API Reference
This page provides detailed documentation for the SPURS package. SPURS offers APIs for protein stability prediction and functional site identification.
Getting Started
Model Loading and Inference
- spurs.inference.get_SPURS(ckpt_path: str, device: str = 'cpu')[source]
Load a SPURS model from a local checkpoint file.
This function loads a SPURS model and its configuration from a local checkpoint directory. The checkpoint should contain both the model weights and the configuration file.
- Parameters
ckpt_path (str) – Path to the checkpoint directory containing .hydra/config.yaml and checkpoints/best.ckpt
device (str, optional) – Device to load the model on. Defaults to 'cpu'.
- Returns
- A tuple containing:
model (torch.nn.Module): The loaded SPURS model
cfg (OmegaConf): The model configuration
- Return type
tuple
Example
>>> model, cfg = get_SPURS("path/to/checkpoint")
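The directory layout that ckpt_path must follow can be checked before loading. The following is a minimal sketch that assumes only the two relative paths named in the parameter description; missing_checkpoint_files is a hypothetical helper and is not part of SPURS.

```python
from pathlib import Path

# Relative paths that the ckpt_path description says the directory must contain.
REQUIRED_FILES = (".hydra/config.yaml", "checkpoints/best.ckpt")


def missing_checkpoint_files(ckpt_path):
    """Return the required relative paths that are absent under ckpt_path.

    Hypothetical helper for illustration; SPURS itself may validate differently.
    """
    root = Path(ckpt_path)
    return [rel for rel in REQUIRED_FILES if not (root / rel).is_file()]
```

An empty list means the directory has the expected layout and can be passed to get_SPURS.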
- spurs.inference.get_SPURS_from_hub(repo_id: str = 'cyclization9/SPURS', device: str = 'cpu')[source]
Load a pre-trained SPURS model directly from the Hugging Face model hub.
This function downloads and loads a pre-trained SPURS model and its configuration from the Hugging Face model hub. This is the recommended way to get started with SPURS.
- Parameters
repo_id (str, optional) – Hugging Face model repository ID. Defaults to “cyclization9/SPURS”.
device (str, optional) – Device to load the model on. Defaults to 'cpu'.
- Returns
- A tuple containing:
model (torch.nn.Module): The loaded SPURS model
cfg (OmegaConf): The model configuration
- Return type
tuple
Example
>>> model, cfg = get_SPURS_from_hub()
- spurs.inference.get_SPURS_multi_from_hub(repo_id: str = 'cyclization9/SPURS', device: str = 'cpu')[source]
Load a pre-trained SPURS multi-mutation model from the Hugging Face model hub.
This function downloads and loads a pre-trained SPURS model specialized for predicting the effects of multiple mutations simultaneously.
- Parameters
repo_id (str, optional) – Hugging Face model repository ID. Defaults to “cyclization9/SPURS”.
device (str, optional) – Device to load the model on. Defaults to 'cpu'.
- Returns
- A tuple containing:
model (torch.nn.Module): The loaded SPURS multi-mutation model
cfg (OmegaConf): The model configuration
- Return type
tuple
Example
>>> model, cfg = get_SPURS_multi_from_hub()
- spurs.inference.parse_pdb(pdb_path: str, pdb_name: str, chain: str, cfg, device: str = 'cpu')[source]
Parse a PDB file and prepare it for SPURS model input.
This function processes a PDB file and converts it into the format required by SPURS models. It handles coordinate extraction, feature generation, and device placement.
- Parameters
pdb_path (str) – Path to the PDB file
pdb_name (str) – Name of the protein
chain (str) – Chain identifier to analyze
cfg – Model configuration containing alphabet settings
device (str, optional) – Device to place the tensors on. Defaults to 'cpu'.
- Returns
- A dictionary containing processed PDB data including:
coordinates
sequence information
features required by the model
All tensors are placed on the specified device.
- Return type
dict
- spurs.inference.parse_pdb_for_mutation(mut_info_list)[source]
Parse mutation information into tensor format required by model.
This function converts a list of mutation specifications into the tensor format required by SPURS models for mutation analysis.
- Parameters
mut_info_list (List[List[str]]) – List of lists containing mutation strings. Each inner list represents a set of mutations to analyze together. Format: [['V2C', 'P3T'], ['W1A', 'V2Y']], where each mutation string is formatted as 'OriginalAAPositionNewAA' (original amino acid, 1-indexed position, new amino acid).
- Returns
- A tuple containing:
mut_ids (torch.Tensor): Tensor of mutation positions (0-indexed)
append_tensors (torch.Tensor): Tensor of amino acid indices for wild-type and mutant residues
- Return type
tuple
Example
>>> mut_ids, append_tensors = parse_pdb_for_mutation([['V2C', 'P3T']])
>>> print(mut_ids) # Shows positions 1,2 (0-indexed)
>>> print(append_tensors) # Shows amino acid indices for V->C, P->T
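The 'OriginalAAPositionNewAA' format and the shift to 0-indexed positions can be illustrated without tensors. This is a sketch only: split_mutation and split_mutation_sets are hypothetical helpers, and the real parse_pdb_for_mutation additionally maps amino acids to indices and returns torch tensors.

```python
import re

# One mutation: original amino acid, 1-indexed position, new amino acid.
MUTATION_PATTERN = re.compile(r"^([A-Z])(\d+)([A-Z])$")


def split_mutation(mutation):
    """Split 'V2C' into ('V', 1, 'C'), converting the position to 0-indexed."""
    match = MUTATION_PATTERN.match(mutation)
    if match is None:
        raise ValueError(f"Not a valid mutation string: {mutation!r}")
    wild_type, position, mutant = match.groups()
    return wild_type, int(position) - 1, mutant


def split_mutation_sets(mut_info_list):
    """Apply split_mutation to every string in a list of mutation sets."""
    return [[split_mutation(m) for m in group] for group in mut_info_list]
```

For the input [['V2C', 'P3T']] the positions come out as 1 and 2, matching the 0-indexed positions in the example above.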
Basic usage:
from spurs.inference import get_SPURS_from_hub, parse_pdb
# Load pre-trained model
model, cfg = get_SPURS_from_hub()
# Parse PDB file
pdb = parse_pdb(
    pdb_path='path/to/protein.pdb',
    pdb_name='PROTEIN_NAME',
    chain='A',
    cfg=cfg
)
# Make predictions
ddg = model(pdb, return_logist=True)
Functional Site Identification
This module implements functionality for identifying and analyzing functional sites in proteins using SPURS model predictions and sigmoid-based regression analysis. It provides tools for normalizing mutation effect scores, fitting sigmoid functions to the data, and visualizing the results to identify functionally important positions in proteins.
- spurs.functional_site_annotation.normalize(data, lower=-4.36, upper=1.21, new_min=0, new_max=1)[source]
Normalize mutation effect scores to a specified range.
- Parameters
data – Raw mutation effect scores
lower – Lower bound for clipping (default: -4.36, 0.1 percentile of ddG values)
upper – Upper bound for clipping (default: 1.21, 99.9 percentile of ddG values)
new_min – Minimum value in normalized range (default: 0)
new_max – Maximum value in normalized range (default: 1)
- Returns
Normalized and clipped data scaled to [new_min, new_max]
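The clip-then-rescale behavior described above can be sketched in plain Python. The linear min-max formula is an assumption about how the scaling is done, and normalize_sketch is a hypothetical name; the library function may operate on arrays and differ in detail.

```python
def normalize_sketch(data, lower=-4.36, upper=1.21, new_min=0, new_max=1):
    """Clip each score to [lower, upper], then rescale linearly to [new_min, new_max]."""
    scale = (new_max - new_min) / (upper - lower)
    normalized = []
    for value in data:
        clipped = min(max(value, lower), upper)  # clip to the given bounds
        normalized.append(new_min + (clipped - lower) * scale)
    return normalized
```

Under this assumption, any score at or below lower maps to new_min and any score at or above upper maps to new_max.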
- spurs.functional_site_annotation.sigmoid_function(x, xmid, scal)[source]
Compute sigmoid function values for given parameters.
- Parameters
x – Input values
xmid – Midpoint parameter of the sigmoid
scal – Scale parameter controlling steepness
- Returns
Sigmoid function values
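One common two-parameter form with a midpoint and a scale is the logistic curve below. This is an assumed form for illustration, not a confirmed copy of the library implementation; sigmoid_sketch is a hypothetical name, and the sign convention in SPURS may differ.

```python
import math


def sigmoid_sketch(x, xmid, scal):
    """Logistic curve that equals 0.5 at x == xmid; a smaller scal gives a steeper curve."""
    return 1.0 / (1.0 + math.exp(-(x - xmid) / scal))
```

With this form, xmid shifts the curve along the x-axis and scal stretches it, so halving scal doubles the slope at the midpoint.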
- class spurs.functional_site_annotation.SigmoidRegressor(*args: Any, **kwargs: Any)[source]
Bases: sklearn.base.BaseEstimator, sklearn.base.RegressorMixin
Custom scikit-learn compatible estimator for fitting sigmoid functions to mutation effect data. Used to identify functional sites by finding positions where mutations have non-linear effects.
The sigmoid function is fit using weighted least squares optimization with bounds derived from empirical protein stability data.
- __init__(xmid_initial=0.0, scal_initial=1.0)[source]
Initialize sigmoid regressor with starting parameters.
- Parameters
xmid_initial – Initial guess for sigmoid midpoint
scal_initial – Initial guess for sigmoid scale factor
- spurs.functional_site_annotation.get_sigmoid_results(mask_results, ddg)[source]
Fit sigmoid function to mutation effect data for functional site identification.
- Parameters
mask_results – Mutation effect predictions from ESM
ddg – Predicted stability measurements
- Returns
(normalized_ddg, mask_results, None, sigmoid_parameters)
- Return type
tuple
- spurs.functional_site_annotation.inference_wt_seq(sequence: str, indices: list, batch_converter, model, device: torch.device, alphabet)[source]
Perform inference on wild-type sequence using the ESM model.
- Parameters
sequence – Protein sequence to analyze
indices – List of positions to analyze
batch_converter – ESM batch converter
model – ESM model
device – Computation device
alphabet – ESM alphabet for token mapping
- Returns
Logits for each position and possible amino acid substitution
- spurs.functional_site_annotation.get_wt_aa_logit_differences(sequence: str, mut_indices: list, batch_converter, model, device: torch.device, alphabet, shift: int = 1)[source]
Calculate differences between wild-type and mutant amino acid logits.
This function helps identify positions where mutations would have the strongest effect by comparing model predictions for all possible amino acid substitutions against the wild-type amino acid at each position.
- Parameters
sequence – Protein sequence
mut_indices – Positions to analyze
batch_converter – ESM batch converter
model – ESM model
device – Computation device
alphabet – ESM alphabet
shift – Index adjustment (default: 1)
- Returns
Tensor of logit differences between wild-type and all possible mutations
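The comparison against the wild-type residue can be sketched with plain lists. Several points here are assumptions: the difference is taken as every logit minus the wild-type logit, shift is read as the offset between a position in mut_indices and a Python string index, and the alphabet is a plain string rather than an ESM alphabet object. logit_differences is a hypothetical helper.

```python
def logit_differences(logits, sequence, mut_indices, alphabet, shift=1):
    """Subtract the wild-type logit from every logit at each requested position.

    logits: one row of per-amino-acid scores for each entry in mut_indices
    alphabet: string whose character order matches the columns of logits
    shift: offset subtracted from a position to index into sequence
    """
    differences = []
    for row, position in zip(logits, mut_indices):
        wild_type = sequence[position - shift]
        wild_type_logit = row[alphabet.index(wild_type)]
        differences.append([score - wild_type_logit for score in row])
    return differences
```

In each output row the wild-type column is zero, so the sign of the other entries shows whether a substitution scores above or below the wild-type residue.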
- spurs.functional_site_annotation.plot_sigmoid_results(result, shift=1, vcenter=0, highlight_positions=[])[source]
Visualize functional site prediction results using a scatter plot.
Creates a plot showing normalized mutation effects across protein positions, with optional highlighting of specific positions of interest. The color scheme uses a diverging colormap centered at vcenter.
- Parameters
result – Tuple containing (normalized_ddg, mask_results, None, sigmoid_parameters)
shift – Position numbering offset (default: 1)
vcenter – Center value for color normalization (default: 0)
highlight_positions – List of positions to highlight with markers
- Returns
Z-score normalized mutation effects
- Return type
normalized_data
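The z-score normalization named in the return value can be sketched as follows. Using the population standard deviation is an assumption, and z_score is a hypothetical helper rather than the function the plotting code calls.

```python
from statistics import mean, pstdev


def z_score(values):
    """Center values on their mean and divide by the population standard deviation."""
    center = mean(values)
    spread = pstdev(values)
    if spread == 0:
        raise ValueError("z-score is undefined when all values are equal")
    return [(value - center) / spread for value in values]
```

After this transform the values have zero mean, which is consistent with the default vcenter=0 for the diverging colormap.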
Basic usage:
from spurs.functional_site_annotation import get_wt_aa_logit_differences
import torch
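# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# batch_converter, model, and alphabet must come from an ESM model loaded
# beforehand; this page does not specify which checkpoint to use.
sequence = "YOUR_PROTEIN_SEQUENCE"
mut_indices = list(range(1, len(sequence) + 1))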
# Get predictions
results = get_wt_aa_logit_differences(
    sequence,
    mut_indices,
    batch_converter,
    model,
    device,
    alphabet
)
Model Architecture
SPURS Models
- class spurs.models.stability.spurs.SPURSConfig(encoder: spurs.models.stability.protein_mpnn.ProteinMPNNConfig = ProteinMPNNConfig(d_model=128, d_node_feats=128, d_edge_feats=128, k_neighbors=48, augment_eps=0.0, n_enc_layers=3, dropout=0.1, tune=False, use_input_decoding_order=False, n_vocab=22, n_dec_layers=3, random_decoding_order=True, nar=True, crf=False, use_esm_alphabet=False), adapter_layer_indices: List = <factory>, separate_loss: bool = True, name: str = 'esm2_t33_650M_UR50D', dropout: float = 0.1, mlp: spurs.models.stability.mlp.MLPConfig = MLPConfig(num_layers=3, input_dim=2602, hidden_dim=512, output_dim=1, dropout=0.1, ckpt_path='', append_tensors=True, flat_dim=-1))[source]
Bases: object
- encoder: spurs.models.stability.protein_mpnn.ProteinMPNNConfig = ProteinMPNNConfig(d_model=128, d_node_feats=128, d_edge_feats=128, k_neighbors=48, augment_eps=0.0, n_enc_layers=3, dropout=0.1, tune=False, use_input_decoding_order=False, n_vocab=22, n_dec_layers=3, random_decoding_order=True, nar=True, crf=False, use_esm_alphabet=False)
- adapter_layer_indices: List
- separate_loss: bool = True
- name: str = 'esm2_t33_650M_UR50D'
- dropout: float = 0.1
- mlp: spurs.models.stability.mlp.MLPConfig = MLPConfig(num_layers=3, input_dim=2602, hidden_dim=512, output_dim=1, dropout=0.1, ckpt_path='', append_tensors=True, flat_dim=-1)
- __init__(encoder: spurs.models.stability.protein_mpnn.ProteinMPNNConfig = ProteinMPNNConfig(d_model=128, d_node_feats=128, d_edge_feats=128, k_neighbors=48, augment_eps=0.0, n_enc_layers=3, dropout=0.1, tune=False, use_input_decoding_order=False, n_vocab=22, n_dec_layers=3, random_decoding_order=True, nar=True, crf=False, use_esm_alphabet=False), adapter_layer_indices: List = <factory>, separate_loss: bool = True, name: str = 'esm2_t33_650M_UR50D', dropout: float = 0.1, mlp: spurs.models.stability.mlp.MLPConfig = MLPConfig(num_layers=3, input_dim=2602, hidden_dim=512, output_dim=1, dropout=0.1, ckpt_path='', append_tensors=True, flat_dim=-1)) → None
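The <factory> shown as the default for adapter_layer_indices is how Python renders a dataclass field declared with default_factory, which suggests SPURSConfig is a dataclass. The sketch below uses a hypothetical stand-in class with only a few of the listed fields to show the resulting behavior; it is not the SPURS definition.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ConfigSketch:
    """Hypothetical stand-in with a few of the SPURSConfig fields listed above."""

    # default_factory builds a fresh list per instance; signatures show it as <factory>.
    adapter_layer_indices: List = field(default_factory=list)
    separate_loss: bool = True
    name: str = "esm2_t33_650M_UR50D"
    dropout: float = 0.1


first = ConfigSketch()
second = ConfigSketch(dropout=0.2)
first.adapter_layer_indices.append(0)  # changes only the first instance's list
```

Because each instance gets its own list, editing the layer indices on one configuration leaves other configurations untouched.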
- class spurs.models.stability.spurs.SPURS(cfg)[source]
Bases: spurs.models.stability.basemodel.BaseModel
SPURS (Structure-based Protein Understanding and Recognition System) model for protein stability prediction.
This model combines protein structure information (from ProteinMPNN) and sequence information (from ESM2) to predict protein stability changes. The architecture consists of three main components:
Encoder (ProteinMPNN): Processes protein structure information
Decoder (ESM2): Processes sequence information with structural prior
MLP: Final stability prediction layer
The model uses a structural adapter to effectively combine structural and sequence information, allowing for more accurate stability predictions.
- Parameters
- cfg (SPURSConfig) – Configuration object containing model parameters:
encoder: ProteinMPNN configuration
adapter_layer_indices: List of ESM2 layer indices to adapt
name: ESM2 model name
dropout: Dropout rate
mlp: MLP configuration
- __init__(cfg) → None[source]
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(batch, **kwargs)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_encoder(batch)[source]
Forward pass through the encoder (ProteinMPNN) component of the SPURS model.
This function processes the input protein structure data (X, S, mask, chain_M, residue_idx, chain_encoding_all, randn_1) and returns the encoded features from the ProteinMPNN encoder.
- Parameters
- batch (dict) – Input batch containing protein structure data:
X: Protein structure coordinates
S: Protein sequence tokens
mask: Mask indicating valid positions
chain_M: Chain mask
residue_idx: Residue indices
chain_encoding_all: Chain encoding
- Returns
Encoded features from the ProteinMPNN encoder
- Return type
torch.Tensor
- training: bool
ProteinMPNN Model
- class spurs.models.stability.mpnn.MPNN(cfg)[source]
Bases: spurs.models.stability.basemodel.BaseModel
ProteinMPNN-based model for protein stability prediction.
This model uses the ProteinMPNN architecture to process protein structure information and predict stability changes. Key features:
Structure-aware message passing neural network
Support for both fixed and trainable encoders
MLP-based regression head for stability predictions
- Parameters
- cfg (MPNN) – Configuration object containing:
encoder: ProteinMPNN configuration
name: Model name
dropout: Dropout rate
mlp: MLP configuration
- __init__(cfg) → None[source]
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- training: bool
- forward(batch, **kwargs)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
ESM Model
- class spurs.models.stability.esm.ESM(cfg)[source]
Bases: spurs.models.stability.basemodel.BaseModel
ESM-based model for protein stability prediction.
This model leverages the ESM2 language model to understand protein sequences and predict stability changes. Features include:
Pre-trained ESM2 language model for sequence encoding
MLP-based regression head for stability predictions
- Parameters
- cfg (ESM) – Configuration object containing:
name: ESM2 model name (e.g. 'esm2_t33_650M_UR50D')
dropout: Dropout rate
mlp: MLP configuration
- __init__(cfg) → None[source]
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- training: bool
- forward(batch, **kwargs)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Transfer Model
- class spurs.models.stability.org_transfer_model.TransferModelConfig(encoder: spurs.models.stability.protein_mpnn.ProteinMPNNConfig = ProteinMPNNConfig(d_model=128, d_node_feats=128, d_edge_feats=128, k_neighbors=48, augment_eps=0.0, n_enc_layers=3, dropout=0.1, tune=False, use_input_decoding_order=False, n_vocab=22, n_dec_layers=3, random_decoding_order=True, nar=True, crf=False, use_esm_alphabet=False))[source]
Bases: object
- encoder: spurs.models.stability.protein_mpnn.ProteinMPNNConfig = ProteinMPNNConfig(d_model=128, d_node_feats=128, d_edge_feats=128, k_neighbors=48, augment_eps=0.0, n_enc_layers=3, dropout=0.1, tune=False, use_input_decoding_order=False, n_vocab=22, n_dec_layers=3, random_decoding_order=True, nar=True, crf=False, use_esm_alphabet=False)
- __init__(encoder: spurs.models.stability.protein_mpnn.ProteinMPNNConfig = ProteinMPNNConfig(d_model=128, d_node_feats=128, d_edge_feats=128, k_neighbors=48, augment_eps=0.0, n_enc_layers=3, dropout=0.1, tune=False, use_input_decoding_order=False, n_vocab=22, n_dec_layers=3, random_decoding_order=True, nar=True, crf=False, use_esm_alphabet=False)) → None
- class spurs.models.stability.org_transfer_model.TransferModel(cfg)[source]
Bases: spurs.models.stability.basemodel.BaseModel
ThermoMPNN (https://github.com/Kuhlman-Lab/ThermoMPNN)
This model combines ProteinMPNN’s structure embeddings with a light attention mechanism and MLP layers for stability prediction. Features include:
Pre-trained ProteinMPNN encoder for structure understanding
Light attention mechanism for sequence context
MLP layers for final prediction
Support for mutation effect prediction
- Parameters
cfg (TransferModelConfig) – Configuration object containing: - encoder: ProteinMPNN configuration
- __init__(cfg)[source]
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(batch, tied_feat=True)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
Data Modules
- class spurs.datamodules.datasets.megascale.MegaScaleDataset(reduce: str = '', split: str = 'train', single_mut: bool = False, mut_seq: bool = False, std_ratio: float = 0.75, loss_ratio: float = 1.0, train_ratio: float = 0.05)[source]
Bases: torch.utils.data.dataset.Dataset
A dataset class for handling protein stability data from Tsuboyama et al. 2023.
This dataset contains a large collection of protein stability measurements, including single-site mutations and their effects on protein stability (ddG values). The dataset supports train/val/test splits and includes structure information from AlphaFold.
- Parameters
reduce (str, optional) – Reduction strategy for the dataset. Defaults to ‘’.
split (str, optional) – Dataset split (‘train’, ‘val’, ‘test’). Defaults to ‘train’.
single_mut (bool, optional) – Whether to handle only single mutations. Defaults to False.
mut_seq (bool, optional) – Whether to include mutated sequences. Defaults to False.
std_ratio (float, optional) – Will be removed in the future
loss_ratio (float, optional) – Will be removed in the future
train_ratio (float, optional) – Will be removed in the future
- class spurs.datamodules.datasets.megascale.Featurizer(alphabet: spurs.datamodules.datasets.data_utils.Alphabet, to_pifold_format=False, coord_nan_to_zero=True, atoms=('N', 'CA', 'C', 'O'), single_mut=False, mut_seq=False)[source]
Bases: object
- class spurs.datamodules.datasets.megascale.MegaScaleTestDatasets[source]
Bases: torch.utils.data.dataset.Dataset
A comprehensive test dataset combining multiple protein stability benchmarks.
This dataset aggregates various test sets including:
MegaScale test set
FireProt homologue-free set
SSYM direct and inverse sets
S669 dataset
ddg datasets (S461, S783, S8754, S2648)
dtm datasets (S571, S4346)
- class spurs.datamodules.datasets.megascale_multi.MegaScaleDoubleDataset(reduce: str = '', split: str = 'train')[source]
Bases: torch.utils.data.dataset.Dataset
A dataset class for handling double mutations from the MegaScale dataset.
This dataset specifically handles double mutations from Tsuboyama et al. 2023, where each data point contains two mutations and their combined effect on protein stability. The dataset supports train/val/test splits and includes structure information from AlphaFold.
- Parameters
reduce (str, optional) – Reduction strategy for the dataset. Defaults to ‘’.
split (str, optional) – Dataset split (‘train’, ‘val’, ‘test’). Defaults to ‘train’.
- class spurs.datamodules.datasets.megascale_multi.Featurizer(alphabet: spurs.datamodules.datasets.data_utils.Alphabet, to_pifold_format=False, coord_nan_to_zero=True, atoms=('N', 'CA', 'C', 'O'), single_mut=False, mut_seq=False)[source]
Bases: object
- class spurs.datamodules.datasets.domainome.domainome(pdb_dir, csv_fname, dataset_name, stage='full', mut_seq=False, train_size=1)[source]
Bases: torch.utils.data.dataset.Dataset
A dataset class for handling domain-level protein stability data.
This dataset provides domain-specific stability measurements and is used for analyzing and predicting stability changes at the protein domain level. It includes domain-level annotations and corresponding stability measurements.
- Parameters
pdb_dir (str) – Directory containing PDB structure files.
csv_fname (str) – Path to the CSV file containing domain mutation data.
dataset_name (str) – Name of the dataset.
stage (str, optional) – Dataset stage (‘full’, ‘train’, ‘test’). Defaults to ‘full’.
mut_seq (bool, optional) – Whether to include mutated sequences. Defaults to False.
train_size (float, optional) – Training data size ratio. Defaults to 1.
- class spurs.datamodules.datasets.ddgbench.ddgBenchDataset(pdb_dir, csv_fname, dataset_name)[source]
Bases: torch.utils.data.dataset.Dataset
A dataset class for handling standard protein stability benchmark datasets.
This class is used for loading and processing several benchmark datasets:
SSYM-dir: Direct SSYM dataset for stability measurements
SSYM-inv: Inverse SSYM dataset for stability measurements
S669: A curated set of 669 stability measurements
- Parameters
pdb_dir (str) – Directory containing PDB structure files.
csv_fname (str) – Path to the CSV file containing mutation data.
dataset_name (str) – Name of the dataset (‘ssym_dir’, ‘ssym_inv’, or ‘S669’).
- class spurs.datamodules.datasets.ddggeo.ddgGeo(pdb_dir, csv_fname, dataset_name, stage='full', mut_seq=False, train_size=1)[source]
Bases: torch.utils.data.dataset.Dataset
A dataset class for handling geometric-aware protein stability datasets (https://github.com/Gonglab-THU/SPIRED-Fitness).
This class handles multiple datasets that incorporate geometric features:
S461: 461 mutations with structural information
S783: 783 mutations with geometric features
S8754: Large-scale dataset with 8,754 mutations
S2648: Test set with 2,648 mutations
S571: Temperature-based stability dataset
S4346: Extended temperature mutation dataset
- Parameters
pdb_dir (str) – Directory containing PDB structure files.
csv_fname (str) – Path to the CSV file containing mutation data.
dataset_name (str) – Name of the dataset (e.g., ‘S461’, ‘S783’, etc.).
stage (str, optional) – Dataset stage (‘full’, ‘train’, ‘test’). Defaults to ‘full’.
mut_seq (bool, optional) – Whether to include mutated sequences. Defaults to False.
train_size (float, optional) – Training data size ratio. Defaults to 1.
- class spurs.datamodules.datasets.fireport.FireProtDataset(split)[source]
Bases: torch.utils.data.dataset.Dataset
Note
ddgBenchDataset is used for ssym-dir, ssym-inv, and S669 datasets
ddgGeo is used for S461, S783, S8754, S2648, S571, and S4346 datasets
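The note above amounts to a lookup from dataset name to dataset class. A minimal sketch of that mapping follows; dataset_class_name is a hypothetical helper, the names for the SSYM sets follow the dataset_name values listed for ddgBenchDataset, and class names are kept as strings so the example runs without SPURS installed.

```python
# Dataset names taken from the class descriptions on this page.
DATASET_CLASS_BY_NAME = {
    "ssym_dir": "ddgBenchDataset",
    "ssym_inv": "ddgBenchDataset",
    "S669": "ddgBenchDataset",
    "S461": "ddgGeo",
    "S783": "ddgGeo",
    "S8754": "ddgGeo",
    "S2648": "ddgGeo",
    "S571": "ddgGeo",
    "S4346": "ddgGeo",
}


def dataset_class_name(dataset_name):
    """Return the name of the class documented as handling dataset_name."""
    try:
        return DATASET_CLASS_BY_NAME[dataset_name]
    except KeyError:
        raise ValueError(f"No dataset class documented for {dataset_name!r}") from None
```

Both classes take pdb_dir, csv_fname, and dataset_name, so the looked-up class can be constructed with the same three arguments in either case.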