1. 分散訓練の必要性
1.1 メモリ要件
LLM訓練に必要なGPUメモリ(FP16/BF16):
| 項目 |
計算式 |
70Bモデル例 |
| パラメータ |
2B(FP16) |
140 GB |
| 勾配 |
2B(FP16) |
140 GB |
| オプティマイザ状態(Adam) |
8B(FP32 m, v, パラメータコピー) |
560 GB |
| 活性化値 |
バッチ・シーケンス依存 |
〜100+ GB |
| 合計 |
- |
〜1 TB |
H100 80GB × 1では到底不足。分散が必須。
1.2 計算時間
GPT-3(175B)の訓練:
- 約300Bトークン
- 推定3.64E23 FLOP
- 単一A100(312 TFLOP/s)で約34年
- 1,024 GPUで約2週間
1.3 分散訓練の目標
- メモリ分散:単一GPUに収まらないモデルの訓練
- 計算並列化:訓練時間の短縮
- スループット最大化:GPU利用効率の向上
- スケーラビリティ:GPUを増やすと線形に高速化
2. データ並列(Data Parallelism)
2.1 基本概念
各GPUに同じモデルを複製し、異なるデータで訓練:
- 各GPUがミニバッチの一部を処理
- 勾配を計算
- 勾配をAllReduceで集約・平均
- 各GPUでパラメータを更新
2.2 AllReduce通信
全GPUの勾配を集約し、結果を全GPUに配布:
- Ring AllReduce:帯域効率的、通信量 2(N-1)/N × データサイズ
- Tree AllReduce:レイテンシ効率的
- NCCL:NVIDIAの最適化通信ライブラリ
2.3 PyTorch DDP
DistributedDataParallel:PyTorchの標準的なデータ並列実装。
- 各プロセスが1GPUを担当
- 勾配計算とAllReduceをオーバーラップ
- Gradient Bucketingによる効率化
2.4 データ並列の限界
- メモリ複製:各GPUに全パラメータが必要
- 大モデルに不向き:単一GPUに収まるモデルサイズが上限
- 通信オーバーヘッド:GPU数増加でAllReduceコストが増大
3. モデル並列(Model Parallelism)
3.1 テンソル並列(Tensor Parallelism)
各層のテンソルをGPU間で分割:
例:Linear層の分割
- 列分割:出力次元を分割 → AllGatherで結合
- 行分割:入力次元を分割 → ReduceScatterで集約
Megatron-LM方式:
- Attention:Query/Key/Valueを列分割
- FFN:最初の線形層を列分割、2番目を行分割
- 各Transformer層で2回のAllReduce
3.2 テンソル並列の特徴
- 利点:メモリを分割、活性化値も分散
- 欠点:頻繁な通信(各層で同期)
- 適用範囲:同一ノード内のGPU間(高帯域接続が必要)
3.3 シーケンス並列
シーケンス次元での分割(Ring Attention等):
- 長いシーケンスの活性化値メモリを削減
- LayerNorm、Dropout等の演算をシーケンス分割
4. パイプライン並列(Pipeline Parallelism)
4.1 基本概念
モデルを層のグループ(ステージ)に分割し、各GPUに配置:
- GPU 0:層 1-10
- GPU 1:層 11-20
- GPU 2:層 21-30
- ...
4.2 パイプラインバブル問題
単純な実装では、GPUがアイドルになる時間(バブル)が発生:
- 前段のGPUが計算中、後段はアイドル
- バブル比率 = (P-1) / M(P:パイプラインステージ数、M:マイクロバッチ数)
4.3 マイクロバッチング
ミニバッチを小さなマイクロバッチに分割:
- 複数のマイクロバッチを順次投入
- パイプラインを「満たす」ことでバブルを削減
- マイクロバッチ数を増やすほど効率向上
4.4 1F1B(One Forward One Backward)
Forwardとbackwardを交互に実行するスケジュール:
- メモリ使用量を削減(活性化値を早期に解放)
- PipeDream、GPipe等で採用
4.5 パイプライン並列の特徴
- 利点:通信量が少ない(ステージ間のみ)
- 利点:ノード間の接続でも効率的
- 欠点:バブルによる効率低下
- 欠点:実装が複雑
5. ZeRO最適化
5.1 ZeRO(Zero Redundancy Optimizer)
ZeRO(Rajbhandari et al., 2020):メモリ冗長性を排除。
データ並列の各GPUが持つ冗長なデータを分散。
5.2 ZeROの3段階
| 段階 |
分散対象 |
メモリ削減 |
| ZeRO-1 |
オプティマイザ状態 |
4倍 |
| ZeRO-2 |
+ 勾配 |
8倍 |
| ZeRO-3 |
+ パラメータ |
N倍(N=GPU数) |
5.3 ZeRO-3の動作
- Forward時:必要なパラメータをAllGatherで収集
- 計算後:パラメータを破棄
- Backward時:再度AllGatherでパラメータ収集
- 勾配計算後:ReduceScatterで勾配を分散
- 更新:各GPUが担当パラメータのみを更新
5.4 ZeRO-Offload / ZeRO-Infinity
- ZeRO-Offload:CPUメモリにオフロード
- ZeRO-Infinity:NVMeにもオフロード
- GPUメモリを超えるモデルの訓練が可能
5.5 FSDP(Fully Sharded Data Parallel)
PyTorchのZeRO-3相当の実装:
- FairScaleからPyTorch本体に統合
- ZeRO-3と同様のSharding戦略
- PyTorchエコシステムとのシームレスな統合
6. 3D並列
6.1 並列化の組み合わせ
大規模訓練では複数の並列化を組み合わせ:
- データ並列(DP):バッチを分割
- テンソル並列(TP):層内を分割
- パイプライン並列(PP):層間を分割
6.2 典型的な構成
例:1024 GPU、175Bモデル
- TP = 8(同一ノード内の8 GPU)
- PP = 16(16ステージ)
- DP = 8(8つのデータ並列レプリカ)
- 合計:8 × 16 × 8 = 1024 GPU
6.3 Expert並列(MoEモデル)
Mixture of Expertsモデルでの追加の並列化:
- Expert(専門家ネットワーク)をGPU間に分散
- All-to-All通信でトークンをルーティング
- 4D並列(DP + TP + PP + EP)
6.4 設計のトレードオフ
| 並列化 |
通信パターン |
適したスコープ |
| テンソル並列 |
頻繁、高帯域必要 |
ノード内(NVLink) |
| パイプライン並列 |
まれ、小データ |
ノード間も可 |
| データ並列 |
定期的、大データ |
全体 |
7. フレームワーク
7.1 DeepSpeed
DeepSpeed(Microsoft):
- ZeRO最適化の開発元
- ZeRO-1/2/3、ZeRO-Offload、ZeRO-Infinity
- 混合精度、勾配累積の統合
- Hugging Face Transformersと統合
7.2 Megatron-LM
Megatron-LM(NVIDIA):
- テンソル並列の開発元
- 3D並列のサポート
- 高性能な実装
- Megatron-DeepSpeed:両者の統合
7.3 PyTorch FSDP
Fully Sharded Data Parallel:
- PyTorch 1.11+で標準搭載
- ZeRO-3相当の機能
- PyTorchネイティブ、移行が容易
7.4 JAX / Pax
JAX(Google):
- pjit/shardによる自動並列化
- XLA最適化
- TPU向けに最適化
7.5 選択の指針
| 要件 |
推奨 |
| Hugging Faceモデルを簡単に |
DeepSpeed + Accelerate |
| PyTorchネイティブで |
FSDP |
| 最大性能(NVIDIA GPU) |
Megatron-LM / Megatron-DeepSpeed |
| TPU使用 |
JAX / Pax |
8. 参考文献
主要論文
- Rajbhandari et al. (2020). "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models" SC
- Shoeybi et al. (2019). "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism" arXiv
- Narayanan et al. (2021). "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM" SC
- Huang et al. (2019). "GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism" NeurIPS
- Zhao et al. (2023). "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel" VLDB
ドキュメント