Getting Started¶
Installation¶
Install noxton from PyPI:
pip install noxton
Optional dependencies¶
| Extra | When to install |
|---|---|
noxton[hub] |
Required to load ESM models (fetches weights from HuggingFace Hub) |
torch |
Required only during the first from_pretrained() call to convert PyTorch weights. Not needed at inference time after the first run. |
Requirements¶
- Python ≥ 3.13
- JAX ≥ 0.4 (CPU, CUDA, or Metal backend)
- Equinox
Basic usage¶
Loading a pretrained model¶
All pretrained models share the same from_pretrained() interface:
import jax.numpy as jnp
from noxton.models import ResNet
model, state = ResNet.from_pretrained(
model="resnet50",
weights="resnet50_IMAGENET1K_V2",
key=jax.random.key(0),
dtype=jnp.float32,
)
Weights are downloaded once to ~/.noxton/ and reused on subsequent calls.
Inference mode¶
Call eqx.nn.inference_mode to freeze BatchNorm running statistics and disable dropout:
import equinox as eqx
model, state = eqx.nn.inference_mode((model, state))
Batched inference¶
Use eqx.filter_vmap to run a model over a batch. The axis_name="batch" argument is required for models that contain BatchNorm:
import jax
import jax.numpy as jnp
import equinox as eqx
images = jnp.zeros((8, 3, 224, 224)) # batch of 8
logits, _ = eqx.filter_vmap(
model,
in_axes=(0, None), # vmap over images, broadcast state
out_axes=(0, None), # stack logits, pass state through
axis_name="batch",
)(images, state)
probs = jax.nn.softmax(logits, axis=-1) # (8, 1000)
Training¶
For training you typically want to:
- Keep
inference=False(the default) so BatchNorm tracks running statistics. - Split parameters from state using
eqx.partition/eqx.filter. - Pass a
keyargument to layers that use dropout.
import optax
optimizer = optax.adamw(1e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
@eqx.filter_jit
def train_step(model, state, opt_state, images, labels):
def loss_fn(model):
logits, new_state = eqx.filter_vmap(
model, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(images, state)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, new_state
(loss, new_state), grads = eqx.filter_value_and_grad(loss_fn, has_aux=True)(model)
updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
model = eqx.apply_updates(model, updates)
return model, new_state, opt_state, loss
Weight cache¶
Downloaded weights are stored at:
~/.noxton/
pytorch_weights/ # raw .pth / .pt files from the internet
models/ # converted JAX checkpoints (per dtype)
Delete these directories to force a fresh download and conversion.
dtype support¶
All models accept a dtype argument:
# half-precision inference (faster on GPU/TPU)
model, state = ResNet.from_pretrained("resnet50", "resnet50_IMAGENET1K_V2",
key=jax.random.key(0), dtype=jnp.float16)
Supported dtypes: jnp.float32, jnp.float16, jnp.bfloat16.