NN Layers — Overview
noxton.nn provides JAX/Equinox implementations of common neural network building blocks. All classes are eqx.Module subclasses and accept key: PRNGKeyArray and dtype constructor arguments.
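The constructor convention above (an explicit `key` plus a `dtype`) is the usual JAX pattern: the caller owns the random state, and the module splits it once per randomly initialised parameter. The following is a minimal sketch that assumes only `jax` is installed. `ToyLinear` is a hypothetical class for illustration, not part of `noxton.nn`, and it is a plain Python class rather than an `eqx.Module` to keep the example dependency-light.

```python
import jax
import jax.numpy as jnp
from jax import Array


class ToyLinear:
    """Hypothetical layer following the `key` + `dtype` constructor convention."""

    def __init__(self, in_features: int, out_features: int, *, key: Array, dtype=jnp.float32):
        # Split the caller-provided key: one sub-key per randomly initialised parameter.
        w_key, b_key = jax.random.split(key)
        bound = 1.0 / in_features ** 0.5
        # Parameters are created directly in the requested dtype.
        self.weight = jax.random.uniform(
            w_key, (out_features, in_features), minval=-bound, maxval=bound, dtype=dtype
        )
        self.bias = jax.random.uniform(
            b_key, (out_features,), minval=-bound, maxval=bound, dtype=dtype
        )

    def __call__(self, x: Array) -> Array:
        # Single (unbatched) input vector of shape (in_features,).
        return self.weight @ x + self.bias


key = jax.random.PRNGKey(0)
layer = ToyLinear(4, 3, key=key, dtype=jnp.bfloat16)
y = layer(jnp.ones(4, dtype=jnp.bfloat16))
```

Because initialisation depends only on the key, rebuilding the layer with the same key reproduces the same parameters.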
Module list
| Module | Class(es) | Description |
|---|---|---|
| Attention | MultiheadAttention, SqueezeExcitation | Multi-head scaled dot-product attention; SE channel recalibration |
| Normalization | BatchNorm, LayerNorm, LocalResponseNormalization | Batch, layer, and LRN normalization |
| Convolution | ConvNormActivation | Conv → Norm → Activation block |
| Embedding | EmbeddingWithPadding, EmbeddingBag | Embedding tables with padding support |
| Regularization | StochasticDepth | DropPath / stochastic depth |
| Linear | BatchedLinear | Linear layer with arbitrary batch dimensions |
| State Space (Mamba) | SelectiveStateSpaceModel, MambaBlock, ResidualBlock, Mamba | Mamba sequence model components |
| Transformer | TransformerEncoderLayer, TransformerDecoderLayer, TransformerEncoder, TransformerDecoder, Transformer, VisionTransformer | Transformer encoder/decoder stack; ViT |
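Stochastic depth (DropPath), listed under Regularization above, drops an entire residual branch per sample during training and rescales the surviving samples by 1/(1-p) so the expected output is unchanged. The following is a minimal sketch of the common per-sample ("row") formulation in plain `jax.numpy`. The function name `drop_path` and its parameters are hypothetical; they do not describe the constructor or call signature of `noxton.nn.StochasticDepth`.

```python
import jax
import jax.numpy as jnp
from jax import Array


def drop_path(x: Array, p: float, *, key: Array, training: bool = True) -> Array:
    """Zero each sample (leading axis) with probability p; rescale the rest by 1/(1-p).

    `p` and `training` are assumed to be plain Python values (not traced),
    since they drive Python-level branches below.
    """
    if not training or p == 0.0:
        # Identity at inference time or when nothing is dropped.
        return x
    if p == 1.0:
        # Every sample is dropped; avoid dividing by zero.
        return jnp.zeros_like(x)
    keep_prob = 1.0 - p
    # One Bernoulli draw per sample, broadcast over all remaining axes.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep = jax.random.bernoulli(key, keep_prob, mask_shape)
    return jnp.where(keep, x / keep_prob, jnp.zeros_like(x))


key = jax.random.PRNGKey(42)
x = jnp.ones((8, 4, 5))
y = drop_path(x, 0.5, key=key)
```

In a residual block this typically wraps only the branch output, e.g. `x + drop_path(branch(x), p, key=k)`, so a dropped sample falls back to the identity path.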
Importing
```python
from noxton.nn import (
    MultiheadAttention,
    SqueezeExcitation,
    ConvNormActivation,
    BatchNorm,
    LayerNorm,
    LocalResponseNormalization,
    EmbeddingWithPadding,
    EmbeddingBag,
    StochasticDepth,
    BatchedLinear,
    SelectiveStateSpace,
    SelectiveStateSpaceModel,
    MambaBlock,
    Mamba,
    TransformerEncoder,
    TransformerDecoder,
    Transformer,
    VisionTransformer,
)
```
Abstract base classes
Two abstract base classes define the interface for normalization layers used internally by ConvNormActivation:
- `AbstractNorm`: stateless normalisation (e.g. `LayerNorm`). `__call__(x) -> x`.
- `AbstractNormStateful`: stateful normalisation (e.g. `BatchNorm`). `__call__(x, state) -> (x, state)`.
Custom norm layers can extend either base class to be compatible with ConvNormActivation.
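The two call contracts can be illustrated with small stand-in base classes. This is a hedged sketch, not the library's code: it does not import `AbstractNorm` or `AbstractNormStateful` (their import path is not given here), and `StatelessNorm`, `StatefulNorm`, `SimpleLayerNorm`, `RunningMeanCenter` and `apply_norm` are hypothetical names. `apply_norm` shows one plausible way a block such as `ConvNormActivation` could dispatch on the two interfaces; how noxton actually does this is an assumption.

```python
import abc

import jax.numpy as jnp
from jax import Array


class StatelessNorm(abc.ABC):
    """Stand-in for the stateless interface: __call__(x) -> x."""

    @abc.abstractmethod
    def __call__(self, x: Array) -> Array: ...


class StatefulNorm(abc.ABC):
    """Stand-in for the stateful interface: __call__(x, state) -> (x, state)."""

    @abc.abstractmethod
    def __call__(self, x: Array, state: Array) -> tuple[Array, Array]: ...


class SimpleLayerNorm(StatelessNorm):
    """Hypothetical: normalise over the last axis, no learned affine parameters."""

    def __init__(self, eps: float = 1e-5):
        self.eps = eps

    def __call__(self, x: Array) -> Array:
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mean) / jnp.sqrt(var + self.eps)


class RunningMeanCenter(StatefulNorm):
    """Hypothetical: centre by the batch mean and track a running mean in `state`."""

    def __init__(self, momentum: float = 0.1):
        self.momentum = momentum

    def __call__(self, x: Array, state: Array) -> tuple[Array, Array]:
        batch_mean = x.mean(axis=0)
        # Exponential moving average of the batch mean, returned as the new state.
        new_state = (1.0 - self.momentum) * state + self.momentum * batch_mean
        return x - batch_mean, new_state


def apply_norm(norm, x, state=None):
    """Hypothetical dispatcher: thread `state` only through stateful norms."""
    if isinstance(norm, StatefulNorm):
        return norm(x, state)
    return norm(x), state


x = jnp.arange(12.0).reshape(3, 4)
y_ln, _ = apply_norm(SimpleLayerNorm(), x)
y_bn, running_mean = apply_norm(RunningMeanCenter(), x, jnp.zeros(4))
```

Keeping the state as an explicit input and output (rather than mutating the module) is what lets a stateful norm work inside pure JAX transformations.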