効率的訓練技術
効率的訓練技術
深層学習の効率的訓練技術。混合精度訓練、勾配チェックポイント、FlashAttention、量子化訓練、コンパイル最適化。より少ないリソースでより速く訓練する方法。
最終更新:2025年11月
1. 効率化の概要
1.1 効率化の3つの軸
- メモリ効率:同じGPUでより大きなモデル/バッチ
- 計算効率:同じ計算をより速く
- 通信効率:分散訓練のオーバーヘッド削減
1.2 ボトルネックの理解
GPU訓練の主なボトルネック:
- メモリ帯域:HBMとSM間のデータ転送
- 計算能力:FLOPの実行速度
- メモリ容量:GPUメモリの上限
多くの演算は計算律速ではなくメモリ律速。
1.3 Arithmetic Intensity
計算量 / メモリアクセス量の比率:
- 高い:行列乗算(計算律速)
- 低い:LayerNorm、Softmax、Element-wise(メモリ律速)
効率化はボトルネックに応じた最適化が必要。
2. 混合精度訓練
2.1 数値精度の種類
| 形式 | ビット数 | 範囲 | 用途 |
|---|---|---|---|
| FP32 | 32 | 広い | 従来の標準 |
| FP16 | 16 | 狭い | 高速訓練 |
| BF16 | 16 | FP32同等 | LLM訓練の標準 |
| TF32 | 19 | FP32同等 | Ampere以降のGPU |
| FP8 | 8 | 狭い | Hopper GPU(H100) |
2.2 BF16(Brain Floating Point)
Googleが開発した16ビット形式:
- 指数部8ビット:FP32と同じ範囲
- 仮数部7ビット:精度は低い
- 利点:アンダーフロー/オーバーフローが起きにくい
- LLM訓練の事実上の標準
2.3 混合精度の実装
AMP(Automatic Mixed Precision):
- Forward/Backwardは低精度(FP16/BF16)
- パラメータ更新はFP32(Master Weight)
- Loss Scaling:勾配のアンダーフロー防止
Loss Scalingの仕組み:
- Lossを大きな値でスケール(例:2^16)
- Backwardで勾配も同様にスケール
- 更新前にスケールを戻す
- Dynamic Loss Scaling:オーバーフロー時に自動調整
2.4 効果
- メモリ:約半減(FP32 → FP16/BF16)
- 速度:1.5〜2倍(Tensor Core活用)
- 精度:ほぼ維持(適切な実装で)
3. 勾配チェックポイント
3.1 活性化値のメモリ問題
Backwardには中間活性化値が必要:
- すべての層の活性化値を保持 → 大量のメモリ消費
- バッチサイズ、シーケンス長に比例
- 長いシーケンスでは活性化値がパラメータより大きくなる
3.2 勾配チェックポイントの原理
Gradient Checkpointing(Activation Checkpointing):
- Forward時に一部の活性化値のみ保存
- Backward時に必要な活性化値を再計算
- メモリ ↔ 計算のトレードオフ
3.3 チェックポイント戦略
| 戦略 | メモリ | 計算 | 説明 |
|---|---|---|---|
| チェックポイントなし | O(N) | 1x | 全活性化値を保持 |
| 全層チェックポイント | O(√N) | 〜1.33x | 各層でチェックポイント |
| 選択的チェックポイント | 調整可能 | 調整可能 | 特定層のみ |
3.4 実装
PyTorch:
from torch.utils.checkpoint import checkpoint
# 通常のforward
output = layer(input)
# チェックポイント付きforward
output = checkpoint(layer, input, use_reentrant=False)
Transformerでの適用:
- 各Transformer層をチェックポイント単位に
- Hugging Face:
gradient_checkpointing_enable()
4. FlashAttention
4.1 標準Attentionの問題
Self-Attentionの計算:
- QK^Tを計算 → O(N²d)のメモリ
- Softmaxを適用
- Vと乗算
N(シーケンス長)が大きいと、注意行列のメモリがボトルネック。
4.2 FlashAttentionの革新
FlashAttention(Dao et al., 2022):
- タイリング:注意行列をブロック単位で計算
- カーネル融合:QK^T、Softmax、×Vを1カーネルに
- オンラインSoftmax:ブロックごとに漸進的に計算
- メモリI/O削減:HBMアクセスを最小化
4.3 効果
- メモリ:O(N²) → O(N)(注意行列を保持しない)
- 速度:2〜4倍高速(メモリI/O削減)
- 長シーケンス:より長いシーケンスが処理可能
4.4 FlashAttention-2 / FlashAttention-3
FlashAttention-2(2023):
- 並列化の改善(シーケンス次元)
- Warp間のワークパーティショニング最適化
- さらに2倍程度の高速化
FlashAttention-3(2024):
- Hopper GPU(H100)の機能活用
- FP8サポート
- 非同期処理、Warp specialization
4.5 実装
# PyTorch 2.0+(SDPAにFlashAttention統合)
import torch.nn.functional as F
output = F.scaled_dot_product_attention(
query, key, value,
attn_mask=None,
dropout_p=0.0,
is_causal=True # Causal maskの効率的実装
)
- flash-attnライブラリ:スタンドアロン実装
- xformers:Meta提供、Memory-efficient attention
- PyTorch SDPA:自動バックエンド選択
5. 量子化訓練
5.1 量子化の概要
より低いビット精度での表現:
- FP32/FP16 → INT8/INT4
- メモリ削減、計算高速化
- 推論では広く使用、訓練は発展途上
5.2 量子化の種類
| 種類 | タイミング | 特徴 |
|---|---|---|
| Post-Training Quantization (PTQ) | 訓練後 | 簡単、精度低下の可能性 |
| Quantization-Aware Training (QAT) | 訓練中 | 高精度、コスト大 |
| 低精度訓練 | 訓練全体 | FP8、INT8訓練 |
5.3 FP8訓練
H100 GPUでサポートされる8ビット浮動小数点:
- E4M3:4ビット指数、3ビット仮数(Forward向け)
- E5M2:5ビット指数、2ビット仮数(Backward向け)
- BF16比で〜2倍の高速化
5.4 QLoRA
QLoRA(Dettmers et al., 2023):
- ベースモデルを4ビット量子化(NF4)
- LoRAアダプタをFP16/BF16で追加訓練
- 65Bモデルを単一48GB GPUでファインチューニング可能
5.5 量子化手法
- GPTQ:Post-training、レイヤーごとの最適化
- AWQ:Activation-aware、重要な重みを保護
- bitsandbytes:8ビット/4ビットの効率的実装
6. コンパイル最適化
6.1 torch.compile
PyTorch 2.0で導入されたJITコンパイラ:
import torch
model = MyModel()
model = torch.compile(model) # コンパイル
# 以降は通常通り使用
output = model(input)
6.2 コンパイルの効果
- カーネル融合:複数の演算を1カーネルに
- メモリI/O削減:中間テンソルの書き込み削減
- グラフ最適化:不要な演算の削除
- 効果:1.5〜2倍の高速化(モデルによる)
6.3 バックエンド
- TorchInductor:デフォルト、GPU向けTritonコード生成
- cudagraphs:CUDAグラフによるカーネル起動オーバーヘッド削減
- ONNX Runtime:推論向け最適化
6.4 Triton
Triton(OpenAI):
- Pythonライクな言語でGPUカーネルを記述
- CUDAより簡単、高い生産性
- FlashAttentionもTritonで実装可能
- torch.compileのバックエンドとしても使用
6.5 XLA
XLA(Accelerated Linear Algebra):
- TensorFlow、JAXのコンパイラ
- PyTorch/XLA:PyTorchでXLAを使用
- TPU向けに最適化
7. その他の技術
7.1 勾配累積
メモリに収まるバッチサイズで複数回Forward/Backwardし、勾配を累積:
for i, batch in enumerate(dataloader):
loss = model(batch) / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
実効バッチサイズ = ミニバッチ × 累積ステップ × GPU数
7.2 Speculative Decoding
推論時の高速化技術:
- 小さなドラフトモデルで複数トークンを投機的に生成
- 大きなターゲットモデルで検証
- バッチ処理の効率を活用
- 2〜3倍の高速化
7.3 KVキャッシュ最適化
推論時のKey-Valueキャッシュ:
- PagedAttention(vLLM):メモリの断片化防止
- Continuous Batching:動的バッチ管理
- MLA(DeepSeek):KVキャッシュの低ランク圧縮
7.4 データローダー最適化
- プリフェッチ:GPUが計算中に次のバッチを準備
- ピン止めメモリ:CPU→GPU転送の高速化
- 非同期データロード:複数ワーカーで並列読み込み
7.5 組み合わせの指針
| 目的 | 推奨技術 |
|---|---|
| メモリ不足 | 勾配チェックポイント + 混合精度 + 勾配累積 |
| 訓練高速化 | FlashAttention + torch.compile + 混合精度 |
| 長シーケンス | FlashAttention + 勾配チェックポイント |
| ファインチューニング(低コスト) | QLoRA + FlashAttention |
8. 参考文献
主要論文
- Micikevicius et al. (2018). "Mixed Precision Training" ICLR
- Chen et al. (2016). "Training Deep Nets with Sublinear Memory Cost" arXiv(勾配チェックポイント)
- Dao et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" NeurIPS
- Dao (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"
- Dettmers et al. (2023). "QLoRA: Efficient Finetuning of Quantized LLMs" NeurIPS
- Ansel et al. (2024). "PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation"
ライブラリ・ドキュメント
- FlashAttention:GitHub
- bitsandbytes:GitHub
- Triton:公式サイト
- PyTorch AMP:PyTorch Docs