Какие архитектуры для сегментации изображений знаете?
Комментарии (1)
Ответ сгенерирован нейросетью и может содержать ошибки
Какие архитектуры для сегментации изображений знаете?
Сегментация изображений — это задача классификации каждого пикселя (или региона) в изображении. Существует несколько основных архитектур, которые доминируют в этой области. Давайте разберём самые важные.
1. FCN (Fully Convolutional Networks)
FCN была первой архитектурой, полностью основанной на свёртках для пиксельного предсказания (2014).
Ключевая идея
Вместо fully-connected слоёв, которые теряют пространственную информацию, FCN использует только свёртки и upsampling операции. Это позволяет работать с изображениями произвольного размера.
import torch
import torch.nn as nn
import torch.nn.functional as F
class FCN(nn.Module):
def __init__(self, num_classes=21):
super(FCN, self).__init__()
# Encoder (VGG16-like)
self.conv1 = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2) # 1/2
)
self.conv2 = nn.Sequential(
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2) # 1/4
)
# Decoder: upsampling слои
self.score_pool2 = nn.Conv2d(128, num_classes, 1) # 1x1 conv для классификации
self.upsample2x = nn.ConvTranspose2d(
num_classes, num_classes, 4, stride=2, padding=1
)
self.upsample8x = nn.ConvTranspose2d(
num_classes, num_classes, 16, stride=8, padding=4
)
def forward(self, x):
# Encoder
pool1 = self.conv1(x) # 1/2
pool2 = self.conv2(pool1) # 1/4
# Decoder
score2 = self.score_pool2(pool2)
upsample2 = self.upsample2x(score2) # 1/2
# Finalize: upsampling in 8x
out = self.upsample8x(upsample2) # 1/1
# Crop to match input size
return out[:, :, :x.shape[2], :x.shape[3]]
Особенности:
- Простая архитектура
- Skip connections (использует информацию с разных уровней)
- Хорошо работает на PASCAL VOC
2. U-Net
U-Net — самая популярная архитектура для сегментации, особенно в медицинской визуализации (2015).
Ключевая идея
Симметричная encoder-decoder архитектура с skip connections. Encoder сжимает изображение, decoder его увеличивает, а skip connections передают информацию с разных уровней.
class UNet(nn.Module):
def __init__(self, num_classes=1, num_channels=3):
super(UNet, self).__init__()
# Encoder
self.enc1 = self.conv_block(num_channels, 64)
self.pool1 = nn.MaxPool2d(2, 2)
self.enc2 = self.conv_block(64, 128)
self.pool2 = nn.MaxPool2d(2, 2)
self.enc3 = self.conv_block(128, 256)
self.pool3 = nn.MaxPool2d(2, 2)
# Bottleneck
self.bottleneck = self.conv_block(256, 512)
# Decoder с skip connections
self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
self.dec3 = self.conv_block(512, 256) # 256+256 из skip
self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
self.dec2 = self.conv_block(256, 128) # 128+128 из skip
self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.dec1 = self.conv_block(128, 64) # 64+64 из skip
# Output
self.final = nn.Conv2d(64, num_classes, 1)
def conv_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
# Encoder
enc1 = self.enc1(x)
x = self.pool1(enc1)
enc2 = self.enc2(x)
x = self.pool2(enc2)
enc3 = self.enc3(x)
x = self.pool3(enc3)
# Bottleneck
x = self.bottleneck(x)
# Decoder с skip connections
x = self.up3(x)
x = torch.cat([x, enc3], dim=1) # Skip connection
x = self.dec3(x)
x = self.up2(x)
x = torch.cat([x, enc2], dim=1) # Skip connection
x = self.dec2(x)
x = self.up1(x)
x = torch.cat([x, enc1], dim=1) # Skip connection
x = self.dec1(x)
# Output
x = self.final(x)
return x
Особенности:
- Симметричная архитектура
- Skip connections передают детали
- Работает хорошо на маленьких датасетах
- Стандарт для медицинской визуализации
3. DeepLab
DeepLab — использует atrous convolution (dilated convolution) для увеличения receptive field без потери разрешения.
Ключевая идея
Atrous convolution позволяет увеличить восприятие сети без уменьшения размера картинки. ASPP (Atrous Spatial Pyramid Pooling) объединяет признаки на разных масштабах.
class AtrousConv(nn.Module):
def __init__(self, in_channels, out_channels, dilation):
super(AtrousConv, self).__init__()
self.atrous = nn.Conv2d(
in_channels, out_channels, 3,
dilation=dilation,
padding=dilation
)
def forward(self, x):
return self.atrous(x)
class ASPP(nn.Module):
def __init__(self, in_channels, out_channels=256):
super(ASPP, self).__init__()
# Разные dilation rates
self.conv1x1 = nn.Conv2d(in_channels, out_channels, 1)
self.atrous3 = AtrousConv(in_channels, out_channels, 3)
self.atrous6 = AtrousConv(in_channels, out_channels, 6)
self.atrous12 = AtrousConv(in_channels, out_channels, 12)
# Image pooling
self.image_pool = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1),
nn.ReLU()
)
self.project = nn.Conv2d(out_channels * 5, out_channels, 1)
def forward(self, x):
# Применяем разные receptive fields
conv1x1 = self.conv1x1(x)
atrous3 = self.atrous3(x)
atrous6 = self.atrous6(x)
atrous12 = self.atrous12(x)
# Image pooling
pool = self.image_pool(x)
pool = F.interpolate(pool, size=x.shape[2:], mode='bilinear')
# Конкатенируем
x = torch.cat(
[conv1x1, atrous3, atrous6, atrous12, pool],
dim=1
)
# Проект
return self.project(x)
class DeepLabV3(nn.Module):
def __init__(self, num_classes=21):
super(DeepLabV3, self).__init__()
# Backbone (ResNet)
self.backbone = nn.Sequential(
nn.Conv2d(3, 64, 7, stride=2, padding=3),
nn.BatchNorm2d(64),
nn.ReLU()
)
# ASPP
self.aspp = ASPP(64, 256)
# Decoder
self.decoder = nn.Sequential(
nn.Conv2d(256, 256, 3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256, num_classes, 1)
)
def forward(self, x):
input_shape = x.shape[2:]
# Backbone
x = self.backbone(x)
# ASPP
x = self.aspp(x)
# Decoder
x = self.decoder(x)
# Upsampling to original size
x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
return x
Особенности:
- Atrous convolution для большого receptive field
- ASPP для многомасштабных признаков
- Не теряет разрешение как обычные CNN
- Отличные результаты на PASCAL VOC, Cityscapes
4. Mask R-CNN
Mask R-CNN — расширение Faster R-CNN для instance segmentation (2017).
Ключевая идея
Вначале находит bounding boxes объектов (detection), затем для каждого предсказывает маску пикселей.
# Использование предтренированного Mask R-CNN
from torchvision.models import detection
import torchvision.transforms as T
import torch
# Загрузка предобученной модели
model = detection.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()
# Подготовка изображения
image = T.functional.pil_to_tensor(Image.open('image.jpg'))
image = image.float() / 255.0
# Inference
with torch.no_grad():
predictions = model([image])
# Результаты
boxes = predictions[0]['boxes']
masks = predictions[0]['masks']
scores = predictions[0]['scores']
labels = predictions[0]['labels']
# Фильтруем по confidence threshold
threshold = 0.5
valid = scores > threshold
boxes = boxes[valid]
masks = masks[valid]
labels = labels[valid]
Особенности:
- Instance segmentation (разные объекты разные маски)
- Использует RPN для detection
- Очень точная на COCO
- Требует больше вычислений
5. SegFormer
SegFormer — современная архитектура на основе Vision Transformers (2021).
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import torch
# Загрузка модели
processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-cityscapes-1024-1024")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-cityscapes-1024-1024")
# Подготовка
image = Image.open('image.jpg')
inputs = processor(images=image, return_tensors="pt")
# Inference
with torch.no_grad():
outputs = model(**inputs)
predictions = outputs.logits.argmax(dim=1)[0]
# Visualize
import matplotlib.pyplot as plt
plt.imshow(predictions.numpy())
plt.colorbar()
plt.show()
Особенности:
- Трансформеры вместо CNN
- Лучше захватывает глобальный контекст
- Очень хорошие результаты
- Требует больше памяти
Сравнение архитектур
| Архитектура | Скорость | Точность | Применение | Сложность |
|---|---|---|---|---|
| FCN | Быстро | Хорошо | Baseline | Низкая |
| U-Net | Быстро | Отлично | Медицина | Низкая |
| DeepLab | Медленно | Отлично | Сцены | Средняя |
| Mask R-CNN | Медленно | Отлично | Instance seg | Высокая |
| SegFormer | Быстро | Отлично | SOTA | Средняя |
Loss functions для сегментации
import torch.nn as nn
# Cross Entropy (стандартная)
loss_fn = nn.CrossEntropyLoss()
# Dice Loss (для дисбаланса классов)
class DiceLoss(nn.Module):
def forward(self, pred, target, smooth=1.0):
pred = torch.sigmoid(pred)
intersection = (pred * target).sum()
dice = (2 * intersection + smooth) / (pred.sum() + target.sum() + smooth)
return 1 - dice
# Focal Loss (для очень несбалансированных данных)
loss_fn = nn.FocalLoss(alpha=0.25, gamma=2.0)
# Комбинированная
def combined_loss(pred, target):
ce_loss = nn.CrossEntropyLoss()(pred, target)
dice_loss = DiceLoss()(pred, target)
return ce_loss + dice_loss
Обучение
model = UNet(num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
model.train()
for images, masks in train_loader:
images = images.to(device)
masks = masks.to(device)
# Forward
outputs = model(images)
loss = loss_fn(outputs, masks)
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Evaluation
model.eval()
with torch.no_grad():
iou = calculate_iou(model, val_loader)
print(f"Epoch {epoch}, Loss: {loss.item():.4f}, IoU: {iou:.4f}")
Вывод: выбор архитектуры зависит от задачи. U-Net для медицины, DeepLab для сцен, Mask R-CNN для instance segmentation, SegFormer для SOTA результатов.