Attention

from noxton.nn import MultiheadAttention, SqueezeExcitation

MultiheadAttention

A 1-to-1 JAX/Equinox port of torch.nn.MultiheadAttention. Splits queries, keys and values into num_heads independent attention heads, computes scaled dot-product attention for each head in parallel, and projects the concatenated outputs back to embed_dim.
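The split/attend/concatenate computation described above can be sketched in plain JAX. This is a hypothetical standalone helper for illustration, not part of noxton; it omits the input/output projections, biases, masking, and dropout that the real module adds:

```python
import jax
import jax.numpy as jnp

def multihead_sdpa(q, k, v, num_heads):
    # q: (tgt_len, embed_dim); k, v: (src_len, embed_dim)
    tgt_len, embed_dim = q.shape
    head_dim = embed_dim // num_heads
    # Split the embedding dimension into num_heads independent heads.
    split = lambda x: x.reshape(x.shape[0], num_heads, head_dim).transpose(1, 0, 2)
    qh, kh, vh = split(q), split(k), split(v)      # (num_heads, len, head_dim)
    # Scaled dot-product attention, computed for all heads in parallel.
    logits = qh @ kh.transpose(0, 2, 1) / jnp.sqrt(head_dim)
    weights = jax.nn.softmax(logits, axis=-1)      # (num_heads, tgt_len, src_len)
    out = weights @ vh                             # (num_heads, tgt_len, head_dim)
    # Concatenate the heads back to (tgt_len, embed_dim).
    return out.transpose(1, 0, 2).reshape(tgt_len, embed_dim), weights
```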

Note

This implementation is intentionally API-compatible with torch.nn.MultiheadAttention. Unless you specifically need that compatibility, prefer eqx.nn.MultiheadAttention, which is more idiomatic for JAX.

Constructor

MultiheadAttention(
    embed_dim: int,
    num_heads: int,
    dropout: float = 0.0,
    bias: bool = True,
    add_bias_kv: bool = False,
    add_zero_attn: bool = False,
    kdim: int | None = None,
    vdim: int | None = None,
    inference: bool = False,
    *,
    key: PRNGKeyArray,
    dtype = None,
)
Parameter Description
embed_dim Total dimensionality of the model (query) embeddings.
num_heads Number of attention heads. embed_dim must be divisible by num_heads.
dropout Dropout probability on attention weights during training. Default 0.0.
bias Add learnable bias to input and output projections. Default True.
add_bias_kv Append learnable bias vectors to key and value sequences. Default False.
add_zero_attn Append a batch of zeros to key and value sequences. Default False.
kdim Dimensionality of key inputs. Defaults to embed_dim.
vdim Dimensionality of value inputs. Defaults to embed_dim.
inference Disable dropout (eval mode). Default False.
key JAX PRNG key for parameter initialisation.
dtype Parameter dtype. Defaults to project default.

__call__

mha(
    query: Array,           # (tgt_len, embed_dim)
    key: Array,             # (src_len, kdim)
    value: Array,           # (src_len, vdim)
    key_padding_mask = None,
    need_weights: bool = True,
    attn_mask = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
    dropout_key = None,
) -> tuple[Array, Array | None]

Returns (attn_output, attn_weights) where attn_output has shape (tgt_len, embed_dim) and attn_weights is None when need_weights=False, shape (tgt_len, src_len) when averaged, or (num_heads, tgt_len, src_len) per-head.
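To illustrate what is_causal does, assuming the torch-style convention that each query position may only attend to positions at or before itself, the snippet below builds the equivalent additive mask by hand. It is a self-contained sketch using only jax, not a call into noxton:

```python
import jax
import jax.numpy as jnp

tgt_len = src_len = 5
# Additive causal mask: 0 where position i may attend to j (j <= i), -inf elsewhere.
mask = jnp.where(jnp.tril(jnp.ones((tgt_len, src_len), bool)), 0.0, -jnp.inf)
logits = jax.random.normal(jax.random.PRNGKey(0), (tgt_len, src_len))
# Softmax over masked logits zeroes out all future (upper-triangular) positions.
weights = jax.nn.softmax(logits + mask, axis=-1)
```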

Example

import jax
import jax.numpy as jnp
from noxton.nn import MultiheadAttention

key = jax.random.PRNGKey(0)
init_key, data_key = jax.random.split(key)
mha = MultiheadAttention(embed_dim=64, num_heads=4, key=init_key)

q = jax.random.normal(data_key, (10, 64))  # (seq_len, embed_dim)
out, weights = mha(q, q, q)            # self-attention
# out.shape    -> (10, 64)
# weights.shape -> (10, 10)

# Causal (decoder) attention
out, _ = mha(q, q, q, is_causal=True, need_weights=False)

SqueezeExcitation

Squeeze-and-Excitation channel-attention block (Hu et al., 2018). Globally pools the spatial dimensions of a feature map, passes the result through two 1×1 convolutions to produce a per-channel scale vector, and multiplies it back into the input.

scale = sigmoid(fc2(relu(fc1(avgpool(x)))))
output = scale * x
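The two equations above can be sketched in plain JAX. Because the 1×1 convolutions act on a globally pooled (C,) vector, they reduce to dense layers, so the hypothetical weights w1/w2 below are plain matrices rather than noxton modules:

```python
import jax
import jax.numpy as jnp

def squeeze_excitation(x, w1, b1, w2, b2):
    # x: (C, H, W). Squeeze: global average pool over the spatial dimensions.
    pooled = x.mean(axis=(1, 2))              # (C,)
    # Excitation: bottleneck MLP, equivalent to the two 1x1 convolutions.
    hidden = jax.nn.relu(w1 @ pooled + b1)    # (squeeze_channels,)
    scale = jax.nn.sigmoid(w2 @ hidden + b2)  # (C,), each entry in (0, 1)
    # Rescale every channel of the input by its learned importance.
    return scale[:, None, None] * x
```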

Constructor

SqueezeExcitation(
    input_channels: int,
    squeeze_channels: int,
    *,
    key: PRNGKeyArray,
    dtype = None,
)
Parameter Description
input_channels Number of channels in the input feature map.
squeeze_channels Bottleneck width (typically input_channels // reduction_ratio).
key JAX PRNG key for parameter initialisation.
dtype Parameter dtype. Defaults to project default.

__call__

se(
    x: Array,                            # (C, H, W)
    activation = jax.nn.relu,
    scale_activation = jax.nn.sigmoid,
) -> Array                               # (C, H, W)

Example

import jax
from noxton.nn import SqueezeExcitation

key = jax.random.PRNGKey(0)
init_key, data_key = jax.random.split(key)
se = SqueezeExcitation(input_channels=64, squeeze_channels=16, key=init_key)

x = jax.random.normal(data_key, (64, 28, 28))  # (C, H, W)
out = se(x)
# out.shape -> (64, 28, 28)