Convolution
from noxton.nn import ConvNormActivation
ConvNormActivation
Composable Conv → Normalization → Activation block. Stacks a single eqx.nn.Conv layer, an optional normalisation layer, and an optional activation function into one Equinox StatefulLayer.
When padding is None, same-padding is computed automatically from kernel_size and dilation, so spatial dimensions are preserved when stride=1. When use_bias is None, the convolution bias is enabled only if no norm_layer is provided, since a following normalisation layer makes the bias redundant.
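The auto-padding rule can be illustrated with a short sketch. It assumes the common convention padding = (kernel_size - 1) // 2 * dilation per spatial dimension; the source does not state the exact formula, and the helper names `same_padding` and `conv_output_size` are hypothetical, not part of noxton.

```python
def _as_tuple(value, n):
    """Broadcast an int to n spatial dims, or check a sequence has length n."""
    if isinstance(value, int):
        return (value,) * n
    value = tuple(value)
    if len(value) != n:
        raise ValueError(f"expected {n} values, got {len(value)}")
    return value


def same_padding(num_spatial_dims, kernel_size=3, dilation=1):
    """Per-dimension padding under the ASSUMED rule (k - 1) // 2 * d."""
    ks = _as_tuple(kernel_size, num_spatial_dims)
    ds = _as_tuple(dilation, num_spatial_dims)
    return tuple((k - 1) // 2 * d for k, d in zip(ks, ds))


def conv_output_size(size, kernel_size, stride, padding, dilation):
    """Standard output-length formula for one spatial dimension."""
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


pad = same_padding(2, kernel_size=3, dilation=1)         # (1, 1)
preserved = conv_output_size(28, 3, 1, pad[0], 1)        # 28: size kept at stride 1
halved = conv_output_size(28, 3, 2, pad[0], 1)           # 14: stride 2 shrinks it
```

Under this assumption the spatial size is preserved exactly for odd kernel sizes at stride 1; with a larger stride the output shrinks even though the padding is unchanged.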
Constructor
ConvNormActivation(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: int | Sequence[int] = 3,
    stride: int | Sequence[int] = 1,
    padding = None,
    groups: int = 1,
    activation_layer = jax.nn.relu,
    dilation: int | Sequence[int] = 1,
    use_bias: bool | None = None,
    *,
    norm_layer: Callable | None,
    key: PRNGKeyArray,
    dtype,
)
| Parameter | Description |
|---|---|
| num_spatial_dims | Number of spatial dimensions, e.g. 2 for images. |
| in_channels | Number of input channels. |
| out_channels | Number of output channels (filters). |
| kernel_size | Convolution kernel size. Default 3. |
| stride | Convolution stride. Default 1. |
| padding | Explicit padding or None for auto same-padding. |
| groups | Number of blocked connections (grouped convolution). Default 1. |
| activation_layer | Activation callable or None to skip. Default jax.nn.relu. |
| dilation | Kernel dilation. Default 1. |
| use_bias | Bias in convolution. None = True iff no norm layer. |
| norm_layer | Zero-argument factory returning an AbstractNorm/AbstractNormStateful, or None. |
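The interaction between use_bias and norm_layer described above, and what "zero-argument factory" means in practice, can be sketched in plain Python. This is an illustration only: `ToyNorm` and `resolve_use_bias` are hypothetical names, not noxton classes or functions, and the real constructor may resolve the default differently.

```python
from functools import partial


class ToyNorm:
    """Hypothetical stand-in for a normalisation layer (not a noxton class)."""

    def __init__(self, size, eps=1e-5):
        self.size = size
        self.eps = eps


def resolve_use_bias(use_bias, norm_layer):
    """Mirror the documented rule: None means 'bias only if no norm layer'."""
    if use_bias is None:
        return norm_layer is None
    return use_bias


# functools.partial binds every required argument up front, so the result
# can be called with no arguments, which is what norm_layer expects.
norm_factory = partial(ToyNorm, size=32)
norm = norm_factory()

bias_without_norm = resolve_use_bias(None, None)          # True
bias_with_norm = resolve_use_bias(None, norm_factory)     # False
bias_forced = resolve_use_bias(True, norm_factory)        # True: explicit value wins
```

Passing a factory rather than an instance lets the block construct the normalisation layer itself; an explicit use_bias always overrides the default.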
__call__
block(
    x: Array,  # (in_channels, *spatial_dims)
    state: eqx.nn.State,
    key = None,
) -> tuple[Array, eqx.nn.State]
Example
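import jax
import jax.numpy as jnp
import equinox as eqx
from functools import partial
from noxton.nn import ConvNormActivation, BatchNorm

# Use separate keys for parameter initialisation and for the dummy input.
model_key, data_key = jax.random.split(jax.random.PRNGKey(0))

# Zero-argument factory; size=32 matches out_channels below.
norm_factory = partial(BatchNorm, size=32, axis_name="batch")
block = ConvNormActivation(
    num_spatial_dims=2,
    in_channels=16,
    out_channels=32,
    kernel_size=3,
    norm_layer=norm_factory,
    key=model_key,
    dtype=jnp.float32,
)
state = eqx.nn.State(block)

x = jax.random.normal(data_key, (4, 16, 28, 28))  # (batch, C, H, W)
# The block takes a single unbatched sample, so vmap over the batch axis.
# The state is passed unbatched (in_axes=None) and returned unbatched.
out, state = jax.vmap(
    block, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
)(x, state)
# out.shape -> (4, 32, 28, 28)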