Normalization

from noxton.nn import BatchNorm, LayerNorm, LocalResponseNormalization

BatchNorm

Batch normalisation for use inside jax.vmap or jax.pmap. Normalises each channel across the batch using statistics gathered via jax.lax.pmean over the axis named axis_name. Running mean and variance are maintained in an Equinox State object.

At inference time (inference=True) stored running statistics are used instead of batch statistics.

Constructor

BatchNorm(
    size: int,
    axis_name: Hashable | Sequence[Hashable],
    eps: float = 1e-5,
    momentum: float = 0.1,
    affine: bool = True,
    inference: bool = False,
    dtype = None,
)
Parameter   Description
size        Number of channels to normalise.
axis_name   The vmap/pmap axis name used for cross-batch aggregation.
eps         Numerical stability constant. Default 1e-5.
momentum    EMA factor for running statistics. Default 0.1.
affine      Learn per-channel scale and shift. Default True.
inference   Use running statistics instead of batch statistics. Default False.
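The running statistics follow a standard exponential moving average. A minimal sketch of one update step, assuming the conventional `new = (1 - momentum) * old + momentum * batch` rule (the source does not spell out the exact convention, so treat this as illustrative):

```python
def ema_update(running, batch_stat, momentum=0.1):
    # EMA convention assumed here: a small momentum keeps most of the
    # old running value and mixes in a little of the new batch statistic
    return (1.0 - momentum) * running + momentum * batch_stat

# starting from a zero running mean, one update toward a batch mean of 1.0
print(ema_update(0.0, 1.0))  # 0.1
```

With the default momentum of 0.1, the running statistics therefore adapt slowly, which is why they only become reliable after many training steps.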

__call__

bn(
    x: Array,          # (size,) or (size, *spatial_dims)
    state: eqx.nn.State,
    key = None,
) -> tuple[Array, eqx.nn.State]

Must be called inside a jax.vmap (or jax.pmap) scope whose axis name matches the axis_name given to the constructor.
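To see why the vmap scope matters, here is a self-contained sketch (plain jax.lax.pmean, no noxton) of how per-channel statistics are gathered across the mapped axis:

```python
import jax
import jax.numpy as jnp

def batch_mean(x):
    # each mapped instance sees a single row; pmean averages across
    # all instances mapped under the axis named "batch"
    return jax.lax.pmean(x, axis_name="batch")

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # batch of 2, 2 channels
means = jax.vmap(batch_mean, axis_name="batch")(x)
# every row equals the per-channel batch mean [2., 3.]
```

Without a surrounding vmap/pmap carrying the matching axis name, jax.lax.pmean has no axis to reduce over and raises an error, which is why calling the layer outside such a scope fails.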

Example

import jax
import jax.numpy as jnp
import equinox as eqx
from noxton.nn import BatchNorm

bn = BatchNorm(size=16, axis_name="batch")
state = eqx.nn.State(bn)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 16))  # batch of 4

out, state = jax.vmap(
    bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
)(x, state)
# out.shape -> (4, 16)

LayerNorm

Layer normalisation over the last one or more dimensions. Normalises by subtracting the mean and dividing by the standard deviation, then optionally applies learnable affine parameters weight (scale) and bias (shift). Computation is performed in at least float32 precision and the result is cast back to the input dtype.
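The core computation can be sketched in a few lines of plain jax.numpy (this is a reference for the maths only, without the affine parameters or the float32 upcast):

```python
import jax.numpy as jnp

def layer_norm(x, eps=1e-5):
    # subtract the mean and divide by the standard deviation,
    # computed over the last dimension only
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

x = jnp.arange(6.0).reshape(2, 3)
y = layer_norm(x)
# each row of y now has (approximately) zero mean and unit variance
```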

Constructor

LayerNorm(
    shape: int | Sequence[int],
    eps: float = 1e-5,
    use_weight: bool = True,
    use_bias: bool = True,
    dtype = None,
)

__call__

ln(x: Array, key=None) -> Array

Example

import jax.numpy as jnp
from noxton.nn import LayerNorm

ln = LayerNorm(shape=64)
x = jnp.ones((10, 64))
out = ln(x)
# out.shape -> (10, 64)

# 2-D normalisation
ln2d = LayerNorm(shape=(28, 28))
x2d = jnp.ones((16, 28, 28))
out2d = ln2d(x2d)
# out2d.shape -> (16, 28, 28)

LocalResponseNormalization

Local Response Normalisation (LRN) across adjacent channels, as used in AlexNet (Krizhevsky et al., 2012):

b[c,h,w] = x[c,h,w] / (k + alpha * sum_j x[j,h,w]^2)^beta

where the sum runs over at most n channels centred on c.
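A naive reference implementation of this formula, looping over channels of a (C, H, W) array (a sketch for clarity, not how the library computes it):

```python
import jax.numpy as jnp

def lrn(x, k=2.0, n=5, alpha=1e-4, beta=0.75):
    # divide each channel by (k + alpha * sum of squares over a window
    # of at most n channels centred on it), raised to the power beta
    half = n // 2
    sq = x ** 2
    rows = []
    for c in range(x.shape[0]):
        lo, hi = max(0, c - half), min(x.shape[0], c + half + 1)
        denom = (k + alpha * sq[lo:hi].sum(axis=0)) ** beta
        rows.append(x[c] / denom)
    return jnp.stack(rows)

x = jnp.ones((8, 4, 4))
y = lrn(x)
# with k=2 the denominator exceeds 1, so every output is smaller
# than the corresponding input
```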

Constructor

LocalResponseNormalization(
    k: int = 2,
    n: int = 5,
    alpha: float = 1e-4,
    beta: float = 0.75,
)

__call__

lrn(x: Array) -> Array   # (C, H, W) -> (C, H, W)

Example

import jax.numpy as jnp
from noxton.nn import LocalResponseNormalization

lrn = LocalResponseNormalization()
x = jnp.ones((8, 14, 14))
out = lrn(x)
# out.shape -> (8, 14, 14)