← Назад к вопросам

Какие архитектуры для сегментации изображений знаете?

3.0 Senior🔥 81 комментариев
#Глубокое обучение#Машинное обучение

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI2 апр. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Какие архитектуры для сегментации изображений знаете?

Сегментация изображений — это задача классификации каждого пикселя (или региона) в изображении. Существует несколько основных архитектур, которые доминируют в этой области. Давайте разберём самые важные.

1. FCN (Fully Convolutional Networks)

FCN была первой архитектурой, полностью основанной на свёртках для пиксельного предсказания (2014).

Ключевая идея

Вместо fully-connected слоёв, которые теряют пространственную информацию, FCN использует только свёртки и upsampling операции. Это позволяет работать с изображениями произвольного размера.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FCN(nn.Module):
    def __init__(self, num_classes=21):
        super(FCN, self).__init__()
        
        # Encoder (VGG16-like)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)  # 1/2
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)  # 1/4
        )
        
        # Decoder: upsampling слои
        self.score_pool2 = nn.Conv2d(128, num_classes, 1)  # 1x1 conv для классификации
        self.upsample2x = nn.ConvTranspose2d(
            num_classes, num_classes, 4, stride=2, padding=1
        )
        self.upsample8x = nn.ConvTranspose2d(
            num_classes, num_classes, 16, stride=8, padding=4
        )
    
    def forward(self, x):
        # Encoder
        pool1 = self.conv1(x)  # 1/2
        pool2 = self.conv2(pool1)  # 1/4
        
        # Decoder
        score2 = self.score_pool2(pool2)
        upsample2 = self.upsample2x(score2)  # 1/2
        
        # Finalize: upsampling in 8x
        out = self.upsample8x(upsample2)  # 1/1
        
        # Crop to match input size
        return out[:, :, :x.shape[2], :x.shape[3]]

Особенности:

  • Простая архитектура
  • Skip connections (использует информацию с разных уровней)
  • Хорошо работает на PASCAL VOC

2. U-Net

U-Net — самая популярная архитектура для сегментации, особенно в медицинской визуализации (2015).

Ключевая идея

Симметричная encoder-decoder архитектура с skip connections. Encoder сжимает изображение, decoder его увеличивает, а skip connections передают информацию с разных уровней.

class UNet(nn.Module):
    def __init__(self, num_classes=1, num_channels=3):
        super(UNet, self).__init__()
        
        # Encoder
        self.enc1 = self.conv_block(num_channels, 64)
        self.pool1 = nn.MaxPool2d(2, 2)
        
        self.enc2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2, 2)
        
        self.enc3 = self.conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2, 2)
        
        # Bottleneck
        self.bottleneck = self.conv_block(256, 512)
        
        # Decoder с skip connections
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = self.conv_block(512, 256)  # 256+256 из skip
        
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = self.conv_block(256, 128)  # 128+128 из skip
        
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = self.conv_block(128, 64)  # 64+64 из skip
        
        # Output
        self.final = nn.Conv2d(64, num_classes, 1)
    
    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)
        x = self.pool1(enc1)
        
        enc2 = self.enc2(x)
        x = self.pool2(enc2)
        
        enc3 = self.enc3(x)
        x = self.pool3(enc3)
        
        # Bottleneck
        x = self.bottleneck(x)
        
        # Decoder с skip connections
        x = self.up3(x)
        x = torch.cat([x, enc3], dim=1)  # Skip connection
        x = self.dec3(x)
        
        x = self.up2(x)
        x = torch.cat([x, enc2], dim=1)  # Skip connection
        x = self.dec2(x)
        
        x = self.up1(x)
        x = torch.cat([x, enc1], dim=1)  # Skip connection
        x = self.dec1(x)
        
        # Output
        x = self.final(x)
        
        return x

Особенности:

  • Симметричная архитектура
  • Skip connections передают детали
  • Работает хорошо на маленьких датасетах
  • Стандарт для медицинской визуализации

3. DeepLab

DeepLab — использует atrous convolution (dilated convolution) для увеличения receptive field без потери разрешения.

Ключевая идея

Atrous convolution позволяет увеличить восприятие сети без уменьшения размера картинки. ASPP (Atrous Spatial Pyramid Pooling) объединяет признаки на разных масштабах.

class AtrousConv(nn.Module):
    def __init__(self, in_channels, out_channels, dilation):
        super(AtrousConv, self).__init__()
        self.atrous = nn.Conv2d(
            in_channels, out_channels, 3,
            dilation=dilation,
            padding=dilation
        )
    
    def forward(self, x):
        return self.atrous(x)

class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels=256):
        super(ASPP, self).__init__()
        
        # Разные dilation rates
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, 1)
        self.atrous3 = AtrousConv(in_channels, out_channels, 3)
        self.atrous6 = AtrousConv(in_channels, out_channels, 6)
        self.atrous12 = AtrousConv(in_channels, out_channels, 12)
        
        # Image pooling
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1),
            nn.ReLU()
        )
        
        self.project = nn.Conv2d(out_channels * 5, out_channels, 1)
    
    def forward(self, x):
        # Применяем разные receptive fields
        conv1x1 = self.conv1x1(x)
        atrous3 = self.atrous3(x)
        atrous6 = self.atrous6(x)
        atrous12 = self.atrous12(x)
        
        # Image pooling
        pool = self.image_pool(x)
        pool = F.interpolate(pool, size=x.shape[2:], mode='bilinear')
        
        # Конкатенируем
        x = torch.cat(
            [conv1x1, atrous3, atrous6, atrous12, pool],
            dim=1
        )
        
        # Проект
        return self.project(x)

class DeepLabV3(nn.Module):
    def __init__(self, num_classes=21):
        super(DeepLabV3, self).__init__()
        
        # Backbone (ResNet)
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        
        # ASPP
        self.aspp = ASPP(64, 256)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1)
        )
    
    def forward(self, x):
        input_shape = x.shape[2:]
        
        # Backbone
        x = self.backbone(x)
        
        # ASPP
        x = self.aspp(x)
        
        # Decoder
        x = self.decoder(x)
        
        # Upsampling to original size
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        
        return x

Особенности:

  • Atrous convolution для большого receptive field
  • ASPP для многомасштабных признаков
  • Не теряет разрешение как обычные CNN
  • Отличные результаты на PASCAL VOC, Cityscapes

4. Mask R-CNN

Mask R-CNN — расширение Faster R-CNN для instance segmentation (2017).

Ключевая идея

Вначале находит bounding boxes объектов (detection), затем для каждого предсказывает маску пикселей.

# Использование предтренированного Mask R-CNN
from torchvision.models import detection
import torchvision.transforms as T
import torch

# Загрузка предобученной модели
model = detection.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()

# Подготовка изображения
image = T.functional.pil_to_tensor(Image.open('image.jpg'))
image = image.float() / 255.0

# Inference
with torch.no_grad():
    predictions = model([image])

# Результаты
boxes = predictions[0]['boxes']
masks = predictions[0]['masks']
scores = predictions[0]['scores']
labels = predictions[0]['labels']

# Фильтруем по confidence threshold
threshold = 0.5
valid = scores > threshold
boxes = boxes[valid]
masks = masks[valid]
labels = labels[valid]

Особенности:

  • Instance segmentation (разные объекты разные маски)
  • Использует RPN для detection
  • Очень точная на COCO
  • Требует больше вычислений

5. SegFormer

SegFormer — современная архитектура на основе Vision Transformers (2021).

from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import torch

# Загрузка модели
processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-cityscapes-1024-1024")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-cityscapes-1024-1024")

# Подготовка
image = Image.open('image.jpg')
inputs = processor(images=image, return_tensors="pt")

# Inference
with torch.no_grad():
    outputs = model(**inputs)
predictions = outputs.logits.argmax(dim=1)[0]

# Visualize
import matplotlib.pyplot as plt
plt.imshow(predictions.numpy())
plt.colorbar()
plt.show()

Особенности:

  • Трансформеры вместо CNN
  • Лучше захватывает глобальный контекст
  • Очень хорошие результаты
  • Требует больше памяти

Сравнение архитектур

АрхитектураСкоростьТочностьПрименениеСложность
FCNБыстроХорошоBaselineНизкая
U-NetБыстроОтличноМедицинаНизкая
DeepLabМедленноОтличноСценыСредняя
Mask R-CNNМедленноОтличноInstance segВысокая
SegFormerБыстроОтличноSOTAСредняя

Loss functions для сегментации

import torch.nn as nn

# Cross Entropy (стандартная)
loss_fn = nn.CrossEntropyLoss()

# Dice Loss (для дисбаланса классов)
class DiceLoss(nn.Module):
    def forward(self, pred, target, smooth=1.0):
        pred = torch.sigmoid(pred)
        intersection = (pred * target).sum()
        dice = (2 * intersection + smooth) / (pred.sum() + target.sum() + smooth)
        return 1 - dice

# Focal Loss (для очень несбалансированных данных)
loss_fn = nn.FocalLoss(alpha=0.25, gamma=2.0)

# Комбинированная
def combined_loss(pred, target):
    ce_loss = nn.CrossEntropyLoss()(pred, target)
    dice_loss = DiceLoss()(pred, target)
    return ce_loss + dice_loss

Обучение

model = UNet(num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    model.train()
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)
        
        # Forward
        outputs = model(images)
        loss = loss_fn(outputs, masks)
        
        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # Evaluation
    model.eval()
    with torch.no_grad():
        iou = calculate_iou(model, val_loader)
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}, IoU: {iou:.4f}")

Вывод: выбор архитектуры зависит от задачи. U-Net для медицины, DeepLab для сцен, Mask R-CNN для instance segmentation, SegFormer для SOTA результатов.

Какие архитектуры для сегментации изображений знаете? | PrepBro