← Назад к вопросам

Какой граф вычислений в PyTorch?

2.3 Middle🔥 131 комментариев
#Python#Глубокое обучение

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI30 мар. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Граф вычислений в PyTorch

Пиру вычислений (Computational Graph) — один из ключевых концепций PyTorch, который позволяет автоматически вычислять градиенты для обратного распространения.

Динамический граф вычислений

В отличие от TensorFlow 1.x (статический граф), PyTorch использует динамический граф вычислений, который строится на лету во время прямого прохода.

Основные характеристики:

  • Построение во время выполнения: граф создаётся во время forward pass, а не предопределяется
  • Гибкость: возможны условные операции, циклы и динамические структуры
  • Отладка: легче отлаживать, так как можно использовать обычные инструменты Python
  • Эффективность: граф существует только до начала backward pass

Как работает граф

import torch

# Создаём тензоры с require_grad=True
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)

# Forward pass: строится граф вычислений
z = x**2 + y * 2  # z = 4 + 6 = 10

# В этот момент граф выглядит так:
# z ← y*2 (умножение)
# z ← x**2 (возведение в степень)
# Каждая операция содержит функцию для обратного прохода

# Backward pass: вычисляем градиенты
z.backward()

print(x.grad)  # dz/dx = 2*x = 4.0
print(y.grad)  # dz/dy = 2

Структура узла графа

Каждый тензор, созданный с requires_grad=True, содержит:

  1. data — сам тензор с данными
  2. grad — накопленные градиенты
  3. grad_fn — функция для обратного прохода (откуда тензор взялся)
  4. requires_grad — флаг, нужно ли вычислять градиент
  5. is_leaf — листовой узел (исходная переменная) или промежуточный результат
x = torch.tensor([2.0], requires_grad=True)
z = x**2 + x

print(z.grad_fn)  # <PowBackward0 object>
print(x.grad_fn)  # None (листовой узел)
print(z.is_leaf)  # False
print(x.is_leaf)  # True

Листовые и промежуточные узлы

Листовые узлы (Leaf nodes):

  • Исходные переменные (параметры модели, входные данные)
  • grad_fn = None
  • Могут сохранять градиенты (если requires_grad=True)

Промежуточные узлы (Intermediate nodes):

  • Результаты операций
  • grad_fn указывает на операцию, которая их создала
  • По умолчанию их retain_graph=False, т.е. они удаляются после backward()

Кроме backward pass

После того как вызывается backward(), граф автоматически удаляется (по умолчанию).

x = torch.tensor([2.0], requires_grad=True)
y = x**2

y.backward()
print(x.grad)  # 4.0

# Граф удалён, нельзя вызвать backward ещё раз
y.backward()  # RuntimeError: leaf variable has been moved into the graph interior

Сохранение графа:

x = torch.tensor([2.0], requires_grad=True)
y = x**2

y.backward(retain_graph=True)
print(x.grad)  # 4.0

y.backward()  # Работает, так как граф сохранён
print(x.grad)  # 8.0 (градиент накопился)

Отключение отслеживания графа

Для инференса (без gradients):

with torch.no_grad():
    y = model(x)  # Граф не строится, экономим память

# Или
y = model(x).detach()  # Отсоедиеяем от графа

Отключение для отдельного тензора:

x = torch.tensor([2.0], requires_grad=True)
x_no_grad = x.detach()  # Новый тензор без отслеживания

Пример с нейронной сетью

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
x = torch.randn(1, 10)
y_true = torch.tensor([1.0])

# Forward pass
y_pred = model(x)
loss = (y_pred - y_true)**2

# Граф выглядит:
# loss → (y_pred - y_true)^2
# y_pred → Linear(weight, bias, x)
# x → исходные данные
# weight, bias → параметры модели (листовые узлы)

# Backward pass
loss.backward()

# Теперь model.weight.grad и model.bias.grad содержат градиенты
with torch.no_grad():
    model.weight -= 0.01 * model.weight.grad
    model.bias -= 0.01 * model.bias.grad

Основные преимущества динамического графа

  1. Удобство отладки — можно использовать print(), pdb, IDE
  2. Гибкость — условные операции, переменные циклы
  3. Простота — код выглядит как обычный Python
  4. Эффективность памяти — граф удаляется после backward()
  5. Поддержка рекуррентных сетей — естественная работа с циклами

Это делает PyTorch идеальным для исследований и быстрого прототипирования, где нужна гибкость и простота разработки.