Source code for spurs.datamodules.datasets.ddgbench

import torch
from torch.utils.data import ConcatDataset
import pandas as pd
import numpy as np
import pickle
import os
from Bio import pairwise2
from math import isnan
from tqdm import tqdm
from dataclasses import dataclass
from typing import Optional
from .utils import fermi_transform,tied_featurize,get_pdb,parse_pdb,alt_parse_PDB
import lmdb
import glob
from spurs import utils
log = utils.get_logger(__name__)  # module-level logger obtained from the project's `spurs.utils` helper
import json
from .LMDBDataset import LMDBDataset
from .batch import CoordBatchConverter
from .data_utils import Alphabet
from collections import defaultdict
# Residue vocabulary: the 20 standard amino-acid one-letter codes in alphabetical
# order followed by 'X' (presumably the unknown/non-standard residue — confirm).
# ddgBenchDataset.__getitem__ uses the position of a letter in this string as its
# index in the 21-dim one-hot encodings. The misspelled name is kept as-is since
# other modules may import it.
ALPAHBET = 'ACDEFGHIKLMNPQRSTVWYX'

class ddgBenchDataset(torch.utils.data.Dataset):
    """A dataset class for handling standard protein stability benchmark datasets.

    This class is used for loading and processing several benchmark datasets:

    - SSYM-dir: Direct SSYM dataset for stability measurements
    - SSYM-inv: Inverse SSYM dataset for stability measurements
    - S669: A curated set of 669 stability measurements

    The CSV is read with at least the columns ``PDB`` (PDB id immediately
    followed by a one-letter chain id), ``MUT`` (wild-type letter, residue
    position, mutant letter) and ``DDG``; ``SEQ`` is also read unless
    ``pdb_dir`` contains ``'S669'``. Each dataset item is one distinct ``PDB``
    entry together with all of its mutations.

    Args:
        pdb_dir (str): Directory containing PDB structure files
            (opened as ``<pdb id>.pdb``).
        csv_fname (str): Path to the CSV file containing mutation data.
        dataset_name (str): Name of the dataset ('ssym_dir', 'ssym_inv', or 'S669').
    """

    def __init__(self, pdb_dir, csv_fname, dataset_name):
        self.pdb_dir = pdb_dir
        df = pd.read_csv(csv_fname)
        self.df = df
        self.dataset_name = dataset_name
        self.wt_seqs = {}
        self.mut_rows = {}
        # One item per distinct PDB entry, in order of first appearance.
        self.wt_names = df.PDB.unique()
        for wt_name_query in self.wt_names:
            # Strip the trailing chain letter: rows, sequences and structure
            # files are keyed by the bare PDB id.
            wt_name = wt_name_query[:-1]
            if wt_name in self.mut_rows:
                # NOTE(review): two chains of the same PDB id collapse onto one
                # key, so the later chain's rows replace the earlier ones and both
                # items will return the later chain's mutations.
                log.warning(
                    "ddgBenchDataset: entry %r shares PDB id %r with an earlier entry; "
                    "its mutation rows replace the earlier ones",
                    wt_name_query,
                    wt_name,
                )
            self.mut_rows[wt_name] = df[df.PDB == wt_name_query].reset_index(drop=True)
            if 'S669' not in self.pdb_dir:
                self.wt_seqs[wt_name] = self.mut_rows[wt_name].SEQ[0]
        # Same value as ``pdb_dir``; attached to every item as ``'pdb_path'``.
        self.structure_path = pdb_dir
        # Lazy cache of parsed structures keyed by [PDB id][chain];
        # the default -1 means "not parsed yet".
        self.json_dataset = defaultdict(lambda: defaultdict(lambda: -1))

    def __len__(self):
        """Return the number of items, i.e. distinct ``PDB`` entries in the CSV."""
        return len(self.wt_names)

    def __getitem__(self, index):
        """Return one wild-type chain with all of its usable mutations encoded.

        Each item is a single protein (one entry of ``self.wt_names``).
        Mutations whose position is not found in the parsed residue-ID list
        (e.g. insertion codes) are skipped.

        Returns:
            dict: the dict produced by ``get_pdb`` with these keys filled in:
            ``mut_ids`` (list of 0-based indices into the parsed sequence),
            ``ddG`` (float32 tensor with one row per kept mutation holding the
            CSV ``DDG`` with its sign flipped, NaN where the CSV value is
            missing), ``append_tensors`` (float32 tensor of wild-type one-hot
            concatenated with mutant one-hot, one row per kept mutation),
            ``pdb_path`` and ``dataset``.

        Raises:
            AssertionError: if the CSV wild-type residue cannot be matched to
                the structure sequence even after the gap-based index shift.
        """
        full_name = self.wt_names[index]
        chain = [full_name[-1]]
        wt_name = full_name.split(".pdb")[0][:-1]
        mut_data = self.mut_rows[wt_name]

        # Parse each (PDB id, chain) once and reuse the result on later calls.
        chain_cache = self.json_dataset[wt_name]
        if isinstance(chain_cache[chain[0]], int):
            # Modified PDB parser returns the list of residue IDs so CSV
            # positions can be aligned to sequence indices.
            chain_cache[chain[0]] = alt_parse_PDB(
                os.path.join(self.pdb_dir, wt_name + ".pdb"), chain
            )
        pdb = chain_cache[chain[0]]
        seq = pdb[0]['seq']
        resn_list = pdb[0]["resn_list"]
        protein = get_pdb(pdb[0], wt_name, wt_name, check_assert=False)

        n_tokens = len(ALPAHBET)
        for _, row in mut_data.iterrows():
            mut_info = row.MUT
            wtAA, mutAA = mut_info[0], mut_info[-1]
            try:
                pdb_idx = resn_list.index(mut_info[1:-1])
            except ValueError:
                # Skip positions with insertion codes for now - hard to parse.
                continue

            # Explicit checks instead of `assert`, so the realignment still
            # runs when Python is started with -O.
            if seq[pdb_idx] != wtAA:
                # Contingency for mis-alignments: shift by the number of gap
                # characters ('-'), counted over the whole sequence for S669 and
                # otherwise up to 10 residues past the site (a heuristic);
                # with no gaps, shift by one.
                if 'S669' in self.pdb_dir:
                    n_gaps = sum(res == '-' for res in seq)
                else:
                    n_gaps = sum(res == '-' for res in seq[:pdb_idx + 10])
                pdb_idx += n_gaps if n_gaps > 0 else 1
                if seq[pdb_idx] != wtAA:
                    raise AssertionError(
                        f"{full_name} {mut_info}: wild-type residue {wtAA!r} does not "
                        f"match structure residue {seq[pdb_idx]!r} at index {pdb_idx}"
                    )

            # The CSV value is sign-flipped. A missing value becomes NaN so the
            # torch.stack below can still build the tensor (None cannot be stacked).
            ddg_value = float('nan') if pd.isna(row.DDG) else row.DDG * -1.
            ddG = torch.tensor([ddg_value], dtype=torch.float32)

            # Wild-type one-hot in the first n_tokens slots, mutant one-hot after.
            append_tensor = torch.zeros(2 * n_tokens, dtype=torch.float32)
            append_tensor[ALPAHBET.index(wtAA)] = 1
            append_tensor[n_tokens + ALPAHBET.index(mutAA)] = 1

            protein['mut_ids'].append(pdb_idx)
            protein['ddG'].append(ddG)
            protein['append_tensors'].append(append_tensor)

        # torch.stack rejects an empty list, which happens when every mutation
        # of this entry was skipped; fall back to correctly-shaped empty tensors.
        protein['ddG'] = (
            torch.stack(protein['ddG'])
            if len(protein['ddG'])
            else torch.empty((0, 1), dtype=torch.float32)
        )
        protein['append_tensors'] = (
            torch.stack(protein['append_tensors'])
            if len(protein['append_tensors'])
            else torch.empty((0, 2 * n_tokens), dtype=torch.float32)
        )
        protein['pdb_path'] = self.structure_path
        protein['dataset'] = self.dataset_name
        return protein