第6章:機械学習フレームワーク

更新日:2025年12月9日

本章では、Pythonにおける機械学習・深層学習フレームワークを解説する。PyTorchとJAXの設計思想と使い分け、効率的な学習ループの設計パターン、DDP/FSDPによる分散学習、Weights & Biases等による実験管理、再現性を確保するためのベストプラクティスについて学ぶ。フレームワークの特性を理解することで、研究から本番運用まで適切な技術選択が可能になる。

1. PyTorch vs JAX

PyTorchとJAXは、現代の深層学習研究における2大フレームワークである。Table 1に両者の特徴を比較する。

Table 1. PyTorch vs JAX

観点 PyTorch JAX
開発元 Meta (Facebook) Google
計算グラフ 動的(Define-by-Run) 関数変換ベース
自動微分 autograd grad(関数変換)
JITコンパイル torch.compile(2.0+) jax.jit(コア機能)
エコシステム 非常に豊富 成長中(Flax, Optax等)
デバッグ 容易(Eager実行) やや難(関数型)
主な用途 汎用、プロダクション 研究、TPU、大規模学習

1.1 PyTorchの基本

import torch
import torch.nn as nn
import torch.optim as optim

# デバイス設定
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# モデル定義
class MLP(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

# インスタンス化
model = MLP(784, 256, 10).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

# 学習ステップ
def train_step(model, batch, optimizer, criterion):
    model.train()
    x, y = batch
    x, y = x.to(device), y.to(device)

    optimizer.zero_grad()
    logits = model(x)
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()

    return loss.item()

1.2 JAXの基本:関数型プログラミングのパラダイムを採用[1]。

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from flax import linen as nn
from flax.training import train_state
import optax

# JAXはステートレス:パラメータを明示的に管理
class MLP(nn.Module):
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x, training: bool = True):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.1, deterministic=not training)(x)
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.1, deterministic=not training)(x)
        x = nn.Dense(self.output_dim)(x)
        return x

# 初期化
model = MLP(hidden_dim=256, output_dim=10)
key = jax.random.PRNGKey(0)
params = model.init(key, jnp.ones((1, 784)))

# Optaxによるオプティマイザ
tx = optax.adamw(learning_rate=1e-3, weight_decay=0.01)

# TrainState: パラメータとオプティマイザ状態を管理
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)

# JITコンパイルされた学習ステップ
@jit
def train_step(state, batch):
    def loss_fn(params):
        x, y = batch
        logits = state.apply_fn(params, x, training=True, rngs={'dropout': jax.random.PRNGKey(0)})
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

Fig. 1に両フレームワークの設計思想の違いを示す。

2. 学習ループ設計

効率的で保守性の高い学習ループの設計パターンを解説する。

2.1 データローダー:PyTorchのDataLoaderは並列データ読み込みを提供。

from torch.utils.data import Dataset, DataLoader
from typing import Tuple
import numpy as np

class CustomDataset(Dataset):
    def __init__(self, data: np.ndarray, labels: np.ndarray):
        self.data = torch.FloatTensor(data)
        self.labels = torch.LongTensor(labels)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.data[idx], self.labels[idx]

# DataLoader設定
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,        # 並列ワーカー数
    pin_memory=True,      # GPU転送を高速化
    drop_last=True,       # 最後の不完全バッチを除外
    persistent_workers=True,  # ワーカーを維持(エポック間)
)

2.2 学習ループのベストプラクティス

import torch
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm

def train_epoch(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    scaler: GradScaler | None = None,  # Mixed Precision用
) -> dict:
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='Training')
    for batch_idx, (x, y) in enumerate(pbar):
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)  # set_to_none=Trueでメモリ効率化

        # Mixed Precision Training
        if scaler is not None:
            with autocast(dtype=torch.float16):
                logits = model(x)
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        # メトリクス計算
        total_loss += loss.item()
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)

        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{correct/total:.4f}'
        })

    return {
        'loss': total_loss / len(train_loader),
        'accuracy': correct / total
    }

2.3 PyTorch 2.0 torch.compile:JITコンパイルによる高速化[2]。

# torch.compileによる最適化
model = MLP(784, 256, 10).to(device)

# コンパイル(初回実行時にコンパイルされる)
compiled_model = torch.compile(
    model,
    mode='reduce-overhead',  # 'default', 'reduce-overhead', 'max-autotune'
    fullgraph=True,          # グラフ全体をコンパイル
)

# 使用方法は通常と同じ
output = compiled_model(input_tensor)

# 典型的な高速化: 1.5x - 2x(モデルとハードウェアに依存)

3. 分散学習

3.1 DDP(DistributedDataParallel)

DDPはデータ並列学習の標準的な手法である[3]。各GPUが同じモデルのコピーを持ち、異なるデータバッチで学習する。

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import os

def setup(rank: int, world_size: int):
    """分散環境の初期化"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def train_ddp(rank: int, world_size: int):
    setup(rank, world_size)

    # モデルをDDPでラップ
    model = MLP(784, 256, 10).to(rank)
    model = DDP(model, device_ids=[rank])

    # DistributedSamplerでデータを分割
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True,
    )

    optimizer = optim.AdamW(model.parameters(), lr=1e-3 * world_size)  # Linear scaling

    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)  # シャッフルのために必要
        train_epoch(model, train_loader, optimizer, criterion, rank)

    cleanup()

# 実行
if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(train_ddp, args=(world_size,), nprocs=world_size)

Fig. 2にDDPの動作を示す。

3.2 FSDP(Fully Sharded Data Parallel)

FSDPは大規模モデル向けの分散学習手法である。モデルパラメータ、勾配、オプティマイザ状態をGPU間で分割する[4]。

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools

# Mixed Precision設定
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# 自動ラップポリシー(Transformerレイヤーごとにシャード)
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerEncoderLayer},
)

# FSDPでラップ
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mp_policy,
    auto_wrap_policy=auto_wrap_policy,
    device_id=torch.cuda.current_device(),
)

# 以降は通常のDDPと同様に学習

Table 2. DDP vs FSDP

観点 DDP FSDP
メモリ効率 低(全パラメータを各GPUに保持) 高(パラメータを分割)
通信オーバーヘッド 中〜高
実装複雑度
適用モデルサイズ 〜数十億パラメータ 数百億パラメータ以上

4. 実験管理

機械学習プロジェクトでは、多数の実験を追跡・比較する必要がある。

4.1 Weights & Biases(wandb):クラウドベースの実験管理プラットフォーム[5]。

import wandb

# 初期化
wandb.init(
    project='my-ml-project',
    name='experiment-001',
    config={
        'learning_rate': 1e-3,
        'batch_size': 64,
        'hidden_dim': 256,
        'epochs': 100,
        'optimizer': 'AdamW',
        'weight_decay': 0.01,
    }
)

# 学習ループ内でログ
for epoch in range(num_epochs):
    train_metrics = train_epoch(model, train_loader, optimizer, criterion, device)
    val_metrics = evaluate(model, val_loader, criterion, device)

    wandb.log({
        'epoch': epoch,
        'train/loss': train_metrics['loss'],
        'train/accuracy': train_metrics['accuracy'],
        'val/loss': val_metrics['loss'],
        'val/accuracy': val_metrics['accuracy'],
        'learning_rate': optimizer.param_groups[0]['lr'],
    })

    # モデルのチェックポイント保存
    if val_metrics['accuracy'] > best_accuracy:
        best_accuracy = val_metrics['accuracy']
        wandb.save('best_model.pt')
        torch.save(model.state_dict(), 'best_model.pt')

# 終了
wandb.finish()

4.2 MLflow:オープンソースの実験管理・モデル管理プラットフォーム。

import mlflow

# 実験の設定
mlflow.set_experiment('my-ml-project')

with mlflow.start_run(run_name='experiment-001'):
    # パラメータのログ
    mlflow.log_params({
        'learning_rate': 1e-3,
        'batch_size': 64,
        'hidden_dim': 256,
    })

    for epoch in range(num_epochs):
        train_metrics = train_epoch(...)

        # メトリクスのログ
        mlflow.log_metrics({
            'train_loss': train_metrics['loss'],
            'train_accuracy': train_metrics['accuracy'],
        }, step=epoch)

    # モデルの保存
    mlflow.pytorch.log_model(model, 'model')

    # アーティファクトの保存
    mlflow.log_artifact('config.yaml')

Table 3. 実験管理ツールの比較

ツール ホスティング 特徴
Weights & Biases クラウド / セルフホスト UI優秀、チーム機能充実
MLflow セルフホスト OSS、モデルレジストリ
TensorBoard ローカル シンプル、TF/PyTorch対応
Neptune クラウド メタデータ管理に強い
Comet クラウド / セルフホスト コード差分追跡

5. 再現性確保

機械学習の再現性は研究・運用の両面で重要である。

5.1 乱数シードの固定

import random
import numpy as np
import torch

def set_seed(seed: int = 42):
    """全ての乱数シードを固定"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # 決定論的アルゴリズムの使用(性能低下の可能性あり)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # PyTorch 2.0+
    torch.use_deterministic_algorithms(True)

# 注意: 一部の操作は決定論的バージョンが存在しない
# 環境変数で警告を有効化
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

5.2 環境の固定

# uv.lock / poetry.lock による依存関係の固定

# Dockerによる環境の完全な再現
# Dockerfile
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime

WORKDIR /app
COPY pyproject.toml uv.lock ./
RUN pip install uv && uv sync --frozen

COPY . .
CMD ["python", "train.py"]

5.3 チェックポイントの保存:学習状態の完全な保存と復元。

def save_checkpoint(
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler,
    epoch: int,
    best_metric: float,
    path: str,
):
    """チェックポイントの保存"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'best_metric': best_metric,
        'rng_state': {
            'python': random.getstate(),
            'numpy': np.random.get_state(),
            'torch': torch.get_rng_state(),
            'cuda': torch.cuda.get_rng_state_all(),
        },
    }
    torch.save(checkpoint, path)

def load_checkpoint(
    path: str,
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler=None,
):
    """チェックポイントの復元"""
    checkpoint = torch.load(path)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler and checkpoint['scheduler_state_dict']:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # 乱数状態の復元
    rng_state = checkpoint['rng_state']
    random.setstate(rng_state['python'])
    np.random.set_state(rng_state['numpy'])
    torch.set_rng_state(rng_state['torch'])
    torch.cuda.set_rng_state_all(rng_state['cuda'])

    return checkpoint['epoch'], checkpoint['best_metric']

5.4 再現性チェックリスト

# 再現性確保のためのチェックリスト
"""
□ 乱数シード固定(Python, NumPy, PyTorch, CUDA)
□ 決定論的アルゴリズム有効化
□ 依存関係のロックファイル(uv.lock, poetry.lock)
□ Dockerイメージのタグ固定
□ データのバージョン管理(DVC等)
□ コードのバージョン管理(Git commit hash記録)
□ ハイパーパラメータの記録
□ 学習・評価スクリプトの記録
□ ハードウェア情報の記録(GPU型番、CUDA版等)
"""

References

[1] JAX, "JAX: Autograd and XLA," jax.readthedocs.io, 2024.

[2] PyTorch, "torch.compile Tutorial," pytorch.org/tutorials, 2024.

[3] PyTorch, "DistributedDataParallel," pytorch.org/docs, 2024.

[4] PyTorch, "FSDP Tutorial," pytorch.org/tutorials, 2024.

[5] Weights & Biases, "W&B Documentation," docs.wandb.ai, 2024.

免責事項
本コンテンツは2025年12月時点の情報に基づいて作成されている。各フレームワークのAPIは活発に開発が進んでおり、変更される可能性がある。最新の情報は公式ドキュメントを参照されたい。

← 前章:データ処理次章:LLM開発 →