CLIP¶
from noxton.models import CLIP
from noxton.models.clip import clip_tokenize
CLIP (Radford et al., 2021) trains an image encoder and a text encoder jointly with a contrastive objective, enabling zero-shot image classification and image–text similarity scoring. Noxton includes the original OpenAI pretrained weights.
Variants¶
model |
Image encoder | Image size | Top-1 (zero-shot) |
|---|---|---|---|
RN50 |
ModifiedResNet-50 | 224 | 59.8% |
RN101 |
ModifiedResNet-101 | 224 | 62.3% |
RN50x4 |
ModifiedResNet-50×4 | 288 | 66.2% |
RN50x16 |
ModifiedResNet-50×16 | 384 | 70.8% |
RN50x64 |
ModifiedResNet-50×64 | 448 | 75.3% |
ViT-B/32 |
ViT-Base / patch 32 | 224 | 63.3% |
ViT-B/16 |
ViT-Base / patch 16 | 224 | 68.3% |
ViT-L/14 |
ViT-Large / patch 14 | 224 | 75.5% |
ViT-L/14@336px |
ViT-Large / patch 14 | 336 | 76.6% |
from_pretrained¶
CLIP.from_pretrained(
model: str,
dtype = None,
) -> tuple[CLIP, eqx.nn.State]
Weights are the original OpenAI releases downloaded from Azure CDN and cached at ~/.noxton/. No key argument is needed (weights are fully determined by the pretrained checkpoint).
__call__¶
clip(
image: Array, # (3, H, W) — single image
text: Array, # (77,) int32 token ids — single caption
state: eqx.nn.State,
key: PRNGKeyArray | None = None,
) -> tuple[Array, Array, eqx.nn.State]
# (logits_per_image,), (logits_per_text,), state
Use eqx.filter_vmap over text to score one image against multiple captions (see example).
encode_image¶
clip.encode_image(
image: Array, # (3, H, W)
state: eqx.nn.State,
) -> tuple[Array, eqx.nn.State] # (embed_dim,), state
encode_text¶
clip.encode_text(
text: Array, # (77,)
state: eqx.nn.State,
) -> tuple[Array, eqx.nn.State] # (embed_dim,), state
Tokenization¶
from noxton.models.clip import clip_tokenize
tokens = clip_tokenize(["a photo of a cat", "a photo of a dog"])
# tokens.shape -> (2, 77), dtype int32
# sequences are padded / truncated to 77 tokens
Example: zero-shot classification¶
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
from PIL import Image
from noxton.models import CLIP
from noxton.models.clip import clip_tokenize
# --- Load model ---
clip, state = CLIP.from_pretrained(model="ViT-B/32", dtype=jnp.float16)
clip, state = eqx.nn.inference_mode((clip, state))
# --- Prepare inputs ---
def preprocess_image(path, n_px=224):
img = Image.open(path).convert("RGB")
img = img.resize((n_px, n_px), Image.BICUBIC)
x = jnp.array(np.array(img), dtype=jnp.float16) / 255.0
x = jnp.transpose(x, (2, 0, 1)) # HWC -> CHW
mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(3, 1, 1)
std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(3, 1, 1)
return (x - mean) / std
image = preprocess_image("cat.jpg")
text = clip_tokenize(["a photo of a human", "a photo of a cat", "a photo of a dog"])
# --- Run model: vmap over text candidates ---
logits_per_image, logits_per_text, state = eqx.filter_vmap(
clip,
in_axes=(None, 0, None), # broadcast image, map over text
out_axes=(0, 0, None),
axis_name="batch",
)(image, text, state)
probs = jax.nn.softmax(logits_per_image)
for label, p in zip(["human", "cat", "dog"], probs):
print(f"{label}: {p * 100:.1f}%")