State Space (Mamba)¶
from noxton.nn import SelectiveStateSpaceModel, MambaBlock, Mamba
Noxton implements the Mamba architecture from "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (Gu & Dao, 2023). The class hierarchy is:
Mamba                                ← full language model (embedding + blocks + LM head)
└─ ResidualBlock                     ← RMSNorm + MambaBlock + residual
   └─ MambaBlock                     ← conv + SelectiveStateSpaceModel + gating
      └─ SelectiveStateSpaceModel    ← the core SSM recurrence
A lower-level SelectiveStateSpace class in noxton.nn exposes the SSM with an additional output projection.
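The ResidualBlock level of the hierarchy (RMSNorm + MambaBlock + residual) can be sketched in plain JAX. This is a minimal illustration, not Noxton's code: the pre-norm ordering `x + inner(norm(x))`, the `eps` value, and the names `rms_norm`, `residual_block`, and `inner` are assumptions made for the example.

```python
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-5):
    """RMSNorm over the last axis: x / sqrt(mean(x**2) + eps) * weight."""
    rms = jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight


def residual_block(x, weight, inner):
    """Pre-norm residual wrapper (assumed ordering): x + inner(rms_norm(x))."""
    return x + inner(rms_norm(x, weight))


x = jnp.arange(12.0).reshape(3, 4)  # (seq_len, d_model)
weight = jnp.ones(4)                # learnable scale, initialised to ones

# `inner` stands in for a MambaBlock; a zero map turns the block into the identity.
out = residual_block(x, weight, inner=lambda h: jnp.zeros_like(h))
```

Because the residual path adds the input back unchanged, the output keeps the `(seq_len, d_model)` shape, which is what lets such blocks be stacked.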
SelectiveStateSpaceModel¶
The core SSM recurrence without an output projection. Implements input-dependent state-space dynamics:
- Project each token to (dt_rank + 2 * d_state) to obtain delta, B, C.
- Project delta from dt_rank → d_inner and apply softplus.
- Run the selective scan recurrence.
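The three steps above can be sketched in plain JAX as follows. This is an illustrative, sequential version under stated assumptions, not Noxton's implementation: the parameter names (`W_x`, `W_dt`, `A_log`, `D`), the `A = -exp(A_log)` parameterisation, and the skip term `D * x` follow the Mamba paper's reference code and may not match how Noxton stores its weights. Biases are omitted, matching the `False` defaults of the constructor flags.

```python
import jax
import jax.numpy as jnp


def selective_scan(x, delta, A, B, C, D):
    """Sequential selective-scan recurrence for one unbatched sequence.

    x, delta: (seq_len, d_inner); A: (d_inner, d_state);
    B, C: (seq_len, d_state); D: (d_inner,).
    """
    def step(h, inputs):
        x_t, delta_t, B_t, C_t = inputs
        A_bar = jnp.exp(delta_t[:, None] * A)                # discretised state matrix
        Bx = delta_t[:, None] * B_t[None, :] * x_t[:, None]  # discretised input term
        h = A_bar * h + Bx                                   # (d_inner, d_state)
        y_t = h @ C_t + D * x_t                              # (d_inner,)
        return h, y_t

    h0 = jnp.zeros(A.shape, dtype=x.dtype)
    _, y = jax.lax.scan(step, h0, (x, delta, B, C))
    return y


def ssm_forward(params, x, dt_rank, d_state):
    # Step 1: project each token to (dt_rank + 2 * d_state) and split.
    proj = x @ params["W_x"]
    delta_low, B, C = jnp.split(proj, [dt_rank, dt_rank + d_state], axis=-1)
    # Step 2: project delta from dt_rank to d_inner and apply softplus.
    delta = jax.nn.softplus(delta_low @ params["W_dt"])
    # Step 3: run the recurrence; A is kept negative so the state decays.
    A = -jnp.exp(params["A_log"])
    return selective_scan(x, delta, A, B, C, params["D"])


seq_len, d_inner, dt_rank, d_state = 6, 8, 2, 4
k_x, k_wx, k_wdt = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "W_x": 0.1 * jax.random.normal(k_wx, (d_inner, dt_rank + 2 * d_state)),
    "W_dt": 0.1 * jax.random.normal(k_wdt, (dt_rank, d_inner)),
    "A_log": jnp.log(jnp.tile(jnp.arange(1.0, d_state + 1), (d_inner, 1))),
    "D": jnp.ones(d_inner),
}
x = jax.random.normal(k_x, (seq_len, d_inner))
y = ssm_forward(params, x, dt_rank, d_state)  # (seq_len, d_inner)
```

The explicit `jax.lax.scan` keeps the recurrence easy to read; a production implementation would typically use a parallel scan or a fused kernel instead.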
Constructor¶
SelectiveStateSpaceModel(
    d_inner: int,
    dt_rank: int,
    d_state: int,
    use_input_proj_bias: bool = False,
    use_delta_proj_bias: bool = False,
    *,
    key: PRNGKeyArray,
    dtype = None,
)
__call__¶
ssm(x: Array) -> Array # (seq_len, d_inner) -> (seq_len, d_inner)
MambaBlock¶
Complete Mamba block with 1-D depthwise convolution, SSM, and output gating.
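How the three pieces fit together can be sketched in plain JAX. The layout below follows the Mamba paper's reference block (input projection split into a main branch and a gate, causal depthwise convolution, SiLU, SSM, gating, output projection). Whether Noxton's MambaBlock matches it in every detail is an assumption, and the parameter names (`W_in`, `conv_w`, `conv_b`, `W_out`) and the `ssm` stand-in are hypothetical.

```python
import jax
import jax.numpy as jnp


def causal_depthwise_conv1d(x, kernel, bias):
    """x: (seq_len, d_inner); kernel: (d_conv, d_inner); bias: (d_inner,)."""
    seq_len, d_conv = x.shape[0], kernel.shape[0]
    # Left-pad so output position t only sees inputs t - d_conv + 1 .. t.
    padded = jnp.pad(x, ((d_conv - 1, 0), (0, 0)))
    out = sum(padded[k : k + seq_len] * kernel[k] for k in range(d_conv))
    return out + bias


def mamba_block_sketch(x, params, ssm):
    xz = x @ params["W_in"]               # (seq_len, 2 * d_inner)
    u, z = jnp.split(xz, 2, axis=-1)      # main branch and gate branch
    u = jax.nn.silu(causal_depthwise_conv1d(u, params["conv_w"], params["conv_b"]))
    y = ssm(u) * jax.nn.silu(z)           # output gating
    return y @ params["W_out"]            # back to (seq_len, d_model)


seq_len, d_model, d_inner, d_conv = 5, 4, 8, 4
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "W_in": 0.1 * jax.random.normal(k1, (d_model, 2 * d_inner)),
    "conv_w": 0.1 * jax.random.normal(k2, (d_conv, d_inner)),
    "conv_b": jnp.zeros(d_inner),
    "W_out": 0.1 * jax.random.normal(k3, (d_inner, d_model)),
}
x = jax.random.normal(k4, (seq_len, d_model))

# The identity function stands in for the SSM to keep the sketch short.
out = mamba_block_sketch(x, params, ssm=lambda u: u)
```

The left-only padding is what makes the convolution causal: changing a later token cannot affect the block's output at earlier positions.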
Constructor¶
MambaBlock(
    d_model: int,
    d_inner: int,
    dt_rank: int,
    d_state: int,
    d_conv: int,
    *,
    key: PRNGKeyArray,
    dtype = None,
)
__call__¶
block(x: Array) -> Array # (seq_len, d_model) -> (seq_len, d_model)
Mamba¶
Full Mamba language model: token embedding → stack of ResidualBlocks → LM head.
Constructor¶
Mamba(
    vocab_size: int,
    d_model: int,
    d_state: int = 16,
    d_conv: int = 4,
    expand: int = 2,
    n_layers: int = 4,
    *,
    key: PRNGKeyArray,
    dtype = None,
)
| Parameter | Description |
|---|---|
| vocab_size | Vocabulary size for token embedding and LM head. |
| d_model | Model dimensionality. |
| d_state | Latent state size N. Default 16. |
| d_conv | Depthwise convolution kernel size. Default 4. |
| expand | Expansion factor; d_inner = expand * d_model. Default 2. |
| n_layers | Number of residual Mamba blocks. Default 4. |
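The Mamba constructor takes no d_inner or dt_rank, although MambaBlock requires both, so these sizes are presumably derived internally. The helper below is a hypothetical sketch of such a derivation: `d_inner = expand * d_model` comes from the table above, while `dt_rank = ceil(d_model / 16)` is the default in the Mamba paper's reference code and only an assumption about Noxton.

```python
import math


def derived_dims(d_model: int, expand: int = 2) -> dict:
    """Hypothetical helper: sizes a MambaBlock would need, given Mamba's arguments."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)  # assumption: reference-implementation default
    return {"d_inner": d_inner, "dt_rank": dt_rank}


dims = derived_dims(d_model=128)
# dims == {"d_inner": 256, "dt_rank": 8}
```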
__call__¶
model(x: Array) -> Array # (seq_len,) int tokens -> (seq_len, vocab_size) logits
Example¶
import jax
import jax.numpy as jnp
from noxton.nn import Mamba

key = jax.random.PRNGKey(0)
model = Mamba(
    vocab_size=256,
    d_model=128,
    n_layers=4,
    key=key,
)

tokens = jnp.array([1, 5, 42, 7, 100])  # (seq_len,)
logits = model(tokens)
# logits.shape -> (5, 256)
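Since `__call__` takes a single unbatched sequence, a batch of sequences can be handled by mapping over the leading axis with `jax.vmap`. The sketch below uses a hypothetical stand-in, `toy_model`, with the same `(seq_len,) -> (seq_len, vocab_size)` signature so that it runs without Noxton. Whether a Mamba instance can be passed to `jax.vmap` directly is an assumption to verify against the library.

```python
import jax
import jax.numpy as jnp


def toy_model(tokens):
    """Hypothetical stand-in for a Mamba instance: (seq_len,) -> (seq_len, vocab_size)."""
    vocab_size = 256
    return jax.nn.one_hot(tokens, vocab_size)


batch = jnp.array([[1, 5, 42, 7, 100],
                   [2, 6, 43, 8, 101]])      # (batch, seq_len)

batched_logits = jax.vmap(toy_model)(batch)  # (batch, seq_len, vocab_size)

# Greedy next-token choice from the logits at the last position of each sequence.
next_tokens = jnp.argmax(batched_logits[:, -1, :], axis=-1)  # (batch,)
```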