Embedding¶
from noxton.nn import EmbeddingWithPadding, EmbeddingBag
EmbeddingWithPadding¶
Embedding table that zeros out embeddings for a configurable padding index. Wraps eqx.nn.Embedding and multiplies every looked-up embedding by a binary mask that is 0 wherever the input equals padding_idx.
Constructor¶
EmbeddingWithPadding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: int = 0,
    *,
    key: PRNGKeyArray,
    dtype=None,
)
__call__¶
emb(x: Array) -> Array # (seq_len,) -> (seq_len, embedding_dim)
Example¶
import jax
import jax.numpy as jnp
from noxton.nn import EmbeddingWithPadding
key = jax.random.PRNGKey(0)
emb = EmbeddingWithPadding(num_embeddings=10, embedding_dim=4, key=key)
ids = jnp.array([0, 1, 2, 0]) # 0 is padding
out = emb(ids)
# out.shape -> (4, 4)
# out[0] and out[3] are zero vectors
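The masking behavior described above can be sketched in plain JAX. This is an illustrative re-implementation, not the library's actual internals: `embed_with_padding` and the raw weight matrix are stand-ins for what `EmbeddingWithPadding` wraps via `eqx.nn.Embedding`.

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch of the padding-mask mechanism, using a plain
# weight matrix in place of eqx.nn.Embedding.
def embed_with_padding(weight, ids, padding_idx=0):
    # Look up one row per token id: (seq_len, embedding_dim)
    embs = jnp.take(weight, ids, axis=0)
    # Binary mask: 0 where the input equals padding_idx, 1 elsewhere
    mask = (ids != padding_idx).astype(embs.dtype)[:, None]
    return embs * mask

key = jax.random.PRNGKey(0)
weight = jax.random.normal(key, (10, 4))
out = embed_with_padding(weight, jnp.array([0, 1, 2, 0]))
# Rows 0 and 3 come out as zero vectors, matching the example above
```

Because the mask is applied after lookup, gradients to the padding row are zeroed as well, so the padding embedding never receives updates through masked positions.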
EmbeddingBag¶
Sums a bag of token embeddings into a single vector. Looks up each index using EmbeddingWithPadding (padding tokens contribute zero), then reduces by summing. Analogous to torch.nn.EmbeddingBag(mode="sum").
Constructor¶
EmbeddingBag(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: int = 0,
    *,
    key: PRNGKeyArray,
    dtype=None,
)
__call__¶
bag(x: Array) -> Array # (bag_size,) -> (embedding_dim,)
Example¶
import jax
import jax.numpy as jnp
from noxton.nn import EmbeddingBag
key = jax.random.PRNGKey(0)
bag = EmbeddingBag(num_embeddings=10, embedding_dim=4, key=key)
ids = jnp.array([1, 3, 5])
out = bag(ids)
# out.shape -> (4,)
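The lookup-then-sum reduction can likewise be sketched in plain JAX. Again, `embedding_bag_sum` and the raw weight matrix are illustrative stand-ins for the library's internals, assuming the behavior documented above (padding tokens contribute zero to the sum).

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch of EmbeddingBag's sum reduction over one bag of ids.
def embedding_bag_sum(weight, ids, padding_idx=0):
    embs = jnp.take(weight, ids, axis=0)           # (bag_size, embedding_dim)
    mask = (ids != padding_idx).astype(embs.dtype)[:, None]
    return (embs * mask).sum(axis=0)               # (embedding_dim,)

key = jax.random.PRNGKey(0)
weight = jax.random.normal(key, (10, 4))
out = embedding_bag_sum(weight, jnp.array([1, 3, 5]))

# Since __call__ takes a single bag of shape (bag_size,), a batch of
# equal-length bags can be handled with jax.vmap:
batch = jnp.array([[1, 3, 5], [2, 4, 0]])          # 0 is padding
batched = jax.vmap(lambda b: embedding_bag_sum(weight, b))(batch)
```

Summing (rather than averaging) matches `torch.nn.EmbeddingBag(mode="sum")`; an unweighted mean would additionally need to divide by the count of non-padding tokens per bag.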