# Noxton
Noxton is a JAX/Equinox library of neural network building blocks and pretrained models. It provides:
- `noxton.nn` — low-level layers (attention, normalization, convolution, Mamba, Transformer, …) that compose naturally with Equinox
- `noxton.models` — high-level vision and language models with `from_pretrained()` loaders that download and convert PyTorch weights on demand
All modules are pure JAX — no hidden state, no magic globals. Models are `eqx.Module` subclasses, so they work with `jax.jit`, `jax.vmap`, `jax.grad`, and the rest of the JAX ecosystem out of the box.
## Available models
| Model | Variants | Task |
|---|---|---|
| AlexNet | alexnet | ImageNet classification |
| ResNet | resnet18/34/50/101/152, resnext50/101, wide_resnet50/101 | ImageNet classification |
| ConvNeXt | tiny, small, base, large | ImageNet classification |
| EfficientNet | B0–B7, V2-S/M/L | ImageNet classification |
| Swin Transformer | swin_t/s/b, swin_v2_t/s/b | ImageNet classification |
| CLIP | RN50, RN101, ViT-B/32, ViT-B/16, ViT-L/14 | Zero-shot image–text |
| ESM | ESM3, ESMC | Protein language modelling |
## Quick example
```python
import jax
import jax.numpy as jnp
import equinox as eqx

from noxton.models import ResNet

# Load ResNet-50 with pretrained ImageNet V2 weights
model, state = ResNet.from_pretrained(
    model="resnet50",
    weights="resnet50_IMAGENET1K_V2",
    key=jax.random.key(0),
    dtype=jnp.float32,
)

# Switch to inference mode (disables BatchNorm updates, dropout)
model, state = eqx.nn.inference_mode((model, state))

# Batch inference via filter_vmap
images = jax.random.normal(jax.random.key(1), (4, 3, 224, 224))
logits, _ = eqx.filter_vmap(
    model, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(images, state)
probs = jax.nn.softmax(logits, axis=-1)  # (4, 1000)
```
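From the `(4, 1000)` probabilities, per-image top-k predictions can be read off with `jax.lax.top_k`; the stand-in array below simply mimics the shape of `probs` from the example:

```python
import jax
import jax.numpy as jnp

# Stand-in for the (4, 1000) class probabilities computed above.
probs = jax.nn.softmax(jax.random.normal(jax.random.key(2), (4, 1000)), axis=-1)

# Top-5 probabilities and class indices per image, each shaped (4, 5),
# sorted in descending order of probability.
top_probs, top_classes = jax.lax.top_k(probs, k=5)
```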
## Design principles
**Functional first.** Every layer is an `eqx.Module`. Stateful operations (BatchNorm running statistics, dropout keys) flow through explicit arguments — there are no hidden side effects.
**Composable.** All `noxton.nn` primitives accept a `key: PRNGKeyArray` constructor argument and a `dtype` argument. They can be freely mixed with standard Equinox and JAX primitives.
**Pretrained weights.** Weights are downloaded once to `~/.noxton/` and converted from PyTorch `.pth`/`.pt` files using `statedict2pytree`. Once converted, PyTorch is not required at inference time.