U-Net : Image segmentation | U-Net 구조 이해하기
1. U-Net
- 이미지 세그멘테이션에 사용되는 딥러닝 아키텍처
- 인코더-디코더 기반의 모델
- 충분한 양의 라벨이 없는 작은 데이터셋에서도 효과적으로 작동
2. U-Net의 구조
Encoder 역할의 Contracting path
Decoder 역할의 Expansive path
Skip connection
U-Net은 위와 같이 세 가지 파트로 나눌 수 있다.
Encoder (Contracting Path)
입력 이미지의 특징을 추출할 수 있도록
채널 수를 늘리면서 차원을 축소하는 단계
1. convolution 연산
: 일반적으로 3x3 크기의 커널을 사용,
필터의 수는 해당 레이어에서 추출할 특징의 수를 나타낸다.
2. pooling 연산
: 풀링 연산을 통해 down sampling을 수행
주로 Max Pooling이 사용되며, 2x2 크기의 stride 2 풀링 윈도우가 자주 사용된다.
파란색 박스를 확장해보면 위와 같은 과정을 거치게 된다.
3x3 conv연산, Batch Normalization, ReLU 함수를 거치게 되는 것이다.
import torch
import torch.nn as nn
class Convblock(nn.Module):
def __init__(self, input_channel, output_channel, kernel=3, stride=1, padding=1):
super().__init__()
self.convblock = nn.Sequential(
nn.Conv2d(input_channel, output_channel, kernel, stride, padding),
nn.BatchNorm2d(output_channel),
nn.ReLU(inplace=True),
nn.Conv2d(output_channel, output_channel, kernel),
nn.ReLU(inplace=True),
)
def forward(self, x):
x = self.convblock(x)
return x
class UNetEncoder(nn.Module):
def __init__(self, in_channels, out_channels):
super(UNetEncoder, self).__init__()
self.conv1 = Convblock(in_channels, 64)
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = Convblock(64, 128)
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.maxpool1(x1)
x3 = self.conv2(x2)
x4 = self.maxpool2(x3)
return x4
# 인코더 객체 생성
in_channels = 3
out_channels = 64
encoder = UNetEncoder(in_channels, out_channels)
# 임의의 입력 이미지 생성
input_image = torch.randn(1, 3, 256, 256)
# 인코더에 입력 이미지 전달하여 출력 확인
output_feature_map = encoder(input_image)
print("Encoder Output Shape:", output_feature_map.shape)
Bottleneck
네트워크의 중간에 위치하며
인코더의 다운 샘플링 과정에서 얻은 고수준 특징을 잘 요약하고, 디코더로 전달하기 위해 사용된다.
1x1 convolution 연산을 수행하는데
1. 채널 간 상호작용의 강조 : 채널 간의 상호작용을 학습하고 강조
2. 채널 수 조절 : 채널 수를 적절히 조절하고 특징의 추상화 수준을 조절
3. 고수준 특징의 요약 : 인코더에서 추출된 고수준 특징을 1x1 컨볼루션을 통해 요약하고, 이를 디코더로 전달하여 세그멘테이션을 수행하는 데 활용
Decoder (Expansive Path)
저차원으로 인코딩된 정보만 이용하여 채널의 수를 줄이고
차원을 늘려서 고차원 이미지를 복원하는 단계
업샘플링 연산을 통해 고수준 특징을 복원하고,
skip connection을 통해 인코더에서 가져온 저수준 특징과 결합
다시 리마인드시키자면,
인코딩에서 차원 축소 단계를 거치면서 이미지 객체에 대한 공간적인 정보가 손실됨
디코딩 단계에서도 축소된 이미지 정보만을 이용하므로 손실된 정보를 회복하지 못함
skip connection : 인코딩의 각 layer에서 얻은 특징을 디코딩의 레이어에 connection
1. transposed convolution
: 디코더에서는 업샘플링 연산을 사용하여 추상화된 feature map을 확대한다.
(transposed convolution 연산에 대해서는 다음에 자세히 다루도록 하겠다,,)
2. skip connection
: 스킵 연결을 통해 인코더에서 가져온 저수준 특징을 현재 층의 특징과 결합한다.
-> 이때 Concatenation 또는 Addition 연산을 사용하여 스킵 연결을 수행
3. convolution layer
: 업샘플링된 고수준 특징과 스킵 연결된 저수준 특징을 결합한 후,
컨볼루션 연산을 사용하여 새로운 특징을 추출한다.
-> 다양한 채널을 가진 특징을 학습한다.
class DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, skip_channels=None):
super(DecoderBlock, self).__init__()
# 업샘플링을 위한 ConvTranspose2d
self.upconv = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
# 스킵 연결을 위한 Convolution
if skip_channels is not None:
self.skip_conv = nn.Conv2d(skip_channels, in_channels // 2, kernel_size=1)
# 컨볼루션 연산
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
# 활성화 함수 (ReLU)
self.relu = nn.ReLU(inplace=True)
def forward(self, x, skip_connection=None):
x = self.upconv(x)
# 스킵 연결 추가
if skip_connection is not None:
x = torch.cat([x, self.skip_conv(skip_connection)], dim=1)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.relu(x)
return x
in_channels = 128 # 업샘플링 된 특징의 채널 수
out_channels = 64 # 현재 블록에서 생성될 특징의 채널 수
skip_channels = 64 # 스킵 연결에서 가져온 특징의 채널 수
decoder_block = DecoderBlock(in_channels, out_channels, skip_channels)
# 임의의 업샘플링된 특징과 스킵 연결된 특징을 생성
upsampled_feature = torch.randn(1, in_channels, 16, 16)
skip_connection_feature = torch.randn(1, skip_channels, 32, 32)
# 디코더 블록에 특징 전달
output_feature = decoder_block(upsampled_feature, skip_connection_feature)
# 출력 특징의 크기 확인
print("Output Feature Shape:", output_feature.shape)