Transformer

from noxton.nn import (
    TransformerEncoderLayer,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerDecoder,
    Transformer,
    VisionTransformer,
)

TransformerEncoderLayer

Single encoder layer: multi-head self-attention + feed-forward MLP, each sub-layer wrapped in a residual connection with layer normalization (applied after each sub-layer by default, or before when `norm_first=True`).

Constructor

TransformerEncoderLayer(
    d_model: int,
    nhead: int,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    activation = jax.nn.relu,
    layer_norm_eps: float = 1e-5,
    norm_first: bool = False,
    bias: bool = True,
    inference: bool = False,
    *,
    key: PRNGKeyArray,
    dtype = None,
)

__call__

layer(
    src: Array,                 # (seq_len, d_model)
    src_mask = None,
    src_key_padding_mask = None,
    is_causal: bool = False,
    key = None,
) -> Array                      # (seq_len, d_model)

TransformerDecoderLayer

Single decoder layer: self-attention + cross-attention + feed-forward MLP, each sub-layer wrapped in a residual connection with layer normalization (placement controlled by `norm_first`, as in the encoder layer).

Constructor

TransformerDecoderLayer(
    d_model: int,
    nhead: int,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    activation = jax.nn.relu,
    layer_norm_eps: float = 1e-5,
    norm_first: bool = False,
    bias: bool = True,
    inference: bool = False,
    *,
    key: PRNGKeyArray,
    dtype = None,
)

__call__

layer(
    tgt: Array,                 # (tgt_len, d_model)
    memory: Array,              # (src_len, d_model)
    tgt_mask = None,
    memory_mask = None,
    tgt_key_padding_mask = None,
    memory_key_padding_mask = None,
    tgt_is_causal: bool = False,
    memory_is_causal: bool = False,
    key = None,
) -> Array                      # (tgt_len, d_model)

TransformerEncoder

Stack of TransformerEncoderLayers with optional final LayerNorm.

Constructor

TransformerEncoder(
    encoder_layer: TransformerEncoderLayer,
    num_layers: int,
    norm = None,
)

__call__

encoder(
    src: Array,
    mask = None,
    src_key_padding_mask = None,
    is_causal: bool = False,
    key = None,
) -> Array

TransformerDecoder

Stack of TransformerDecoderLayers with optional final LayerNorm.

Constructor

TransformerDecoder(
    decoder_layer: TransformerDecoderLayer,
    num_layers: int,
    norm = None,
)

__call__

Applies each decoder layer in the stack to `(tgt, memory, ...)`, then the optional final norm. The accepted arguments mirror TransformerDecoderLayer.__call__ (masks, padding masks, causal flags, key). NOTE(review): confirm the exact parameter names against the implementation — TransformerEncoder.__call__ renames `src_mask` to `mask`, so the decoder stack may likewise rename its mask arguments.

Transformer

Full encoder-decoder architecture.

Constructor

Transformer(
    d_model: int = 512,
    nhead: int = 8,
    num_encoder_layers: int = 6,
    num_decoder_layers: int = 6,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    activation = jax.nn.relu,
    layer_norm_eps: float = 1e-5,
    norm_first: bool = False,
    bias: bool = True,
    inference: bool = False,
    *,
    key: PRNGKeyArray,
    dtype = None,
)

__call__

transformer(
    src: Array,                 # (src_len, d_model)
    tgt: Array,                 # (tgt_len, d_model)
    src_mask = None,
    tgt_mask = None,
    memory_mask = None,
    src_key_padding_mask = None,
    tgt_key_padding_mask = None,
    memory_key_padding_mask = None,
    src_is_causal: bool = False,
    tgt_is_causal: bool = False,
    memory_is_causal: bool = False,
    key = None,
) -> Array                      # (tgt_len, d_model)

VisionTransformer

Vision Transformer (ViT) with patch embedding and CLS token. Divides the input image into patches, linearly embeds each patch, prepends a learnable [CLS] token, adds positional embeddings, and passes the sequence through a stack of ResidualAttentionBlocks.

Constructor

VisionTransformer(
    image_size: int,
    patch_size: int,
    width: int,
    layers: int,
    heads: int,
    output_dim: int,
    *,
    key: PRNGKeyArray,
    dtype = None,
)
| Parameter | Description |
| --- | --- |
| `image_size` | Input image resolution (square: height = width = `image_size`). |
| `patch_size` | Side length of each square patch; `image_size` should be divisible by `patch_size`. |
| `width` | Transformer hidden dimension. |
| `layers` | Number of transformer blocks. |
| `heads` | Number of attention heads. |
| `output_dim` | Dimension of the projected output (CLS token after projection). |

__call__

vit(x: Array) -> Array   # (C, H, W) -> (output_dim,)

Example

import jax
import jax.numpy as jnp
from noxton.nn import VisionTransformer

key = jax.random.PRNGKey(0)
vit = VisionTransformer(
    image_size=224,
    patch_size=16,
    width=768,
    layers=12,
    heads=12,
    output_dim=512,
    key=key,
)

image = jax.random.normal(key, (3, 224, 224))
features = vit(image)
# features.shape -> (512,)