ml-training-recipes
ml-training-recipes provides production-grade PyTorch training patterns for neural networks across domains including LLMs, vision, diffusion, medical imaging, and biomedical discovery. It includes reference guides for architecture selection by data type and scale, optimizer configuration (AdamW, Muon), learning rate scheduling, mixed precision training, debugging techniques, and scaling laws. Use this skill when implementing training loops, selecting model architectures, troubleshooting training failures like loss spikes or out-of-memory errors, or optimizing computational efficiency across different problem domains.
git clone --depth 1 https://github.com/Orchestra-Research/AI-Research-SKILLs /tmp/ml-training-recipes && cp -r /tmp/ml-training-recipes/10-optimization/ml-training-recipes ~/.claude/skills/ml-training-recipes
# ML Training Recipes
Battle-tested patterns for PyTorch training across domains. Drawn from production codebases
(Karpathy's autoresearch/nanochat, torchvision, HuggingFace) and modern training practice.
## Reference files (read when needed)
- `references/architecture.md` — Transformer/LLM architecture code patterns, weight init
- `references/optimizers.md` — Muon, AdamW hybrid, per-group LR, compiled optimizer steps
- `references/domain-specific.md` — Vision, diffusion, contrastive, distributed, checkpointing, data loading
- `references/scaling-and-selection.md` — Scaling laws, compute budget tables, decision trees, DGX Spark
- `references/biomedical.md` — Drug discovery, protein models, medical imaging, genomics, clinical NLP
- `references/experiment-loop.md` — Autonomous experiment loop (autoresearch keep/discard/revert)
---
## Architecture Selection
Pick the right model by **data type** and **data scale**:
| Data Type | < 10K samples | 10K-100K | > 100K |
|-----------|--------------|----------|--------|
| **Images** | Pretrained CNN + fine-tune | Fine-tune ViT or CNN | ViT from scratch |
| **Text (gen)** | Few-shot prompting | Fine-tune GPT/LLaMA (LoRA) | Pretrain from scratch |
| **Tabular** | XGBoost/LightGBM | Still XGBoost | Neural viable |
| **Audio** | Pretrained Whisper | Fine-tune AST | Train from scratch |
| **Molecules** | Pretrained GNN | Fine-tune molecular LM | Train GNN from scratch |
| **Proteins** | ESM-2 embeddings + head | Fine-tune ESM-2 | Train protein LM |
| **Medical img** | Pretrained CNN | nnU-Net (auto-config) | Swin-UNETR / MedSAM |
**Key principle**: architecture matters less than training recipe at equal compute. A well-tuned
ResNet beats a poorly-tuned ViT (ref: "ResNet Strikes Back", Wightman et al., 2021).
For biomedical domains, see `references/biomedical.md`.
For sequence model selection and compute planning, see `references/scaling-and-selection.md`.
---
## Scaling Laws
### Chinchilla rule (Hoffmann et al., 2022)
Compute-optimal training: **~20 tokens per parameter**.
| Model Size | Compute-Optimal (~20 tokens/param) | Inference-Optimal (~100 tokens/param) |
|-----------|----------------|--------------------------|
| 125M | 2.5B tokens | 12.5B tokens |
| 1B | 20B tokens | 100B tokens |
| 7B | 140B tokens | 700B tokens |
**FLOPs ≈ 6 × N × D** (N=params, D=tokens). Data repetition limit: ~4 epochs before diminishing returns.
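The table rows and the FLOPs estimate follow from two multiplications. As a minimal sketch (the helper name `chinchilla_budget` is hypothetical, and the 20 tokens/param default is the rule of thumb above, not an exact constant):

```python
def chinchilla_budget(n_params: float, tokens_per_param: float = 20.0):
    """Return (training tokens, training FLOPs) for a model with n_params parameters."""
    tokens = tokens_per_param * n_params
    flops = 6.0 * n_params * tokens  # FLOPs ≈ 6 × N × D
    return tokens, flops


for n in (125e6, 1e9, 7e9):
    tokens, flops = chinchilla_budget(n)
    print(f"{n / 1e9:.3f}B params: {tokens / 1e9:.1f}B tokens, {flops:.2e} FLOPs")
```

Passing `tokens_per_param=100.0` reproduces the inference-optimal column.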
---
## Training Loop
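```python
import gc
import time

import torch

# Assumed to be defined elsewhere: model, optimizer, update_lr, train_loader
# (an iterator yielding (x, y) batches), total_batch_size (in tokens),
# batch_size, seq_len, and time_budget (seconds of training allowed).
torch.manual_seed(42)
torch.set_float32_matmul_precision("high")  # TF32 on Ampere+
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
grad_accum_steps = total_batch_size // (batch_size * seq_len)

step = 0
train_time = 0.0
x, y = next(train_loader)  # fetch the first batch before entering the loop
while train_time < time_budget:
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        (loss / grad_accum_steps).backward()  # average over micro-steps
        x, y = next(train_loader)  # load the next batch while the GPU is busy
    update_lr(optimizer, train_time / time_budget)  # progress in [0, 1]
    optimizer.step()
    model.zero_grad(set_to_none=True)  # frees memory vs zeroing
    if loss.item() > 100:  # fast-fail on divergence
        print("FAIL: loss exploded")
        raise SystemExit(1)
    torch.cuda.synchronize()  # wait for the GPU so the timing is accurate
    if step == 0:
        # Avoid ~500ms GC stalls: freeze startup objects, then turn GC off
        gc.collect(); gc.freeze(); gc.disable()
    train_time += time.time() - t0
    step += 1
```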
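The `grad_accum_steps` line in the loop above silently floors when the token budget is not an exact multiple of the micro-batch. A small sketch of the same arithmetic with an explicit check (the helper name and the example sizes are illustrative assumptions, not values from the source):

```python
def compute_grad_accum_steps(total_batch_tokens: int, micro_batch_size: int, seq_len: int) -> int:
    """Number of micro-steps needed to accumulate one optimizer step."""
    tokens_per_micro_step = micro_batch_size * seq_len
    if total_batch_tokens % tokens_per_micro_step != 0:
        raise ValueError(
            f"total batch of {total_batch_tokens} tokens is not a multiple of "
            f"{tokens_per_micro_step} tokens per micro-step"
        )
    return total_batch_tokens // tokens_per_micro_step


# Hypothetical sizes: 2**19 tokens per optimizer step, 32 sequences of 2048 tokens
print(compute_grad_accum_steps(total_batch_tokens=2**19, micro_batch_size=32, seq_len=2048))
```

Failing early here is cheaper than finding out later that the effective batch size was smaller than planned.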
### Key principles
- **Gradient clipping**: `clip_grad_norm_(params, 1.0)` — near-universal for Transformers.
Exception: Muon optimizer normalizes updates via orthogonalization, so clipping is optional.
- **Tensor Core alignment**: keep batch size and hidden dims at multiples of 8 for bf16, and ideally multiples of 64 on A100.
- **Time-based budgets** make experiments comparable across hardware.
- **`cudnn.benchmark = True`** for fixed-size vision inputs.
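For the Tensor Core alignment point, a common trick is to pad a dimension up to the next multiple rather than redesign the model. A minimal sketch (the helper `round_up` is hypothetical; 50257 is the GPT-2 vocabulary size, used here only as an illustration):

```python
def round_up(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


vocab_size = 50257                       # GPT-2 BPE vocabulary
padded_vocab = round_up(vocab_size, 64)  # extra rows are simply never used
print(vocab_size, "->", padded_vocab)
```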
---
## Optimizer Configuration
Modern LLM training uses different optimizers per parameter group:
| Parameter Type | Optimizer | LR (base) | Weight Decay |
|---------------|-----------|-----------|--------------|
| 2D weight matrices | Muon | 0.04 | 0.2 |
| Token embeddings | AdamW | 0.6 × scale | 0.0 |
| Unembedding (lm_head) | AdamW | 0.004 × scale | 0.0 |
| Per-layer scalars | AdamW | 0.005 × scale | 0.0 |
**LR scaling by dimension**: `lr * (d_model / 768)^(-0.5)` — keeps dynamics stable across sizes.
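As a quick worked example of the scaling rule, assuming the `scale` factor in the table refers to `(d_model / 768)^(-0.5)` (the function name and the reference width parameter are illustrative):

```python
def scale_lr(base_lr: float, d_model: int, ref_dim: int = 768) -> float:
    """Scale a base LR by (d_model / ref_dim) ** -0.5."""
    return base_lr * (d_model / ref_dim) ** -0.5


# Embedding LR from the table (0.6) at a few model widths
for d_model in (768, 1536, 3072):
    print(d_model, round(scale_lr(0.6, d_model), 4))
```

Quadrupling the width halves the learning rate, so wider models take proportionally smaller steps.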
### Rules of thumb
- Embeddings need higher LR (sparse updates). Never weight-decay embeddings.
- Weight decay scheduling: linearly decay WD to 0 over training.
- AdamW settings: β1=0.9, β2=0.95, eps=1e-10. These differ from the PyTorch defaults (β2=0.999, eps=1e-8); the smaller eps prevents stale updates in bf16.
For Muon details (polar express orthogonalization, NorMuon), see `references/optimizers.md`.
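The weight decay scheduling rule of thumb above can be sketched as a function of training progress. This is a minimal illustration; the function name is hypothetical, and the 0.2 default is taken from the Muon row of the table:

```python
def weight_decay_at(progress: float, base_wd: float = 0.2) -> float:
    """Linearly decay weight decay from base_wd to 0 as progress goes 0 -> 1."""
    progress = min(max(progress, 0.0), 1.0)  # clamp to [0, 1]
    return base_wd * (1.0 - progress)


for p in (0.0, 0.5, 1.0):
    print(p, weight_decay_at(p))
```

In a training loop the result would be written into each decayed parameter group's `weight_decay` entry before `optimizer.step()`, while groups such as embeddings stay at 0.0.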
---
## Learning Rate Scheduling
### Time-based (autoresearch style)
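```python
# Example values (assumptions): no warmup, decay over the last half of the
# budget, final LR of zero. Tune these per run.
warmup_ratio = 0.0
warmdown_ratio = 0.5
final_lr_frac = 0.0

def get_lr_multiplier(progress):  # progress = elapsed_time / time_budget
    if progress < warmup_ratio:
        return progress / warmup_ratio  # linear warmup from 0 to 1
    elif progress < 1.0 - warmdown_ratio:
        return 1.0  # stable phase at peak LR
    else:
        # cooldown goes 1 -> 0 over the warmdown phase
        cooldown = (1.0 - progress) / warmdown_ratio
        return cooldown + (1 - cooldown) * final_lr_frac
```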
### Cosine decay
```python
import math

def get_lr(step, total_steps, max_lr, min_lr, warmup_steps):
    if step < warmup_steps:
        return max_lr * step / warmup_steps  # linear warmup from 0
    # Fraction of the post-warmup phase completed, from 0 to 1
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
```
**WSD (Warmup-Stable-Decay)**: gaining traction because the constant-LR stable phase makes it easier to resume or extend training mid-run.
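A step-based WSD schedule might look like the sketch below. The linear decay shape, the function name, and the parameters are assumptions for illustration; WSD variants differ in how they shape the final decay.

```python
def wsd_lr(step, total_steps, max_lr, warmup_steps, decay_steps, min_lr=0.0):
    """Warmup-Stable-Decay: linear warmup, constant plateau, linear decay."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    decay_start = total_steps - decay_steps
    if step < decay_start:
        return max_lr  # stable phase: a checkpoint here can be resumed or extended
    remaining = max(total_steps - step, 0) / decay_steps  # goes 1 -> 0
    return min_lr + (max_lr - min_lr) * remaining


for s in (0, 50, 500, 850, 1000):
    print(s, wsd_lr(s, total_steps=1000, max_lr=1.0, warmup_steps=100, decay_steps=300))
```

Because the plateau does not depend on `total_steps`, extending a run only moves the start of the decay phase.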
### Guidance
- **Warmup**: 1-5% of training. Zero warmup valid with Muon (autoresearch uses `WARMUP_RATIO=0.0`).
- **Warmdown**: 30-50% of training in LR decay. Matters more than warmup for final quality.
- **Final LR**: 0 or ~10% of peak. Zero is simpler.
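A schedule multiplier still has to be written into the optimizer. One plausible way to do that, sketched with a hypothetical helper `apply_lr_multiplier` (this is not necessarily how the `update_lr` call in the training loop is implemented). It relies only on the `param_groups` list of dicts that PyTorch optimizers expose, so a stand-in object is used here to keep the example dependency-free:

```python
from types import SimpleNamespace


def apply_lr_multiplier(optimizer, multiplier: float) -> None:
    """Set each group's LR to its recorded base LR times the schedule multiplier."""
    for group in optimizer.param_groups:
        group.setdefault("initial_lr", group["lr"])  # remember the base LR once
        group["lr"] = group["initial_lr"] * multiplier


# Stand-in for a real optimizer; the base LRs mirror the table values
opt = SimpleNamespace(param_groups=[{"lr": 0.04}, {"lr": 0.6}])
apply_lr_multiplier(opt, 0.5)
print([g["lr"] for g in opt.param_groups])
```

Storing the base value per group keeps the per-parameter-type learning rates in proportion as the schedule moves.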
---
## Mixed Precision & Compilation
```python
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" # before torch import
import torch
torch.set_float32_matmul_precision("high")
```