← Назад к вопросам

Что такое contrastive learning?

2.7 Senior🔥 121 комментариев
#Глубокое обучение#Машинное обучение

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI2 апр. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Что такое contrastive learning?

Contrastive Learning - это парадигма самообучения (self-supervised learning), которая учит модель различать похожие и непохожие примеры. Основная идея: модель должна научиться представлять похожие примеры близко в пространстве признаков, а непохожие - далеко друг от друга. Это очень мощный подход для обучения без размеченных данных.

Основная концепция

Вместо предсказания меток класса, contrastive learning работает с парами примеров:

  • Positive pair (положительная пара): два похожих примера (например, два аугментированных варианта одного изображения)
  • Negative pair (отрицательная пара): два непохожих примера

Цель: минимизировать расстояние между положительными парами и максимизировать расстояние между отрицательными.

Метрика сходства: cosine similarity

import numpy as np
import torch
import torch.nn.functional as F

def cosine_similarity(x, y):
    """
    Вычисляет косинусное сходство между двумя векторами
    cosine_sim = (x . y) / (||x|| * ||y||)
    Результат от -1 до 1, где 1 = идентичны
    """
    return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))

# Пример
x = np.array([1, 0, 0])  # Единичный вектор
y = np.array([1, 0, 0])  # Идентичный
z = np.array([0, 1, 0])  # Ортогональный

print(f"sim(x, y) = {cosine_similarity(x, y):.3f}")  # 1.0
print(f"sim(x, z) = {cosine_similarity(x, z):.3f}")  # 0.0

# В PyTorch
x_torch = torch.tensor([1.0, 0.0, 0.0])
y_torch = torch.tensor([1.0, 0.0, 0.0])
z_torch = torch.tensor([0.0, 1.0, 0.0])

print("PyTorch:")
print(f"sim(x, y) = {F.cosine_similarity(x_torch.unsqueeze(0), y_torch.unsqueeze(0)).item():.3f}")
print(f"sim(x, z) = {F.cosine_similarity(x_torch.unsqueeze(0), z_torch.unsqueeze(0)).item():.3f}")

Функция потерь: Contrastive Loss (Triplet Loss)

import torch
import torch.nn.functional as F

def contrastive_loss(anchor, positive, negative, margin=1.0):
    """
    Triplet Loss: придвигает positive ближе к anchor,
    а negative отодвигает на расстояние margin
    
    L = max(0, d(anchor, positive) - d(anchor, negative) + margin)
    """
    pos_dist = F.pairwise_distance(anchor, positive, p=2)
    neg_dist = F.pairwise_distance(anchor, negative, p=2)
    
    loss = torch.relu(pos_dist - neg_dist + margin)
    return loss.mean()

# Пример
anchor = torch.randn(32, 128)  # batch_size=32, embedding_dim=128
positive = torch.randn(32, 128)
negative = torch.randn(32, 128)

loss = contrastive_loss(anchor, positive, negative)
print(f"Triplet Loss: {loss.item():.4f}")

NT-Xent Loss (InfoNCE)

Самая популярная функция потерь для contrastive learning. Используется в SimCLR, MoCo и других.

import torch
import torch.nn.functional as F

def nt_xent_loss(z_i, z_j, temperature=0.5, batch_size=32):
    """
    Normalized Temperature-scaled Cross Entropy Loss (NT-Xent)
    Основная формула SimCLR
    """
    # Объединяем репрезентации: [2*batch_size, embedding_dim]
    representations = torch.cat([z_i, z_j], dim=0)
    
    # Вычисляем матрицу косинусного сходства
    # similarity_matrix shape: [2*batch_size, 2*batch_size]
    similarity_matrix = F.cosine_similarity(
        representations.unsqueeze(1),
        representations.unsqueeze(0),
        dim=2
    )
    
    # Нормализуем на температуру
    similarity_matrix = similarity_matrix / temperature
    
    # Создаем матрицу положительных пар
    # Пара (i, batch_size + i) и (batch_size + i, i) положительны
    pos_mask = torch.eye(2 * batch_size, dtype=torch.bool)
    pos_mask[:batch_size, batch_size:] = torch.eye(batch_size, dtype=torch.bool)
    pos_mask[batch_size:, :batch_size] = torch.eye(batch_size, dtype=torch.bool)
    
    # Диагональные элементы (i с самим собой) исключаем
    pos_mask.fill_diagonal_(False)
    
    # Применяем softmax и берем логарифм
    logits = similarity_matrix
    labels = torch.arange(2 * batch_size, device=z_i.device)
    
    # Смещаем индексы так, чтобы положительная пара была на диагонали
    logits[~pos_mask] = -float('inf')
    
    loss = F.cross_entropy(logits / temperature, labels)
    return loss

SimCLR - популярный framework

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

class SimCLR(nn.Module):
    def __init__(self, encoder, projection_dim=128):
        super().__init__()
        self.encoder = encoder  # ResNet50 или другая архитектура
        
        # Projection head
        feat_dim = encoder.fc.in_features
        self.projection = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, projection_dim)
        )
    
    def forward(self, x):
        # Получаем features из энкодера
        features = self.encoder(x)
        # Проходим через projection head
        return self.projection(features)

# Data augmentation
augment = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

class ContrastiveDataset(CIFAR10):
    def __getitem__(self, index):
        x, y = super().__getitem__(index)
        # Создаем две аугментации одного примера
        x1 = augment(x)
        x2 = augment(x)
        return x1, x2

# Тренировка
dataset = ContrastiveDataset('./data', download=True, transform=augment)
loader = DataLoader(dataset, batch_size=256, shuffle=True)

model = SimCLR(encoder=resnet50(), projection_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(100):
    for batch_idx, (x1, x2) in enumerate(loader):
        # Forward pass
        z1 = model(x1)
        z2 = model(x2)
        
        # Loss
        loss = nt_xent_loss(z1, z2, temperature=0.5)
        
        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

Применения Contrastive Learning

1. Computer Vision

# Обучение на неразмеченных изображениях
# SimCLR: 69.3% на ImageNet с одним слоем классификатора
# vs. 25.2% у random features

2. Natural Language Processing

# Contrastive learning для текстовых представлений
# CLIP: связывает образы с текстом через contrastive learning
# Позволяет делать zero-shot классификацию

3. Метрическое обучение

# Когда нужно измерять сходство между объектами
# Face recognition, similarity search, recommendation systems

Вариации и улучшения

МетодОсобенностьГод
SimCLRПростая, но эффективная2020
MoCoMomentum контраст, очередь негативов2020
BYOLНе требует негативных примеров2020
SwAVКластеризация + контраст2020
CLIPVision + Language2021

Ключевые преимущества

  • Не требует разметки: обучается на неразмеченных данных
  • Эффективна: меньше потребляет памяти, чем другие self-supervised методы
  • Масштабируемость: хорошо работает на больших датасетах
  • Переносимость: выученные представления хорошо переносятся на downstream tasks

Выводы

Contrastive learning революционизировал самообучение и показал, что модели могут выучить мощные представления без размеченных данных. Это особенно важно для областей, где разметка дорогая или недоступна.