Training nanoGPT on Slurm with a Nix-Pinned Environment
A researcher prototypes a model on her MacBook Pro. She uses current-stable PyTorch, a Python interpreter, some standard Python training libraries, a few native libs, and a Conda environment.
This is a Python workload, so her single biggest concern isn’t “Can I validate my assumption with a small local experiment?” It’s “Will all these Python dependencies actually run on my MacBook?” She’s using Conda, so everything should work. Holding her breath, she types a command and kicks off the run. Huzzah! The model trains … and there’s signal! Elated, she pushes her code and opens a PR.
It’s at precisely this point that anything can happen, because what works locally doesn’t automatically work everywhere else. On the GPU cluster, things go sideways as soon as the EKS pod comes up. PyTorch wasn’t compiled against the cloud GPU’s CUDA stack. A native extension tries to load libstdc++ from a path that doesn’t exist. The data loader expects to read from local disk instead of S3. The job fails.
Fear, Loathing, and ML/AI Handoff
Feel familiar? People who work with ML/AI live this every day, sometimes several times a day. Maybe a job sails through prototyping on a GPU cluster … only to founder during training on Slurm:
RuntimeError: CUDA error: CUDA-capable device(s) is/are busy or unavailable
Or maybe it fails in eval. CI. MLOps. Staging. Maybe it transits all of these before failing in production. The point is the PTSD: the nagging anxiety that it’s going to fail, inexplicably, somewhere downstream.
This article describes a pattern for creating reproducible runtime environments for ML/AI using declared, graph-backed environments based on Nix and Flox. The same Nix and Flox environments work on Linux and macOS, x86 or ARM, NVIDIA CUDA or Apple Metal/MLX. They travel from model training in local dev to checkpoint validation in eval. They run as-is, pulling in exactly the same dependencies, in CI, MLOps, and production.
The pattern looks like this:
- Teams define GPU-accelerated PyTorch, JAX, TensorRT, etc. as Nix or Flox runtime environments;
- ML/AI researchers define project-specific Nix or Flox environments on top of the appropriate runtime;
- Researchers run Nix or Flox ML/AI stacks on their MacBooks, prototype on NVIDIA DGX nodes, train models on Slurm. Apple Metal/MLX and NVIDIA CUDA get GPU-accelerated libraries;
- MLOps teams use Nix or Flox environments when evaluating + packaging checkpoints for production;
- Platform teams maintain just one environment for Slurm (training) and Kubernetes (prod).
The upshot is that a single GPU-accelerated environment transits the ML/AI software lifecycle without accumulating stage-specific runtime barnacles. Teams can compose modular Nix and Flox environments to create ML/AI stacks bundling CUDA or Metal/MLX for GPU support; PyTorch, JAX, or TensorRT for training or inferencing; project-specific native and Python dependencies; and the code, data pipelines, and tools required to train, package, and ship ML/AI models. This pattern reduces debugging cycles and gives orgs a safe, atomic way to promote new releases, or to roll back (if necessary) to known-good ones.
A reusable, cross-platform PyTorch runtime
This article uses a PyTorch runtime as its baseline example. But the same pattern works with JAX, TensorRT, and other ML frameworks. It also works with model-serving runtimes, distributed training frameworks, batch inference jobs, EDA pipelines, eval harnesses, and RAG/embedding pipelines.
Creating a cross-platform, GPU-accelerated PyTorch runtime is straightforward with both Nix and Flox. Each is “declarative” in the sense that teams declare the versions of packages they want to be available in an environment; from there, each tool’s resolving machinery figures out how to make these coexist.
With Nix and Flox, then, a build recipe or runtime environment encodes named inputs, sources, patches, build instructions, toolchains, target systems, and environment variables as derivations; realizing these derivations produces store objects under /nix/store. Reproducibility is therefore a function of the declared, resolved dependency graph, the lock state, the derivations, and the closure of the realized store paths.
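To make the graph dependence concrete, here is a toy Python model of content addressing: a store path’s hash covers a derivation’s own inputs plus the store paths of everything it depends on, so changing anything upstream (say, enabling CUDA or patching the interpreter) yields a different path. This is a simplified sketch under stated assumptions; the `Derivation` class and `store_path` function are hypothetical, and real Nix serializes derivations and encodes hashes differently.

```python
import hashlib
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class Derivation:
    """Hypothetical, heavily simplified stand-in for a Nix derivation."""
    name: str
    builder: str        # build script or command
    env: tuple = ()     # (key, value) pairs, e.g. (("cudaSupport", "1"),)
    inputs: tuple = ()  # other Derivation objects this one depends on


def store_path(drv: Derivation) -> str:
    """Deterministic path whose hash covers drv and, recursively, its inputs."""
    payload = json.dumps(
        {
            "name": drv.name,
            "builder": drv.builder,
            "env": sorted(drv.env),
            # Hash the inputs' store paths, so any upstream change propagates.
            "inputs": sorted(store_path(dep) for dep in drv.inputs),
        },
        sort_keys=True,
    )
    digest = hashlib.sha256(payload.encode()).hexdigest()[:32]
    return f"/nix/store/{digest}-{drv.name}"


python = Derivation("python3-3.13", "build-python.sh")
torch_cpu = Derivation("torch", "build-torch.sh", inputs=(python,))
torch_cuda = Derivation("torch", "build-torch.sh",
                        env=(("cudaSupport", "1"),), inputs=(python,))

print(store_path(torch_cpu))
print(store_path(torch_cpu) == store_path(torch_cuda))  # False: different env
```

The same recursion is why two nodes that evaluate the same locked graph end up with the same paths, and why a one-flag difference anywhere below PyTorch shows up as a different closure rather than as a silent runtime mismatch.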
With Nix
As a Nix flake, a cross-platform PyTorch runtime looks like this:
{
description = "Python 3.13 + PyTorch runtime base shared by dev, training, eval, CI, and container outputs";
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
};
outputs = { nixpkgs, ... }:
let
lib = nixpkgs.lib;
systems = [
"x86_64-linux"
"aarch64-linux"
"x86_64-darwin"
"aarch64-darwin"
];
# CUDA is intentionally opt-in. Add systems here only after verifying
# that torchWithCuda evaluates and builds for that system/pinned nixpkgs.
cudaSystems = [
"x86_64-linux"
"aarch64-linux"
];
forAllSystems = f: lib.genAttrs systems f;
pkgsFor = system: import nixpkgs {
inherit system;
config.allowUnfree = true;
};
isCudaSystem = system: builtins.elem system cudaSystems;
torchFor = system: ps:
if isCudaSystem system
then ps.torchWithCuda
else ps.torch;
# Shared runtime layer.
basePkgs = system: ps: [
(torchFor system ps)
];
# Training layer extends the shared runtime layer.
trainingPkgs = system: ps: basePkgs system ps ++ [
ps.tensorboard
ps.wandb
];
# Eval layer extends the shared runtime layer.
evalPkgs = system: ps: basePkgs system ps ++ [
# Add eval-specific deps here.
# Example:
# ps.mlflow
];
# Dev layer extends the shared runtime layer.
devPkgs = system: ps: basePkgs system ps ++ [
ps.jupyter
ps.ipython
];
in
{
packages = forAllSystems (system:
let
pkgs = pkgsFor system;
pythonWith = layerFn:
pkgs.python313.withPackages (layerFn system);
runtime = pythonWith basePkgs;
training = pythonWith trainingPkgs;
evalEnv = pythonWith evalPkgs;
in
{
runtime = runtime;
training = training;
eval = evalEnv;
default = runtime;
} // lib.optionalAttrs (isCudaSystem system) {
container = pkgs.dockerTools.buildLayeredImage {
name = "pytorch-runtime";
tag = "latest";
contents = [
runtime
pkgs.cacert
pkgs.iana-etc
];
config = {
Cmd = [ "${runtime}/bin/python" ];
Env = [
"PATH=${runtime}/bin"
"SSL_CERT_FILE=${pkgs.cacert}/etc/ssl/certs/ca-bundle.crt"
"NVIDIA_VISIBLE_DEVICES=all"
"NVIDIA_DRIVER_CAPABILITIES=compute,utility"
];
};
};
}
);
devShells = forAllSystems (system:
let
pkgs = pkgsFor system;
pythonWith = layerFn:
pkgs.python313.withPackages (layerFn system);
runtime = pythonWith basePkgs;
devEnv = pythonWith devPkgs;
in
{
default = pkgs.mkShell {
packages = [ devEnv ];
shellHook = ''
python - <<'PY'
import torch
print("torch:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
print("cuda device:", torch.cuda.get_device_name(0))
if hasattr(torch.backends, "mps"):
print("mps available:", torch.backends.mps.is_available())
PY
'';
};
ci = pkgs.mkShell {
packages = [ runtime ];
};
}
);
};
}
This flake defines a cross-platform Python 3.13 + PyTorch base and composes it into context-specific outputs: runtime, training, and eval packages; a dev shell and a CI shell; and, on the CUDA systems, an OCI container image. This gives platform teams a reproducible, cross-platform PyTorch base that works from local development → production.
Nix has excellent container tooling. The flake above defines an OCI image with dockerTools.buildLayeredImage. This tells Nix to build the image from the runtime’s closure, then to write the container’s default command and environment variables into the image metadata. For this example, we define the OCI image as part of the project repo’s Nix flake, but release teams can (and do) maintain their own downstream flakes. These consume the project flake’s output and emit an OCI image.
This second pattern lets application teams expose a runtime closure as the complete set of dependencies for their application. CI can test against the same closure, and downstream release flakes can package it into an OCI image, then tag, sign, scan, and publish it.
With Flox
The Flox equivalent of this environment is less verbose:
[install]
python.pkg-path = "python3"
python.version = "3.13.12"
python.pkg-group = "runtime"
python.priority = 5
python.outputs = "all"
# CUDA-accelerated PyTorch for Linux (x86_64 and aarch64)
torch.pkg-path = "flox-cuda/python3Packages.torch"
torch.pkg-group = "torch"
torch.systems = ["x86_64-linux", "aarch64-linux"]
# MPS-accelerated PyTorch for macOS (Apple Silicon and Intel)
mps-torch.pkg-path = "python313Packages.torch"
mps-torch.systems = ["x86_64-darwin", "aarch64-darwin"]
mps-torch.priority = 6
mps-torch.pkg-group = "torch"
mps-torch.outputs = "all"
That’s it. The Nix flake explicitly defines a series of lifecycle roles (viz., a dev shell, a runtime package, a CI shell, a container image) that the Flox model abstracts. To take one example, the Flox manifest declares packages, versions, supported systems, environment variables, services, and build recipes; it doesn’t, however, declare entrypoints for specific roles. It doesn’t need to. A Flox environment exposes dev tooling and libraries for local dev by default; activating it with the --mode run flag, or declaring this in the Flox manifest, restricts access to dev tooling and libraries. (run mode is the default on Kubernetes.)
Flox environments don’t need to declare the lifecycle options (e.g., default command, environment variables, etc.) for OCI images, either: the flox containerize command does this automatically.
Flox primitives like package groups, priorities, systems filters, and outputs abstract common Nix patterns; they don’t map one-for-one to a single Nix primitive. For example, package groups abstract the work of getting packages that resolve against different historical nixpkgs commits to coexist in the same environment. The packages defined in the manifest above get isolated into Flox package groups so that it’s easy to version and manage them: You can define specific versions without worrying about conflicts with other packages.
Compose a Training Stack
The PyTorch runtime is the foundation. One or more downstream environments can easily consume it.
Teams might compose this environment with other environments designed for:
- CUDA development. Nix or Flox environments that define nvcc, cudart, cublas, cudnn, nccl, and other CUDA-specific dependencies. Essential on CUDA, skipped on macOS.
- Model training. An environment declaring Python packages / native libraries used to train models.
- Building + packaging. Linux gets gcc, macOS gets clang; all get cmake + other tools.
- CUDA profiling / performance. NVIDIA’s Nsight Systems and Nsight Compute; PyTorch profiler workflows; CPU and memory profilers; kernel-level performance analysis.
- Offline eval. Eval harnesses, metrics libraries, dataset clients, tokenizers, etc.
- Model packaging. Model export, conversion, quantization, artifact packaging, and metadata tools.
- Model serving. Platforms like llama.cpp, vLLM, NVIDIA Triton, SGLang, Ollama.
Each is a Nix or Flox environment that can use the base PyTorch runtime (along with others) as an input.
With Nix
The model-training flake consumes the foundational PyTorch runtime flake (pytorch-runtime) as an input, along with two more: cuda-dev-essentials and build-env. On Linux/NVIDIA CUDA, it pulls in CUDA dev packages and the GCC stack; on macOS/Metal (or MLX), it skips CUDA and pulls in clang, along with other essential deps. The flake below declares a model-training Nix dev shell. It provides uv and activates a project-local Python virtual environment, installing dependencies that run against the PyTorch runtime.
{
description = "ML training environment";
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
build-env = {
url = "github:flox/ml-ai-lifecycle?dir=build-env";
inputs.nixpkgs.follows = "nixpkgs";
};
cuda-dev-essentials = {
url = "github:flox/ml-ai-lifecycle?dir=cuda-dev-essentials";
inputs.nixpkgs.follows = "nixpkgs";
};
pytorch-runtime = {
url = "github:flox/ml-ai-lifecycle?dir=pytorch-runtime";
inputs.nixpkgs.follows = "nixpkgs";
};
};
outputs =
{ nixpkgs
, build-env
, cuda-dev-essentials
, pytorch-runtime
, ...
}:
let
lib = nixpkgs.lib;
systems = [
"x86_64-linux"
"aarch64-linux"
"x86_64-darwin"
"aarch64-darwin"
];
linuxSystems = [
"x86_64-linux"
"aarch64-linux"
];
forAllSystems = f:
lib.genAttrs systems f;
pkgsFor = system:
import nixpkgs {
inherit system;
config = {
allowUnfree = true;
cudaSupport = isLinux system;
};
};
isLinux = system:
builtins.elem system linuxSystems;
getPackage = flake: system: name:
if builtins.hasAttr "packages" flake
&& builtins.hasAttr system flake.packages
&& builtins.hasAttr name flake.packages.${system}
then
flake.packages.${system}.${name}
else
throw "Expected flake to expose packages.${system}.${name}";
in
{
devShells = forAllSystems (system:
let
pkgs = pkgsFor system;
buildTools =
getPackage build-env system "default";
pytorchRuntime =
getPackage pytorch-runtime system "runtime";
cudaTools =
lib.optionals (isLinux system) [
(getPackage cuda-dev-essentials system "default")
];
in
{
default = pkgs.mkShell {
packages =
[
pkgs.uv
buildTools
pytorchRuntime
]
++ cudaTools;
# Use the Python interpreter from the PyTorch runtime layer so the
# venv can inherit torch and related runtime packages.
PYTHON_FOR_VENV = "${pytorchRuntime}/bin/python3";
shellHook = ''
export SSL_CERT_FILE=${pkgs.cacert}/etc/ssl/certs/ca-bundle.crt
export NIX_SSL_CERT_FILE=${pkgs.cacert}/etc/ssl/certs/ca-bundle.crt
''
+ lib.optionalString (isLinux system) (
let cudaPkgs = pkgs.cudaPackages_12_9; in ''
export LD_LIBRARY_PATH=${pkgs.gcc-unwrapped.lib}/lib:${cudaPkgs.cuda_cudart}/lib:${cudaPkgs.libcublas}/lib:${cudaPkgs.cudnn}/lib:${cudaPkgs.cuda_cupti}/lib:${cudaPkgs.nccl}/lib:${cudaPkgs.libcutensor}/lib''${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}
export CPATH=${cudaPkgs.cuda_cudart}/include:${cudaPkgs.libcublas}/include:${cudaPkgs.cudnn}/include:${cudaPkgs.nccl}/include:${cudaPkgs.libcutensor}/include''${CPATH:+:$CPATH}
export LIBRARY_PATH=${cudaPkgs.cuda_cudart}/lib:${cudaPkgs.libcublas}/lib:${cudaPkgs.cudnn}/lib:${cudaPkgs.cuda_cupti}/lib:${cudaPkgs.nccl}/lib:${cudaPkgs.libcutensor}/lib''${LIBRARY_PATH:+:$LIBRARY_PATH}
export CUDA_PATH=${cudaPkgs.cuda_nvcc}
export CUDA_HOME=$CUDA_PATH
'')
+ ''
export ML_TRAINING_CACHE="$PWD/.cache/ml-training"
export ML_TRAINING_VENV="$ML_TRAINING_CACHE/venv"
export UV_CACHE_DIR="$ML_TRAINING_CACHE/uv"
export PIP_CACHE_DIR="$ML_TRAINING_CACHE/pip"
mkdir -p "$ML_TRAINING_CACHE" "$UV_CACHE_DIR" "$PIP_CACHE_DIR"
ml_training_setup() {
# No set -euo pipefail here: shellHook runs in the interactive shell, so
# those options would leak into (and can terminate) the whole session.
venv="$ML_TRAINING_VENV"
if [ ! -d "$venv" ]; then
uv venv "$venv" \
--python "$PYTHON_FOR_VENV" \
--system-site-packages || return 1
fi
if [ -f "$venv/bin/activate" ]; then
. "$venv/bin/activate"
fi
if [ ! -f "$ML_TRAINING_CACHE/.training_deps_installed" ]; then
uv pip install --python "$venv/bin/python" --quiet \
numpy datasets tokenizers transformers accelerate \
safetensors tensorboard scikit-learn tqdm pyyaml \
fastapi uvicorn || return 1
touch "$ML_TRAINING_CACHE/.training_deps_installed"
fi
}
ml_training_setup
'';
};
});
};
}
With Flox
The model-training Flox environment composes the foundational PyTorch runtime with separate build and CUDA development environments declared using Flox [include]s. On Linux/NVIDIA CUDA, the included flox-labs/cuda-dev-essentials environment pulls in CUDA development dependencies, while the flox-labs/build-env environment provides GCC. macOS platforms (using Metal or MLX) skip CUDA and pull in clang from flox-labs/build-env, plus other platform-appropriate dependencies.
On both macOS and Linux, the manifest installs uv and defines venv and package-cache paths under $FLOX_ENV_CACHE; its [hook] then creates and activates the project-local Python venv and installs the model-training Python dependencies on top of the shared PyTorch runtime:
schema-version = "1.12.0"
[install]
uv.pkg-path = "uv"
[vars]
ML_TRAINING_VENV = "$FLOX_ENV_CACHE/venv"
UV_CACHE_DIR = "$FLOX_ENV_CACHE/uv"
PIP_CACHE_DIR = "$FLOX_ENV_CACHE/pip"
[hook]
on-activate = '''
ml_training_setup() {
venv="$FLOX_ENV_CACHE/venv"
if [ ! -d "$venv" ]; then
uv venv "$venv" --python python3
fi
if [ -f "$venv/bin/activate" ]; then
source "$venv/bin/activate"
fi
if [ ! -f "$FLOX_ENV_CACHE/.training_deps_installed" ]; then
uv pip install --python "$venv/bin/python" --quiet \
numpy datasets tokenizers transformers accelerate \
safetensors tensorboard scikit-learn tqdm pyyaml \
fastapi uvicorn
touch "$FLOX_ENV_CACHE/.training_deps_installed"
fi
}
ml_training_setup
'''
[include]
environments = [
{ remote = "flox-labs/build-env" },
{ remote = "flox-labs/cuda-dev-essentials" },
{ remote = "flox-labs/pytorch-runtime" }
]
The [include] section is specific to Flox. It fulfills a function similar to the Nix flake’s top-level inputs attribute: it declares the external inputs (in this case, other Flox manifests) that the composed environment consumes as dependencies. The difference is the unit of composition. Nix fetches and locks the flake’s inputs, then passes them to the outputs function; the flake author decides how to use those inputs to produce packages, dev shells, apps, containers, or other outputs. By contrast, the [include] section in a Flox manifest references other manifests, and Flox determines how to merge them into the composing manifest. This eliminates the need to author and maintain the wiring that composes these inputs, at the cost of slightly less control than a Nix flake offers.
The Nix flake expects to consume its flake inputs from GitHub; this Flox manifest composes remote FloxHub environments. It’s a minor difference, but probably worth calling out.
The end result
With both Nix and Flox, an ML/AI researcher using an M5 MacBook Pro gets the following stack:
| Environment | What resolves |
|---|---|
| `pytorch-runtime` | Python 3.13 + MPS-accelerated PyTorch |
| `cuda-dev-essentials` | *Nothing*: all packages restricted to Linux |
| `build-env` | clang, cmake, coreutils, GNU userland |
| `model-training` | uv, venv, numpy, transformers, etc. |

While an ML/AI researcher working with CUDA (locally or on a GPU cluster) gets:
| Environment | What resolves |
|---|---|
| `pytorch-runtime` | Python 3.13 + CUDA-accelerated PyTorch |
| `cuda-dev-essentials` | nvcc, cudart, cublas, cudnn, nccl, profiler, debugger |
| `build-env` | gcc, gcc-unwrapped, glibc, cmake |
| `model-training` | uv, venv, numpy, transformers, etc. |

Researchers can prototype or train locally, emitting a checkpoint.pt and a runtime.json, then upload or copy them to an artifact registry or model store. The runtime.json travels with the model.
Batch Model Training
Slurm by itself doesn't address the core challenge of getting every node in a cluster to run against the same ML/AI runtime, with the same dependency graph and the same versions of CUDA, Python, and other finicky dependencies. In the wild, users rely on module load commands to assemble the right CUDA, Python, compilers, libraries, and other dependencies. But module definitions are notorious for drifting across login and compute nodes: the command module load cuda/12.8 python/3.12 can (and often does) resolve differently across a cluster. Containerized ML workflows on Slurm often use HPC container runtimes like Singularity or Apptainer; these improve reproducibility, but must be configured for each cluster’s runtime setup, security policy, GPU/MPI settings, and Slurm conventions.
Alternatives like Conda have drawbacks too: Conda environments that bake in a large number of CUDA and Python dependencies can take a long time to resolve. Teams cannot copy a working Conda environment to a new prefix (i.e., path) and expect it to keep working. Unless every node sees the same shared environment path, teams typically need to add a separate packaging or distribution step.
So modules, HPC container runtimes, and Conda add operational layers on top of the ML/AI workload.
nanoGPT on Slurm
A composed Nix or Flox ML stack runs as-is on Slurm clusters, provided that Nix or Flox is available on each node. You can run this stack from a single shared environment (accessible cluster-wide via NFS) or independently on each GPU node. Either way, Nix and Flox ensure each node gets the same packages and the same runtime environment, with the same env vars and secrets.
The following sub-sections show how this works with an example nanoGPT training job. Both the Nix and Flox environments consume two or more input environments (cross-platform PyTorch; Linux-only CUDA dev; cross-platform Python / general-purpose dev) to compose a single unified ML stack environment.
Allowing for the tutorial-specific config.sh wrapper (see below), Nix and Flox drop into the normal operating model for HPC systems. Submit jobs from the login node with sbatch. Slurm schedules the data prep, training, sampling, and eval jobs using standard sbatch --dependency job rules. But each job script runs its workload with Nix or Flox, so every node in the cluster gets the same pinned runtime.
Getting started
First clone the repo, then change into the nanogpt-slurm directory and edit the config.sh script.
This defines a run_in_env helper that dispatches to Flox or Nix based on what’s set in ENV_MANAGER. It makes it possible for the same Slurm scripts to run with either Flox or Nix. In a real-world deployment, you wouldn’t need this; rather, you’d pick Nix or Flox and call it directly.
Note: Clone this repo into a filesystem that’s visible across both the login node and the Slurm compute nodes, such as a shared NFS or GPFS mount. The job scripts source config.sh at runtime, so config.sh and any training code referenced by the scripts must be available on the compute node when the job starts.
If your cluster does not provide a shared filesystem, stage the repo onto the compute node yourself: clone it as part of the job script, copy it to node-local scratch, or distribute it with sbcast. Another viable pattern is to use the Nix or Flox environment to provide tools like git, Python, CUDA, and PyTorch, and then to create an activation hook that clones a pinned revision of the training repo into $SLURM_TMPDIR before running the training command.
# Environment manager: "flox" or "nix"
ENV_MANAGER="${ENV_MANAGER:-flox}"
# Flox: FloxHub environment path (e.g., "youruser/nanogpt-slurm")
# Push your environment first: flox auth login && flox push
NANOGPT_FLOX_ENV="${NANOGPT_FLOX_ENV:-youruser/nanogpt-slurm}"
# Nix: Flake reference (GitHub repo, branch, pinned rev, or local path)
NANOGPT_FLAKE="${NANOGPT_FLAKE:-github:flox/ml-ai-lifecycle?dir=model-training}"
# Helper: run a command inside the chosen environment.
# Usage: run_in_env bash -c '...'
run_in_env() {
if [ "$ENV_MANAGER" = "nix" ]; then
nix develop "$NANOGPT_FLAKE" --command "$@"
else
flox activate -r "$NANOGPT_FLOX_ENV" -- "$@"
fi
}
In this repo, every job script sources config.sh and calls:
source config.sh
run_in_env bash -c '...'
If you plan to use Nix
- Edit config.sh and switch to Nix mode:
ENV_MANAGER="nix"
NANOGPT_FLAKE="github:flox/ml-ai-lifecycle?dir=model-training"
The flake reference can be any valid flake URL:
| Form | Example |
|------|---------|
| GitHub repo (default) | `github:flox/ml-ai-lifecycle?dir=model-training` |
| Specific branch | `github:flox/ml-ai-lifecycle/main?dir=model-training` |
| Pinned revision | `github:flox/ml-ai-lifecycle/<commit-sha>?dir=model-training` |
| Local path | `path:/shared/nfs/ml-ai-lifecycle?dir=model-training` |

- Verify on the login node:
nix develop "$NANOGPT_FLAKE" --command python3 -c "import torch; print(torch.cuda.is_available())"
Note: Consider configuring Nix binary substitution before running Slurm jobs at scale! Without a binary cache or shared Nix store, each GPU node must build CUDA dependencies the first time it runs the job. For PyTorch and other ML stacks, this can take an extremely long time: up to several hours.
Alternatively, Nix users can pull pre-built, pre-patched CUDA-accelerated packages from Flox’s binary cache. Just add Flox as an extra substituter in nix.conf, like so:
extra-substituters = https://cache.flox.dev
extra-trusted-public-keys = flox-cache-public-1:7F4OyH7ZCnFhcze3fJdfyXYLQw/aV7GEed86nQ7IsOs=
If you plan to use Flox
- Edit config.sh to set your FloxHub username.
- Run flox push to publish the environment to FloxHub. This way the Slurm GPU node pulls it dynamically at runtime.
- Verify on the login node:
flox activate -r "$NANOGPT_FLOX_ENV" -- python3 -c "import torch; print(torch.cuda.is_available())"
Compute nodes pull the environment by name at job start. No further setup is needed on each node.
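Both verification one-liners only print whether CUDA is available. A slightly stricter smoke test, sketched below under stated assumptions, reports which accelerator PyTorch can actually use (CUDA, MPS, or CPU) and returns a non-zero status when that differs from what you expect, so a submission script can stop early. The module name smoke.py and the functions `detect_accelerator` and `smoke_test` are hypothetical; the torch module is passed in so the logic can be exercised without a GPU.

```python
import sys


def detect_accelerator(torch) -> str:
    """Return 'cuda', 'mps', or 'cpu' for an imported torch module."""
    if torch.cuda.is_available():
        return "cuda"
    # Guard in case this torch build lacks torch.backends.mps.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


def smoke_test(torch, expected: str) -> int:
    """Print a one-line report; return 0 if `expected` was found, else 1."""
    found = detect_accelerator(torch)
    print(f"torch {torch.__version__}: accelerator={found}")
    if found != expected:
        print(f"expected {expected}, found {found}", file=sys.stderr)
        return 1
    return 0
```

Saved as smoke.py next to config.sh, it could be run on a GPU node with something like `run_in_env python3 -c 'import sys, torch, smoke; sys.exit(smoke.smoke_test(torch, "cuda"))'`, and on a MacBook with "mps" as the expected value.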
Running Slurm Jobs with Nix and/or Flox
Once config.sh is configured for either Nix or Flox, submitting Slurm jobs looks the same.
Each script in ./jobs/ is a standard Slurm batch script (with #SBATCH headers for resources like GPUs, CPUs, time limits). When a compute node runs the job, it does:
# With Nix:
source config.sh # get Nix flake ref
# Realize/enter the Nix dev shell, then run nanoGPT inside that env
nix develop "$NANOGPT_FLAKE" --command \
python3 train.py ...
# OR, with Flox:
source config.sh # get FloxHub env name
# Pull the env from FloxHub, activate it, then run nanoGPT inside that env
flox activate -r "$NANOGPT_FLOX_ENV" -- \
python3 train.py ...
The first job set runs a Shakespeare smoke test to validate the setup. Paste this on the login node:
PREP=$(sbatch --parsable jobs/shakespeare-prep.sh) # CPU: tokenize data
TRAIN=$(sbatch --parsable --dependency=afterok:$PREP jobs/shakespeare-train.sh) # GPU: train
sbatch --dependency=afterok:$TRAIN jobs/shakespeare-sample.sh # GPU: generate text
This trains a small (~10M parameter) model and usually finishes in 10 minutes or less.
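If you drive submissions from Python rather than the shell, the same afterok chain can be sketched with subprocess. `sbatch --parsable` prints the job ID (followed by `;` and the cluster name when one is present), and each subsequent job gets `--dependency=afterok:<previous ID>`. The `submit_chain` helper and its injectable `runner` are assumptions added so the logic can be dry-run without Slurm; it models a linear chain only, so the GPT-2 pipeline’s two post-training jobs would each need their own dependency on the training job.

```python
import subprocess
from typing import Callable, Sequence

# A runner takes an argv list and returns sbatch's stdout.
Runner = Callable[[Sequence[str]], str]


def run_sbatch(argv: Sequence[str]) -> str:
    """Real runner: execute sbatch and return what it printed."""
    result = subprocess.run(list(argv), check=True, capture_output=True, text=True)
    return result.stdout


def submit_chain(scripts: Sequence[str], runner: Runner = run_sbatch) -> list[str]:
    """Submit scripts in order; each starts only if the previous one succeeded."""
    job_ids: list[str] = []
    for script in scripts:
        argv = ["sbatch", "--parsable"]
        if job_ids:
            argv.append(f"--dependency=afterok:{job_ids[-1]}")
        argv.append(script)
        # --parsable prints "<jobid>" or "<jobid>;<cluster>"
        job_ids.append(runner(argv).strip().split(";")[0])
    return job_ids


# Usage on the login node (same chain as the shell snippet above):
#   submit_chain(["jobs/shakespeare-prep.sh",
#                 "jobs/shakespeare-train.sh",
#                 "jobs/shakespeare-sample.sh"])
```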
GPT-2 124M pipeline
GPT-2 training uses the same Nix and/or Flox dependencies, just with much longer-running jobs. It also requires choosing a training dataset (OpenWebText or FineWeb-Edu 10BT, as shown in the pipeline below).
Train on GPU (≈10-20 hours / ≈100 GB disk):
TRAIN=$(sbatch --parsable --dependency=afterok:$PREP jobs/gpt2-train.sh)
Before submitting, check batch_size and gradient_accumulation_steps in jobs/gpt2-train.sh against your GPU's VRAM. The defaults target GPUs with 32 GB or more of VRAM (e.g., RTX 5090, A100):
| VRAM | GPUs | `batch_size` | `gradient_accumulation_steps` |
|------|------|:------------:|:-----------------------------:|
| 40-80 GB | A100, H100 | 32 | 16 |
| 32 GB | RTX 5090 | 16 | 32 |
| 24 GB | RTX 4090, A5000 | 12 | 43 |
| 16 GB | RTX 4080, T4 | 8 | 64 |
| 8-12 GB | RTX 3080, RTX 4060 | 4 | 128 |

The two values should satisfy batch_size * gradient_accumulation_steps * 1024 ≈ 524,288, where 1024 is the GPT-2 context length (block_size). Some rows round slightly due to integer constraints. A smaller batch_size means you need more gradient accumulation steps to compensate, so training takes longer but uses less VRAM.
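The arithmetic behind the table is easy to reproduce. This sketch assumes a single GPU and block_size = 1024 as stated above, and rounds the accumulation steps to the nearest integer, which is where the 24 GB row’s 43 comes from (12 × 43 × 1024 = 528,384 tokens, slightly above the target). The helper names are illustrative.

```python
TARGET_TOKENS = 524_288  # tokens per optimizer step (~0.5M)
BLOCK_SIZE = 1024        # GPT-2 context length


def grad_accum_steps(batch_size: int, target: int = TARGET_TOKENS,
                     block_size: int = BLOCK_SIZE) -> int:
    """Nearest whole number of micro-batches that reaches ~target tokens."""
    return max(1, round(target / (batch_size * block_size)))


def tokens_per_step(batch_size: int, accum_steps: int,
                    block_size: int = BLOCK_SIZE) -> int:
    return batch_size * accum_steps * block_size


for batch_size in (32, 16, 12, 8, 4):
    steps = grad_accum_steps(batch_size)
    print(f"batch_size={batch_size:>2} steps={steps:>3} "
          f"tokens/step={tokens_per_step(batch_size, steps):,}")
```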
Once you've tweaked those, submit the whole pipeline at once on the login node:
# Pick one dataset:
PREP=$(sbatch --parsable jobs/openwebtext-prep.sh) # Option A: OpenWebText (~2-4 hrs, ~20 GB)
# PREP=$(sbatch --parsable jobs/fineweb-prep.sh) # Option B: FineWeb-Edu 10BT (~4-6 hrs, ~40 GB)
# Train
TRAIN=$(sbatch --parsable --dependency=afterok:$PREP jobs/gpt2-train.sh)
# Inference (both wait for training, then run)
sbatch --dependency=afterok:$TRAIN jobs/gpt2-sample.sh
sbatch --dependency=afterok:$TRAIN jobs/gpt2-eval.sh
Promote models like software packages
Working with declared, graph-backed technologies like Nix and Flox gives teams a straightforward path from training → eval → CI → prod. They can reuse modular Nix or Flox environments as inputs for other environments, composing them to create ML/AI stacks. After batch training, teams can use this same pattern to compose modular environments for evaluation, benchmarking, checkpoint packaging, and release gating. Similarly, platform teams can compose their own Nix or Flox environments for staging, serving, canaries, observability, and production rollout. Everybody starts with the Nix/Flox PyTorch and/or CUDA diagnostic environments and composes them with their own use-case-specific environments.
Because Nix and Flox are also reproducible build systems, they make it straightforward to package, publish, and pull software, too. (Full disclosure: Flox inherits this virtuous behavior from Nix.) So when training needs to hand off to eval, it can use Nix or Flox to package the checkpoint, model code, tokenizer assets, metadata, and runtime inputs before publishing them to a binary cache or to its private Flox Catalog.
But why do this? Because with declared, graph-backed technologies, all of the transitive dependencies that software needs in order to run get packaged along with it. So instead of saying “Here’s checkpoint.pt; have fun tracking down the right Python, PyTorch, CUDA, tokenizer, and so on!”, ML/AI researchers package and publish their model artifacts with the dependency graphs needed to use them.
Eval declares these packages in their own environments, reuses the PyTorch runtime, and gets everything it needs to score the model. CI pulls eval’s scored model package, runs its release gates, and publishes an approved artifact. MLOps registers that artifact, attaches release metadata, and promotes it. Platform teams can declare the package as an input to a container build, or use Nix/Flox to generate OCI images.
Declared technologies like Nix or Flox are not in any sense a panacea for ML work. But this pattern replaces the most error-prone part of the ML/AI lifecycle: the handoff. The result is a more regular promotion cycle: package the artifact, publish it, declare it downstream, and promote by reference.
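To illustrate “promote by reference”, here is a minimal sketch under a hypothetical layout (not a Nix or Flox feature): a channel such as prod is a small JSON file naming an immutable artifact reference, such as a store path, catalog version, or digest. Promotion and rollback rewrite only that pointer, using os.replace(), which is atomic on POSIX when the temporary file and the channel share a filesystem, so readers see either the old reference or the new one.

```python
import json
import os
import tempfile
from pathlib import Path


def read_channel(channel: Path) -> dict:
    """Return {'current': ref, 'history': [...]}, or an empty channel."""
    if channel.exists():
        return json.loads(channel.read_text())
    return {"history": []}


def _write_atomic(channel: Path, state: dict) -> None:
    # Write a temp file beside the channel, flush it to disk, then rename it
    # over the channel in a single step.
    fd, tmp = tempfile.mkstemp(dir=channel.parent, prefix=f".{channel.name}.")
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, channel)


def promote(channel: Path, artifact_ref: str) -> None:
    """Point the channel at artifact_ref, remembering the previous ref."""
    state = read_channel(channel)
    if "current" in state:
        state["history"].append(state["current"])
    state["current"] = artifact_ref
    _write_atomic(channel, state)


def rollback(channel: Path) -> str:
    """Re-point the channel at the most recent earlier ref and return it."""
    state = read_channel(channel)
    if not state["history"]:
        raise RuntimeError(f"{channel} has no earlier release to roll back to")
    state["current"] = state["history"].pop()
    _write_atomic(channel, state)
    return state["current"]


# Example: promote(Path("channels/prod.json"), "<artifact store path or digest>")
```

Because the artifacts themselves never change, rolling back is just re-pointing at a reference that is already known to be good.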


