import torch
from torch.utils.data import ConcatDataset
import pandas as pd
import numpy as np
import pickle
import os
from Bio import pairwise2
from math import isnan
from tqdm import tqdm
from dataclasses import dataclass
from typing import Optional
from .utils import get_pdb,parse_pdb
import lmdb
import glob
from spurs import utils
log = utils.get_logger(__name__)
import json
from .LMDBDataset import LMDBDataset
from .batch import CoordBatchConverter
from .data_utils import Alphabet
from .fireport import FireProtDataset
from .ddgbench import ddgBenchDataset
from .ddggeo import ddgGeo
from .domainome import domainome
ALPAHBET = 'ACDEFGHIKLMNPQRSTVWYX'
from joblib import Parallel, delayed
from collections import defaultdict
import math
import random
class MegaScaleDataset(torch.utils.data.Dataset):
    """Protein stability data from Tsuboyama et al. 2023 (MegaScale).

    One item per wild-type protein: its AlphaFold structure featurized with
    ``get_pdb`` plus every single-site substitution measured for it and the
    corresponding ddG values.  When the table is loaded, rows whose ``ddG_ML``
    is ``'-'`` are dropped, as are insertions, deletions and multi-site
    mutants (``':'`` in ``mut_type``).  For every split except ``'test'`` the
    rows listed in ``mmseq_mut_search_0.25.m8`` are removed as well.

    Args:
        reduce (str or float, optional): Reduction strategy for the training
            split. ``'prot'`` keeps a random subset of (at most) 58 distinct
            training proteins; a float keeps that fraction of each training
            protein's mutations. Defaults to ``''`` (no reduction).
        split (str, optional): Key of the splits pickle to load (e.g.
            ``'train'``, ``'val'``, ``'test'``) or ``'all'`` for
            train + val + test. Defaults to ``'train'``.
        single_mut (bool, optional): Stored as ``self.single_mut``; not read by
            this class. Defaults to False.
        mut_seq (bool, optional): Also featurize every mutant sequence and
            concatenate it with the wild type along the first dimension.
            Defaults to False.
        std_ratio (float, optional): Will be removed in the future.
        loss_ratio (float, optional): Will be removed in the future.
        train_ratio (float, optional): Will be removed in the future.
    """

    # Keys pre-populated (empty) in ``split_wt_names``; only ``self.split`` is filled.
    _SPLIT_KEYS = (
        "val", "test", "train", "train_s669", "all",
        *(f"cv_train_{i}" for i in range(5)),
        *(f"cv_val_{i}" for i in range(5)),
        *(f"cv_test_{i}" for i in range(5)),
    )

    def __init__(self,
                 reduce: str = '',
                 split: str = 'train',
                 single_mut: bool = False,
                 mut_seq: bool = False,
                 std_ratio: float = 0.75,
                 loss_ratio: float = 1.,
                 train_ratio: float = 0.05,
                 ):
        self.split = split
        current_dir = os.path.dirname(os.path.abspath(__file__))
        root_path = os.path.join(current_dir, '../../../')

        fname = os.path.join(root_path, 'data/dataset/megascale/Tsuboyama2023_Dataset2_Dataset3_20230416.csv')
        df = pd.read_csv(fname, usecols=["ddG_ML", "mut_type", "WT_name", "aa_seq", "dG_ML"])
        # remove unreliable data and more complicated mutations
        df = df.loc[df.ddG_ML != '-', :].reset_index(drop=True)
        df = df.loc[~df.mut_type.str.contains("ins")
                    & ~df.mut_type.str.contains("del")
                    & ~df.mut_type.str.contains(":"), :].reset_index(drop=True)
        # Snapshot taken before the homology filter below.
        self.df = df

        if self.split != 'test':
            # The second .m8 column is interpreted as row indices of the
            # filtered table above; those rows are excluded.
            hits = self._read_mmseqs_hits(
                os.path.join(root_path, 'data/dataset/megascale/mmseq_mut_search_0.25.m8'))
            previous_len = len(df)
            df = df.loc[~df.index.isin(hits), :].reset_index(drop=True)
            log.info(f"removed {previous_len - len(df)} rows from the dataset")

        # load splits produced by mmseqs clustering
        # NOTE: pickle.load is only acceptable because this is a trusted local file.
        with open(os.path.join(root_path, 'data/dataset/megascale/mega_splits.pkl'), 'rb') as f:
            splits = pickle.load(f)  # split name -> FULL PDB names in that split
        self.split_wt_names = {key: [] for key in self._SPLIT_KEYS}
        self.split_wt_names[self.split] = self._select_wt_names(splits, self.split, reduce)
        self.wt_names = self.split_wt_names[self.split]

        # Per-protein mutation tables; proteins without a "wt" row are dropped.
        frac = float(reduce) if isinstance(reduce, float) and self.split == 'train' else None
        previous_len = len(self.wt_names)
        self.wt_seqs, self.mut_rows, self.wt_names = self._partition_by_protein(
            df, tqdm(self.wt_names), frac)
        log.info(f"removed {previous_len - len(self.wt_names)} wt names from the dataset")

        structure_path = os.path.join(root_path, 'data/dataset/megascale/AlphaFold_model_PDBs/')
        structure_path_json = os.path.join(structure_path, "../parsed_structure.json")
        self.structure_path = structure_path
        parse_pdb(structure_path, structure_path_json)
        log.info("loading structure dataset")
        with open(structure_path_json, 'r') as file:
            self.json_dataset = json.load(file)

        self.mut_seq = mut_seq
        self.single_mut = single_mut
        self.std_ratio = std_ratio
        self.loss_ratio = loss_ratio
        self.train_ratio = train_ratio

    @staticmethod
    def _read_mmseqs_hits(path):
        """Return the set of integers in the second tab-separated column of *path*."""
        with open(path, 'r') as f:
            return {int(line.split("\t")[1]) for line in f}

    @staticmethod
    def _select_wt_names(splits, split, reduce, n_prots_reduced=58):
        """Pick the wild-type names for *split* from the *splits* mapping.

        ``'all'`` concatenates train, val and test.  ``reduce == 'prot'`` on
        the train split keeps a random subset of ``n_prots_reduced`` distinct
        proteins (fewer if the split is smaller).
        """
        if split == 'all':
            return np.concatenate([splits['train'], splits['val'], splits['test']])
        if reduce == 'prot' and split == 'train':
            train = splits["train"]
            # Without replacement, so the subset really holds distinct proteins.
            return np.random.choice(train, min(n_prots_reduced, len(train)), replace=False)
        return splits[split]

    @staticmethod
    def _partition_by_protein(df, wt_names, frac=None):
        """Split *df* into wild-type and mutant rows for each name in *wt_names*.

        Args:
            df: Table with at least ``WT_name``, ``mut_type`` and ``aa_seq``.
            wt_names: Iterable of names, consumed once.
            frac: If not None, each protein's mutant rows are subsampled to
                this fraction (without replacement).

        Returns:
            ``(wt_seqs, mut_rows, kept_names)``: the first ``aa_seq`` of each
            protein's ``"wt"`` row, the mutant rows of every requested name
            (empty frame when absent), and the names that have a ``"wt"`` row,
            deduplicated in their original order.
        """
        rows_by_protein = {name: group for name, group in df.groupby('WT_name', sort=False)}
        no_rows = df.iloc[:0]
        wt_seqs, mut_rows, kept = {}, {}, []
        for wt_name in wt_names:
            protein_rows = rows_by_protein.get(wt_name, no_rows)
            is_wt = protein_rows.mut_type == "wt"
            mutants = protein_rows[~is_wt].reset_index(drop=True)
            if frac is not None:
                mutants = mutants.sample(frac=frac, replace=False)
            mut_rows[wt_name] = mutants
            wt_rows = protein_rows[is_wt]
            if len(wt_rows) > 0:
                wt_seqs[wt_name] = wt_rows.aa_seq.iloc[0]
                kept.append(wt_name)
        # Order-preserving dedup keeps the index -> protein mapping reproducible.
        return wt_seqs, mut_rows, list(dict.fromkeys(kept))

    def cal_index2mt(self, index):
        """Return ``self.index2mt[index]``.

        ``index2mt`` is never assigned by this class, so this raises
        ``AttributeError`` unless a caller sets it; item access does not use it.
        """
        return self.index2mt[index]

    def __len__(self):
        return len(self.wt_names)

    def _get_wt_item(self, index):
        """Featurize protein *index* together with all of its mutations.

        Returns the dict produced by ``get_pdb`` (assumed to hold list-valued
        ``'mut_ids'``, ``'ddG'``, ``'append_tensors'`` and ``'mut_seq'`` entries
        plus tensors such as ``'X'``; see ``.utils``) extended with per-mutation
        data, or ``None`` if a row is not a plain substitution or lacks a ddG.
        """
        wt_name = self.wt_names[index]
        wt_seq = self.wt_seqs[wt_name]
        mut_data = self.mut_rows[wt_name]
        # Structure keys drop the ".pdb" suffix and use ':' instead of '|'.
        structure_key = wt_name.split(".pdb")[0].replace("|", ":")
        pdb = self.json_dataset[structure_key]
        protein = get_pdb(pdb, wt_seq, structure_key)
        if self.mut_seq:
            protein['S'] = [protein['S']]

        for row in mut_data.itertuples(index=False):
            if self.mut_seq:
                # Copy instead of writing into ``pdb``: that dict is the cached
                # entry in ``self.json_dataset`` and must keep the wild type.
                mut_pdb = dict(pdb, seq=row.aa_seq)
                mut_protein = get_pdb(mut_pdb, row.aa_seq, structure_key)
                protein['S'].append(mut_protein['S'])
            mut_type = row.mut_type
            if "ins" in mut_type or "del" in mut_type or ":" in mut_type:
                return None
            assert len(row.aa_seq) == len(wt_seq)
            wt, mut = mut_type[0], mut_type[-1]
            mut_id = int(mut_type[1:-1]) - 1  # mut_type positions are 1-based
            assert wt_seq[mut_id] == wt
            assert row.aa_seq[mut_id] == mut
            if row.ddG_ML == '-':
                return None
            # Stored with the sign flipped relative to the ddG_ML column.
            ddG = -torch.tensor([float(row.ddG_ML)], dtype=torch.float32)
            wt_onehot = torch.zeros(len(ALPAHBET))
            wt_onehot[ALPAHBET.index(wt)] = 1
            mt_onehot = torch.zeros(len(ALPAHBET))
            mt_onehot[ALPAHBET.index(mut)] = 1
            protein['mut_ids'].append(mut_id)
            protein['ddG'].append(ddG)
            protein['append_tensors'].append(torch.cat([wt_onehot, mt_onehot]).float())
            protein['mut_seq'].append(row.aa_seq)

        protein['ddG'] = torch.stack(protein['ddG']).to(protein['X'].device, non_blocking=True)
        protein['append_tensors'] = torch.stack(protein['append_tensors'])
        protein['dataset'] = 'megascale'
        protein['pdb_path'] = self.structure_path
        if self.mut_seq:
            # One row per sequence (wild type first); per-residue tensors are
            # repeated so every sequence shares the wild-type structure.
            protein['S'] = torch.cat(protein['S'], dim=0).clone()
            n_seqs = len(protein['S'])
            protein['X'] = protein['X'].expand(n_seqs, -1, -1, -1).clone()
            for key in ('mask', 'chain_M', 'chain_M_chain_M_pos', 'residue_idx',
                        'chain_encoding_all', 'randn_1'):
                protein[key] = protein[key].expand(n_seqs, -1).clone()
        protein['std_ratio'] = self.std_ratio
        protein['loss_ratio'] = self.loss_ratio
        return protein

    def __getitem__(self, index):
        return self._get_wt_item(index)

    def collect_func(self, batch):
        """Collate function: the loader is expected to use batch size 1."""
        return batch[0]
class Featurizer(object):
    """Collate callable that adds sequence tokens to protein dicts.

    Two call modes, selected by ``single_mut``:

    * ``single_mut=False``: ``raw_batch[0]`` is the protein dict to featurize.
      Its ``'seq'`` (followed by every entry of ``'mut_seq'`` when ``mut_seq``
      is True) is tokenized against the same backbone coordinates. The dict
      gets ``'tokens'``, ``'mut_tokens'`` (always None) and a flattened
      ``'ddG'``, and is returned.
    * ``single_mut=True``: ``raw_batch`` is a list of protein dicts. Each one
      gets ``'tokens'`` (cached per protein ``'name'``) and ``'mut_tokens'``
      (None). A summary dict with stacked ``'ddG'`` and ``'append_tensors'``
      is returned.

    Args:
        alphabet: Alphabet handed to ``CoordBatchConverter``.
        to_pifold_format: Forwarded to ``CoordBatchConverter``.
        coord_nan_to_zero: Forwarded to ``CoordBatchConverter``.
        atoms: Keys of ``protein['coords']`` stacked (along axis 1), in order.
        single_mut: Selects the call mode described above.
        mut_seq: In ``single_mut=False`` mode, also tokenize ``protein['mut_seq']``.
    """

    def __init__(self, alphabet: Alphabet,
                 to_pifold_format=False,
                 coord_nan_to_zero=True,
                 atoms=('N', 'CA', 'C', 'O'),
                 single_mut=False,
                 mut_seq=False,
                 ):
        self.alphabet = alphabet
        self.batcher = CoordBatchConverter(
            alphabet=alphabet,
            coord_pad_inf=alphabet.add_special_tokens,
            to_pifold_format=to_pifold_format,
            coord_nan_to_zero=coord_nan_to_zero,
        )
        self.single_mut = single_mut
        self.atoms = atoms
        # Token cache for single-mutation mode. A plain dict (instead of a
        # defaultdict with a lambda factory) keeps the featurizer picklable,
        # which DataLoader workers started with "spawn" require.
        self.cache = {}
        self.mut_seq = mut_seq

    def _coords(self, protein):
        """Stack the configured backbone atoms of *protein* along axis 1."""
        return np.stack([protein['coords'][atom] for atom in self.atoms], 1)

    def _tokenize(self, seqs, coords):
        """Tokenize *seqs*, pairing every sequence with the same *coords*."""
        _coords, _confidence, _strs, tokens, _lengths, _coord_mask = self.batcher.from_lists(
            coords_list=[coords] * len(seqs), confidence_list=None, seq_list=seqs
        )
        return tokens

    def __call__(self, raw_batch: dict):
        if not self.single_mut:
            protein = raw_batch[0]
            seqs = [protein['seq']]
            if self.mut_seq:
                seqs = seqs + protein['mut_seq']
            protein['tokens'] = self._tokenize(seqs, self._coords(protein))
            protein['mut_tokens'] = None
            protein['ddG'] = protein['ddG'].reshape(-1)
            return protein

        for protein in raw_batch:
            # NOTE(review): keyed by name only, which assumes each name maps to
            # a single sequence; confirm for datasets that reuse names.
            name = protein['name']
            cached = self.cache.get(name)
            if cached is None:
                cached = {
                    'tokens': self._tokenize([protein['seq']], self._coords(protein)),
                    'mut_tokens': None,
                }
                self.cache[name] = cached
            protein['tokens'] = cached['tokens']
            protein['mut_tokens'] = None

        ddg = torch.stack([protein['ddG'] for protein in raw_batch])
        return {
            'raw_batch': raw_batch,
            'mut_ids': [protein['mut_ids'] for protein in raw_batch],
            'append_tensors': torch.stack([protein['append_tensors'] for protein in raw_batch]),
            'ddG': ddg,
            'name': [protein['name'] + protein['chain_ids'] for protein in raw_batch],
            'dataset': [protein['dataset'] for protein in raw_batch],
        }
[docs]class MegaScaleTestDatasets(torch.utils.data.Dataset):
"""A comprehensive test dataset combining multiple protein stability benchmarks.
This dataset aggregates various test sets including:
- MegaScale test set
- FireProt homologue-free set
- SSYM direct and inverse sets
- S669 dataset
- ddg datasets (S461, S783, S8754, S2648)
- dtm datasets (S571, S4346)
"""
def __init__(self,
             ):
    """Build every benchmark.

    Data paths are resolved against the directory three levels above this file.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    root = os.path.join(here, '../../../')

    self.megascale = MegaScaleDataset(
        reduce='',
        split='test',
    )
    self.fireport = FireProtDataset(split='homologue-free')

    # (attribute, pdb directory, csv file, dataset_name) for ddgBenchDataset.
    bench_sets = (
        ('ssym_dir', 'data/dataset/ssym/pdb', 'data/dataset/ssym/ssym-5fold_clean_dir.csv', 'ssym_dir'),
        ('ssym_inv', 'data/dataset/ssym/pdb', 'data/dataset/ssym/ssym-5fold_clean_inv.csv', 'ssym_inv'),
        ('s669', 'data/dataset/S669/pdb', 'data/dataset/S669/s669_clean_dir.csv', 'S669'),
    )
    for attr, pdb_rel, csv_rel, label in bench_sets:
        setattr(self, attr, ddgBenchDataset(
            pdb_dir=os.path.join(root, pdb_rel),
            csv_fname=os.path.join(root, csv_rel),
            dataset_name=label))

    # (name, geostab_data subdirectory, extra kwargs) for ddgGeo; the name is
    # used as attribute, directory, csv stem and dataset_name alike.
    geo_sets = (
        ('S461', 'ddG_cleaned', {}),
        ('S783', 'ddG_cleaned', {}),
        ('S8754', 'ddG_cleaned', {}),
        ('S2648', 'ddG_cleaned', {'stage': 'test'}),
        ('S571', 'dTm_cleaned', {}),
        ('S4346', 'dTm_cleaned', {}),
    )
    for name, subset, extra in geo_sets:
        stem = f'data/dataset/geostab_data/{subset}/{name}'
        setattr(self, name, ddgGeo(
            pdb_dir=os.path.join(root, stem),
            csv_fname=os.path.join(root, stem + '.csv'),
            dataset_name=name,
            **extra))
def __len__(self):
return self.megascale.__len__() + self.fireport.__len__() + self.ssym_dir.__len__() + self.ssym_inv.__len__() + self.s669.__len__() + self.S461.__len__() + self.S783.__len__() + self.S571.__len__() + self.S8754.__len__() + self.S2648.__len__() + self.S4346.__len__()
def __getitem__(self, index):
if index < self.megascale.__len__():
return self.megascale.__getitem__(index)
index -= self.megascale.__len__()
if index < self.fireport.__len__():
return self.fireport.__getitem__(index)
index -= self.fireport.__len__()
if index < self.ssym_dir.__len__():
return self.ssym_dir.__getitem__(index)
index -= self.ssym_dir.__len__()
if index < self.ssym_inv.__len__():
return self.ssym_inv.__getitem__(index)
index -= self.ssym_inv.__len__()
if index < self.s669.__len__():
return self.s669.__getitem__(index)
index -= self.s669.__len__()
if index < self.S461.__len__():
return self.S461.__getitem__(index)
index -= self.S461.__len__()
if index < self.S783.__len__():
return self.S783.__getitem__(index)
index -= self.S783.__len__()
if index < self.S571.__len__():
return self.S571.__getitem__(index)
index -= self.S571.__len__()
if index < self.S8754.__len__():
return self.S8754.__getitem__(index)
index -= self.S8754.__len__()
if index < self.S2648.__len__():
return self.S2648.__getitem__(index)
index -= self.S2648.__len__()
if index < self.S4346.__len__():
return self.S4346.__getitem__(index)