Swin Transformer¶
from noxton.models import SwinTransformer
Swin Transformer (Liu et al., 2021) introduces a hierarchical vision backbone built on shifted window attention: the feature map is partitioned into non-overlapping windows so attention is computed locally, and the window grid is shifted between consecutive layers to allow cross-window connections. SwinV2 (Liu et al., 2022) adds a continuous relative position bias and other training-stabilisation improvements.
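To make the windowing concrete, here is a minimal sketch of window partitioning and the cyclic shift in plain JAX. The function names are illustrative only, not noxton internals, and it assumes H and W are divisible by the window size:

```python
import jax.numpy as jnp

def window_partition(x, window_size):
    """Split a (H, W, C) feature map into (num_windows, ws*ws, C) windows."""
    H, W, C = x.shape
    x = x.reshape(H // window_size, window_size,
                  W // window_size, window_size, C)
    x = x.transpose(0, 2, 1, 3, 4)  # group the two window-grid axes together
    return x.reshape(-1, window_size * window_size, C)

def cyclic_shift(x, window_size):
    """Shift the map by half a window, as in shifted-window (SW-MSA) layers."""
    s = window_size // 2
    return jnp.roll(x, shift=(-s, -s), axis=(0, 1))

# 8x8 map, 4-channel, window size 4 -> 2x2 = 4 windows of 16 tokens each
x = jnp.arange(8 * 8 * 4, dtype=jnp.float32).reshape(8, 8, 4)
windows = window_partition(cyclic_shift(x, 4), window_size=4)
print(windows.shape)  # (4, 16, 4)
```

Attention is then computed independently within each of the `num_windows` token groups, which keeps the cost linear in image size rather than quadratic.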
Variants¶
Swin V1¶
| Model | Params | Input size | Top-1 |
|---|---|---|---|
| swin_t | 28.3M | 224 | 81.5% |
| swin_s | 49.6M | 224 | 83.2% |
| swin_b | 87.8M | 224 | 83.5% |
Swin V2¶
| Model | Params | Input size | Top-1 |
|---|---|---|---|
| swin_v2_t | 28.4M | 256 | 82.0% |
| swin_v2_s | 49.7M | 256 | 83.7% |
| swin_v2_b | 87.9M | 256 | 84.6% |
Note
V1 and V2 models use different input resolutions: resize to 232 then center-crop to 224 for V1, and resize to 260 then center-crop to 256 for V2. See the preprocessing section below.
from_pretrained¶
SwinTransformer.from_pretrained(
    model: str,
    weights: str,
    key: PRNGKeyArray,
    dtype=None,
) -> tuple[SwinTransformer, eqx.nn.State]
Available weights¶
| Weights | Top-1 |
|---|---|
| swin_t_IMAGENET1K_V1 | 81.5% |
| swin_s_IMAGENET1K_V1 | 83.2% |
| swin_b_IMAGENET1K_V1 | 83.5% |
| swin_v2_t_IMAGENET1K_V1 | 82.0% |
| swin_v2_s_IMAGENET1K_V1 | 83.7% |
| swin_v2_b_IMAGENET1K_V1 | 84.6% |
__call__¶
model(
    x: Array,  # (3, H, W)
    state: eqx.nn.State,
    key: PRNGKeyArray | None = None,
) -> tuple[Array, eqx.nn.State]  # (1000,), state
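The returned array holds raw ImageNet-1k logits. A common post-processing step is softmax followed by top-k; this sketch uses random dummy logits in place of a real forward pass:

```python
import jax
import jax.numpy as jnp

# Dummy (1000,) logits standing in for: logits, _ = model(x, state)
logits = jax.random.normal(jax.random.key(0), (1000,))

probs = jax.nn.softmax(logits)                  # class probabilities
top_probs, top_idx = jax.lax.top_k(probs, k=5)  # five most likely classes
print(top_idx.shape)  # (5,)
```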
Example¶
import jax
import jax.numpy as jnp
import equinox as eqx
from noxton.models import SwinTransformer
# V1 — input 224
model_v1, state_v1 = SwinTransformer.from_pretrained(
    model="swin_t",
    weights="swin_t_IMAGENET1K_V1",
    key=jax.random.key(0),
    dtype=jnp.float32,
)
model_v1, state_v1 = eqx.nn.inference_mode((model_v1, state_v1))
images_v1 = jnp.zeros((2, 3, 224, 224))
logits_v1, _ = eqx.filter_vmap(
    model_v1, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(images_v1, state_v1)
# V2 — input 256
model_v2, state_v2 = SwinTransformer.from_pretrained(
    model="swin_v2_t",
    weights="swin_v2_t_IMAGENET1K_V1",
    key=jax.random.key(0),
    dtype=jnp.float32,
)
model_v2, state_v2 = eqx.nn.inference_mode((model_v2, state_v2))
images_v2 = jnp.zeros((2, 3, 256, 256))
logits_v2, _ = eqx.filter_vmap(
    model_v2, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(images_v2, state_v2)
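The `in_axes=(0, None)` / `out_axes=(0, None)` pattern above maps over the leading batch axis of the images while sharing a single state across every call. The same idea in plain `jax.vmap`, with a stand-in function (`fake_model` is illustrative only, not a noxton model):

```python
import jax
import jax.numpy as jnp

# Stand-in for model(x, state): takes a (3, H, W) image plus a shared
# scalar "state", returns (1000,) logits and the state unchanged.
def fake_model(x, state):
    logits = jnp.zeros((1000,)) + x.mean() + state
    return logits, state

images = jnp.zeros((2, 3, 224, 224))
state = jnp.float32(0.0)

# Batch axis 0 of images is mapped; state is broadcast (None) on the way
# in and returned unbatched (None) on the way out.
logits, _ = jax.vmap(
    fake_model, in_axes=(0, None), out_axes=(0, None)
)(images, state)
print(logits.shape)  # (2, 1000)
```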
Input preprocessing¶
Swin V1 (224px)
from torchvision import transforms
preprocess = transforms.Compose([
    transforms.Resize(232),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
Swin V2 (256px)
from torchvision import transforms
preprocess = transforms.Compose([
    transforms.Resize(260),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])