matlab-train-network
The matlab-train-network skill enables training, evaluation, and deployment of neural networks in MATLAB using modern APIs like trainnet and dlnetwork. Use this skill when a user requests neural network training for classification or regression tasks, fine-tuning pretrained models, running inference on trained networks, or migrating legacy deep learning code to current MATLAB standards.
```bash
git clone --depth 1 https://github.com/matlab/matlab-agentic-toolkit /tmp/matlab-train-network && cp -r /tmp/matlab-train-network/skills-catalog/ai-and-statistics/matlab-train-network ~/.claude/skills/matlab-train-network
```
# matlab-train-network
Train, evaluate, and deploy neural networks in MATLAB using the
recommended `dlnetwork`-based API (`trainnet`, `dlnetwork`, `minibatchpredict`,
`scores2label`, `testnet`, `imagePretrainedNetwork`, `fitcnet`, `fitrnet`).
## When to Use
Activate this skill when a user asks to:
- Train any neural network (classifier, regression, multi-output, LSTM, CNN, etc.)
- Fine-tune or use a pretrained model for transfer learning
- Evaluate a trained network on test data
- Run inference / predict with a trained network
- Export a trained network to Simulink
- Migrate existing legacy deep learning code (trainNetwork, patternnet, fitnet,
narxnet, gensim) to recommended APIs
- Create a "pattern recognition network", "function fitting network", "NARX
network", or any task historically associated with the Neural Network Toolbox
shallow nets API
## When NOT to Use
- Importing/exporting models (importNetworkFromPyTorch, exportONNXNetwork)
- Data loading and preprocessing (imageDatastore, transforms, augmentation)
- Network architecture design decisions (choosing CNN vs LSTM vs transformer)
- Reinforcement learning workflows (use Reinforcement Learning Toolbox)
- Object detection (use specialized detector training functions in Computer Vision Toolbox)
## Decision: fitrnet/fitcnet or trainnet
- **Tabular data?** → `fitrnet` (regression) or `fitcnet` (classification)
- **Tabular data, but need a non-LBFGS solver or a non-MSE/cross-entropy loss?** → `trainnet`
- **Everything else?** → `trainnet`
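```matlab
% Classification (shallow network on tabular data)
mdl = fitcnet(XTrain, TTrain, LayerSizes=20);        % one hidden layer with 20 units
[labels, score] = predict(mdl, XTest);               % predicted labels and class scores
L = loss(mdl, XTest, TTest);                         % classification error by default

% Regression
mdl = fitrnet(XTrain, TTrain, LayerSizes=[20 20]);   % two hidden layers, 20 units each
YPred = predict(mdl, XTest);                         % predicted responses
L = loss(mdl, XTest, TTest);                         % mean squared error by default
```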
- From R2024b, `fitrnet` supports multi-response variables.
- From R2025a, for custom architectures beyond `LayerSizes`, `Activations`, `LayerWeightsInitializer`, and `LayerBiasesInitializer`, pass a `dlnetwork` via the `Network` name-value argument.
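A minimal sketch of multi-response regression with `fitrnet`, assuming R2024b or later and that the responses are passed as an n-by-k numeric matrix. The data below is synthetic and exists only for illustration; `Standardize=true` is an optional choice, not a requirement.

```matlab
% Synthetic tabular data: 200 observations, 3 predictors, 2 response variables
rng(0)
X = rand(200,3);
Y = [X*[1; 2; 3], sin(2*pi*X(:,1))] + 0.01*randn(200,2);

% One model for both responses (assumes R2024b+ multiresponse support)
mdl = fitrnet(X, Y, LayerSizes=[20 20], Standardize=true);

% predict returns one column per response variable
YPred = predict(mdl, X);
```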
---
## Conventions
### Training with trainnet + dlnetwork
#### Data formats
`trainnet` expects data in specific orientations by default:
| Input layer | Expected data shape |
|-------------|-------------------|
| `featureInputLayer(C)` | observations×channels (e.g., 150×4) |
| `imageInputLayer([H W C])` | H×W×C×observations (e.g., 28×28×1×5000) |
| `sequenceInputLayer(C)` | timesteps×channels×observations, or an observations×1 cell array where each element is a timesteps×channels time series |
If your data has a different layout, use `InputDataFormats` and/or
`TargetDataFormats` in `trainingOptions` instead of transposing the data manually.
The format string describes your data's current layout — one letter per
dimension, not the desired layout. MATLAB handles the remapping internally.
For cell arrays, add `"B"` (batch) to the format string — e.g.,
`InputDataFormats="CTB"` for cells of C×T matrices. Do not specify these
options when data already matches the input layer's default.
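As a sketch of the format-string rule above: the hypothetical toy data below stores each sequence as C×T (channels first) instead of the default T×C, so `InputDataFormats="CTB"` describes the current layout of each cell plus `"B"` for the cell dimension. Sizes, epochs, and labels are arbitrary illustrative values.

```matlab
% Hypothetical toy data: 50 sequences, each stored as C-by-T (3 channels, 20 time steps)
rng(0)
numObs = 50;
XTrain = cell(numObs,1);
for i = 1:numObs
    XTrain{i} = rand(3,20);          % C-by-T, not the default T-by-C
end
TTrain = categorical(randi(2,numObs,1), 1:2);

layers = [
    sequenceInputLayer(3)
    lstmLayer(16, OutputMode="last")
    fullyConnectedLayer(2)
    softmaxLayer];

% Describe the data as it is ("C" then "T" per cell, "B" for the cell array);
% the data itself is not transposed
options = trainingOptions("adam", ...
    InputDataFormats="CTB", ...
    MaxEpochs=5, ...
    Verbose=false, ...
    Plots="none");

net = trainnet(XTrain, TTrain, layers, "crossentropy", options);
```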
#### What trainnet supports
Use `trainnet` and `dlnetwork` for all Deep Learning Toolbox training. This includes:
- Standard classification and regression
- Transfer learning
- Multi-input or multi-output networks
- Custom loss functions (pass a function handle to `trainnet`)
- Custom loss function backward passes via `DifferentiableFunction`
- Custom metrics (string, function handle, or `deep.Metric` subclass)
- Custom stopping criteria via `OutputFcn` in `trainingOptions`
- Custom layers
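A minimal sketch combining two of the customizations listed above in a single `trainnet` call: a custom loss passed as a function handle of `(Y,T)`, and a custom stopping criterion passed through `OutputFcn`. The data, the 0.1 L1 weighting, the 0.01 threshold, and the local function `stopIfLossBelow` are all hypothetical; the local function must sit at the end of the script file.

```matlab
% Hypothetical toy regression data: 200 observations, 4 features, 1 response
rng(0)
XTrain = rand(200,4);
TTrain = XTrain*[1; -2; 0.5; 3] + 0.05*randn(200,1);

layers = [
    featureInputLayer(4)
    fullyConnectedLayer(16)
    reluLayer
    fullyConnectedLayer(1)];

% Custom loss: any function handle taking (Y,T); here an mse term plus a small L1 term
lossFcn = @(Y,T) mse(Y,T) + 0.1*l1loss(Y,T);

options = trainingOptions("adam", ...
    MaxEpochs=50, ...
    Metrics="rmse", ...                                   % built-in metric by name
    OutputFcn=@(info) stopIfLossBelow(info, 0.01), ...    % custom stopping criterion
    Verbose=false, ...
    Plots="none");

net = trainnet(XTrain, TTrain, layers, lossFcn, options);

function stop = stopIfLossBelow(info, threshold)
% Hypothetical helper: returning true stops training once the training loss
% drops below threshold (TrainingLoss can be empty, e.g. at the start call)
stop = ~isempty(info.TrainingLoss) && info.TrainingLoss < threshold;
end
```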
**Only** use a custom training loop (`dlfeval`/`dlgradient`/update functions) when a
customization is impossible via `trainingOptions` — for example, a custom weight
update rule. Note that `trainingOptions` also supports the L-BFGS solver
`"lbfgs"` (R2023b+) and the Levenberg-Marquardt solver `"lm"` (R2024b+).
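A sketch showing that these solvers are selected through `trainingOptions` with the same `trainnet` call, so no custom loop is needed. It assumes R2024b or later for `"lm"`; the toy data and iteration counts are illustrative only. Both solvers process the full data set each iteration, so they suit small problems like this one.

```matlab
% Hypothetical small, full-batch regression problem
rng(0)
XTrain = rand(100,2);
TTrain = XTrain(:,1).^2 - XTrain(:,2) + 0.01*randn(100,1);

layers = [
    featureInputLayer(2)
    fullyConnectedLayer(10)
    tanhLayer
    fullyConnectedLayer(1)];

% L-BFGS (R2023b+): training length is controlled by iterations, not epochs
optionsLBFGS = trainingOptions("lbfgs", MaxIterations=200, Verbose=false, Plots="none");
netLBFGS = trainnet(XTrain, TTrain, layers, "mse", optionsLBFGS);

% Levenberg-Marquardt (R2024b+): same call, different solver name
optionsLM = trainingOptions("lm", MaxIterations=100, Verbose=false, Plots="none");
netLM = trainnet(XTrain, TTrain, layers, "mse", optionsLM);
```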
### NEVER use these legacy APIs
If the user has existing code using these APIs, migrate it to the recommended
replacement and briefly explain which APIs were replaced and what the modern
equivalents are. If the user asks for a legacy API by name, acknowledge their
request and explain that the function has been replaced with a recommended
alternative before providing the solution.
| Legacy API | Recommended replacement |
|-----------|-------------------|
| `trainNetwork` | `trainnet` |
| `patternnet` | `fitcnet` (preferred), or `dlnetwork` + `trainnet` |
| `fitnet` | `fitrnet` (preferred), or `dlnetwork` + `trainnet` |
| `feedforwardnet` | `dlnetwork` + `trainnet` |
| `narxnet`, `timedelaynet` | `nlarx` (preferred), or `dlnetwork` + `trainnet` |
| `train()` (shallow `network` object) | `trainnet` |
| `classify` | `minibatchpredict` + `scores2label` |
| `activations` | `minibatchpredict(net,data,Outputs=layer)` |
| `predictAndUpdateState`, `classifyAndUpdateState` | `[Y, state] = predict(net,X); net.State = state;` |
| `classificationLayer` | Not required — use `trainnet` with `"crossentropy"` as the loss |
| `regressionLayer` | Not required — use `trainnet` with `"mse"` as the loss |
| `DAGNetwork`, `SeriesNetwork`, `layerGraph` | `dlnetwork` (supports `addLayers` and `connectLayers` for multi-branch architectures, so anything `layerGraph` can do, `dlnetwork` can do directly) |
| `resnet18`, `googlenet`, `squeezenet`, etc. (pretrained network functions that return `DAGNetwork`) | `imagePretrainedNetwork("resnet18", ...)` — returns a `dlnetwork` |
| Manually converting network scores to labels (e.g., `[~,idx] = max(scores)`) | `scores2label` |
| `plotconfusion` | `confusionchart` |
| `gensim` | `exportNetworkToSimulink` (preferred), or Predict block |
| `preparets` | `nlarx` (preferred, handles delays internally), or `dlnetwork` with `sequenceInputLayer(C, MinLength=numDelays)` + `convolution1dLayer(numDelays, ..., Padding="causal")` |
| `closeloop` | `forecast` (preferred, with `nlarx`), or iterative `predict` loop feeding previous predictions back as input |
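A minimal migration sketch for the most common rows of the table: `trainNetwork` + `classificationLayer` + `classify` + `plotconfusion` become `trainnet` with a `"crossentropy"` loss, `minibatchpredict` + `scores2label`, and `confusionchart`. The image data and the small network are hypothetical placeholders.

```matlab
% Hypothetical toy image data: 28x28 grayscale images, 3 classes
rng(0)
XTrain = rand(28,28,1,200);
TTrain = categorical(randi(3,200,1), 1:3);
XTest  = rand(28,28,1,50);
TTest  = categorical(randi(3,50,1), 1:3);

% Legacy pattern being replaced (reference only):
%   layers = [... softmaxLayer classificationLayer];
%   net = trainNetwork(XTrain, TTrain, layers, options);
%   YPred = classify(net, XTest);

% Recommended: no classificationLayer; the loss is passed to trainnet instead
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3, 8, Padding="same")
    reluLayer
    fullyConnectedLayer(3)
    softmaxLayer];
options = trainingOptions("adam", MaxEpochs=3, Verbose=false, Plots="none");
net = trainnet(XTrain, TTrain, layers, "crossentropy", options);

% classify -> minibatchpredict + scores2label
scores = minibatchpredict(net, XTest);
YPred = scores2label(scores, categories(TTrain));

% plotconfusion -> confusionchart
confusionchart(TTest, YPred);
```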
### Inference — use minibatchpredict (or predict)
- `predict` on a `dlnetwork` …
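A sketch of the inference options, assuming R2024b or later for `testnet` and hypothetical toy sequence data in the default T×C layout. `minibatchpredict` handles batching for whole data sets; `predict` with the returned state is the stateful, step-by-step pattern from the migration table.

```matlab
% Hypothetical toy data: sequences are T-by-C (20 time steps, 3 channels), 2 classes
rng(0)
XTrain = arrayfun(@(k) rand(20,3), (1:30)', UniformOutput=false);
TTrain = categorical(randi(2,30,1), 1:2);
XTest  = arrayfun(@(k) rand(20,3), (1:10)', UniformOutput=false);
TTest  = categorical(randi(2,10,1), 1:2);

layers = [
    sequenceInputLayer(3)
    lstmLayer(8, OutputMode="last")
    fullyConnectedLayer(2)
    softmaxLayer];
options = trainingOptions("adam", MaxEpochs=3, Verbose=false, Plots="none");
net = trainnet(XTrain, TTrain, layers, "crossentropy", options);

% Batch inference: minibatchpredict handles mini-batching
scores = minibatchpredict(net, XTest, MiniBatchSize=16);
YPred = scores2label(scores, categories(TTrain));

% Evaluation with a named metric (testnet, R2024b+)
acc = testnet(net, XTest, TTest, "accuracy");

% Stateful prediction: feed one time step at a time and carry the LSTM state forward
netStream = net;                          % copy so the trained network keeps its state
x = XTest{1};
for t = 1:size(x,1)
    xt = dlarray(x(t,:)', "CTB");         % one time step of one observation
    [yt, state] = predict(netStream, xt);
    netStream.State = state;
end
```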