ray-train
Ray Train orchestrates distributed machine learning training from a single machine to multi-node clusters, with automatic GPU allocation, fault tolerance, and checkpointing. Use it when scaling PyTorch, TensorFlow, or HuggingFace models beyond single-GPU capacity, or when running distributed hyperparameter tuning experiments with Ray Tune, without rewriting core training logic.
git clone --depth 1 https://github.com/Orchestra-Research/AI-Research-SKILLs /tmp/ray-train && cp -r /tmp/ray-train/08-distributed-training/ray-train ~/.claude/skills/ray-train
# Ray Train - Distributed Training Orchestration
## Quick start
Ray Train scales machine learning training from single GPU to multi-node clusters with minimal code changes.
**Installation**:
```bash
pip install -U "ray[train]"
```
**Basic PyTorch training** (single node):
```python
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
import torch
import torch.nn as nn

# Define training function
def train_func(config):
    # Your normal PyTorch code
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Prepare for distributed (Ray handles device placement)
    model = train.torch.prepare_model(model)

    for epoch in range(10):
        # Your training loop
        output = model(torch.randn(32, 10))
        loss = output.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Report metrics (logged automatically)
        train.report({"loss": loss.item(), "epoch": epoch})

# Run distributed training
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=4,  # 4 GPUs/workers
        use_gpu=True
    )
)

result = trainer.fit()
print(f"Final loss: {result.metrics['loss']}")
```
**That's it!** Ray handles:
- Distributed coordination
- GPU allocation
- Fault tolerance
- Checkpointing
- Metric aggregation
## Common workflows
### Workflow 1: Scale existing PyTorch code
**Original single-GPU code**:
```python
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(epochs):
    for batch in dataloader:
        optimizer.zero_grad()  # clear gradients from the previous step
        loss = model(batch)
        loss.backward()
        optimizer.step()
```
**Ray Train version** (scales to multi-GPU/multi-node):
```python
import torch
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_func(config):
    # MyModel, dataloader, and epochs are placeholders from your existing code
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters())

    # Prepare for distributed (automatic device placement)
    model = train.torch.prepare_model(model)
    # New name avoids shadowing the outer dataloader inside this function
    train_loader = train.torch.prepare_data_loader(dataloader)

    for epoch in range(epochs):
        for batch in train_loader:
            optimizer.zero_grad()
            loss = model(batch)
            loss.backward()
            optimizer.step()

        # Report metrics once per epoch
        train.report({"loss": loss.item()})

# Scale to 8 GPUs
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
)
trainer.fit()
```
**Benefits**: The same training function runs on 1 GPU or 1000 GPUs; only `num_workers` in `ScalingConfig` changes.
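In data-parallel training, each worker should see a different slice of the dataset in every epoch, which is the job `prepare_data_loader` takes on when it attaches a distributed sampler. The plain-Python sketch below illustrates the idea only. `shard_indices` is a hypothetical helper, not a Ray or PyTorch API, and the round-robin assignment is an assumption modeled on how PyTorch's `DistributedSampler` strides over indices.

```python
def shard_indices(num_samples: int, num_workers: int, rank: int) -> list[int]:
    """Return the sample indices assigned to worker `rank` (round-robin)."""
    if not 0 <= rank < num_workers:
        raise ValueError("rank must be in [0, num_workers)")
    # Worker `rank` takes every num_workers-th sample, starting at its own rank
    return list(range(rank, num_samples, num_workers))


# Split 10 samples across 4 workers
shards = [shard_indices(10, 4, rank) for rank in range(4)]
for rank, shard in enumerate(shards):
    print(f"worker {rank}: {shard}")
# worker 0: [0, 4, 8]
# worker 1: [1, 5, 9]
# worker 2: [2, 6]
# worker 3: [3, 7]
```

The shards are disjoint and together cover the whole dataset. A real sampler typically also shuffles per epoch and pads or drops samples so every worker runs the same number of steps; this sketch omits both.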
### Workflow 2: HuggingFace Transformers integration
```python
from ray.train import ScalingConfig
from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
from ray.train.torch import TorchTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

def train_func(config):
    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Training arguments (HuggingFace API)
    training_args = TrainingArguments(
        output_dir="./output",
        num_train_epochs=3,
        per_device_train_batch_size=8,
        learning_rate=2e-5,
    )

    # train_dataset is a placeholder: a tokenized dataset you build with `tokenizer`
    hf_trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )
    # Forward HuggingFace metrics and checkpoints to Ray Train
    hf_trainer.add_callback(RayTrainReportCallback())
    # Let Ray Train set up distributed training for the HuggingFace Trainer
    hf_trainer = prepare_trainer(hf_trainer)
    hf_trainer.train()

# Scale to multi-node (2 nodes × 8 GPUs = 16 workers)
# Ray 2.7+ runs the HuggingFace Trainer inside a TorchTrainer
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=16,
        use_gpu=True,
        resources_per_worker={"GPU": 1},
    ),
)
result = trainer.fit()
```
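The scaling arithmetic in this workflow is worth spelling out, because `per_device_train_batch_size` applies to each worker separately. The sketch below is an illustration under stated assumptions: pure data parallelism, one worker per GPU, and no gradient accumulation. `ClusterLayout` is a hypothetical helper, not part of Ray or Transformers.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ClusterLayout:
    num_nodes: int
    gpus_per_node: int
    per_device_batch_size: int

    @property
    def num_workers(self) -> int:
        # One worker per GPU, matching resources_per_worker={"GPU": 1}
        return self.num_nodes * self.gpus_per_node

    @property
    def global_batch_size(self) -> int:
        # Every worker processes its own batch on each optimizer step
        return self.num_workers * self.per_device_batch_size


layout = ClusterLayout(num_nodes=2, gpus_per_node=8, per_device_batch_size=8)
print(layout.num_workers)        # 16
print(layout.global_batch_size)  # 128
```

With the settings above, each optimizer step consumes 128 samples rather than 8, so a learning rate tuned on a single GPU may need to be revisited.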
### Workflow 3: Hyperparameter tuning with Ray Tune
```python
import torch
from ray import train, tune
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.tune.schedulers import ASHAScheduler

def train_func(config):
    # Use hyperparameters from config
    lr = config["lr"]
    batch_size = config["batch_size"]

    # MyModel and train_epoch are placeholders for your own model and epoch loop
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model = train.torch.prepare_model(model)

    for epoch in range(10):
        # Training loop
        loss = train_epoch(model, optimizer, batch_size)
        train.report({"loss": loss, "epoch": epoch})

# Define search space; TorchTrainer passes train_loop_config to train_func as config
param_space = {
    "train_loop_config": {
        "lr": tune.loguniform(1e-5, 1e-2),
        "batch_size": tune.choice([16, 32, 64, 128]),
    }
}

# Run 20 trials with early stopping
tuner = tune.Tuner(
    TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    ),
    param_space=param_space,
    tune_config=tune.TuneConfig(
        num_samples=20,
        scheduler=ASHAScheduler(metric="loss", mode="min"),
    ),
)

results = tuner.fit()
best = results.get_best_result(metric="loss", mode="min")
print(f"Best hyperparameters: {best.config['train_loop_config']}")
```
**Result**: 20 trials run across the cluster, each on 4 GPU workers, with ASHA stopping underperforming trials early.
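To show what the search is doing conceptually, the sketch below samples the same search space in plain Python and prunes it with successive halving. It is a simplification under stated assumptions, not Ray Tune's implementation: ASHA is an asynchronous variant of successive halving, while this version is synchronous. `sample_config`, `toy_loss`, and `successive_halving` are hypothetical helpers, and the objective is made up for illustration.

```python
import math
import random


def sample_config(rng: random.Random) -> dict:
    """Draw one configuration from the lr / batch_size search space."""
    log_lr = rng.uniform(math.log(1e-5), math.log(1e-2))
    return {
        "lr": math.exp(log_lr),  # log-uniform between 1e-5 and 1e-2
        "batch_size": rng.choice([16, 32, 64, 128]),
    }


def toy_loss(config: dict, epochs: int) -> float:
    """Made-up objective: lowest near lr=1e-3, improves with more epochs."""
    distance = abs(math.log10(config["lr"]) + 3.0)
    return distance + 1.0 / epochs


def successive_halving(configs, min_epochs=1, max_epochs=8, reduction_factor=2):
    """At each rung, keep the best 1/reduction_factor and train them longer."""
    survivors = list(configs)
    epochs = min_epochs
    while len(survivors) > 1 and epochs <= max_epochs:
        ranked = sorted(survivors, key=lambda c: toy_loss(c, epochs))
        keep = max(1, len(ranked) // reduction_factor)
        survivors = ranked[:keep]
        epochs *= reduction_factor
    return survivors[0]


rng = random.Random(0)
configs = [sample_config(rng) for _ in range(20)]
best = successive_halving(configs)
print(f"Best config: {best}")
```

Sampling the learning rate in log space spreads trials evenly across orders of magnitude, and pruning at each rung spends most of the training budget on the configurations that look promising early.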
### Workflow 4: Checkpointing and fault tolerance
```python
import os
import tempfile

import torch
from ray import train
from ray.train import Checkpoint

def train_func(config):
    # MyModel and train_epoch are placeholders for your own model and epoch loop
    model = MyModel()
    # Place the model on its device before building and restoring the optimizer
    model = train.torch.prepare_model(model)
    optimizer = torch.optim.Adam(model.parameters())
    # With multiple workers the model is wrapped in DDP; use the inner module for state
    base_model = getattr(model, "module", model)

    # Try to resume from checkpoint
    start_epoch = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            state = torch.load(os.path.join(checkpoint_dir, "model.pt"), map_location="cpu")
            base_model.load_state_dict(state["model"])
            optimizer.load_state_dict(state["optimizer"])
            # Resume with the epoch after the one that was saved
            start_epoch = state["epoch"] + 1

    for epoch in range(start_epoch, 100):
        loss = train_epoch(model, optimizer)

        # Save checkpoint every 10 epochs
        if epoch % 10 == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                torch.save({
                    "model": base_model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch
                }, os.path.join(tmpdir, "model.pt"))
                checkpoint = Checkpoint.from_directory(tmpdir)
                train.report({"loss": loss}, checkpoint=checkpoint)
```