SEMANTIC IMAGE SEGMENTATION WITH DEEP CONVOLUTIONAL NETS AND FULLY CONNECTED CRFS论文解读
V2链接:
https://blog.csdn.net/weixin_44543648/article/details/122599976
V3链接:
https://blog.csdn.net/weixin_44543648/article/details/122829741
论文地址:
https://arxiv.org/pdf/1412.7062v3.pdf
代码地址:
https://github.com/wangleihitcs/DeepLab-V1-PyTorch
主要内容:
主要解决DCNN用于图像分割中存在的两个问题:
下采样:最大池化和下采样的重复组合会导致信号分辨率降低。
空间上的“不敏感性”(不变性):分类器要获得以对象为中心的决策,需要对空间变换具有不变性,这固有地限制了DCNN模型的空间精度。
主要采用两个方法来解决这一问题:
-
采用hole算法(空洞卷积)解决下采样的问题,扩大感受野:即设置conv2d中的dilation,如图所示:
-
使用全连接条件随机场(CRF),提高了模型捕获细节的能力。一般用于test阶段,不用于train阶段。
公式:
代码:
CRF代码:
import numpy as np
import pydensecrf.densecrf as dcrf
import pydensecrf.utils as utils
class DenseCRF(object):
    """Fully connected CRF post-processing for semantic segmentation.

    Refines a per-pixel class probability map (e.g. softmax output of a
    DCNN) by running mean-field inference in a dense CRF with a Gaussian
    (position-only) kernel and a bilateral (position + RGB) kernel.
    Typically applied at test time only.
    """

    def __init__(self, iter_max, pos_w, pos_xy_std, bi_w, bi_xy_std, bi_rgb_std):
        # Mean-field iteration count and kernel hyper-parameters.
        self.iter_max = iter_max        # number of inference iterations
        self.pos_w = pos_w              # weight of the Gaussian (smoothness) kernel
        self.pos_xy_std = pos_xy_std    # spatial std-dev of the Gaussian kernel
        self.bi_w = bi_w                # weight of the bilateral (appearance) kernel
        self.bi_xy_std = bi_xy_std      # spatial std-dev of the bilateral kernel
        self.bi_rgb_std = bi_rgb_std    # color std-dev of the bilateral kernel

    def __call__(self, image, probmap):
        """Run CRF inference.

        Args:
            image: HxWx3 uint8 RGB image (assumed — TODO confirm against caller).
            probmap: (C, H, W) softmax probability map.

        Returns:
            (C, H, W) refined marginal probabilities as a numpy array.
        """
        num_classes, height, width = probmap.shape

        # Unary term from the network's softmax; the CRF bindings require
        # C-contiguous buffers, hence the ascontiguousarray calls.
        unary = np.ascontiguousarray(utils.unary_from_softmax(probmap))
        image = np.ascontiguousarray(image)

        crf = dcrf.DenseCRF2D(width, height, num_classes)
        crf.setUnaryEnergy(unary)
        # Smoothness kernel: penalizes label changes between nearby pixels.
        crf.addPairwiseGaussian(sxy=self.pos_xy_std, compat=self.pos_w)
        # Appearance kernel: nearby pixels with similar color prefer the same label.
        crf.addPairwiseBilateral(
            sxy=self.bi_xy_std, srgb=self.bi_rgb_std, rgbim=image, compat=self.bi_w
        )

        marginals = crf.inference(self.iter_max)
        return np.array(marginals).reshape((num_classes, height, width))
网络代码:
import torch
import torch.nn as nn
from torchvision import models
class VGG16_LargeFOV(nn.Module):
    """DeepLab-v1 backbone: VGG-16 converted to a fully convolutional net.

    conv1-conv3 downsample by 8x (three stride-2 poolings); conv4/conv5
    poolings use stride 1, with conv5 using dilation=2 so the receptive
    field keeps growing without further resolution loss. fc6/fc7/fc8 are
    implemented as convolutions, fc6 with the "LargeFOV" rate-12 dilation.

    Args:
        num_classes: number of output channels of the score map (fc8).
        input_size: side length used to upsample the score map at test time.
        split: 'train' returns the 1/8-resolution score map; 'test'
            bilinearly upsamples it back to (input_size, input_size).
        init_weights: if True, re-initialize the fc8 classifier layer.
    """

    def __init__(self, num_classes=21, input_size=321, split='train', init_weights=True):
        super(VGG16_LargeFOV, self).__init__()
        self.input_size = input_size
        self.split = split

        def conv_stack(cin, cout, repeat, dilation=1):
            # `repeat` 3x3 conv+ReLU pairs; padding == dilation keeps spatial size.
            stack = []
            for i in range(repeat):
                stack.append(nn.Conv2d(cin if i == 0 else cout, cout,
                                       kernel_size=3, stride=1,
                                       padding=dilation, dilation=dilation))
                stack.append(nn.ReLU(True))
            return stack

        layers = []
        # conv1, conv2: stride-2 pooling (downsample 2x each)
        layers += conv_stack(3, 64, 2) + [nn.MaxPool2d(kernel_size=3, stride=2, padding=1)]
        layers += conv_stack(64, 128, 2) + [nn.MaxPool2d(kernel_size=3, stride=2, padding=1)]
        # conv3: last stride-2 pooling -> total stride 8
        layers += conv_stack(128, 256, 3) + [nn.MaxPool2d(kernel_size=3, stride=2, padding=1)]
        # conv4: pooling kept at stride 1 (no further downsampling)
        layers += conv_stack(256, 512, 3) + [nn.MaxPool2d(kernel_size=3, stride=1, padding=1)]
        # conv5: dilated (rate 2) to compensate for the removed strides
        layers += conv_stack(512, 512, 3, dilation=2) + [nn.MaxPool2d(kernel_size=3, stride=1, padding=1)]
        # extra average pooling used by DeepLab-v1
        layers.append(nn.AvgPool2d(kernel_size=3, stride=1, padding=1))
        # fc6: 3x3 conv with rate-12 dilation ("LargeFOV"), then dropout
        layers += [nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=12, dilation=12),
                   nn.ReLU(True),
                   nn.Dropout2d(0.5)]
        # fc7: 1x1 conv, then dropout
        layers += [nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
                   nn.ReLU(True),
                   nn.Dropout2d(0.5)]
        # fc8: 1x1 classifier -> per-class score map (index 38 in the Sequential)
        layers.append(nn.Conv2d(1024, num_classes, kernel_size=1, stride=1, padding=0))

        self.features = nn.Sequential(*layers)

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return the score map; upsampled to input_size at test time."""
        logits = self.features(x)
        if self.split == 'test':
            logits = nn.functional.interpolate(
                logits, size=(self.input_size, self.input_size),
                mode='bilinear', align_corners=True)
        return logits

    def _initialize_weights(self):
        # Only the task-specific classifier (fc8, i.e. features.38) is
        # re-initialized; the rest is expected to come from pretrained VGG.
        for name, module in self.named_modules():
            if isinstance(module, nn.Conv2d) and name == 'features.38':
                nn.init.normal_(module.weight.data, mean=0, std=0.01)
                nn.init.constant_(module.bias.data, 0.0)
if __name__ == "__main__":
model = VGG16_LargeFOV()
x = torch.ones([2, 3, 321, 321])
y = model(x)
print(y.shape)
版权声明:本文为weixin_44543648原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。