# PyTorch 2.0+(SDPAにFlashAttention統合)
import torch.nn.functional as F
output = F.scaled_dot_product_attention(
query, key, value,
attn_mask=None,
dropout_p=0.0,
is_causal=True # Causal maskの効率的実装
)
flash-attnライブラリ:スタンドアロン実装
xformers:Meta提供、Memory-efficient attention
PyTorch SDPA:自動バックエンド選択
5. 量子化訓練
5.1 量子化の概要
より低いビット精度での表現:
FP32/FP16 → INT8/INT4
メモリ削減、計算高速化
推論では広く使用、訓練は発展途上
5.2 量子化の種類
種類
タイミング
特徴
Post-Training Quantization (PTQ)
訓練後
簡単、精度低下の可能性
Quantization-Aware Training (QAT)
訓練中
高精度、コスト大
低精度訓練
訓練全体
FP8、INT8訓練
5.3 FP8訓練
H100 GPUでサポートされる8ビット浮動小数点:
E4M3:4ビット指数、3ビット仮数(Forward向け)
E5M2:5ビット指数、2ビット仮数(Backward向け)
BF16比で〜2倍の高速化
5.4 QLoRA
QLoRA(Dettmers et al., 2023):
ベースモデルを4ビット量子化(NF4)
LoRAアダプタをFP16/BF16で追加訓練
65Bモデルを単一48GB GPUでファインチューニング可能
5.5 量子化手法
GPTQ:Post-training、レイヤーごとの最適化
AWQ:Activation-aware、重要な重みを保護
bitsandbytes:8ビット/4ビットの効率的実装
6. コンパイル最適化
6.1 torch.compile
PyTorch 2.0で導入されたJITコンパイラ:
import torch
model = MyModel()
model = torch.compile(model) # コンパイル
# 以降は通常通り使用
output = model(input)
6.2 コンパイルの効果
カーネル融合:複数の演算を1カーネルに
メモリI/O削減:中間テンソルの書き込み削減
グラフ最適化:不要な演算の削除
効果:1.5〜2倍の高速化(モデルによる)
6.3 バックエンド
TorchInductor:デフォルト、GPU向けTritonコード生成
cudagraphs:CUDAグラフによるカーネル起動オーバーヘッド削減
ONNX Runtime:推論向け最適化
6.4 Triton
Triton(OpenAI):
Pythonライクな言語でGPUカーネルを記述
CUDAより簡単、高い生産性
FlashAttentionもTritonで実装可能
torch.compileのバックエンドとしても使用
6.5 XLA
XLA(Accelerated Linear Algebra):
TensorFlow、JAXのコンパイラ
PyTorch/XLA:PyTorchでXLAを使用
TPU向けに最適化
7. その他の技術
7.1 勾配累積
メモリに収まるバッチサイズで複数回Forward/Backwardし、勾配を累積:
for i, batch in enumerate(dataloader):
loss = model(batch) / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
実効バッチサイズ = ミニバッチ × 累積ステップ × GPU数
7.2 Speculative Decoding
推論時の高速化技術:
小さなドラフトモデルで複数トークンを投機的に生成
大きなターゲットモデルで検証
バッチ処理の効率を活用
2〜3倍の高速化
7.3 KVキャッシュ最適化
推論時のKey-Valueキャッシュ:
PagedAttention(vLLM):メモリの断片化防止
Continuous Batching:動的バッチ管理
MLA(DeepSeek):KVキャッシュの低ランク圧縮
7.4 データローダー最適化
プリフェッチ:GPUが計算中に次のバッチを準備
ピン止めメモリ:CPU→GPU転送の高速化
非同期データロード:複数ワーカーで並列読み込み
7.5 組み合わせの指針
目的
推奨技術
メモリ不足
勾配チェックポイント + 混合精度 + 勾配累積
訓練高速化
FlashAttention + torch.compile + 混合精度
長シーケンス
FlashAttention + 勾配チェックポイント
ファインチューニング(低コスト)
QLoRA + FlashAttention
8. 参考文献
主要論文
Micikevicius et al. (2018). "Mixed Precision Training" ICLR
Chen et al. (2016). "Training Deep Nets with Sublinear Memory Cost" arXiv(勾配チェックポイント)
Dao et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" NeurIPS
Dao (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"
Dettmers et al. (2023). "QLoRA: Efficient Finetuning of Quantized LLMs" NeurIPS
Ansel et al. (2024). "PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation"