Source code for spurs.models.stability.spurs

from dataclasses import dataclass, field
from typing import List

import torch
from spurs.models import register_model
from spurs.models.stability.basemodel import BaseModel
from spurs.models.stability.protein_mpnn import ProteinMPNNConfig

from spurs.models.stability.modules.esm2 import ESM2
from spurs import utils
from spurs.models.stability.org_transfer_model import get_protein_mpnn
import torch.nn.functional as F
from spurs.models.stability.modules.esm2_adapter import ESM2WithStructuralAdatper
import torch.nn as nn
# Module-level logger for this file, obtained from the project's logging helper.
log = utils.get_logger(__name__)
from .mlp import MLP, MLPConfig


@dataclass
class SPURSConfig:
    """Configuration for the :class:`SPURS` stability model.

    The nested ``encoder`` and ``mlp`` configs are built with
    ``default_factory`` so that every ``SPURSConfig`` owns its own copies.
    ``SPURS.__init__`` writes to ``cfg.encoder.d_model`` and
    ``cfg.mlp.input_dim``; with a shared default instance those writes could
    leak from one config into another. Recent Python versions also refuse
    unhashable dataclass instances passed through ``field(default=...)``.
    """

    # Structure-encoder settings; SPURS reads ``tune`` and
    # ``use_input_decoding_order`` from it and overwrites ``d_model``.
    encoder: ProteinMPNNConfig = field(default_factory=ProteinMPNNConfig)
    # Layer indices handed to the adapter-equipped ESM2 decoder through the
    # config object (presumably the layers that receive a structural adapter;
    # the consumer lives outside this file).
    adapter_layer_indices: List = field(default_factory=lambda: [-1, ])
    # Not read anywhere in this module; kept for whoever consumes the config.
    separate_loss: bool = True
    # Name of the pretrained ESM2 checkpoint passed to ``from_pretrained``.
    name: str = 'esm2_t33_650M_UR50D'
    # Dropout probability applied after the optional flattening projection.
    dropout: float = 0.1
    # Prediction-head settings; SPURS reads ``input_dim`` and ``flat_dim``
    # and overwrites ``input_dim`` with the concatenated feature width.
    mlp: MLPConfig = field(default_factory=MLPConfig)
@register_model('spurs')
class SPURS(BaseModel):
    """
    SPURS (Structure-based Protein Understanding and Recognition System) model
    for protein stability prediction.

    This model combines protein structure information (from ProteinMPNN) and
    sequence information (from ESM2) to predict protein stability changes.
    The architecture consists of three main components:

    1. Encoder (ProteinMPNN): Processes protein structure information
    2. Decoder (ESM2): Processes sequence information with structural prior
    3. MLP: Final stability prediction layer

    The model uses a structural adapter to combine structural and sequence
    information.

    Args:
        cfg (SPURSConfig): Configuration object containing model parameters
            - encoder: ProteinMPNN configuration
            - adapter_layer_indices: List of ESM2 layer indices to adapt
            - name: ESM2 model name
            - dropout: Dropout rate
            - mlp: MLP configuration
    """

    _default_cfg = SPURSConfig()

    def __init__(self, cfg) -> None:
        """Build the encoder, the adapter-equipped decoder and the MLP head.

        Note that this constructor rewrites two config entries in place:
        ``cfg.encoder.d_model`` is set to the original ``cfg.mlp.input_dim``
        and ``cfg.mlp.input_dim`` is enlarged to the width of the concatenated
        features that ``forward`` feeds to the MLP.

        Args:
            cfg: Model configuration (see ``SPURSConfig``); stored on
                ``self.cfg`` by the base class.
        """
        super().__init__(cfg)
        self.tune = cfg.encoder.tune
        self.use_input_decoding_order = cfg.encoder.use_input_decoding_order
        self.encoder = get_protein_mpnn(tune=cfg.encoder.tune)

        # Must be set before the decoder is created, since the decoder is
        # built from ``self.cfg``.
        self.cfg.encoder.d_model = self.cfg.mlp.input_dim
        self.decoder = ESM2WithStructuralAdatper.from_pretrained(
            args=self.cfg, name=self.cfg.name
        )

        # Number of encoder feature channels kept in ``forward``.
        self.input_dim = self.cfg.mlp.input_dim

        # The MLP consumes [decoder representation | encoder features].
        # The decoder part is ``flat_dim`` wide when the flattening projection
        # is active (flat_dim > 0) and 1280 wide otherwise. The test uses
        # ``> 0`` to match ``forward`` exactly; the previous ``< 0`` test
        # disagreed with ``forward`` for flat_dim == 0.
        # NOTE(review): 1280 is hard-coded here and in ``flat_layers``; it
        # presumably is the hidden size of the default ESM2 checkpoint -
        # confirm before using a different ``cfg.name``.
        flat_dim = self.cfg.mlp.flat_dim
        decoder_width = flat_dim if flat_dim > 0 else 1280
        self.cfg.mlp.input_dim = self.cfg.mlp.input_dim + decoder_width
        self.mlp = MLP(self.cfg.mlp)

        if flat_dim > 0:
            self.flat_layers = nn.Linear(1280, flat_dim)
        self.dp = nn.Dropout(self.cfg.dropout)

        # Special-token indices mirrored from the decoder.
        self.padding_idx = self.decoder.padding_idx
        self.mask_idx = self.decoder.mask_idx
        self.cls_idx = self.decoder.cls_idx
        self.eos_idx = self.decoder.eos_idx

    def forward(self, batch, **kwargs):
        """Predict stability changes (ddG) for the mutations in ``batch``.

        The encoder features are computed (without gradients unless
        ``self.tune`` is set), truncated to ``self.input_dim`` channels,
        padded by one position at each end of the second-to-last axis and
        passed to the decoder together with ``batch['tokens']``. The last
        decoder representation, optionally projected by ``flat_layers``, is
        concatenated with the padded encoder features and indexed at the
        mutated positions before going through the MLP.

        Side effects: writes ``batch['feats']``,
        ``batch['muted_id_representation']`` and, in the default mode,
        normalises ``batch['mut_ids']`` to a tensor.

        Args:
            batch (dict): Needs the keys used by ``forward_encoder`` plus
                ``tokens``; additionally ``seq`` when ``return_logist`` is
                truthy, otherwise ``mut_ids`` and ``append_tensors``.
            **kwargs: ``return_logist`` - when truthy, score all 20
                substitutions at every position of ``batch['seq']``.

        Returns:
            torch.Tensor: With ``return_logist``, the ddG values reshaped to
            ``(-1, 20)``. Otherwise one ddG value per listed mutation, with
            NaNs replaced by 10000.
        """
        if not self.tune:
            with torch.no_grad():
                batch['feats'] = self.forward_encoder(batch)
        else:
            batch['feats'] = self.forward_encoder(batch)

        batch['feats'] = batch['feats'][:, :, :self.input_dim]
        # One extra position on each side of the residue axis; presumably
        # this lines the features up with special tokens around
        # ``batch['tokens']`` - confirm against the tokenizer.
        encoder_out = {'feats': F.pad(batch['feats'], (0, 0, 1, 1))}

        decoder_out = self.decoder(
            tokens=batch['tokens'],
            encoder_out=encoder_out,
        )
        representation = decoder_out['representations'][-1]
        if self.cfg.mlp.flat_dim > 0:
            representation = self.flat_layers(representation)
            representation = self.dp(representation)
            representation = F.gelu(representation)
        representation = torch.cat(
            [representation, encoder_out['feats']], dim=-1
        )
        device = representation.device

        if kwargs.get('return_logist'):
            # Saturation-scan mode: every position of the wild-type sequence
            # is paired with each of the 20 amino acids.
            wt = batch['seq']
            n_pos = len(wt)
            # +1 skips the position added in front by the padding above.
            # Index tensors are created on the representation's device, as
            # the default mode below already does.
            shifted_mut_ids = torch.repeat_interleave(
                torch.arange(1, 1 + n_pos, device=device), 20
            )
            batch['muted_id_representation'] = representation[:, shifted_mut_ids.long()]
            pre_output = self.mlp(batch)
            ddg_out = pre_output.squeeze()

            ALPHABET = 'ACDEFGHIKLMNPQRSTVWY'
            num_classes = 21
            mt_aa = torch.arange(20, device=device).repeat(n_pos)
            # ``ALPHABET.index`` raises ValueError for residues outside the
            # 20-letter alphabet.
            wt_aa = torch.repeat_interleave(
                torch.tensor(
                    [ALPHABET.index(s) for s in wt],
                    dtype=torch.long,
                    device=device,
                ),
                20,
            )
            mt_aa_one_hot = F.one_hot(mt_aa, num_classes=num_classes)
            wt_aa_one_hot = F.one_hot(wt_aa, num_classes=num_classes)

            # ddG = score of the mutant residue minus score of the wild type.
            ddg_out_aa = (ddg_out * mt_aa_one_hot).sum(-1)
            ddg_out_wt_aa = (ddg_out * wt_aa_one_hot).sum(-1)
            ddg = ddg_out_aa - ddg_out_wt_aa
            return ddg.reshape(-1, 20)

        # Default mode: score only the mutations listed in the batch.
        if not isinstance(batch['mut_ids'], torch.Tensor):
            batch['mut_ids'] = torch.tensor(batch['mut_ids'])
        # +1 skips the position added in front by the padding above.
        shifted_mut_ids = batch['mut_ids'].to(device) + 1
        batch['muted_id_representation'] = representation[:, shifted_mut_ids.long()]
        pre_output = self.mlp(batch)
        ddg_out = pre_output.squeeze()

        # ``append_tensors`` is split at column 21: the first 21 columns
        # select the wild-type score, the remaining columns the mutant score
        # (presumably one-hot encodings - confirm against the data loader).
        ddg_out_aa = (ddg_out * batch['append_tensors'][:, 21:]).sum(-1)
        ddg_out_wt_aa = (ddg_out * batch['append_tensors'][:, :21]).sum(-1)
        ddg = ddg_out_aa - ddg_out_wt_aa
        # NaN predictions are replaced by a large sentinel value.
        ddg[torch.isnan(ddg)] = 10000
        return ddg

    def forward_encoder(self, batch):
        """Run the ProteinMPNN encoder and concatenate its outputs.

        The encoder is called with ``X``, ``S``, ``mask``, ``chain_M``,
        ``residue_idx`` and ``chain_encoding_all`` from ``batch``, followed by
        ``None`` and ``self.use_input_decoding_order``. Only these six batch
        keys are read; the earlier lookups of ``chain_M_chain_M_pos`` and
        ``randn_1`` were unused and have been dropped, so batches no longer
        need to carry them.

        Args:
            batch (dict): Input batch; the six keys listed above are passed
                to the encoder unchanged.

        Returns:
            torch.Tensor: ``[hidden[0], embedding, hidden[1], hidden[2]]``
            concatenated along the last dimension, where ``hidden`` and
            ``embedding`` are the first two values returned by the encoder.
        """
        all_mpnn_hid, mpnn_embed, _ = self.encoder(
            batch['X'],
            batch['S'],
            batch['mask'],
            batch['chain_M'],
            batch['residue_idx'],
            batch['chain_encoding_all'],
            None,
            self.use_input_decoding_order,
        )
        return torch.cat(
            [all_mpnn_hid[0], mpnn_embed, all_mpnn_hid[1], all_mpnn_hid[2]],
            dim=-1,
        )