Linear¶
from noxton.nn import BatchedLinear
BatchedLinear¶
Linear layer that natively handles arbitrarily batched inputs without requiring an explicit outer vmap. It accepts inputs with any number of leading batch dimensions (including none) and reshapes them internally. Both real and complex dtypes are supported.
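The "reshapes internally" behaviour can be illustrated in plain JAX. This is only a minimal sketch, not noxton's actual implementation: the helper name `apply_batched_linear` and the `(out_features, in_features)` weight layout are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def apply_batched_linear(weight, bias, x):
    """Hypothetical helper: apply a linear map over the last axis of x.

    weight: (out_features, in_features)  -- layout assumed for this sketch
    bias:   (out_features,) or None
    x:      (*batch_dims, in_features)
    """
    batch_shape = x.shape[:-1]
    in_features = x.shape[-1]
    out_features = weight.shape[0]
    # Collapse all leading batch dimensions into one.
    flat = x.reshape(-1, in_features)          # (N, in_features)
    out = flat @ weight.T                      # (N, out_features)
    if bias is not None:
        out = out + bias
    # Restore the original batch dimensions.
    return out.reshape(batch_shape + (out_features,))


key = jax.random.PRNGKey(0)
w_key, b_key, x_key = jax.random.split(key, 3)
weight = jax.random.normal(w_key, (4, 8))
bias = jax.random.normal(b_key, (4,))

out_batched = apply_batched_linear(weight, bias, jax.random.normal(x_key, (3, 5, 8)))
out_vector = apply_batched_linear(weight, bias, jnp.ones(8))
```

Flattening to a single matrix product keeps the computation to one matmul regardless of how many batch dimensions the caller passes, and an unbatched vector is just the case where the batch shape is empty.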
Weights and biases are initialised with a uniform distribution in [-1/√in_features, 1/√in_features].
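The initialisation scheme can be sketched for real dtypes as follows. The function name `init_uniform_params`, the weight layout, and the way the key is split are assumptions for illustration; how noxton samples complex-valued parameters is not covered here.

```python
import math

import jax
import jax.numpy as jnp


def init_uniform_params(key, in_features, out_features, dtype=jnp.float32):
    """Hypothetical helper: sample weight and bias from U[-limit, limit]."""
    w_key, b_key = jax.random.split(key)
    limit = 1.0 / math.sqrt(in_features)
    weight = jax.random.uniform(
        w_key, (out_features, in_features), dtype=dtype, minval=-limit, maxval=limit
    )
    bias = jax.random.uniform(
        b_key, (out_features,), dtype=dtype, minval=-limit, maxval=limit
    )
    return weight, bias


weight, bias = init_uniform_params(jax.random.PRNGKey(0), in_features=8, out_features=4)
```

With `in_features=8` the bound is 1/√8 ≈ 0.354, so every sampled entry lies within that magnitude.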
Constructor¶
BatchedLinear(
in_features: int | Literal["scalar"],
out_features: int | Literal["scalar"],
use_bias: bool = True,
dtype = None,
*,
key: PRNGKeyArray,
)
| Parameter | Description |
|---|---|
| `in_features` | Size of the last input dimension, or `"scalar"`. |
| `out_features` | Size of the last output dimension, or `"scalar"`. |
| `use_bias` | Add a learnable bias vector. Default `True`. |
| `dtype` | Parameter dtype. Supports complex dtypes. Default: project default. |
__call__¶
linear(
x: Array, # (*batch_dims, in_features)
key = None,
) -> Array # (*batch_dims, out_features)
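For contrast, the shape contract `(*batch_dims, in_features) -> (*batch_dims, out_features)` is what a per-vector linear function would otherwise need one `jax.vmap` per batch dimension to achieve. The sketch below uses plain JAX only; `linear_single` and the weight layout are hypothetical and stand in for a layer that accepts a single vector.

```python
import jax
import jax.numpy as jnp


def linear_single(weight, bias, x):
    # Hypothetical per-vector layer: x (in_features,) -> (out_features,)
    return weight @ x + bias


key = jax.random.PRNGKey(0)
w_key, b_key, x_key = jax.random.split(key, 3)
weight = jax.random.normal(w_key, (4, 8))
bias = jax.random.normal(b_key, (4,))
x = jax.random.normal(x_key, (3, 5, 8))

# Two leading batch dimensions require two nested vmaps over x only.
vmapped = jax.vmap(
    jax.vmap(linear_single, in_axes=(None, None, 0)),
    in_axes=(None, None, 0),
)
out_vmap = vmapped(weight, bias, x)

# A batch-agnostic formulation handles any number of leading dimensions.
out_direct = jnp.einsum("...i,oi->...o", x, weight) + bias
```

Both paths produce the same `(3, 5, 4)` result; the batch-agnostic form simply does not need to know the input rank in advance.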
Example¶
import jax
import jax.numpy as jnp
from noxton.nn import BatchedLinear
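key = jax.random.PRNGKey(0)
# Use separate keys for parameter initialisation and input sampling
init_key, data_key = jax.random.split(key)
linear = BatchedLinear(in_features=8, out_features=4, key=init_key)
# Arbitrary batch dimensions
x = jax.random.normal(data_key, (3, 5, 8))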
out = linear(x)
# out.shape -> (3, 5, 4)
# Works on plain 1-D vectors too
out1d = linear(jnp.ones(8))
# out1d.shape -> (4,)