# ESM (Protein Language Models)

```python
from noxton.models import ESM3, ESMC
```
Noxton includes two protein language model architectures from EvolutionaryScale:
- ESMC — a compact, efficient protein sequence model
- ESM3 — a multimodal model that jointly reasons over protein sequence, structure, secondary structure (SS8), solvent accessibility (SASA), and function annotations
Both models require the optional huggingface-hub dependency and will download weights from HuggingFace Hub on first use.
Note
Install the hub extra for ESM support:
```bash
pip install "noxton[hub]"
```
## ESMC

A Transformer-based protein sequence model. Given a tokenised amino-acid sequence, it returns per-residue logits over the vocabulary.
### from_pretrained
```python
ESMC.from_pretrained(
    model: Literal["esmc_300m", "esmc_600m"],
    key: PRNGKeyArray | None = None,
    dtype=None,
) -> ESMC
```

Weights are downloaded from HuggingFace Hub and cached locally.
| model | Parameters | HuggingFace repo |
|---|---|---|
| `esmc_300m` | 300M | `EvolutionaryScale/esmc-300m-2024-12` |
| `esmc_600m` | 600M | `EvolutionaryScale/esmc-600m-2024-12` |
### `__call__`
```python
esmc(
    sequence_tokens: Array,            # (L,) int32
    sequence_id: Array | None = None,  # (L,) bool attention mask (True = real token)
) -> tuple[Array, Array, Array]
# (sequence_logits, last_hidden_state, all_hidden_states)
#   sequence_logits:   (L, 64)
#   last_hidden_state: (L, d_model)
#   all_hidden_states: (n_layers, L, d_model)
```
### Vocabulary
```python
SEQUENCE_VOCAB = ["<cls>", "<pad>", "<eos>", "<unk>",
                  "L", "A", "G", "V", "S", "E", "R", "T",
                  "I", "D", "P", "K", "Q", "N", "F", "Y",
                  "M", "H", "W", "C", "X", "B", "U", "Z",
                  "O", ".", "-", "|", "<mask>"]
```
Special token indices: BOS=0, PAD=1, EOS=2, MASK=32.
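The examples below hard-code token indices for brevity. To encode an arbitrary amino-acid string, a small helper can be built directly from the `SEQUENCE_VOCAB` list above; the `encode` function here is a hypothetical sketch for illustration, not part of the noxton API:

```python
import jax.numpy as jnp

TOKEN_INDEX = {tok: i for i, tok in enumerate(SEQUENCE_VOCAB)}

def encode(sequence: str) -> jnp.ndarray:
    """Map an amino-acid string to BOS + residue indices + EOS."""
    ids = [TOKEN_INDEX["<cls>"]]  # BOS = 0
    ids += [TOKEN_INDEX.get(aa, TOKEN_INDEX["<unk>"]) for aa in sequence]
    ids.append(TOKEN_INDEX["<eos>"])  # EOS = 2
    return jnp.asarray(ids, dtype=jnp.int32)

# encode("LAGLA") -> Array([0, 4, 5, 6, 4, 5, 2], dtype=int32)
```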
### Example
```python
import jax
import jax.numpy as jnp
from noxton.models import ESMC

model = ESMC.from_pretrained("esmc_300m", dtype=jnp.float16)

# Encode a short protein sequence (BOS + residues + EOS).
# Vocabulary indices: BOS=0, L=4, A=5, G=6, EOS=2.
tokens = jnp.array([0, 4, 5, 6, 4, 5, 2], dtype=jnp.int32)

logits, embedding, hiddens = model(tokens)
# logits.shape    -> (7, 64)
# embedding.shape -> (7, 960)   [for esmc_300m]
```
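Continuing from the example above, one common way to score several sequences at once is to pad them to a shared length with the PAD token (index 1), mark real positions via `sequence_id`, and map the model over the batch with `jax.vmap`. This is a sketch of that pattern, assuming the model can be vmapped like an ordinary JAX function; per the signature above, the mask excludes padded positions from attention:

```python
# Two sequences padded to a common length of 7; True marks real tokens.
batch_tokens = jnp.array([
    [0, 4, 5, 6, 4, 5, 2],   # BOS L A G L A EOS
    [0, 4, 5, 2, 1, 1, 1],   # BOS L A EOS <pad> <pad> <pad>
], dtype=jnp.int32)
batch_mask = batch_tokens != 1  # PAD index is 1

batch_logits, batch_embeddings, _ = jax.vmap(model)(batch_tokens, batch_mask)
# batch_logits.shape     -> (2, 7, 64)
# batch_embeddings.shape -> (2, 7, 960)
```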
## ESM3
A multimodal protein language model that reasons jointly over sequence, structure, secondary structure, solvent accessibility, and function annotations (Hayes et al., 2024).
### from_pretrained
```python
ESM3.from_pretrained(
    model: Literal["esm3_open"] = "esm3_open",
    key: PRNGKeyArray | None = None,
    dtype=None,
) -> ESM3
```
| model | Parameters | HuggingFace repo |
|---|---|---|
| `esm3_open` | 1.4B | `EvolutionaryScale/esm3-sm-open-v1` |
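Since the open checkpoint has 1.4B parameters, loading the weights in a reduced-precision dtype can substantially cut memory use; the `dtype` argument works as shown for ESMC. Whether half precision is acceptable depends on your workload, so treat this as a sketch:

```python
import jax.numpy as jnp
from noxton.models import ESM3

model = ESM3.from_pretrained("esm3_open", dtype=jnp.bfloat16)
```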
### `__call__`
All inputs are optional. Any None input is replaced with an appropriate mask or padding token.
```python
esm3(
    sequence_tokens: Array | None = None,            # (L,) int32
    structure_tokens: Array | None = None,           # (L,) int32
    ss8_tokens: Array | None = None,                 # (L,) int32
    sasa_tokens: Array | None = None,                # (L,) int32
    function_tokens: Array | None = None,            # (L, 8) int32
    residue_annotation_tokens: Array | None = None,  # (L, 16) int32
    average_plddt: Array | None = None,              # scalar
    per_res_plddt: Array | None = None,              # (L,)
    structure_coords: Array | None = None,           # (L, 3, 3) N/CA/C coords
    chain_id: Array | None = None,                   # (L,) int32
    sequence_id: Array | None = None,                # (L,) bool attention mask
) -> tuple[Array, Array, Array, Array, Array, Array, Array]
# (sequence_logits, structure_logits, secondary_structure_logits,
#  sasa_logits, function_logits, residue_logits, embedding)
```
### Example: sequence-only forward pass
```python
import jax
import jax.numpy as jnp
from noxton.models import ESM3

model = ESM3.from_pretrained("esm3_open")

# Sequence tokens only; all other modalities are masked internally.
tokens = jnp.array([0, 4, 5, 6, 4, 5, 2], dtype=jnp.int32)

outputs = model(sequence_tokens=tokens)
seq_logits = outputs[0]
# seq_logits.shape -> (7, vocab_size)
```
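A natural follow-up is masked-residue prediction: replace one or more positions with the mask token, run a forward pass, and take the argmax of the sequence logits at those positions. The sketch below assumes ESM3 shares the ESMC sequence vocabulary shown earlier (mask index 32); check the model's tokenizer if that assumption does not hold.

```python
# Continuing from the example above: mask position 3 and predict it.
MASK = 32
masked_tokens = tokens.at[3].set(MASK)

(seq_logits, structure_logits, ss8_logits,
 sasa_logits, function_logits, residue_logits, embedding) = model(sequence_tokens=masked_tokens)

predicted = int(jnp.argmax(seq_logits[3]))
# `predicted` indexes into the sequence vocabulary listed in the ESMC section.
```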