← Назад к вопросам
Расскажите про архитектуру ViT (Vision Transformer)
1.3 Junior🔥 81 комментариев
#Глубокое обучение
Комментарии (1)
🐱
claude-haiku-4.5PrepBro AI30 мар. 2026 г.(ред.)
Ответ сгенерирован нейросетью и может содержать ошибки
Vision Transformer (ViT): Архитектура и принципы
Vision Transformer (ViT) — это революционная архитектура, которая применяет трансформер (изначально разработанный для обработки текста) к задачам компьютерного зрения. Вместо использования сверточных операций (как в CNN), ViT разбивает изображение на патчи и обрабатывает их как последовательность токенов.
История и мотивация
До ViT все выполняли компьютерное зрение с использованием CNN:
- ResNet, EfficientNet, VGG — все использовали свертки
- Трансформеры доминировали только в NLP
Статья ViT (2020) показала, что можно применить чистый трансформер к изображениям и получить лучшие результаты!
Архитектура ViT
import torch
import torch.nn as nn
from einops import rearrange
class PatchEmbedding(nn.Module):
"""Разбиение изображения на патчи и их эмбеддинг"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2 # 14*14 = 196 патчей
# Линейная проекция патчей
self.projection = nn.Linear(
in_channels * patch_size * patch_size, # 3 * 16 * 16 = 768
embed_dim # 768
)
def forward(self, x):
# x: (B, 3, 224, 224)
batch_size = x.shape[0]
# Разбиение на патчи
x = rearrange(
x,
'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
p1=self.patch_size,
p2=self.patch_size
)
# x: (B, 196, 768)
# Линейная проекция
x = self.projection(x) # (B, 196, 768)
return x
class VisionTransformer(nn.Module):
"""Полная архитектура Vision Transformer"""
def __init__(
self,
img_size=224,
patch_size=16,
in_channels=3,
num_classes=1000,
embed_dim=768,
num_heads=12,
depth=12, # количество трансформер блоков
mlp_ratio=4,
dropout=0.1
):
super().__init__()
# 1. Patch Embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
n_patches = (img_size // patch_size) ** 2 # 196
# 2. Класс токен (learnable параметр)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
nn.init.normal_(self.cls_token, std=0.02)
# 3. Позиционные эмбеддинги
self.pos_embed = nn.Parameter(
torch.zeros(1, n_patches + 1, embed_dim) # +1 для cls токена
)
nn.init.normal_(self.pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=dropout)
# 4. Трансформер блоки
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=embed_dim * mlp_ratio,
dropout=dropout,
activation='gelu',
batch_first=True
),
num_layers=depth
)
# 5. Classification head
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)
def forward(self, x):
batch_size = x.shape[0]
# 1. Patch embedding
x = self.patch_embed(x) # (B, 196, 768)
# 2. Добавляем класс токен
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # (B, 1, 768)
x = torch.cat([cls_tokens, x], dim=1) # (B, 197, 768)
# 3. Добавляем позиционные эмбеддинги
x = x + self.pos_embed # (B, 197, 768)
x = self.pos_drop(x)
# 4. Трансформер
x = self.transformer(x) # (B, 197, 768)
# 5. Берём только cls токен для классификации
x = x[:, 0] # (B, 768)
# 6. Layer normalization и линейный слой
x = self.norm(x)
x = self.head(x) # (B, 1000)
return x
# Инициализация
model = VisionTransformer(
img_size=224,
patch_size=16,
num_classes=1000,
embed_dim=768,
num_heads=12,
depth=12
)
print(f"ViT-Base параметров: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
# Тест с батчем
x = torch.randn(2, 3, 224, 224)
output = model(x)
print(f"Вход: {x.shape}")
print(f"Выход: {output.shape}")
Сравнение с CNN (ResNet)
print("ViT vs CNN (ResNet):")
print("="*60)
print("\nCNN (ResNet50):")
print("- Локальные рецептивные поля")
print("- Параметры используются повторно (sharing)")
print("- Иерархичность (feature pyramid)")
print("- Параметров: ~25M")
print("- Требует меньше данных для обучения")
print("\nViT (Vision Transformer):")
print("- Глобальное внимание (каждый патч видит все остальные)")
print("- Параметры уникальны для каждой позиции")
print("- Нет иерархии (плоская архитектура)")
print("- Параметров: ~86M для ViT-Base")
print("- Требует много данных (ImageNet-21k для хорошего результата)")
print("- Требует предварительного обучения (pre-training)")
Ключевые компоненты
1. Patch Embedding
print("\n1. PATCH EMBEDDING")
print("-" * 40)
print(f"Изображение: 224x224 пиксели")
print(f"Размер патча: 16x16 пиксели")
print(f"Количество патчей: (224/16)^2 = 196 патчей")
print(f"Каждый патч: 16x16x3 = 768 чисел")
print(f"После линейной проекции: 768 -> 768 (embed_dim)")
print(f"\nВывод: последовательность из 196 токенов размером 768")
2. Класс токен и позиционные эмбеддинги
print("\n2. КЛАСС ТОКЕН И ПОЗИЦИИ")
print("-" * 40)
print(f"Класс токен: специальный learnable вектор [CLS]")
print(f"Назначение: агрегирует информацию со всех патчей")
print(f"Форма: 1x768")
print(f"\nПозиционные эмбеддинги:")
print(f"- Абсолютные позиции (не относительные как в CNN)")
print(f"- Learnable параметры (не синусные как в трансформерах NLP)")
print(f"- Форма: 197x768 (196 патчей + 1 cls токен)")
print(f"\nПолная последовательность: [CLS] + 196 патчей = 197 токенов")
3. Multi-Head Self-Attention
print("\n3. MULTI-HEAD SELF-ATTENTION")
print("-" * 40)
print(f"Количество heads: 12")
print(f"Размер head: 768 / 12 = 64")
print(f"\nНа каждом слое:")
print(f"- Каждый токен взаимодействует со ВСЕМИ токенами")
print(f"- Вычисляются weights (сколько внимания каждому)")
print(f"- Это дает глобальное восприятие (в отличие от CNN)")
print(f"\nЭффект:")
print(f"- Ранние слои: низкоуровневые паттерны")
print(f"- Средние слои: текстуры, локальные объекты")
print(f"- Поздние слои: высокоуровневые семантические концепции")
4. MLP (Feed-Forward) блоки
print("\n4. MLP (FEED-FORWARD)")
print("-" * 40)
print(f"Размер: 768 -> 3072 -> 768 (mlp_ratio=4)")
print(f"Активация: GELU")
print(f"Применяется к каждому токену независимо")
print(f"\nНасчет формулы:")
print(f"Transformer блок = MultiHeadAttention + MLP")
print(f"x = x + MultiHeadAttention(LayerNorm(x))")
print(f"x = x + MLP(LayerNorm(x))")
print(f"(остаточные связи + layer normalization до операции)")
Разновидности ViT
print("\nЛиния ViT моделей:")
print("="*60)
print("\nViT-Tiny:")
print(" Параметры: 5M")
print(" Embed dim: 192, Heads: 3, Layers: 12")
print(" Использование: Edge devices, мобильные")
print("\nViT-Small:")
print(" Параметры: 22M")
print(" Embed dim: 384, Heads: 6, Layers: 12")
print(" Использование: Ускоренный inference")
print("\nViT-Base (стандартный):")
print(" Параметры: 86M")
print(" Embed dim: 768, Heads: 12, Layers: 12")
print(" Использование: Стандартный выбор")
print("\nViT-Large:")
print(" Параметры: 304M")
print(" Embed dim: 1024, Heads: 16, Layers: 24")
print(" Использование: Высокая точность")
print("\nViT-Huge:")
print(" Параметры: 632M")
print(" Embed dim: 1280, Heads: 16, Layers: 32")
print(" Использование: SOTA результаты")
Преимущества и недостатки ViT
print("\nПРЕИМУЩЕСТВА ViT:")
print("-" * 40)
print("+ Глобальное внимание: модель видит все изображение сразу")
print("+ Масштабируемость: хорошо работает с большим количеством данных")
print("+ Трансферное обучение: легко применить к другим задачам")
print("+ Вычислительная эффективность: хорошо использует GPU")
print("+ SOTA точность: побеждает CNN на ImageNet")
print("+ Интерпретируемость: attention maps показывают что смотрит модель")
print("\nНЕДОСТАТКИ ViT:")
print("-" * 40)
print("- Требует много данных (>100M изображений для хорошего обучения)")
print("- Требует pre-training на ImageNet-21k")
print("- Медленнее CNN на малых изображениях")
print("- Высокие вычислительные затраты на обучение")
print("- Сложнее интегрировать иерархические вычисления")
print("- Менее интуитивен (сложнее отладить)")
Практическое использование
from torchvision.models import vit_b_16
from torchvision import transforms
from PIL import Image
import torch
# Загрузка предобученной модели
model = vit_b_16(pretrained=True)
model.eval()
# Подготовка данных
transforms_pipeline = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]
)
])
# Инференс
with torch.no_grad():
# image = Image.open('cat.jpg')
# x = transforms_pipeline(image).unsqueeze(0)
# output = model(x)
# probabilities = torch.softmax(output, dim=1)
pass
print("ViT легко использовать с torchvision!")
Улучшения ViT
print("\nУЛУЧШЕНИЯ И ВАРИАНТЫ:")
print("="*60)
print("\n1. DeiT (Data-efficient Image Transformers):")
print(" - Обучается на ImageNet, не требует ImageNet-21k")
print(" - Использует knowledge distillation")
print("\n2. Swin Transformer:")
print(" - Иерархическая архитектура (как CNN)")
print(" - Локальное внимание + смещение окна")
print(" - Лучше для detection и segmentation")
print("\n3. BEiT (BERT pre-training of Image Transformers):")
print(" - Самоконтролируемое обучение")
print(" - Маскирует случайные патчи и восстанавливает их")
print("\n4. EVA (Exploring Visual Representation via Vision Transformers):")
print(" - Масштабированные ViT модели")
print(" - SOTA на многих бенчмарках")
Vision Transformer — это перелом в истории компьютерного зрения, который показал, что трансформеры работают и за пределами NLP. Сейчас ViT становится стандартом для новых моделей CV.