Examples

All example scripts are in the examples/ directory. Each script loads a pretrained model, preprocesses cat.jpg, runs inference, and prints top-5 ImageNet predictions.


ResNet inference

import os
import jax
import jax.numpy as jnp
import equinox as eqx
import requests
from PIL import Image
from torchvision import transforms
from noxton.models import ResNet


def get_imagenet_labels():
    url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
    response = requests.get(url)
    return [line.strip() for line in response.text.splitlines()]


def preprocess_image(image_path):
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path)
    return jnp.array(preprocess(image).unsqueeze(0).numpy())


# Load model
model, state = ResNet.from_pretrained(
    model="resnet50",
    weights="resnet50_IMAGENET1K_V2",
    key=jax.random.key(0),
    dtype=jnp.float16,
)
model, state = eqx.nn.inference_mode((model, state))

# Preprocess and run
x = preprocess_image("examples/cat.jpg").astype(jnp.float16)
logits, _ = eqx.filter_vmap(
    model, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(x, state)

probs = jax.nn.softmax(logits[0])
labels = get_imagenet_labels()
top5 = jnp.argsort(probs)[-5:][::-1]
for i, idx in enumerate(top5):
    print(f"{i+1}. {labels[int(idx)]}: {float(probs[idx]):.4f}")

CLIP zero-shot classification

import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
from PIL import Image
from noxton.models import CLIP
from noxton.models.clip import clip_tokenize


def preprocess_image(path, n_px=224):
    img = Image.open(path).convert("RGB").resize((n_px, n_px), Image.BICUBIC)
    x = jnp.array(np.array(img), dtype=jnp.float32) / 255.0
    x = jnp.transpose(x, (2, 0, 1))  # HWC -> CHW
    mean = jnp.array([0.48145466, 0.4578275,  0.40821073]).reshape(3, 1, 1)
    std  = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(3, 1, 1)
    return ((x - mean) / std).astype(jnp.float16)


image  = preprocess_image("examples/cat.jpg")
text   = clip_tokenize(["a photo of a human", "a photo of a cat", "a photo of a dog"])

clip, state = CLIP.from_pretrained(model="ViT-B/32", dtype=jnp.float16)
clip, state = eqx.nn.inference_mode((clip, state))

logits_per_image, _, state = eqx.filter_vmap(
    clip,
    in_axes=(None, 0, None),
    out_axes=(0, 0, None),
    axis_name="batch",
)(image, text, state)

probs = jax.nn.softmax(logits_per_image)
for label, p in zip(["human", "cat", "dog"], probs):
    print(f"{label}: {p * 100:.1f}%")

EfficientNet inference

import functools
import jax
import jax.numpy as jnp
import equinox as eqx
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import EfficientNet_B0_Weights
from noxton.models import EfficientNet


model, state = EfficientNet.from_pretrained(
    "efficientnet_b0",
    weights="efficientnet_b0_IMAGENET1K_V1",
    dtype=jnp.float16,
)
model, state = eqx.nn.inference_mode((model, state))

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
img = Image.open("examples/cat.jpg")
x = jnp.array(preprocess(img).unsqueeze(0).numpy(), dtype=jnp.float16)

key = jax.random.key(0)
model_fn = functools.partial(model, key=key)
logits, _ = eqx.filter_vmap(
    model_fn, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(x, state)

categories = EfficientNet_B0_Weights.IMAGENET1K_V1.meta["categories"]
probs = jax.nn.softmax(logits[0])
_, top5 = jax.lax.top_k(probs, 5)
for i, idx in enumerate(top5):
    print(f"{i+1}. {categories[idx]} ({probs[idx] * 100:.2f}%)")

Swin Transformer inference

import jax
import jax.numpy as jnp
import equinox as eqx
from torchvision import transforms
from PIL import Image
from noxton.models import SwinTransformer

# V1 uses 224px input
preprocess_v1 = transforms.Compose([
    transforms.Resize(232),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# V2 uses 256px input
preprocess_v2 = transforms.Compose([
    transforms.Resize(260),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

for model_name, weights, preprocess in [
    ("swin_t",    "swin_t_IMAGENET1K_V1",    preprocess_v1),
    ("swin_v2_t", "swin_v2_t_IMAGENET1K_V1", preprocess_v2),
]:
    model, state = SwinTransformer.from_pretrained(
        model=model_name, weights=weights,
        key=jax.random.key(0), dtype=jnp.float32,
    )
    model, state = eqx.nn.inference_mode((model, state))

    img = Image.open("examples/cat.jpg")
    x = jnp.array(preprocess(img).unsqueeze(0).numpy())

    logits, _ = eqx.filter_vmap(
        model, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
    )(x, state)

    probs = jax.nn.softmax(logits[0])
    top1  = int(jnp.argmax(probs))
    print(f"{model_name}: top-1 class {top1} ({float(probs[top1]):.4f})")