AlexNet

from noxton.models import AlexNet

AlexNet is the landmark 8-layer CNN that won ImageNet 2012 (Krizhevsky et al., 2012). The architecture is:

Conv(11×11, s=4) → ReLU → MaxPool
Conv(5×5)        → ReLU → MaxPool
Conv(3×3)        → ReLU
Conv(3×3)        → ReLU
Conv(3×3)        → ReLU → MaxPool
AdaptiveAvgPool(6×6)
Dropout → Linear(9216, 4096) → ReLU
Dropout → Linear(4096, 4096) → ReLU
Linear(4096, n_classes)

Local Response Normalisation (LRN) follows the ReLU of the first and second convolutional layers, before the corresponding max-pooling (per Krizhevsky et al., 2012; note that many later implementations, including torchvision's, omit LRN entirely).
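The 9216 input width of the first linear layer follows from the conv/pool arithmetic on a 224×224 input. A quick sketch below checks it; the channel widths (64/192/384/256/256) and paddings are torchvision's values, assumed here since they are not stated above. The AdaptiveAvgPool(6×6) guarantees a 6×6 map for other input sizes, but for 224 the stack already lands there:

```python
def conv_out(n, k, s=1, p=0):
    """Spatial size after a conv/pool layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * p - k) // s + 1

n = 224
n = conv_out(n, k=11, s=4, p=2)   # conv1 -> 55
n = conv_out(n, k=3, s=2)         # pool1 -> 27
n = conv_out(n, k=5, p=2)         # conv2 -> 27
n = conv_out(n, k=3, s=2)         # pool2 -> 13
n = conv_out(n, k=3, p=1)         # conv3 -> 13
n = conv_out(n, k=3, p=1)         # conv4 -> 13
n = conv_out(n, k=3, p=1)         # conv5 -> 13
n = conv_out(n, k=3, s=2)         # pool3 -> 6
flattened = 256 * n * n           # 256 channels x 6 x 6 = 9216
print(flattened)                  # 9216
```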


Constructor

AlexNet(
    *,
    n_classes: int,
    key: PRNGKeyArray,
    inference: bool = False,
    dtype = None,
)
Parameter   Description
---------   -----------
n_classes   Number of output classes. Use 1000 for ImageNet.
key         JAX PRNG key for parameter initialisation.
inference   Disable dropout. Default: False.
dtype       Parameter dtype. Default: project default (float32).

from_pretrained

AlexNet.from_pretrained(
    weights: str = "alexnet_IMAGENET1K_V1",
    key: PRNGKeyArray = ...,
    dtype = None,
) -> tuple[AlexNet, eqx.nn.State]
weights                  Dataset       Top-1
-----------------------  ------------  -----
alexnet_IMAGENET1K_V1    ImageNet-1K   56.5%

Weights are downloaded from PyTorch Hub and cached at ~/.noxton/.


__call__

model(
    x: Array,                # (3, H, W) — single image, no batch dim
    state: eqx.nn.State,
    key: PRNGKeyArray | None = None,
) -> tuple[Array, eqx.nn.State]   # (n_classes,), state
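Because __call__ takes a single unbatched image and threads the state through, batching goes through vmap with the state broadcast rather than mapped. The in_axes=(0, None) / out_axes=(0, None) pattern can be seen on a toy stand-in (the apply function here is illustrative, not noxton's API):

```python
import jax
import jax.numpy as jnp

def apply(x, state):
    # Stand-in for model(x, state): one output per example, state unchanged
    return x.sum(), state

xs = jnp.arange(6.0).reshape(3, 2)   # "batch" of 3 examples
state = jnp.zeros(())                # dummy shared state

# Map over axis 0 of xs, broadcast state in; return the state unbatched
outs, state_out = jax.vmap(
    apply, in_axes=(0, None), out_axes=(0, None)
)(xs, state)
print(outs)   # one scalar per example
```

out_axes=None for the state is valid here because the returned state does not vary across the batch; if it did, vmap would raise an error.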

Example

import jax
import jax.numpy as jnp
import equinox as eqx
from PIL import Image
from torchvision import transforms
from noxton.models import AlexNet

# Load pretrained model
model, state = AlexNet.from_pretrained(
    weights="alexnet_IMAGENET1K_V1",
    key=jax.random.key(0),
    dtype=jnp.float32,
)
# Switch off dropout for inference
model = eqx.nn.inference_mode(model)

# Preprocess
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
image = Image.open("cat.jpg").convert("RGB")            # ensure 3 channels
x = jnp.array(preprocess(image).unsqueeze(0).numpy())   # (1, 3, 224, 224)

# Batch inference: map over the leading image axis, broadcast the state
logits, _ = eqx.filter_vmap(
    model, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(x, state)

probs = jax.nn.softmax(logits[0])
top5 = jnp.argsort(probs)[-5:][::-1]
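The argsort-and-reverse idiom above works; jax.lax.top_k returns the highest values and their indices directly, without a full sort. A small sketch with made-up probabilities:

```python
import jax.numpy as jnp
from jax import lax

probs = jnp.array([0.05, 0.30, 0.10, 0.25, 0.20, 0.10])
values, indices = lax.top_k(probs, k=3)   # sorted highest-first
print(indices)                            # indices of the top 3 classes
```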