Skill · 199 repo stars · updated 16d ago

pymc-bayesian-modeling

Bayesian modeling with PyMC 5: priors, likelihood, NUTS/ADVI sampling, diagnostics (R-hat, ESS), LOO/WAIC comparison, prediction. Hierarchical, logistic, GP variants; predictive checks.

Install in Claude Code
git clone --depth 1 https://github.com/jaechang-hits/SciAgent-Skills /tmp/pymc-bayesian-modeling && cp -r /tmp/pymc-bayesian-modeling/skills/biostatistics/pymc-bayesian-modeling ~/.claude/skills/pymc-bayesian-modeling
Then open a new Claude Code session; the skill loads automatically.

SKILL.md

# PyMC Bayesian Modeling

## Overview

PyMC is a Python library for Bayesian statistical modeling and probabilistic programming. It provides an expressive syntax for defining probabilistic models and efficient inference via MCMC (NUTS) and variational methods (ADVI). This skill covers the full Bayesian modeling cycle from model specification through diagnostics, comparison, and prediction.

## When to Use

- Estimating parameters with full uncertainty quantification (credible intervals, not just point estimates)
- Fitting hierarchical/multilevel models to grouped or nested data
- Performing prior and posterior predictive checks to validate model assumptions
- Comparing candidate models using information criteria (LOO-CV, WAIC)
- Building regression models (linear, logistic, Poisson) in a Bayesian framework
- Handling missing data or measurement error as latent parameters
- Modeling time series with autoregressive or random walk priors
- Generating posterior predictions for new observations with uncertainty bounds
- Use **Stan/PyStan** instead for compiled, more scalable Bayesian inference on large models; use **statsmodels** for frequentist statistical tests

## Prerequisites

- **Python packages**: `pymc >= 5.0`, `arviz`, `numpy`, `matplotlib`
- **Data**: NumPy arrays or pandas DataFrames with numeric columns
- **Environment**: CPU sufficient for most models; GPU via JAX backend for large models

```bash
pip install pymc arviz numpy matplotlib
# Optional: JAX-based NUTS sampler; select it with pm.sample(nuts_sampler="numpyro")
# (GPU use additionally needs a CUDA-enabled jax build)
pip install numpyro
```

## Quick Start

```python
import pymc as pm
import arviz as az
import numpy as np

# Simulate data
np.random.seed(42)
X = np.random.randn(100)
y = 2.5 + 1.3 * X + np.random.randn(100) * 0.5

# Build and fit model
with pm.Model() as model:
    alpha = pm.Normal("alpha", mu=0, sigma=5)
    beta = pm.Normal("beta", mu=0, sigma=5)
    sigma = pm.HalfNormal("sigma", sigma=1)
    mu = alpha + beta * X
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)
    idata = pm.sample(1000, tune=1000, chains=4, random_seed=42)

print(az.summary(idata, var_names=["alpha", "beta", "sigma"]))
# Expected: alpha ~ 2.5, beta ~ 1.3, sigma ~ 0.5
```

## Workflow

### Step 1: Prepare Data

Standardize continuous predictors for better sampling efficiency. Use named coordinates for readable models and ArviZ integration.

```python
import pymc as pm
import arviz as az
import numpy as np

# Simulated data stands in for your own dataset
X = np.random.randn(200, 3)  # 200 obs, 3 predictors
y = X @ np.array([1.0, -0.5, 0.3]) + np.random.randn(200) * 0.8

# Standardize predictors
X_mean, X_std = X.mean(axis=0), X.std(axis=0)
X_scaled = (X - X_mean) / X_std

# Define coordinates for named dimensions
coords = {
    "predictors": ["var1", "var2", "var3"],
    "obs_id": np.arange(len(y)),
}
print(f"Data shape: X={X_scaled.shape}, y={y.shape}")
```
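
Coefficients fitted on standardized predictors are in standard-deviation units, so they usually need converting back before they are reported. The following is a minimal NumPy-only sketch of that conversion. The helper names `fit_ols` and `to_original_scale` and the simulated data are illustrative assumptions, and ordinary least squares stands in for a posterior mean so the result can be checked against a direct fit on the raw predictors.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical raw predictors on very different scales
X_raw = rng.normal(loc=[10.0, -3.0, 0.5], scale=[2.0, 0.1, 5.0], size=(200, 3))
y_demo = 4.0 + X_raw @ np.array([1.0, -0.5, 0.3]) + rng.normal(0.0, 0.8, size=200)

x_mean, x_std = X_raw.mean(axis=0), X_raw.std(axis=0)
X_std = (X_raw - x_mean) / x_std


def fit_ols(X, y):
    """Least-squares intercept and slopes (stand-in for posterior means)."""
    design = np.column_stack([np.ones(len(y)), X])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef[0], coef[1:]


def to_original_scale(alpha_s, beta_s, x_mean, x_std):
    """Map intercept and slopes from standardized to raw predictor units."""
    beta_orig = beta_s / x_std
    alpha_orig = alpha_s - np.sum(beta_s * x_mean / x_std, axis=-1)
    return alpha_orig, beta_orig


alpha_s, beta_s = fit_ols(X_std, y_demo)
alpha_o, beta_o = to_original_scale(alpha_s, beta_s, x_mean, x_std)

# A direct fit on the raw predictors should agree with the converted values
alpha_direct, beta_direct = fit_ols(X_raw, y_demo)
print(np.allclose(alpha_o, alpha_direct), np.allclose(beta_o, beta_direct))
```

Because the sum runs over the last axis, the same function can be applied to an array of posterior draws whose last dimension indexes the predictors.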

### Step 2: Define Model and Set Priors

Specify the model structure inside a `pm.Model()` context. Use weakly informative priors, `dims` for named dimensions, and `HalfNormal` or `Exponential` for scale parameters.

```python
with pm.Model(coords=coords) as model:
    # Priors — weakly informative, not flat
    alpha = pm.Normal("alpha", mu=0, sigma=1)
    beta = pm.Normal("beta", mu=0, sigma=1, dims="predictors")
    sigma = pm.HalfNormal("sigma", sigma=1)

    # Linear predictor
    mu = alpha + pm.math.dot(X_scaled, beta)

    # Likelihood
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims="obs_id")

# Inspect model variables
print(model.basic_RVs)  # Lists: [alpha, beta, sigma, y_obs]
```
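
One way to choose between `HalfNormal` and `Exponential` for a scale parameter is to compare how much prior mass each leaves above a value you consider implausibly large. The sketch below uses the closed-form tail probabilities with the standard library only. The function names and the threshold of 3 are illustrative assumptions, and `lam` follows the rate parameterization of `pm.Exponential`.

```python
import math


def half_normal_tail(x, sigma):
    """P(S > x) for S ~ HalfNormal(sigma), equal to P(|Z| > x / sigma)."""
    return math.erfc(x / (sigma * math.sqrt(2.0)))


def exponential_tail(x, lam):
    """P(S > x) for S ~ Exponential(rate=lam)."""
    return math.exp(-lam * x)


threshold = 3.0  # assumed cutoff for an implausibly large noise scale
print(f"HalfNormal(1):  P(sigma > {threshold}) = {half_normal_tail(threshold, 1.0):.4f}")
print(f"Exponential(1): P(sigma > {threshold}) = {exponential_tail(threshold, 1.0):.4f}")
```

With unit scale and unit rate, the exponential prior keeps noticeably more mass in the tail, which matters when large scale values should remain possible.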

### Step 3: Prior Predictive Check

Validate that priors produce plausible data ranges before fitting. Adjust priors if simulated data is unreasonable.

```python
with model:
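    # Pass the draw count positionally: the keyword was renamed from
    # `samples` to `draws` in later PyMC 5 releases
    prior_pred = pm.sample_prior_predictive(1000, random_seed=42)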

# Check prior-implied data range
prior_y = prior_pred.prior_predictive["y_obs"].values.flatten()
print(f"Prior predictive range: [{prior_y.min():.1f}, {prior_y.max():.1f}]")
print(f"Observed data range:    [{y.min():.1f}, {y.max():.1f}]")

az.plot_ppc(prior_pred, group="prior", num_pp_samples=100)
```
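
To make concrete what a prior predictive check simulates, here is a rough NumPy-only sketch of the same idea for the Step 2 priors: draw parameters from the priors, push them through the likelihood, and look at the implied data. The design matrix `X_demo` and the plausible range of ±10 are illustrative assumptions, not values from this skill.

```python
import numpy as np

rng = np.random.default_rng(42)
n_draws, n_obs, n_pred = 1000, 200, 3

# Hypothetical stand-in for the standardized design matrix from Step 1
X_demo = rng.standard_normal((n_obs, n_pred))

# One draw of each parameter per simulated dataset, matching the Step 2 priors
alpha_draws = rng.normal(0.0, 1.0, size=n_draws)
beta_draws = rng.normal(0.0, 1.0, size=(n_draws, n_pred))
sigma_draws = np.abs(rng.normal(0.0, 1.0, size=n_draws))  # HalfNormal(1)

# Push each parameter draw through the likelihood to simulate a dataset
mu_draws = alpha_draws[:, None] + beta_draws @ X_demo.T  # (n_draws, n_obs)
y_sim = rng.normal(mu_draws, sigma_draws[:, None])       # (n_draws, n_obs)

# Share of simulated values inside a range judged plausible (assumed range)
plausible_low, plausible_high = -10.0, 10.0
share_plausible = np.mean((y_sim > plausible_low) & (y_sim < plausible_high))
print(f"Simulated shape: {y_sim.shape}, share within range: {share_plausible:.3f}")
```

If a large share of simulated values falls outside the range you consider physically or scientifically possible, that is a signal to tighten the priors before sampling.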

### Step 4: Sample Posterior (MCMC)

Run NUTS sampling with multiple chains. Pass `idata_kwargs={"log_likelihood": True}` if you plan to compare models later, since LOO and WAIC need the pointwise log-likelihood.

```python
with model:
    idata = pm.sample(
        draws=2000,
        tune=1000,
        chains=4,
        target_accept=0.9,
        random_seed=42,
        idata_kwargs={"log_likelihood": True},
    )

print(f"Posterior shape: {idata.posterior['beta'].shape}")
# Expected: (4 chains, 2000 draws, 3 predictors)
```

### Step 5: Diagnose Sampling

Check convergence before interpreting results. All three diagnostics must pass: R-hat at or below 1.01, bulk ESS of at least 400 per parameter, and no (or very few) divergences.

```python
# Summary with convergence diagnostics
summary = az.summary(idata, var_names=["alpha", "beta", "sigma"])
print(summary[["mean", "sd", "hdi_3%", "hdi_97%", "r_hat", "ess_bulk"]])

# R-hat convergence check
bad_rhat = summary[summary["r_hat"] > 1.01]
if len(bad_rhat) > 0:
    print(f"WARNING: {len(bad_rhat)} parameters with R-hat > 1.01")
    print(bad_rhat[["r_hat"]])

# Effective sample size check
low_ess = summary[summary["ess_bulk"] < 400]
if len(low_ess) > 0:
    print(f"WARNING: {len(low_ess)} parameters with ESS < 400")

# Divergence check
n_div = idata.sample_stats.diverging.sum().item()
total = len(idata.posterior.draw) * len(idata.posterior.chain)
print(f"Divergences: {n_div}/{total} ({n_div / total * 100:.2f}%)")

# Visual diagnostics — trace plots and rank plots
az.plot_trace(idata, var_names=["alpha", "beta", "sigma"])
az.plot_rank(idata, var_names=["alpha", "beta", "sigma"])
```
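
To show what R-hat measures, the sketch below computes the classic split R-hat by hand: each chain is split in half, and the between-chain variance is compared with the within-chain variance. This is a simplified illustration under stated assumptions. The function name `split_rhat` and the simulated chains are hypothetical, and ArviZ reports a rank-normalized variant, so its values can differ slightly.

```python
import numpy as np


def split_rhat(chains):
    """Classic split R-hat for draws of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    # Split each chain in half so within-chain drift also shows up
    split = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = split.shape[1]
    within = split.var(axis=1, ddof=1).mean()      # W: mean within-chain variance
    between = n * split.mean(axis=1).var(ddof=1)   # B: variance of chain means, scaled
    var_plus = (n - 1) / n * within + between / n  # pooled variance estimate
    return float(np.sqrt(var_plus / within))


rng = np.random.default_rng(1)
mixed = rng.standard_normal((4, 1000))                           # chains agree
stuck = mixed + np.array([0.0, 0.0, 2.0, 2.0])[:, None]          # two chains offset
print(f"well-mixed: {split_rhat(mixed):.3f}, disagreeing: {split_rhat(stuck):.3f}")
```

Chains that explore the same distribution give a value very close to 1, while chains centered in different places inflate the between-chain term and push the value well above the 1.01 threshold.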

### Step 6: Posterior Predictive Check

Validate model fit by comparing simulated data from the posterior to observed data.

```python
with model:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)

az.plot_ppc(idata, num_pp_samples=100)
# With ArviZ's default colors: black = observed data, blue = posterior predictive draws,
# dashed orange = posterior predictive mean
# Systematic deviations indicate
```