第6章:機械学習フレームワーク
更新日:2025年12月9日
1. PyTorch vs JAX
PyTorchとJAXは、現代の深層学習研究における2大フレームワークである。Table 1に両者の特徴を比較する。
Table 1. PyTorch vs JAX
| 観点 | PyTorch | JAX |
|---|---|---|
| 開発元 | Meta (Facebook) | |
| 計算グラフ | 動的(Define-by-Run) | 関数変換ベース |
| 自動微分 | autograd | grad(関数変換) |
| JITコンパイル | torch.compile(2.0+) | jax.jit(コア機能) |
| エコシステム | 非常に豊富 | 成長中(Flax, Optax等) |
| デバッグ | 容易(Eager実行) | やや難(関数型) |
| 主な用途 | 汎用、プロダクション | 研究、TPU、大規模学習 |
1.1 PyTorchの基本:
import torch
import torch.nn as nn
import torch.optim as optim
# デバイス設定
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# モデル定義
class MLP(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, output_dim),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layers(x)
# インスタンス化
model = MLP(784, 256, 10).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
# 学習ステップ
def train_step(model, batch, optimizer, criterion):
model.train()
x, y = batch
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
logits = model(x)
loss = criterion(logits, y)
loss.backward()
optimizer.step()
return loss.item()
1.2 JAXの基本:関数型プログラミングのパラダイムを採用[1]。
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from flax import linen as nn
from flax.training import train_state
import optax
# JAXはステートレス:パラメータを明示的に管理
class MLP(nn.Module):
hidden_dim: int
output_dim: int
@nn.compact
def __call__(self, x, training: bool = True):
x = nn.Dense(self.hidden_dim)(x)
x = nn.relu(x)
x = nn.Dropout(0.1, deterministic=not training)(x)
x = nn.Dense(self.hidden_dim)(x)
x = nn.relu(x)
x = nn.Dropout(0.1, deterministic=not training)(x)
x = nn.Dense(self.output_dim)(x)
return x
# 初期化
model = MLP(hidden_dim=256, output_dim=10)
key = jax.random.PRNGKey(0)
params = model.init(key, jnp.ones((1, 784)))
# Optaxによるオプティマイザ
tx = optax.adamw(learning_rate=1e-3, weight_decay=0.01)
# TrainState: パラメータとオプティマイザ状態を管理
state = train_state.TrainState.create(
apply_fn=model.apply,
params=params,
tx=tx,
)
# JITコンパイルされた学習ステップ
@jit
def train_step(state, batch):
def loss_fn(params):
x, y = batch
logits = state.apply_fn(params, x, training=True, rngs={'dropout': jax.random.PRNGKey(0)})
return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
loss, grads = jax.value_and_grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads)
return state, loss
Fig. 1に両フレームワークの設計思想の違いを示す。
2. 学習ループ設計
効率的で保守性の高い学習ループの設計パターンを解説する。
2.1 データローダー:PyTorchのDataLoaderは並列データ読み込みを提供。
from torch.utils.data import Dataset, DataLoader
from typing import Tuple
import numpy as np
class CustomDataset(Dataset):
def __init__(self, data: np.ndarray, labels: np.ndarray):
self.data = torch.FloatTensor(data)
self.labels = torch.LongTensor(labels)
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
return self.data[idx], self.labels[idx]
# DataLoader設定
train_loader = DataLoader(
dataset=train_dataset,
batch_size=64,
shuffle=True,
num_workers=4, # 並列ワーカー数
pin_memory=True, # GPU転送を高速化
drop_last=True, # 最後の不完全バッチを除外
persistent_workers=True, # ワーカーを維持(エポック間)
)
2.2 学習ループのベストプラクティス:
import torch
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
def train_epoch(
model: nn.Module,
train_loader: DataLoader,
optimizer: optim.Optimizer,
criterion: nn.Module,
device: torch.device,
scaler: GradScaler | None = None, # Mixed Precision用
) -> dict:
model.train()
total_loss = 0.0
correct = 0
total = 0
pbar = tqdm(train_loader, desc='Training')
for batch_idx, (x, y) in enumerate(pbar):
x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
optimizer.zero_grad(set_to_none=True) # set_to_none=Trueでメモリ効率化
# Mixed Precision Training
if scaler is not None:
with autocast(dtype=torch.float16):
logits = model(x)
loss = criterion(logits, y)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
else:
logits = model(x)
loss = criterion(logits, y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
# メトリクス計算
total_loss += loss.item()
pred = logits.argmax(dim=1)
correct += (pred == y).sum().item()
total += y.size(0)
pbar.set_postfix({
'loss': f'{loss.item():.4f}',
'acc': f'{correct/total:.4f}'
})
return {
'loss': total_loss / len(train_loader),
'accuracy': correct / total
}
2.3 PyTorch 2.0 torch.compile:JITコンパイルによる高速化[2]。
# torch.compileによる最適化
model = MLP(784, 256, 10).to(device)
# コンパイル(初回実行時にコンパイルされる)
compiled_model = torch.compile(
model,
mode='reduce-overhead', # 'default', 'reduce-overhead', 'max-autotune'
fullgraph=True, # グラフ全体をコンパイル
)
# 使用方法は通常と同じ
output = compiled_model(input_tensor)
# 典型的な高速化: 1.5x - 2x(モデルとハードウェアに依存)
3. 分散学習
3.1 DDP(DistributedDataParallel)
DDPはデータ並列学習の標準的な手法である[3]。各GPUが同じモデルのコピーを持ち、異なるデータバッチで学習する。
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import os
def setup(rank: int, world_size: int):
"""分散環境の初期化"""
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group('nccl', rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
def train_ddp(rank: int, world_size: int):
setup(rank, world_size)
# モデルをDDPでラップ
model = MLP(784, 256, 10).to(rank)
model = DDP(model, device_ids=[rank])
# DistributedSamplerでデータを分割
train_sampler = DistributedSampler(
train_dataset,
num_replicas=world_size,
rank=rank,
shuffle=True,
)
train_loader = DataLoader(
train_dataset,
batch_size=64,
sampler=train_sampler,
num_workers=4,
pin_memory=True,
)
optimizer = optim.AdamW(model.parameters(), lr=1e-3 * world_size) # Linear scaling
for epoch in range(num_epochs):
train_sampler.set_epoch(epoch) # シャッフルのために必要
train_epoch(model, train_loader, optimizer, criterion, rank)
cleanup()
# 実行
if __name__ == '__main__':
world_size = torch.cuda.device_count()
torch.multiprocessing.spawn(train_ddp, args=(world_size,), nprocs=world_size)
Fig. 2にDDPの動作を示す。
3.2 FSDP(Fully Sharded Data Parallel)
FSDPは大規模モデル向けの分散学習手法である。モデルパラメータ、勾配、オプティマイザ状態をGPU間で分割する[4]。
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools
# Mixed Precision設定
mp_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
)
# 自動ラップポリシー(Transformerレイヤーごとにシャード)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerEncoderLayer},
)
# FSDPでラップ
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mp_policy,
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
)
# 以降は通常のDDPと同様に学習
Table 2. DDP vs FSDP
| 観点 | DDP | FSDP |
|---|---|---|
| メモリ効率 | 低(全パラメータを各GPUに保持) | 高(パラメータを分割) |
| 通信オーバーヘッド | 低 | 中〜高 |
| 実装複雑度 | 低 | 中 |
| 適用モデルサイズ | 〜数十億パラメータ | 数百億パラメータ以上 |
4. 実験管理
機械学習プロジェクトでは、多数の実験を追跡・比較する必要がある。
4.1 Weights & Biases(wandb):クラウドベースの実験管理プラットフォーム[5]。
import wandb
# 初期化
wandb.init(
project='my-ml-project',
name='experiment-001',
config={
'learning_rate': 1e-3,
'batch_size': 64,
'hidden_dim': 256,
'epochs': 100,
'optimizer': 'AdamW',
'weight_decay': 0.01,
}
)
# 学習ループ内でログ
for epoch in range(num_epochs):
train_metrics = train_epoch(model, train_loader, optimizer, criterion, device)
val_metrics = evaluate(model, val_loader, criterion, device)
wandb.log({
'epoch': epoch,
'train/loss': train_metrics['loss'],
'train/accuracy': train_metrics['accuracy'],
'val/loss': val_metrics['loss'],
'val/accuracy': val_metrics['accuracy'],
'learning_rate': optimizer.param_groups[0]['lr'],
})
# モデルのチェックポイント保存
if val_metrics['accuracy'] > best_accuracy:
best_accuracy = val_metrics['accuracy']
wandb.save('best_model.pt')
torch.save(model.state_dict(), 'best_model.pt')
# 終了
wandb.finish()
4.2 MLflow:オープンソースの実験管理・モデル管理プラットフォーム。
import mlflow
# 実験の設定
mlflow.set_experiment('my-ml-project')
with mlflow.start_run(run_name='experiment-001'):
# パラメータのログ
mlflow.log_params({
'learning_rate': 1e-3,
'batch_size': 64,
'hidden_dim': 256,
})
for epoch in range(num_epochs):
train_metrics = train_epoch(...)
# メトリクスのログ
mlflow.log_metrics({
'train_loss': train_metrics['loss'],
'train_accuracy': train_metrics['accuracy'],
}, step=epoch)
# モデルの保存
mlflow.pytorch.log_model(model, 'model')
# アーティファクトの保存
mlflow.log_artifact('config.yaml')
Table 3. 実験管理ツールの比較
| ツール | ホスティング | 特徴 |
|---|---|---|
| Weights & Biases | クラウド / セルフホスト | UI優秀、チーム機能充実 |
| MLflow | セルフホスト | OSS、モデルレジストリ |
| TensorBoard | ローカル | シンプル、TF/PyTorch対応 |
| Neptune | クラウド | メタデータ管理に強い |
| Comet | クラウド / セルフホスト | コード差分追跡 |
5. 再現性確保
機械学習の再現性は研究・運用の両面で重要である。
5.1 乱数シードの固定:
import random
import numpy as np
import torch
def set_seed(seed: int = 42):
"""全ての乱数シードを固定"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# 決定論的アルゴリズムの使用(性能低下の可能性あり)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# PyTorch 2.0+
torch.use_deterministic_algorithms(True)
# 注意: 一部の操作は決定論的バージョンが存在しない
# 環境変数で警告を有効化
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
5.2 環境の固定:
# uv.lock / poetry.lock による依存関係の固定
# Dockerによる環境の完全な再現
# Dockerfile
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
WORKDIR /app
COPY pyproject.toml uv.lock ./
RUN pip install uv && uv sync --frozen
COPY . .
CMD ["python", "train.py"]
5.3 チェックポイントの保存:学習状態の完全な保存と復元。
def save_checkpoint(
model: nn.Module,
optimizer: optim.Optimizer,
scheduler,
epoch: int,
best_metric: float,
path: str,
):
"""チェックポイントの保存"""
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
'best_metric': best_metric,
'rng_state': {
'python': random.getstate(),
'numpy': np.random.get_state(),
'torch': torch.get_rng_state(),
'cuda': torch.cuda.get_rng_state_all(),
},
}
torch.save(checkpoint, path)
def load_checkpoint(
path: str,
model: nn.Module,
optimizer: optim.Optimizer,
scheduler=None,
):
"""チェックポイントの復元"""
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if scheduler and checkpoint['scheduler_state_dict']:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# 乱数状態の復元
rng_state = checkpoint['rng_state']
random.setstate(rng_state['python'])
np.random.set_state(rng_state['numpy'])
torch.set_rng_state(rng_state['torch'])
torch.cuda.set_rng_state_all(rng_state['cuda'])
return checkpoint['epoch'], checkpoint['best_metric']
5.4 再現性チェックリスト:
# 再現性確保のためのチェックリスト
"""
□ 乱数シード固定(Python, NumPy, PyTorch, CUDA)
□ 決定論的アルゴリズム有効化
□ 依存関係のロックファイル(uv.lock, poetry.lock)
□ Dockerイメージのタグ固定
□ データのバージョン管理(DVC等)
□ コードのバージョン管理(Git commit hash記録)
□ ハイパーパラメータの記録
□ 学習・評価スクリプトの記録
□ ハードウェア情報の記録(GPU型番、CUDA版等)
"""
References
[1] JAX, "JAX: Autograd and XLA," jax.readthedocs.io, 2024.
[2] PyTorch, "torch.compile Tutorial," pytorch.org/tutorials, 2024.
[3] PyTorch, "DistributedDataParallel," pytorch.org/docs, 2024.
[4] PyTorch, "FSDP Tutorial," pytorch.org/tutorials, 2024.
[5] Weights & Biases, "W&B Documentation," docs.wandb.ai, 2024.
本コンテンツは2025年12月時点の情報に基づいて作成されている。各フレームワークのAPIは活発に開発が進んでおり、変更される可能性がある。最新の情報は公式ドキュメントを参照されたい。