Lighting Enhancement Aids Reconstruction of Colonoscopic Surfaces

Paper-reading, polyp

Posted by Seasons on June 11, 2022

Title page

[Figure: paper title page]

Conference: IPMI 2021

Year: 2021

GitHub: https://github.com/zhangybzbo/ColonLight

PDF: https://arxiv.org/pdf/2103.10310.pdf

Summary

  1. Lighting inconsistency in colonoscopy videos can cause a key component of the colonoscopic reconstruction system to fail
  2. This paper presents a lighting-correction method that adjusts the intensity distribution across neighboring video frames
  3. The method uses an RNN to adjust gamma values in an unsupervised fashion, and runs in real time
  4. The method significantly improves both the success rate and the quality of colonoscopic surface reconstruction

Workflow

Methods

1. SLAM mechanism

  • Simultaneous Localization And Mapping (SLAM) is one of the most successful approaches to 3D reconstruction.

  • SLAM can achieve real-time dense reconstruction from a sequence of monocular images.

  • SLAM has a localization component and a mapping component that operate cooperatively: the localization (tracking) component predicts the camera pose for each incoming image frame, while the mapping component uses visual cues extracted from the images to jointly optimize the pose predictions and the keypoints' depth estimates.

  • Objective function: the original notes embedded the equation as an image; a sketch of the typical form appears below.
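As a stand-in for the missing equation, this is the typical direct photometric objective minimized by SLAM systems of this kind; the per-frame affine brightness parameters $(a_{ji}, b_{ji})$ are an assumption based on common direct-SLAM formulations, not copied from the paper:

$$E_{\text{photo}} = \sum_{\mathbf{p}} \rho\Big( I_j\big(\pi(\mathbf{T}_{ji}, \mathbf{p}, d_{\mathbf{p}})\big) - \big(a_{ji}\, I_i(\mathbf{p}) + b_{ji}\big) \Big)$$

Here $\pi$ reprojects pixel $\mathbf{p}$ with depth $d_{\mathbf{p}}$ from frame $i$ into frame $j$ under the relative pose $\mathbf{T}_{ji}$, and $\rho$ is a robust norm (e.g. Huber). The affine pair $(a_{ji}, b_{ji})$ is the "extra brightness optimization" mentioned below: it models a global exposure change, which is exactly what breaks down under a moving point light.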

The lighting problem in colonoscopic surface reconstruction

  • What makes colonoscopy special: the point light source moves with the camera and can change rapidly due to motion and occlusion
  • SLAM's extra brightness optimization cannot handle the colonoscopic point-light problem well, e.g. contrast differences (bright regions become brighter and dark regions become darker)

In this work we apply an adaptive intensity mapping to enhance the colonoscopy frame sequence with the help of an RNN.

Implementation details

1. RNN → gamma adjustment (designed with real-time constraints in mind)

[Figure: pipeline of RNN-predicted gamma adjustment]
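The adjustment itself is a plain gamma curve, I_out = I_in^(1/γ), applied per frame (or per pixel in the pixelwise variant). A minimal PyTorch sketch, with tensor shapes assumed for illustration:

import torch

def apply_gamma(frame: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
    # Gamma-style intensity mapping I_out = I_in ** (1 / gamma).
    # frame: B x C x H x W with intensities in [0, 1].
    # gamma: broadcastable to frame, e.g. B x 1 x 1 x 1 (global)
    #        or B x 1 x H x W (pixelwise); gamma > 1 brightens.
    return torch.pow(frame.clamp(min=1e-6), 1.0 / gamma.clamp(min=1e-6))

# Example: brighten a dark frame with a global gamma of 2.
dark = torch.rand(1, 3, 64, 64) * 0.3
bright = apply_gamma(dark, torch.full((1, 1, 1, 1), 2.0))

The clamps here are a defensive addition for the sketch; the repository code applies torch.pow directly (see the annotated code below).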

2. RNN Network

[Figure: RNN network architecture]

3. Training: unsupervised

  • The first two frames of the input window and the first two frames of the whole video sequence (10 frames in total) serve as references

  • Loss function: structural similarity (SSIM) measure

    [Figure: SSIM loss definition]

  • During training, when computing the L_SSIM loss, pixels with input intensity larger than 0.7 are masked out (see the sketch below)
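A minimal sketch of such a masked SSIM loss, assuming a uniform (box-filter) SSIM window and the standard constants; the paper's exact window size and reference-frame pairing are not reproduced here:

import torch
import torch.nn.functional as F

def ssim_map(x, y, window=11, c1=0.01 ** 2, c2=0.03 ** 2):
    # Per-pixel SSIM computed with a uniform window (box filter).
    pad = window // 2
    mu_x = F.avg_pool2d(x, window, stride=1, padding=pad)
    mu_y = F.avg_pool2d(y, window, stride=1, padding=pad)
    var_x = F.avg_pool2d(x * x, window, stride=1, padding=pad) - mu_x ** 2
    var_y = F.avg_pool2d(y * y, window, stride=1, padding=pad) - mu_y ** 2
    cov_xy = F.avg_pool2d(x * y, window, stride=1, padding=pad) - mu_x * mu_y
    num = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    den = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return num / den

def masked_ssim_loss(enhanced, reference, inputs, thresh=0.7):
    # SSIM loss between the enhanced frame and a reference frame,
    # ignoring pixels whose *input* intensity exceeds thresh
    # (e.g. specular highlights, which carry no surface information).
    mask = (inputs < thresh).float()
    loss_map = (1.0 - ssim_map(enhanced, reference)) * mask
    return loss_map.sum() / mask.sum().clamp(min=1.0)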

Results

[Figures: reconstruction results]

APE: Absolute Pose Error
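For reference, the standard APE definition compares the estimated pose $P_i$ with the ground-truth pose $Q_i$ at each frame (this is the usual formulation, e.g. as implemented in the evo trajectory-evaluation tool; the paper's exact variant is assumed, not quoted):

$$\text{APE}_i = \big\| \operatorname{trans}\big(Q_i^{-1} P_i\big) \big\|$$

i.e. the translational part of the relative error transform, typically summarized over the trajectory by its RMSE. Lower APE after lighting enhancement indicates more accurate camera tracking.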

[Figures: APE comparison]

Thoughts and takeaways

  1. Could this be borrowed as threshold-based post-processing for polyp segmentation?
  2. Could it be used as pre-processing: adjust gamma (and other parameters), combined with domain adaptation, so that input colonoscopy images are quickly normalized?

Annotated code

import torch
import torch.nn as nn
import torch.nn.functional as F

'''modified from https://github.com/ndrplz/ConvLSTM_pytorch/'''
class ConvLSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True, forget_bias=1.0):
        """
        Initialize ConvLSTM cell.
        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        forget_bias: float
            Constant added to the forget-gate pre-activation
            (a common LSTM initialization trick).
        """

        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        self.forget_bias = forget_bias # TODO: add or not

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

        torch.nn.init.xavier_normal_(self.conv.weight)

    def forward(self, input_tensor, cur_state):
        '''
        :param input_tensor: B x input_dim x height x width
        :param cur_state: B x 2*hidden_dim x height x width ([h; c] concatenated along channels)
        :return: h_next, next state (torch.cat([h_next, c_next], dim=1))
        '''
        h_cur, c_cur = torch.split(cur_state, self.hidden_dim, dim=1)
        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis

        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f + self.forget_bias)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, torch.cat([h_next, c_next], dim=1)

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))

class RNNmodule(nn.Module):
    def __init__(self, img_channel, hidden_channel, adj_channel, kernel_size, pixelwise):
        super(RNNmodule, self).__init__()

        self.img_channel = img_channel
        self.adj_channel = adj_channel
        self.pixelwise = pixelwise

        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.in_cnv = nn.Conv2d(img_channel, hidden_channel, kernel_size, stride=1, padding=self.padding)
        self.bn = nn.BatchNorm2d(hidden_channel, momentum=0.01) # TODO
        self.RNNcell = ConvLSTMCell(hidden_channel, hidden_channel, kernel_size)
        self.rnn_cnv = nn.Conv2d(hidden_channel, hidden_channel, kernel_size, stride=1, padding=self.padding)
        if pixelwise:
            self.out_cnv = nn.Conv2d(hidden_channel, adj_channel, kernel_size, stride=1, padding=self.padding)
        else:
            self.out_linear = nn.Linear(hidden_channel, 1)
        # TODO: or kernel size=1
        # TODO: or multi-adjustment using rnn

        self.init_param()

    def init_param(self):
        torch.nn.init.xavier_normal_(self.in_cnv.weight)
        torch.nn.init.xavier_normal_(self.rnn_cnv.weight)
        if self.pixelwise:
            torch.nn.init.xavier_normal_(self.out_cnv.weight)
        self.bn.weight.data.normal_(1.0, 0.02)
        self.bn.bias.data.fill_(0)

    def forward(self, inputs, last_hidden):
        b, c, h, w = inputs.size()
        assert c == self.img_channel

        if last_hidden is None:  # first frame: start from a zero state
            last_hidden = torch.cat(self.RNNcell.init_hidden(b, (h, w)), dim=1)

        conv = self.in_cnv(inputs)
        conv = self.bn(conv)
        conv = F.relu(conv)

        rnn, new_hidden = self.RNNcell(conv, last_hidden)
        rnn = self.rnn_cnv(rnn)
        # rnn = self.bn(rnn)
        rnn = F.relu(rnn)

        if self.pixelwise:
            # Predict one gamma map per adjustment channel and apply them in
            # sequence: x <- x ** (1 / gamma), computed per pixel.
            adjs = F.relu(self.out_cnv(rnn))
            adj_t = torch.split(adjs, 1, dim=1)
            x = inputs
            for i in range(self.adj_channel):
                x = torch.pow(x, 1/adj_t[i]) # adj_t[i]: B x 1 x h x w (a zero gamma would divide by zero)
        else:
            # Global adjustment: pool the feature map into one gamma per frame.
            adjs = torch.mean(rnn, dim=[-1, -2])
            adjs = F.relu(self.out_linear(adjs)) # B x 1
            adjs = adjs.view(-1, 1, 1, 1)
            x = torch.pow(inputs, 1/adjs)
        # Commented-out alternative: piecewise-linear mapping with knee (xp, yp).
        # x_gray = torch.mean(x, dim=1, keepdim=True)
        # xp = adjs[:, 0].view(-1, 1, 1, 1)
        # yp = adjs[:, 1].view(-1, 1, 1, 1)
        # x = (x_gray<=xp) * yp/xp * x + (x_gray>xp) * ((1-yp)/(1-xp) * x + (yp-xp)/(1-xp))

        # Return the new LSTM state so the caller can carry it to the next frame.
        return x, adjs, new_hidden
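
A hypothetical usage sketch, unrolling RNNmodule over a short clip and carrying the LSTM state between frames; the channel sizes and clip length are illustrative, not the paper's configuration (this relies on the definitions above):

model = RNNmodule(img_channel=3, hidden_channel=16, adj_channel=1,
                  kernel_size=(3, 3), pixelwise=False)
frames = torch.rand(10, 1, 3, 64, 64)  # T x B x C x H x W, intensities in [0, 1]

hidden = None
enhanced = []
for t in range(frames.size(0)):
    out, gamma, hidden = model(frames[t], hidden)  # reuse state across frames
    enhanced.append(out)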