Что такое contrastive learning?
Комментарии (1)
Ответ сгенерирован нейросетью и может содержать ошибки
Что такое contrastive learning?
Contrastive Learning - это парадигма самообучения (self-supervised learning), которая учит модель различать похожие и непохожие примеры. Основная идея: модель должна научиться представлять похожие примеры близко в пространстве признаков, а непохожие - далеко друг от друга. Это очень мощный подход для обучения без размеченных данных.
Основная концепция
Вместо предсказания меток класса, contrastive learning работает с парами примеров:
- Positive pair (положительная пара): два похожих примера (например, два аугментированных варианта одного изображения)
- Negative pair (отрицательная пара): два непохожих примера
Цель: минимизировать расстояние между положительными парами и максимизировать расстояние между отрицательными.
Метрика сходства: cosine similarity
import numpy as np
import torch
import torch.nn.functional as F
def cosine_similarity(x, y):
"""
Вычисляет косинусное сходство между двумя векторами
cosine_sim = (x . y) / (||x|| * ||y||)
Результат от -1 до 1, где 1 = идентичны
"""
return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
# Пример
x = np.array([1, 0, 0]) # Единичный вектор
y = np.array([1, 0, 0]) # Идентичный
z = np.array([0, 1, 0]) # Ортогональный
print(f"sim(x, y) = {cosine_similarity(x, y):.3f}") # 1.0
print(f"sim(x, z) = {cosine_similarity(x, z):.3f}") # 0.0
# В PyTorch
x_torch = torch.tensor([1.0, 0.0, 0.0])
y_torch = torch.tensor([1.0, 0.0, 0.0])
z_torch = torch.tensor([0.0, 1.0, 0.0])
print("PyTorch:")
print(f"sim(x, y) = {F.cosine_similarity(x_torch.unsqueeze(0), y_torch.unsqueeze(0)).item():.3f}")
print(f"sim(x, z) = {F.cosine_similarity(x_torch.unsqueeze(0), z_torch.unsqueeze(0)).item():.3f}")
Функция потерь: Contrastive Loss (Triplet Loss)
import torch
import torch.nn.functional as F
def contrastive_loss(anchor, positive, negative, margin=1.0):
"""
Triplet Loss: придвигает positive ближе к anchor,
а negative отодвигает на расстояние margin
L = max(0, d(anchor, positive) - d(anchor, negative) + margin)
"""
pos_dist = F.pairwise_distance(anchor, positive, p=2)
neg_dist = F.pairwise_distance(anchor, negative, p=2)
loss = torch.relu(pos_dist - neg_dist + margin)
return loss.mean()
# Пример
anchor = torch.randn(32, 128) # batch_size=32, embedding_dim=128
positive = torch.randn(32, 128)
negative = torch.randn(32, 128)
loss = contrastive_loss(anchor, positive, negative)
print(f"Triplet Loss: {loss.item():.4f}")
NT-Xent Loss (InfoNCE)
Самая популярная функция потерь для contrastive learning. Используется в SimCLR, MoCo и других.
import torch
import torch.nn.functional as F
def nt_xent_loss(z_i, z_j, temperature=0.5, batch_size=32):
"""
Normalized Temperature-scaled Cross Entropy Loss (NT-Xent)
Основная формула SimCLR
"""
# Объединяем репрезентации: [2*batch_size, embedding_dim]
representations = torch.cat([z_i, z_j], dim=0)
# Вычисляем матрицу косинусного сходства
# similarity_matrix shape: [2*batch_size, 2*batch_size]
similarity_matrix = F.cosine_similarity(
representations.unsqueeze(1),
representations.unsqueeze(0),
dim=2
)
# Нормализуем на температуру
similarity_matrix = similarity_matrix / temperature
# Создаем матрицу положительных пар
# Пара (i, batch_size + i) и (batch_size + i, i) положительны
pos_mask = torch.eye(2 * batch_size, dtype=torch.bool)
pos_mask[:batch_size, batch_size:] = torch.eye(batch_size, dtype=torch.bool)
pos_mask[batch_size:, :batch_size] = torch.eye(batch_size, dtype=torch.bool)
# Диагональные элементы (i с самим собой) исключаем
pos_mask.fill_diagonal_(False)
# Применяем softmax и берем логарифм
logits = similarity_matrix
labels = torch.arange(2 * batch_size, device=z_i.device)
# Смещаем индексы так, чтобы положительная пара была на диагонали
logits[~pos_mask] = -float('inf')
loss = F.cross_entropy(logits / temperature, labels)
return loss
SimCLR - популярный framework
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
class SimCLR(nn.Module):
def __init__(self, encoder, projection_dim=128):
super().__init__()
self.encoder = encoder # ResNet50 или другая архитектура
# Projection head
feat_dim = encoder.fc.in_features
self.projection = nn.Sequential(
nn.Linear(feat_dim, feat_dim),
nn.ReLU(),
nn.Linear(feat_dim, projection_dim)
)
def forward(self, x):
# Получаем features из энкодера
features = self.encoder(x)
# Проходим через projection head
return self.projection(features)
# Data augmentation
augment = transforms.Compose([
transforms.RandomResizedCrop(32),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
class ContrastiveDataset(CIFAR10):
def __getitem__(self, index):
x, y = super().__getitem__(index)
# Создаем две аугментации одного примера
x1 = augment(x)
x2 = augment(x)
return x1, x2
# Тренировка
dataset = ContrastiveDataset('./data', download=True, transform=augment)
loader = DataLoader(dataset, batch_size=256, shuffle=True)
model = SimCLR(encoder=resnet50(), projection_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(100):
for batch_idx, (x1, x2) in enumerate(loader):
# Forward pass
z1 = model(x1)
z2 = model(x2)
# Loss
loss = nt_xent_loss(z1, z2, temperature=0.5)
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch_idx % 10 == 0:
print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
Применения Contrastive Learning
1. Computer Vision
# Обучение на неразмеченных изображениях
# SimCLR: 69.3% на ImageNet с одним слоем классификатора
# vs. 25.2% у random features
2. Natural Language Processing
# Contrastive learning для текстовых представлений
# CLIP: связывает образы с текстом через contrastive learning
# Позволяет делать zero-shot классификацию
3. Метрическое обучение
# Когда нужно измерять сходство между объектами
# Face recognition, similarity search, recommendation systems
Вариации и улучшения
| Метод | Особенность | Год |
|---|---|---|
| SimCLR | Простая, но эффективная | 2020 |
| MoCo | Momentum контраст, очередь негативов | 2020 |
| BYOL | Не требует негативных примеров | 2020 |
| SwAV | Кластеризация + контраст | 2020 |
| CLIP | Vision + Language | 2021 |
Ключевые преимущества
- Не требует разметки: обучается на неразмеченных данных
- Эффективна: меньше потребляет памяти, чем другие self-supervised методы
- Масштабируемость: хорошо работает на больших датасетах
- Переносимость: выученные представления хорошо переносятся на downstream tasks
Выводы
Contrastive learning революционизировал самообучение и показал, что модели могут выучить мощные представления без размеченных данных. Это особенно важно для областей, где разметка дорогая или недоступна.