Source code for spurs.functional_site_annotation

"""
This module implements functionality for identifying and analyzing functional sites in proteins
using SPURS model predictions and sigmoid-based regression analysis. It provides tools for
normalizing mutation effect scores, fitting sigmoid functions to the data, and visualizing
the results to identify functionally important positions in proteins.
"""

from spurs.inference import get_SPURS, parse_pdb
# ~ 10s
import torch
import pandas as pd
import numpy as np
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import numpy as np
import seaborn as sns
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
from scipy.optimize import minimize
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colors as mcolors
import matplotlib as mpl

def normalize(data, lower=-4.36, upper=1.21, new_min=0, new_max=1):
    """
    Clip mutation effect scores to ``[lower, upper]`` and rescale them linearly.

    Args:
        data: Raw mutation effect scores (anything ``np.clip`` accepts)
        lower: Clipping floor (default: -4.36, 0.1 percentile of ddG values)
        upper: Clipping ceiling (default: 1.21, 99.9 percentile of ddG values)
        new_min: Value that ``lower`` is mapped to (default: 0)
        new_max: Value that ``upper`` is mapped to (default: 1)

    Returns:
        The clipped scores mapped linearly onto [new_min, new_max]
    """
    source_span = upper - lower
    target_span = new_max - new_min
    # Position of each clipped score between `lower` (0) and `upper` (1).
    unit_scaled = (np.clip(data, lower, upper) - lower) / source_span
    return unit_scaled * target_span + new_min
def sigmoid_function(x, xmid, scal):
    """
    Evaluate the logistic curve ``1 / (1 + exp(-(x - xmid) / scal))``.

    Args:
        x: Input values
        xmid: Midpoint of the curve (the output is 0.5 at ``x == xmid``)
        scal: Scale parameter; smaller magnitudes give a steeper transition

    Returns:
        Sigmoid function values, one per element of ``x``
    """
    scaled_offset = (x - xmid) / scal
    return 1 / (1 + np.exp(-scaled_offset))
# Define a custom estimator for sigmoid fitting
class SigmoidRegressor(BaseEstimator, RegressorMixin):
    """
    Scikit-learn compatible estimator that fits a two-parameter sigmoid to data.

    The model is ``sigmoid_function(X, xmid, scal)``. Both parameters are found by
    minimising a weighted sum of squared residuals with ``scipy.optimize.minimize``
    under fixed box bounds. Used to identify functional sites by finding positions
    where mutations deviate from the fitted curve.

    Attributes:
        xmid_: Fitted sigmoid midpoint (available after ``fit``)
        scal_: Fitted sigmoid scale (available after ``fit``)
    """

    def __init__(self, xmid_initial=0.0, scal_initial=1.0):
        """
        Initialize sigmoid regressor with starting parameters.

        Args:
            xmid_initial: Initial guess for sigmoid midpoint
            scal_initial: Initial guess for sigmoid scale factor
        """
        self.xmid_initial = xmid_initial
        self.scal_initial = scal_initial

    def fit(self, X, y, sample_weight=None):
        """
        Fit the sigmoid parameters by bounded, weighted least squares.

        Args:
            X: Input features (validated with ``check_X_y``; 1-D input is allowed)
            y: Target values
            sample_weight: Optional per-sample weights. The loss multiplies each
                squared residual by the *square* of its weight. ``None`` (the
                default) means every sample gets weight 1.

        Returns:
            self: Fitted estimator
        """
        X, y = check_X_y(X, y, ensure_2d=False)

        # The default used to crash with a TypeError (``None ** 2``); treat it as
        # uniform weighting instead, and accept plain sequences as well as arrays.
        if sample_weight is None:
            sample_weight = np.ones_like(y, dtype=float)
        else:
            sample_weight = np.asarray(sample_weight)
        # Constant across optimizer iterations, so square once up front.
        squared_weights = sample_weight ** 2

        initial_guess = [self.xmid_initial, self.scal_initial]

        def weighted_loss(params, X, y, squared_weights):
            """Calculate weighted squared error loss"""
            residuals = y - sigmoid_function(X, *params)
            return np.sum(squared_weights * (residuals ** 2))

        # Box bounds as (min, max) pairs, in the order (xmid, scal).
        # Weights and bounds follow
        # https://github.com/lehner-lab/domainome/blob/main/09_esm1v_residuals.Rmd#L112
        # NOTE(review): the default xmid_initial=0.0 lies outside the xmid bounds
        # (-20, -5) — confirm how the optimizer treats an out-of-bounds start.
        bounds = [(-20, -5), (None, np.e)]

        result = minimize(
            weighted_loss,
            initial_guess,
            args=(X, y, squared_weights),
            bounds=bounds,
        )
        self.xmid_, self.scal_ = result.x
        return self

    def predict(self, X):
        """
        Predict using fitted sigmoid function.

        Args:
            X: Input features

        Returns:
            Predicted values ``sigmoid_function(X, xmid_, scal_)``
        """
        check_is_fitted(self, ['xmid_', 'scal_'])
        X = check_array(X, ensure_2d=False)
        return sigmoid_function(X, self.xmid_, self.scal_)
def get_sigmoid_results(mask_results,ddg):
    """
    Fit a sigmoid relating mutation-effect predictions to normalized stability scores.

    Args:
        mask_results: Mutation effect predictions from ESM
        ddg: Predicted stability measurements
            (both inputs must support ``.flatten()`` and the flattened values
            ``.numpy()`` — presumably CPU torch tensors; confirm against caller)

    Returns:
        tuple: (negated normalized ddg, flattened; flattened mask_results; None;
        [xmid, scal] of the fitted sigmoid)
    """
    features = mask_results.flatten()
    targets = -normalize(ddg).flatten()
    assert len(features) == len(targets)

    # Per-sample weight: (spread of the targets) minus (target + 1).
    target_range = np.nanmax(targets) - np.nanmin(targets)
    weights = target_range - (targets + 1)

    regressor = SigmoidRegressor(xmid_initial=np.nanmedian(features), scal_initial=0.6)
    print(features.shape,targets.shape,weights.shape)
    regressor.fit(features.numpy(), targets.numpy(), sample_weight=weights.numpy())

    fitted_params = [regressor.xmid_, regressor.scal_]
    return targets, features, None, fitted_params
def inference_wt_seq(sequence: str, indices: list, batch_converter, model, device: torch.device, alphabet):
    """
    Run the model once on the wild-type sequence and collect amino-acid logits.

    Args:
        sequence: Protein sequence to analyze
        indices: Token positions (along the model's sequence axis) to extract
        batch_converter: ESM batch converter
        model: ESM model
        device: Computation device
        alphabet: ESM alphabet; its ``tok_to_idx`` maps residue letters to logit columns

    Returns:
        Tensor with one row per entry of ``indices`` and 20 columns, ordered as
        'ACDEFGHIKLMNPQRSTVWY'
    """
    _, _, tokens = batch_converter([("sequence", sequence)])
    tokens = tokens.to(device)

    with torch.no_grad():
        output = model(tokens, repr_layers=[33], return_contacts=False)

    # Single-sequence batch: take element 0, then the requested positions.
    position_logits = output["logits"][0, indices, :]

    per_residue_columns = []
    for residue in 'ACDEFGHIKLMNPQRSTVWY':
        column = alphabet.tok_to_idx[residue]
        per_residue_columns.append(position_logits[:, column])

    # Stacking gives (20, n_positions); transpose to (n_positions, 20).
    return torch.stack(per_residue_columns).T
def get_wt_aa_logit_differences(sequence: str, mut_indices: list, batch_converter, model, device: torch.device, alphabet, shift: int = 1):
    """
    Express every amino-acid logit relative to the wild-type residue's logit.

    For each requested position the model's 20 amino-acid logits are fetched and
    the logit of the residue actually present in ``sequence`` is subtracted, so
    the wild-type column becomes 0 and other columns are differences from it.

    Args:
        sequence: Protein sequence
        mut_indices: Positions to analyze; used unshifted as model positions and
            as ``sequence[i - shift]`` for the wild-type lookup
        batch_converter: ESM batch converter
        model: ESM model
        device: Computation device
        alphabet: ESM alphabet
        shift: Offset between ``mut_indices`` and 0-based string indices (default: 1)

    Returns:
        Tensor of logit differences between all possible mutations and the
        wild-type residue, one row per position and 20 columns
    """
    amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
    logits = inference_wt_seq(sequence, list(mut_indices), batch_converter, model, device, alphabet)

    # Column of the wild-type residue at each requested position.
    wt_columns = []
    for position in mut_indices:
        wt_columns.append(amino_acids.index(sequence[position - shift]))

    # One-hot selector picking out the wild-type logit in each row.
    n_positions = len(mut_indices)
    wt_selector = torch.zeros(n_positions, 20).to(logits.device)
    wt_selector[range(n_positions), wt_columns] = 1
    wt_logits = (wt_selector * logits).sum(-1).reshape(-1, 1)

    return logits - wt_logits
def plot_sigmoid_results(result,shift=1,vcenter = 0,highlight_positions=None):
    """
    Visualize functional site prediction results using a scatter plot.

    Computes, per position, the summed residual between the observed values and
    the fitted sigmoid (20 substitutions per position), rescales those sums to
    [-1, 1], and draws them against position with a diverging colormap centered
    at ``vcenter``. The figure is shown with ``plt.show()``.

    Args:
        result: Tuple (observed values, sigmoid inputs, None, sigmoid_parameters)
            as returned by ``get_sigmoid_results``; the first two entries must
            support ``.numpy()`` and hold a multiple of 20 values
        shift: Position numbering offset (default: 1)
        vcenter: Center value for color normalization (default: 0); must lie
            strictly between the plotted minimum and maximum
        highlight_positions: Optional iterable of positions (in the same
            numbering as the x axis) to mark with a black 'x'

    Returns:
        normalized_data: Per-position residual sums, z-scored and then min-max
        rescaled to [-1, 1]. All values are NaN if the sums have zero spread.
    """
    # Avoid a shared mutable default; None means "highlight nothing".
    if highlight_positions is None:
        highlight_positions = []

    observed = result[0].numpy()
    fitted = sigmoid_function(result[1].numpy(), *result[-1])
    # Collapse the 20 substitutions at each position into one residual sum.
    per_position = (observed - fitted).reshape(-1, 20).sum(-1)

    # Standardize, then squeeze into [-1, 1] so the y axis and colors share a scale.
    z_scores = (per_position - np.nanmean(per_position)) / np.nanstd(per_position)
    min_z = np.nanmin(z_scores)
    max_z = np.nanmax(z_scores)
    normalized_data = 2 * (z_scores - min_z) / (max_z - min_z) - 1

    x = np.arange(shift, shift + len(normalized_data))

    # NaN-aware limits: builtin min()/max() give order-dependent results with NaN.
    norm = mcolors.TwoSlopeNorm(
        vmin=np.nanmin(normalized_data),
        vmax=np.nanmax(normalized_data),
        vcenter=vcenter,
    )

    # plt.get_cmap works across Matplotlib versions; matplotlib.cm.get_cmap
    # (reached via plt.cm.get_cmap) was removed in Matplotlib 3.9.
    base_cmap = plt.get_cmap('RdBu_r')
    # Drop the darkest 10% at each end of the colormap.
    cmap = mpl.colors.LinearSegmentedColormap.from_list(
        'truncated_cmap', base_cmap(np.linspace(0.1, 0.9, 256))
    )

    plt.figure(figsize=(16, 4))
    scatter = plt.scatter(x, normalized_data, c=normalized_data, cmap=cmap, norm=norm, s=100)
    plt.colorbar(scatter)

    for pos in highlight_positions:
        plt.scatter(pos, normalized_data[pos - shift], color='black', marker='x', s=120, linewidth=1.5)

    plt.ylim(-1.1, 1.1)
    plt.show()
    return normalized_data