ClaudeWave
Skill · 9.6k repo stars · updated 1mo ago

pytorch-fsdp2

This skill teaches how to integrate PyTorch FSDP2 (fully_shard) into training scripts to shard parameters across multiple GPUs. Use it when model parameters, gradients, and optimizer state exceed single-GPU memory, or when you need DTensor-based sharding with DeviceMesh to compose data parallelism with tensor parallelism. It covers proper initialization, bottom-up application of sharding, optimizer creation after sharding, and distributed checkpointing via DCP.

Install in Claude Code
git clone --depth 1 https://github.com/Orchestra-Research/AI-Research-SKILLs /tmp/pytorch-fsdp2 && cp -r /tmp/pytorch-fsdp2/08-distributed-training/pytorch-fsdp2 ~/.claude/skills/pytorch-fsdp2
Then open a new Claude Code session; the skill loads automatically.

SKILL.md

# Skill: Use PyTorch FSDP2 (`fully_shard`) correctly in a training script

This skill teaches a coding agent how to **add PyTorch FSDP2** to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.

> FSDP2 in PyTorch is exposed primarily via `torch.distributed.fsdp.fully_shard` and the `FSDPModule` methods it adds in-place to modules. See: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.

---

## When to use this skill

Use FSDP2 when:
- Your model **doesn’t fit** on one GPU (parameters + gradients + optimizer state).
- You want an eager-mode approach built on **DTensor-based per-parameter sharding**, which is more inspectable and yields simpler sharded state dicts than FSDP1.
- You may later compose DP with **Tensor Parallel** using **DeviceMesh**.
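
To make the "doesn't fit" criterion concrete, here is a rough back-of-the-envelope sketch. The function name and the byte counts are illustrative assumptions (FP32 parameters and gradients, plus an Adam-style optimizer holding two FP32 moments per parameter); activations and buffers are ignored, so real usage will be higher.

```python
def training_state_gb(num_params, bytes_per_param=4, bytes_per_grad=4,
                      optimizer_bytes_per_param=8, num_shards=1):
    """Approximate per-GPU memory (GB) for params + grads + optimizer state.

    Assumption: state is split evenly across `num_shards` GPUs, which is the
    idealized effect of full sharding. Activations are not counted.
    """
    bytes_per_element = bytes_per_param + bytes_per_grad + optimizer_bytes_per_param
    return num_params * bytes_per_element / num_shards / 1e9


# A hypothetical 7B-parameter model:
print(training_state_gb(7_000_000_000))                # 112.0 GB, more than one 80 GB GPU
print(training_state_gb(7_000_000_000, num_shards=8))  # 14.0 GB per GPU when sharded 8 ways
```

If the unsharded figure is comfortably below your GPU memory, DDP may be the simpler choice.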

Avoid (or be careful) if:
- You need strict backwards-compatible checkpoints across PyTorch versions (the DCP docs warn that this is not guaranteed).
- You’re forced onto older PyTorch versions without the FSDP2 stack.

## Alternatives (when FSDP2 is not the best fit)

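- **DistributedDataParallel (DDP)**: Use the standard data-parallel wrapper when the full model, gradients, and optimizer state fit on each GPU and you only need to replicate the model and all-reduce gradients.
- **FullyShardedDataParallel (FSDP1)**: Use the original FSDP wrapper when you need parameter sharding across data-parallel workers but your codebase or PyTorch version still depends on the FSDP1 API.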

Reference: `references/pytorch_ddp_notes.md`, `references/pytorch_fsdp1_api.md`.

---

## Contract the agent must follow

1. **Launch with `torchrun`** and set the CUDA device per process (usually via `LOCAL_RANK`).  
2. **Apply `fully_shard()` bottom-up**, i.e., shard submodules (e.g., Transformer blocks) before the root module.  
3. **Call `model(input)`**, not `model.forward(input)`, so the FSDP2 hooks run (unless you explicitly `unshard()` or register the forward method).  
4. **Create the optimizer after sharding** and make sure it is built on the **DTensor parameters** (post-`fully_shard`).  
5. **Checkpoint using Distributed Checkpoint (DCP)** or the distributed-state-dict helpers, not naïve `torch.save(model.state_dict())` unless you deliberately gather to full tensors.

(Each of these rules is directly described in the official API docs/tutorial; see references.)

---

## Step-by-step procedure

### 0) Version & environment sanity
- Prefer a recent stable PyTorch release whose docs cover both `torch.distributed.fsdp.fully_shard` (FSDP2) and `torch.distributed.checkpoint` (DCP).
- Use `torchrun --nproc_per_node <gpus_per_node> ...` and ensure `RANK`, `WORLD_SIZE`, `LOCAL_RANK` are visible.
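
One way to check that the launcher variables are visible is a small fail-fast helper. This is a minimal sketch; `LaunchEnv` and `read_launch_env` are hypothetical names, not part of PyTorch.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class LaunchEnv:
    rank: int
    world_size: int
    local_rank: int


def read_launch_env(environ=None):
    """Read the variables torchrun exports, failing early if any is missing."""
    environ = os.environ if environ is None else environ
    required = ("RANK", "WORLD_SIZE", "LOCAL_RANK")
    missing = [name for name in required if name not in environ]
    if missing:
        raise RuntimeError(
            f"Missing {missing}; start the script with torchrun --nproc_per_node <gpus_per_node> ..."
        )
    return LaunchEnv(*(int(environ[name]) for name in required))
```

Calling `read_launch_env()` at the top of the training script turns a confusing hang or `KeyError` later on into an immediate, readable error.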

Reference: `references/pytorch_fsdp2_tutorial.md` (launch commands and setup), `references/pytorch_fully_shard_api.md` (user contract).

---

### 1) Initialize distributed and set device
Minimal, correct pattern:
- `dist.init_process_group(backend="nccl")`
- `torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))`
- Optionally create a `DeviceMesh` to describe the data-parallel group(s)
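
The three bullets above can be sketched as a single setup function. This assumes a torchrun launch on CUDA GPUs and a 1-D data-parallel mesh; the function name and the `"dp"` dimension name are illustrative choices.

```python
import os


def init_distributed(backend="nccl"):
    """Join the process group, bind this process to its GPU, return a 1-D mesh."""
    # Imported inside the function so this file can still be imported
    # (for example for config checks) on a machine without PyTorch.
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    # Raises KeyError if the script was not started by torchrun.
    local_rank = int(os.environ["LOCAL_RANK"])

    dist.init_process_group(backend=backend)
    torch.cuda.set_device(local_rank)

    # Optional: a 1-D mesh over all ranks, usable as the `mesh=` argument
    # of fully_shard. A 2-D mesh would be needed to add tensor parallelism.
    return init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("dp",))
```

Call it once at the start of the script, and pair it with `dist.destroy_process_group()` at exit.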

Reference: `references/pytorch_device_mesh_tutorial.md` (why DeviceMesh exists & how it manages process groups).

---

### 2) Build model on meta device (recommended for very large models)
For big models, initialize on `meta`, apply sharding, then materialize weights on GPU:
- `with torch.device("meta"): model = ...`
- apply `fully_shard(...)` on submodules, then `fully_shard(model)`
- `model.to_empty(device="cuda")`
- `model.reset_parameters()` (or your own init routine; `nn.Module` has no built-in `reset_parameters`, so the model must define one)
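
A minimal sketch of this flow follows. `build_from_meta`, `model_factory`, and `apply_sharding` are hypothetical names: the caller supplies a function that builds the model and one that applies `fully_shard` to it. Because a generic model may not define a top-level `reset_parameters()`, the sketch calls it on every submodule that has one, which is an assumption about how your model initializes its weights.

```python
def build_from_meta(model_factory, apply_sharding, device="cuda"):
    """Construct on meta, shard, then allocate and initialize real storage."""
    import torch  # local import: keeps this sketch importable without PyTorch

    # 1) No memory is allocated for parameters created under the meta device.
    with torch.device("meta"):
        model = model_factory()

    # 2) Caller-supplied sharding (fully_shard on blocks, then on the root).
    apply_sharding(model)

    # 3) Allocate uninitialized storage on the target device.
    model.to_empty(device=device)

    # 4) Initialize weights: to_empty() leaves parameter values undefined.
    for module in model.modules():
        reset = getattr(module, "reset_parameters", None)
        if callable(reset):
            reset()
    return model
```

If your model ships its own init routine (for example an `init_weights()` method), call that in step 4 instead.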

Reference: `references/pytorch_fsdp2_tutorial.md` (migration guide shows this flow explicitly).

---

### 3) Apply `fully_shard()` bottom-up (wrapping policy = “apply where needed”)
**Do not** only call `fully_shard` on the topmost module.

Recommended sharding pattern for transformer-like models:
- iterate modules, `if isinstance(m, TransformerBlock): fully_shard(m, ...)`
- then `fully_shard(model, ...)`

Why:
- `fully_shard` forms “parameter groups” for collective efficiency and excludes params already grouped by earlier calls. Bottom-up gives better overlap and lower peak memory.
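
The recommended pattern can be sketched as a small helper. `shard_bottom_up` and `block_cls` are hypothetical names; the helper assumes blocks of type `block_cls` are not nested inside one another, since `model.modules()` yields parents before children.

```python
def shard_bottom_up(model, block_cls, shard_fn=None, **fsdp_kwargs):
    """Shard every `block_cls` submodule first, then the root module last."""
    if shard_fn is None:
        # Default to FSDP2; injectable so the ordering can be checked without GPUs.
        from torch.distributed.fsdp import fully_shard as shard_fn

    for module in model.modules():
        if module is not model and isinstance(module, block_cls):
            shard_fn(module, **fsdp_kwargs)

    # The root call groups whatever parameters the block-level calls left over
    # (for example embeddings and the final output layer).
    shard_fn(model, **fsdp_kwargs)
    return model
```

In a training script this would be called as `shard_bottom_up(model, TransformerBlock, mesh=mesh)`, where `TransformerBlock` stands for your model's own block class.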

Reference: `references/pytorch_fully_shard_api.md` (bottom-up requirement and why).

---

### 4) Configure `reshard_after_forward` for memory/perf trade-offs
Default behavior:
- `None` means `True` for non-root modules and `False` for root modules (good default).

Heuristics:
- If you’re memory-bound: keep defaults or force `True` on many blocks.
- If you’re throughput-bound and can afford memory: consider keeping unsharded params longer (root often `False`).
- Advanced: use an `int` to reshard to a smaller mesh after forward (e.g., intra-node) if it’s a meaningful divisor.
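
The heuristics above can be written down as a small decision helper. Both functions are hypothetical (not PyTorch APIs), and the sketch reads "meaningful divisor" as a divisor of the shard mesh size other than 1 and the mesh size itself, which is an assumption to check against the API reference.

```python
def valid_reshard_sizes(shard_mesh_size):
    """Divisors of the shard mesh size, excluding 1 and the size itself."""
    return [n for n in range(2, shard_mesh_size) if shard_mesh_size % n == 0]


def pick_reshard_after_forward(shard_mesh_size, gpus_per_node, memory_bound):
    """Illustrative heuristic that mirrors the bullets above."""
    if memory_bound:
        return None  # library default: True for non-root modules, False for root
    if gpus_per_node in valid_reshard_sizes(shard_mesh_size):
        return gpus_per_node  # reshard to an intra-node group after forward
    return False  # keep parameters unsharded between forward and backward


print(valid_reshard_sizes(16))                   # [2, 4, 8]
print(pick_reshard_after_forward(16, 8, False))  # 8
```

The result is meant to be passed through, for example `fully_shard(block, reshard_after_forward=pick_reshard_after_forward(16, 8, False))`; measure memory and throughput before committing to a non-default value.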

Reference: `references/pytorch_fully_shard_api.md` (full semantics).

---

### 5) Mixed precision & offload (optional but common)
FSDP2 uses:
- `mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)`
- `offload_policy=CPUOffloadPolicy()` if you want CPU offload

Rules of thumb:
- Start with BF16 parameters/reductions on H100/A100-class GPUs (if numerically stable for your model).
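- Set `reduce_dtype` deliberately: it controls the dtype used for gradient reduction, so choose `torch.float32` if you want full-precision reductions while computing in BF16.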
- If you use CPU offload, budget for PCIe/NVLink traffic and runtime overhead.
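
As a starting point, the two policies can be bundled into keyword arguments for `fully_shard`. This is a sketch: `make_fsdp_policies` and its flags are hypothetical, and it assumes a PyTorch version that exports both policy classes from `torch.distributed.fsdp`.

```python
def make_fsdp_policies(reduce_in_fp32=False, cpu_offload=False):
    """Return kwargs for fully_shard: BF16 params, optional FP32 reduction/offload."""
    import torch  # local import: keeps this sketch importable without PyTorch
    from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy

    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        # BF16 reductions halve communication volume; FP32 trades that for precision.
        reduce_dtype=torch.float32 if reduce_in_fp32 else torch.bfloat16,
    )
    kwargs = {"mp_policy": mp_policy}
    if cpu_offload:
        kwargs["offload_policy"] = CPUOffloadPolicy()
    return kwargs
```

Pass the same kwargs to every call, for example `fully_shard(block, **make_fsdp_policies(reduce_in_fp32=True))`, so all parameter groups share one policy.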

Reference: `references/pytorch_fully_shard_api.md` (MixedPrecisionPolicy / OffloadPolicy classes).

---

### 6) Optimizer, gradient clipping, accumulation
- Create the optimizer **after** sharding so it holds DTensor params.
- If you need gradient accumulation / no_sync:
  - use the FSDP2 mechanism (`set_requires_gradient_sync`) instead of FSDP1’s `no_sync()`.

Gradient clipping:
- Use the approach shown in the FSDP2 tutorial (“Gradient Clipping and Optimizer with DTensor”), because parameters/gradients are DTensors.
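
A hedged sketch of an accumulation loop that ties these points together. `train_with_accumulation` and its arguments are hypothetical; `model` is assumed to be the root module after `fully_shard`, and `optimizer` is assumed to have been built afterwards, for example `torch.optim.AdamW(model.parameters(), lr=1e-4)`.

```python
def train_with_accumulation(model, optimizer, batches, loss_fn,
                            accum_steps=4, max_norm=None):
    """Sync gradients only on the last micro-batch of each accumulation window."""
    for step, (inputs, targets) in enumerate(batches, start=1):
        is_last_micro_batch = step % accum_steps == 0

        # FSDP2 replacement for FSDP1's no_sync(): skip gradient reduction
        # until the final micro-batch of the window.
        model.set_requires_gradient_sync(is_last_micro_batch)

        # Call model(...), not model.forward(...), so the FSDP2 hooks run.
        loss = loss_fn(model(inputs), targets) / accum_steps
        loss.backward()

        if is_last_micro_batch:
            if max_norm is not None:
                from torch.nn.utils import clip_grad_norm_
                # Parameters and gradients are DTensors at this point.
                clip_grad_norm_(model.parameters(), max_norm=max_norm)
            optimizer.step()
            optimizer.zero_grad()
```

Note that the sketch drops any trailing micro-batches that do not fill a complete window; handle that case explicitly if your dataset size is not a multiple of `accum_steps`.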

Reference: `references/pytorch_fsdp2_tutorial.md`.

---

### 7) Checkpointing: prefer DCP or distributed state dict helpers
Two recommended approaches:

**A) Distributed Checkpoint (DCP) — best default**
- DCP saves/loads from multiple ranks in parallel and supports load-time resharding.
- DCP produces **multiple files** (often at least one per r
autoresearch (Skill)

Orchestrates end-to-end autonomous AI research projects using a two-loop architecture. The inner loop runs rapid experiment iterations with clear optimization targets. The outer loop synthesizes results, identifies patterns, and steers research direction. Routes to domain-specific skills for execution, supports continuous agent operation via Claude Code /loop and OpenClaw heartbeat, and produces research presentations and papers. Use when starting a research project, running autonomous experiments, or managing a multi-hypothesis research effort.

implementing-llms-litgpt (Skill)

Implements and trains LLMs using Lightning AI's LitGPT with 20+ pretrained architectures (Llama, Gemma, Phi, Qwen, Mistral). Use when you need clean model implementations, an educational understanding of architectures, or production fine-tuning with LoRA/QLoRA. Single-file implementations, no abstraction layers.

mamba-architecture (Skill)

State-space model with O(n) complexity vs Transformers' O(n²). 5× faster inference, million-token sequences, no KV cache. Selective SSM with hardware-aware design. Mamba-1 (d_state=16) and Mamba-2 (d_state=128, multi-head). Models 130M-2.8B on HuggingFace.

nanogpt (Skill)

Educational GPT implementation in ~300 lines. Reproduces GPT-2 (124M) on OpenWebText. Clean, hackable code for learning transformers. By Andrej Karpathy. Perfect for understanding GPT architecture from scratch. Train on Shakespeare (CPU) or OpenWebText (multi-GPU).

rwkv-architecture (Skill)

RNN+Transformer hybrid with O(n) inference. Linear time, infinite context, no KV cache. Train like GPT (parallel), infer like RNN (sequential). Linux Foundation AI project. Production at Windows, Office, NeMo. RWKV-7 (March 2025). Models up to 14B parameters.

distributed-llm-pretraining-torchtitan (Skill)

Provides PyTorch-native distributed LLM pretraining using torchtitan with 4D parallelism (FSDP2, TP, PP, CP). Use when pretraining Llama 3.1, DeepSeek V3, or custom models at scale from 8 to 512+ GPUs with Float8, torch.compile, and distributed checkpointing.

huggingface-tokenizers (Skill)

Fast tokenizers optimized for research and production. Rust-based implementation tokenizes 1GB in <20 seconds. Supports BPE, WordPiece, and Unigram algorithms. Train custom vocabularies, track alignments, handle padding/truncation. Integrates seamlessly with transformers. Use when you need high-performance tokenization or custom tokenizer training.

sentencepiece (Skill)

Language-independent tokenizer treating text as raw Unicode. Supports BPE and Unigram algorithms. Fast (50k sentences/sec), lightweight (6MB memory), deterministic vocabulary. Used by T5, ALBERT, XLNet, mBART. Train on raw text without pre-tokenization. Use when you need multilingual support, CJK languages, or reproducible tokenization.