Spaces:
Running
on
T4
Running
on
T4
##################################################################### | |
############# PROTEIN SEQUENCE DIFFUSION SAMPLER #################### | |
##################################################################### | |
import sys, os, subprocess, pickle, time, json | |
script_dir = os.path.dirname(os.path.realpath(__file__)) | |
sys.path = sys.path + [script_dir+'/../model/'] + [script_dir+'/'] | |
import shutil | |
import glob | |
import torch | |
import numpy as np | |
import copy | |
import json | |
import matplotlib.pyplot as plt | |
from torch import nn | |
import math | |
import re | |
import pickle | |
import pandas as pd | |
import random | |
from copy import deepcopy | |
import time | |
from collections import namedtuple | |
import math | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
from RoseTTAFoldModel import RoseTTAFoldModule | |
from util import * | |
from inpainting_util import * | |
from kinematics import get_init_xyz, xyz_to_t2d | |
import parsers_inference as parsers | |
import diff_utils | |
import pickle | |
import pdb | |
from utils.calc_dssp import annotate_sse | |
from potentials import POTENTIALS | |
from diffusion import GaussianDiffusion_SEQDIFF | |
# Keyword arguments for RoseTTAFoldModule (see SEQDIFF_sampler.model_init).
MODEL_PARAM = dict(
    n_extra_block=4,
    n_main_block=32,
    n_ref_block=4,
    d_msa=256,
    d_msa_full=64,
    d_pair=128,
    d_templ=64,
    n_head_msa=8,
    n_head_pair=4,
    n_head_templ=4,
    d_hidden=32,
    d_hidden_templ=32,
    p_drop=0.0,
)

# Flat SE3 settings. Keys ending in "_full" / "_topk" belong to one of the two
# SE3 parameter sets; every other key is shared by both.
SE3_PARAMS = dict(
    num_layers_full=1,
    num_layers_topk=1,
    num_channels=32,
    num_degrees=2,
    l0_in_features_full=8,
    l0_in_features_topk=64,
    l0_out_features_full=8,
    l0_out_features_topk=64,
    l1_in_features=3,
    l1_out_features=2,
    num_edge_features_full=32,
    num_edge_features_topk=64,
    div=4,
    n_heads=4,
)

# Split the flat table into the two per-variant dictionaries, dropping the
# five-character "_full" / "_topk" suffix from variant-specific keys.
SE3_param_full = {}
SE3_param_topk = {}
for param, value in SE3_PARAMS.items():
    if "full" not in param and "topk" not in param:
        # common arguments go to both sets
        SE3_param_full[param] = value
        SE3_param_topk[param] = value
    elif "full" in param:
        SE3_param_full[param[:-5]] = value
    else:
        SE3_param_topk[param[:-5]] = value

MODEL_PARAM.update(SE3_param_full=SE3_param_full, SE3_param_topk=SE3_param_topk)

# Checkpoint locations used by SEQDIFF_sampler.model_init.
DEFAULT_CKPT = '/home/jgershon/models/SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'  # this is the one with good sequences
LOOP_CHECKPOINT = '/home/jgershon/models/SEQDIFF_221202_AB_NOSTATE_fromBASE_mod30.pt'
t1d_29_CKPT = '/home/jgershon/models/SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt'
class SEQDIFF_sampler: | |
''' | |
MODULAR SAMPLER FOR SEQUENCE DIFFUSION | |
- the goal for modularizing this code is to make it as | |
easy as possible to edit and mix functions around | |
- in the base implementation here this can handle the standard | |
inference mode with default passes through the model, different | |
forms of partial diffusion, and linear symmetry | |
''' | |
def __init__(self, args=None):
    '''
    store the argument dictionary, pick the torch device and set the
    default lookup tables and model parameter dictionaries

    args : argument dictionary (may be None and supplied later via set_args)
    '''
    self.args = args

    # use the GPU whenever torch reports one, otherwise fall back to CPU
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.DEVICE = torch.device(device_name)

    # residue alphabet (index -> one letter code) and DSSP letter -> class index
    self.conversion = 'ARNDCQEGHILKMFPSTWYVX-'
    self.dssp_dict = {'X':3,'H':0,'E':1,'L':2}

    # module level parameter dictionaries (shared objects, not copies)
    self.MODEL_PARAM = MODEL_PARAM
    self.SE3_PARAMS = SE3_PARAMS
    self.SE3_param_full = SE3_param_full
    self.SE3_param_topk = SE3_param_topk

    # potentials stay off until potential_init() enables them
    self.use_potentials = False

    self.reset_design_num()
def set_args(self, args):
    '''
    swap in a new argument dictionary, e.g. when iterating over several
    argument sets

    # NOTE : the model is not reloaded here, so args pertaining to the model
             are not considered; only the diffuser and (when requested)
             the potentials are re-initialized
    '''
    self.args = args
    self.diffuser_init()

    requested = self.args['potentials']
    if requested is None or requested == '':
        return
    self.potential_init()
def reset_design_num(self):
    '''
    restart the design counter that generate_sample() uses to number outputs
    '''
    self.design_num = 0
def diffuser_init(self):
    '''
    build the GaussianDiffusion_SEQDIFF object from the current args and
    cache its noise schedule on the sampler
    '''
    cfg = self.args
    self.diffuser = GaussianDiffusion_SEQDIFF(
        T=cfg['T'],
        schedule=cfg['noise_schedule'],
        sample_distribution=cfg['sample_distribution'],
        sample_distribution_gmm_means=cfg['sample_distribution_gmm_means'],
        sample_distribution_gmm_variances=cfg['sample_distribution_gmm_variances'],
    )

    # schedule terms: alpha = 1 - beta, plus the running product of alphas
    betas = self.diffuser.betas
    self.betas = betas
    self.alphas = 1-betas
    self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
def make_hotspot_features(self):
    '''
    set up hotspot features

    Writes self.features['hotspot_feat'], a length-L tensor that is 1.0 at
    every design position mapped from a requested hotspot residue and 0.0
    elsewhere.

    self.args['hotspots'] is either None or a comma separated string of
    <chain letter><residue number> tokens such as "A10,B25". Whitespace
    around tokens and empty tokens (e.g. a trailing comma) are ignored.
    The parsed (chain, number) tuples are kept in self.features['hotspots'].
    '''
    # initialize hotspot features to all 0s
    self.features['hotspot_feat'] = torch.zeros(self.features['L'])

    if self.args['hotspots'] is None:
        return

    tokens = [tok.strip() for tok in self.args['hotspots'].split(',')]
    self.features['hotspots'] = [(tok[0], int(tok[1:])) for tok in tokens if tok]

    # complex_con_ref_pdb_idx[n] is the reference residue whose design
    # position is complex_con_hal_idx0[n]
    mappings = self.features['mappings']
    for n, ref_pdb_idx in enumerate(mappings['complex_con_ref_pdb_idx']):
        if ref_pdb_idx in self.features['hotspots']:
            self.features['hotspot_feat'][mappings['complex_con_hal_idx0'][n]] = 1.0
def make_dssp_features(self):
    '''
    build the (L, 4) secondary structure feature in self.features['dssp_feat']

    columns follow self.dssp_dict: 0 helix, 1 strand, 2 loop, 3 mask.
    the source is, in priority order: an explicit secondary structure
    string, a pdb annotated with annotate_sse, or random per residue labels
    drawn from the helix/strand/loop bias args. positions flagged in
    mask_str and positions left without any label get the mask column.
    '''
    ss_string = self.args['secondary_structure']
    ss_pdb = self.args['dssp_pdb']
    assert ss_string is None or ss_pdb is None, \
        f'You are attempting to provide both dssp_pdb and/or secondary_secondary structure, please choose one or the other'

    n_res = self.features['L']
    # start with nothing labeled
    self.features['dssp_feat'] = torch.zeros(n_res,4)

    if ss_string is not None:
        labels = [self.dssp_dict[letter.upper()] for letter in ss_string]
        self.features['secondary_structure'] = labels
        # repeated units plus the two caps must account for every residue
        assert len(labels*self.features['sym'])+self.features['cap']*2 == n_res, \
            f'You have specified a secondary structure string that does not match your design length'
        cap_labels = self.features['cap_dssp']
        full_labels = cap_labels+labels*self.features['sym']+cap_labels
        self.features['dssp_feat'] = torch.nn.functional.one_hot(torch.tensor(full_labels), num_classes=4)
    elif ss_pdb is not None:
        dssp_xyz = torch.from_numpy(parsers.parse_pdb(ss_pdb)['xyz'][:,:,:])
        # annotate from atom index 1 of every residue
        annotated = annotate_sse(np.array(dssp_xyz[:,1,:].squeeze()), percentage_mask=0)
        #we assume binder is chain A
        self.features['dssp_feat'][:annotated.shape[0]] = annotated
    elif (self.args['helix_bias'] + self.args['strand_bias'] + self.args['loop_bias']) > 0.0:
        # one independent random draw per class, in helix, strand, loop order
        for column, bias_key in enumerate(('helix_bias', 'strand_bias', 'loop_bias')):
            chosen = torch.rand(n_res) < self.args[bias_key]
            self.features['dssp_feat'][chosen,column] = 1.0

    #contigs get mask label
    self.features['dssp_feat'][self.features['mask_str'][0],3] = 1.0
    #anything not labeled gets mask label
    unlabeled = torch.where(torch.sum(self.features['dssp_feat'], dim=1) == 0)[0]
    self.features['dssp_feat'][unlabeled,3] = 1.0
def model_init(self):
    '''
    get model set up and choose checkpoint

    Checkpoint selection:
      - no checkpoint in args -> DEFAULT_CKPT
      - hotspots, a secondary structure string, a dssp pdb or any positive
        helix/strand/loop bias requested while on DEFAULT_CKPT -> switch to
        t1d_29_CKPT and set d_t1d to 29
      - loop_design requested while on DEFAULT_CKPT -> LOOP_CHECKPOINT
    A checkpoint explicitly chosen by the caller is never replaced.

    Sets self.ckpt, self.v2_mode and self.model (in eval mode, on
    self.DEVICE). Mutates self.args['checkpoint'] and self.MODEL_PARAM.

    Raises FileNotFoundError when the chosen checkpoint path does not exist.
    '''
    if self.args['checkpoint'] is None:
        self.args['checkpoint'] = DEFAULT_CKPT

    self.MODEL_PARAM['d_t1d'] = self.args['d_t1d']

    # decide based on input args what checkpoint to load.
    # the feature test is grouped explicitly: `and` binds tighter than `or`,
    # so without the grouping the default-checkpoint guard would only apply
    # to the dssp_pdb condition.
    using_default_ckpt = self.args['checkpoint'] == DEFAULT_CKPT
    needs_t1d_29_model = (
        self.args['hotspots'] is not None
        or self.args['secondary_structure'] is not None
        or (self.args['helix_bias'] + self.args['strand_bias'] + self.args['loop_bias']) > 0
        or self.args['dssp_pdb'] is not None
    )
    if needs_t1d_29_model and using_default_ckpt:
        self.MODEL_PARAM['d_t1d'] = 29
        print('You are using features only compatible with a newer model, switching checkpoint...')
        self.args['checkpoint'] = t1d_29_CKPT
    elif self.args['loop_design'] and using_default_ckpt:
        print('Switched to loop design checkpoint')
        self.args['checkpoint'] = LOOP_CHECKPOINT

    # fail early with the offending path instead of warning and then
    # crashing inside torch.load
    if not os.path.exists(self.args['checkpoint']):
        raise FileNotFoundError(f"couldn't find checkpoint: {self.args['checkpoint']}")

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints
    self.ckpt = torch.load(self.args['checkpoint'], map_location=self.DEVICE)

    # check to see if [loader_param, model_param, loss_param] is in checkpoint
    # if so then you are using v2 of inference with t2d bug fixed
    self.v2_mode = False
    if 'model_param' in self.ckpt:
        print('You are running a new v2 model switching into v2 inference mode')
        self.v2_mode = True
        # parameters stored with the checkpoint override the defaults
        for k in self.MODEL_PARAM:
            if k in self.ckpt['model_param']:
                self.MODEL_PARAM[k] = self.ckpt['model_param'][k]
            else:
                print(f'no match for {k} in loaded model params')

    # make model and load checkpoint
    print('Loading model checkpoint...')
    self.model = RoseTTAFoldModule(**self.MODEL_PARAM).to(self.DEVICE)
    model_state = self.ckpt['model_state_dict']
    # strict=False: keys missing from or unexpected in the checkpoint are tolerated
    self.model.load_state_dict(model_state, strict=False)
    self.model.eval()
    print('Successfully loaded model checkpoint')
def feature_init(self):
    '''
    featurize pdb and contigs and choose type of diffusion

    Rebuilds self.features from self.args and dispatches to one of four
    input modes:
      - sequence only (no pdb and a sequence is given)
      - full diffusion from contigs (no trb and sampling_temp == 1.0)
      - partial diffusion with masks read from a trb file
      - partial diffusion without a trb file
    Every mode sets self.max_t and the seq / xyz_t / mask_* / idx_pdb /
    seq_hot / msa* / t1d entries; self.features['L'] is set last.

    Raises AssertionError when both or neither of contigs and sequence are
    given, or when partial diffusion from contigs is requested without a pdb.
    '''
    # initialize features dictionary for all example features
    self.features = {}

    # set up params
    self.loader_params = {'MAXCYCLE':self.args['n_cycle'],
                          'TEMPERATURE':self.args['temperature'],
                          'DISTANCE':self.args['min_decoding_distance']}

    # symmetry
    self.features['sym'] = self.args['symmetry']
    self.features['cap'] = self.args['symmetry_cap']
    # caps are always labeled helix
    self.features['cap_dssp'] = [self.dssp_dict[x.upper()] for x in 'H'*self.features['cap']]
    if self.features['sym'] > 1:
        print(f"Input sequence symmetry {self.features['sym']}")

    # exactly one of contigs / sequence has to be provided
    assert (self.args['contigs'] in [('0'),(0),['0'],[0]] ) ^ (self.args['sequence'] in ['',None]),\
        f'You are specifying contigs ({self.args["contigs"]}) and sequence ({self.args["sequence"]}) (or neither), please specify one or the other'

    # initialize trb dictionary
    self.features['trb_d'] = {}

    if self.args['pdb'] is None and self.args['sequence'] not in ['', None]:
        self._init_sequence_features()
    else:
        assert not (self.args['pdb'] is None and self.args['sampling_temp'] != 1.0),\
            f'You must specify a pdb if attempting to use contigs with partial diffusion, else partially diffuse sequence input'

        if self.args['pdb'] is None:
            # single residue placeholder standing in for a parsed pdb
            self.features['parsed_pdb'] = {'seq':np.zeros((1),'int64'),
                                           'xyz':np.zeros((1,27,3),'float32'),
                                           'idx':np.zeros((1),'int64'),
                                           'mask':np.zeros((1,27), bool),
                                           'pdb_idx':['A',1]}
        else:
            # parse input pdb
            self.features['parsed_pdb'] = parsers.parse_pdb(self.args['pdb'])

        # generate contig map
        self.features['rm'] = ContigMap(self.features['parsed_pdb'], self.args['contigs'],
                                        self.args['inpaint_seq'], self.args['inpaint_str'],
                                        self.args['length'], self.args['ref_idx'],
                                        self.args['hal_idx'], self.args['idx_rf'],
                                        self.args['inpaint_seq_tensor'], self.args['inpaint_str_tensor'])
        self.features['mappings'] = get_mappings(self.features['rm'])
        self.features['pdb_idx'] = self.features['rm'].hal

        ### PREPARE FEATURES DEPENDING ON TYPE OF ARGUMENTS SPECIFIED ###
        if self.args['trb'] is None and self.args['sampling_temp'] == 1.0:
            # FULL DIFFUSION MODE
            self._init_full_diffusion_features()
        elif self.args['trb'] is not None:
            # PARTIAL DIFFUSION MODE, WITH INPUT TRB
            self._init_trb_partial_diffusion_features()
        else:
            # PARTIAL DIFFUSION MODE, NO INPUT TRB
            self._init_partial_diffusion_features()

    # set L
    self.features['L'] = self.features['seq'].shape[0]

def _featurize_fixbb(self, msa_seq, conf_1d):
    '''
    build seq_hot / msa / msa_hot / msa_extra_hot from msa_seq and t1d from
    self.features['seq'] + conf_1d, add a leading dimension to each, and set
    self.max_t from T and sampling_temp
    '''
    self.features['seq_hot'], self.features['msa'], \
        self.features['msa_hot'], self.features['msa_extra_hot'], _ = MSAFeaturize_fixbb(msa_seq[None,:],self.loader_params)
    self.features['t1d'] = TemplFeaturizeFixbb(self.features['seq'], conf_1d=conf_1d)[None,None,:]

    for key in ('seq_hot', 'msa', 'msa_hot', 'msa_extra_hot'):
        self.features[key] = self.features[key].unsqueeze(dim=0)

    self.max_t = int(self.args['T']*self.args['sampling_temp'])

def _idx_pdb_with_chain_jumps(self):
    '''
    return a (1, n) tensor of residue indices for the parsed pdb in which
    every new chain is offset by a further 200
    '''
    pdb_idx = self.features['parsed_pdb']['pdb_idx']
    chains_used = [pdb_idx[0][0]]
    idx_jump = 0
    idx_pdb = []
    for i,x in enumerate(pdb_idx):
        if x[0] not in chains_used:
            chains_used.append(x[0])
            idx_jump += 200
        idx_pdb.append(idx_jump+i)
    return torch.tensor(idx_pdb)[None,:]

def _init_sequence_features(self):
    '''
    sequence only input: no coordinates (xyz_t is all nan), every residue on chain A
    '''
    print('Preparing sequence input')
    sequence = self.args['sequence']
    n_res = len(sequence)

    # every letter of self.conversion except the trailing '-' is accepted
    allowable_aas = set(self.conversion[:-1])
    for x in sequence:
        assert x in allowable_aas, f'Amino Acid {x} is undefinded, please only use standart 20 AAs'

    self.features['seq'] = torch.tensor([self.conversion.index(x) for x in sequence])
    self.features['xyz_t'] = torch.full((1,1,n_res,27,3), np.nan)
    self.features['mask_str'] = torch.zeros(n_res).long()[None,:].bool()

    #added check for if in partial diffusion mode will mask
    if self.args['sampling_temp'] == 1.0:
        # False at 'X' positions, True everywhere else
        self.features['mask_seq'] = torch.tensor([0 if x == 'X' else 1 for x in sequence]).long()[None,:].bool()
    else:
        self.features['mask_seq'] = torch.zeros(n_res).long()[None,:].bool()

    self.features['blank_mask'] = torch.ones(self.features['mask_str'].size()[-1])[None,:].bool()
    self.features['idx_pdb'] = torch.tensor(list(range(n_res)))[None,:]

    conf_1d = torch.ones_like(self.features['seq'])
    conf_1d[~self.features['mask_str'][0]] = 0

    self._featurize_fixbb(self.features['seq'], conf_1d)

    self.features['pdb_idx'] = [('A',i+1) for i in range(n_res)]
    self.features['trb_d']['inpaint_str'] = self.features['mask_str'][0]
    self.features['trb_d']['inpaint_seq'] = self.features['mask_seq'][0]

def _init_full_diffusion_features(self):
    '''
    full diffusion from contigs: copy the reference residues selected by the
    contig map into an otherwise blank design
    '''
    rm = self.features['rm']

    # process contigs and generate masks
    self.features['mask_str'] = torch.from_numpy(rm.inpaint_str)[None,:]
    self.features['mask_seq'] = torch.from_numpy(rm.inpaint_seq)[None,:]
    self.features['blank_mask'] = torch.ones(self.features['mask_str'].size()[-1])[None,:].bool()

    seq_input = torch.from_numpy(self.features['parsed_pdb']['seq'])
    xyz_input = torch.from_numpy(self.features['parsed_pdb']['xyz'][:,:,:])

    # coordinates start as nan; only the first 14 atom slots of mapped residues are filled
    self.features['xyz_t'] = torch.full((1,1,len(rm.ref),27,3), np.nan)
    self.features['xyz_t'][:,:,rm.hal_idx0,:14,:] = xyz_input[rm.ref_idx0,:14,:][None, None,...]

    # sequence starts as token 20 everywhere; mapped residues take the reference identity
    self.features['seq'] = torch.full((1,len(rm.ref)),20).squeeze()
    self.features['seq'][rm.hal_idx0] = seq_input[rm.ref_idx0]

    # template confidence
    conf_1d = torch.ones_like(self.features['seq'])*float(self.args['tmpl_conf'])
    conf_1d[~self.features['mask_str'][0]] = 0 # zero confidence for places where structure is masked

    # the MSA featurizer is given token 21 wherever the sequence holds token 20
    seq_masktok = torch.where(self.features['seq'] == 20, 21, self.features['seq'])

    # Get sequence and MSA input features
    self._featurize_fixbb(seq_masktok, conf_1d)
    self.features['idx_pdb'] = torch.from_numpy(np.array(rm.rf)).int()[None,:]

def _init_trb_partial_diffusion_features(self):
    '''
    partial diffusion of the whole input pdb using the masks stored in a trb file
    '''
    print('Running in partial diffusion mode . . .')
    # NOTE(security): allow_pickle=True unpickles the file; only load trusted trb files
    self.features['trb_d'] = np.load(self.args['trb'], allow_pickle=True)
    self.features['mask_str'] = torch.from_numpy(self.features['trb_d']['inpaint_str'])[None,:]
    self.features['mask_seq'] = torch.from_numpy(self.features['trb_d']['inpaint_seq'])[None,:]
    self.features['blank_mask'] = torch.ones(self.features['mask_str'].size()[-1])[None,:].bool()

    self.features['seq'] = torch.from_numpy(self.features['parsed_pdb']['seq'])
    self.features['xyz_t'] = torch.from_numpy(self.features['parsed_pdb']['xyz'][:,:,:])[None,None,...]

    # empty masks in the trb are replaced by all-False masks of the right length
    if self.features['mask_seq'].shape[1] == 0:
        self.features['mask_seq'] = torch.zeros(self.features['seq'].shape[0])[None].bool()
    if self.features['mask_str'].shape[1] == 0:
        self.features['mask_str'] = torch.zeros(self.features['xyz_t'].shape[2])[None].bool()

    self.features['idx_pdb'] = self._idx_pdb_with_chain_jumps()

    conf_1d = torch.ones_like(self.features['seq'])
    conf_1d[~self.features['mask_str'][0]] = 0

    self._featurize_fixbb(self.features['seq'], conf_1d)

def _init_partial_diffusion_features(self):
    '''
    partial diffusion of the input pdb without a trb: masks come from the
    contig map, or are all False when no contigs were given
    '''
    print('running in partial diffusion mode, with no trb input, diffusing whole input')
    self.features['seq'] = torch.from_numpy(self.features['parsed_pdb']['seq'])
    self.features['xyz_t'] = torch.from_numpy(self.features['parsed_pdb']['xyz'][:,:,:])[None,None,...]

    if self.args['contigs'] in [('0'),(0),['0'],[0]]:
        print('no contigs given partially diffusing everything')
        self.features['mask_str'] = torch.zeros(self.features['xyz_t'].shape[2]).long()[None,:].bool()
        self.features['mask_seq'] = torch.zeros(self.features['seq'].shape[0]).long()[None,:].bool()
    else:
        print('found contigs setting up masking for partial diffusion')
        self.features['mask_str'] = torch.from_numpy(self.features['rm'].inpaint_str)[None,:]
        self.features['mask_seq'] = torch.from_numpy(self.features['rm'].inpaint_seq)[None,:]
    self.features['blank_mask'] = torch.ones(self.features['mask_str'].size()[-1])[None,:].bool()

    self.features['idx_pdb'] = self._idx_pdb_with_chain_jumps()

    conf_1d = torch.ones_like(self.features['seq'])
    conf_1d[~self.features['mask_str'][0]] = 0

    self._featurize_fixbb(self.features['seq'], conf_1d)
def potential_init(self):
    '''
    instantiate every potential named in args['potentials'] with its
    matching scale from args['potential_scale'] (both comma separated),
    store them in self.potential_list and switch potentials on
    '''
    names = self.args['potentials'].split(',')
    scales = list(map(float, self.args['potential_scale'].split(',')))

    assert len(names) == len(scales), \
        f'Please make sure number of potentials matches potential scales specified'

    self.potential_list = []
    for name, scale in zip(names, scales):
        assert name in POTENTIALS, \
            f'The potential specified: {name} , does not match into POTENTIALS dictionary in potentials.py'
        print(f'Using potential: {name}')
        potential = POTENTIALS[name](self.args, self.features, scale, self.DEVICE)
        self.potential_list.append(potential)

    self.use_potentials = True
def setup(self, init_model=True):
    '''
    prepare one example for sampling: build the features, noise/mask them
    at timestep self.max_t, move everything to self.DEVICE and reset the
    per-sample bookkeeping entries

    init_model : accepted for interface compatibility, not read in this method
    '''
    # raw features, optional potentials, then the hotspot and dssp features
    self.feature_init()
    if self.args['potentials'] not in ['', None]:
        self.potential_init()
    self.make_hotspot_features()
    self.make_dssp_features()

    feats = self.features

    # diffuse sequence and mask features
    masked = diff_utils.mask_inputs(feats['seq_hot'],
                                    feats['msa_hot'],
                                    feats['msa_extra_hot'],
                                    feats['xyz_t'],
                                    feats['t1d'],
                                    input_seq_mask=feats['mask_seq'],
                                    input_str_mask=feats['mask_str'],
                                    input_t1dconf_mask=feats['blank_mask'],
                                    diffuser=self.diffuser,
                                    t=self.max_t,
                                    MODEL_PARAM=self.MODEL_PARAM,
                                    hotspots=feats['hotspot_feat'],
                                    dssp=feats['dssp_feat'],
                                    v2_mode=self.v2_mode)
    (feats['seq'], feats['msa_masked'], feats['msa_full'],
     feats['xyz_t'], feats['t1d'], feats['seq_diffused']) = masked

    # move features to device
    dev = self.DEVICE
    feats['idx_pdb'] = feats['idx_pdb'].long().to(dev, non_blocking=True) # (B, L)
    feats['mask_str'] = feats['mask_str'][None].to(dev, non_blocking=True) # (B, L)
    # these two only gain a leading dimension
    for key in ('xyz_t', 't1d'):
        feats[key] = feats[key][None].to(dev, non_blocking=True)
    feats['seq'] = feats['seq'][None].type(torch.float32).to(dev, non_blocking=True)
    # msa is cast but gets no extra leading dimension
    feats['msa'] = feats['msa'].type(torch.float32).to(dev, non_blocking=True)
    for key in ('msa_masked', 'msa_full'):
        feats[key] = feats[key][None].type(torch.float32).to(dev, non_blocking=True)

    # torsion lookup tables
    self.ti_dev = torsion_indices.to(dev, non_blocking=True)
    self.ti_flip = torsion_can_flip.to(dev, non_blocking=True)
    self.ang_ref = reference_angles.to(dev, non_blocking=True)

    feats['xyz_prev'] = torch.clone(feats['xyz_t'][0])
    feats['seq_diffused'] = feats['seq_diffused'][None].to(dev, non_blocking=True)

    # B, N and L are read off the msa shape (the second dimension is ignored)
    feats['B'], _, feats['N'], feats['L'] = feats['msa'].shape

    # t2d is computed before get_init_xyz rewrites the template coordinates
    feats['t2d'] = xyz_to_t2d(feats['xyz_t'])

    # get alphas
    feats['alpha'], feats['alpha_t'] = diff_utils.get_alphas(feats['t1d'], feats['xyz_t'],
                                                             feats['B'], feats['L'],
                                                             self.ti_dev, self.ti_flip, self.ang_ref)

    # processing template coordinates
    feats['xyz_t'] = get_init_xyz(feats['xyz_t'])
    feats['xyz_prev'] = get_init_xyz(feats['xyz_prev'][:,None]).reshape(feats['B'], feats['L'], 27, 3)

    # model outputs and recycled state start empty
    for key in ('xyz', 'pred_lddt', 'logit_s', 'logit_aa_s'):
        feats[key] = None
    feats['best_plddt'] = 0
    feats['best_pred_lddt'] = torch.zeros_like(feats['mask_str'])[0].float()
    for key in ('msa_prev', 'pair_prev', 'state_prev'):
        feats[key] = None
def symmetrize_seq(self, x):
    '''
    symmetrize x according sym in features

    x is split along dimension 0 into an N-terminal cap of `cap` rows,
    `sym` repeat units and a C-terminal cap of `cap` rows. The caps are
    kept and the first repeat unit is copied over all `sym` units, so the
    result has the same shape as x.

    x : tensor whose dimension 0 has length self.features['L']

    Raises AssertionError when (L - 2*cap) is not divisible by sym or when
    x.shape[0] != L.
    '''
    L = self.features['L']
    cap = self.features['cap']
    sym = self.features['sym']

    assert (L-cap*2) % sym == 0, f'symmetry does not match for input length'
    assert x.shape[0] == L, f'make sure that dimension 0 of input matches to L'

    unit_len = (L-cap*2) // sym

    n_cap = torch.clone(x[:cap])
    # slice from L-cap rather than -cap so that cap == 0 yields an empty cap
    c_cap = torch.clone(x[L-cap:])
    # tile the first repeat unit along dim 0 only
    repeats = (sym,) + (1,)*(x.dim()-1)
    sym_x = torch.clone(x[cap:cap+unit_len]).repeat(*repeats)

    return torch.cat([n_cap,sym_x,c_cap], dim=0)
def predict_x(self):
    '''
    run one pass of the model on the current (noised) features through
    diff_utils.take_step_nostate and store its nine outputs back into
    self.features: seq, xyz, pred_lddt, logit_s, logit_aa_s, alpha and
    the msa / pair / state values recycled into the next call
    '''
    feats = self.features
    outputs = diff_utils.take_step_nostate(self.model,
                                           feats['msa_masked'],
                                           feats['msa_full'],
                                           feats['seq'],
                                           feats['t1d'],
                                           feats['t2d'],
                                           feats['idx_pdb'],
                                           self.args['n_cycle'],
                                           feats['xyz_prev'],
                                           feats['alpha'],
                                           feats['xyz_t'],
                                           feats['alpha_t'],
                                           feats['seq_diffused'],
                                           feats['msa_prev'],
                                           feats['pair_prev'],
                                           feats['state_prev'])
    (feats['seq'],
     feats['xyz'],
     feats['pred_lddt'],
     feats['logit_s'],
     feats['logit_aa_s'],
     feats['alpha'],
     feats['msa_prev'],
     feats['pair_prev'],
     feats['state_prev']) = outputs
def self_condition_seq(self):
    '''
    self condition on sequence: overwrite the first 21 channels of t1d
    with the first 21 amino acid logits from the previous prediction
    '''
    # logit_aa_s[0] is indexed (channel, position); t1d wants (position, channel)
    prev_logits = self.features['logit_aa_s'][0,:21,:]
    self.features['t1d'][:,:,:,:21] = prev_logits.transpose(0,1)
def self_condition_str_scheduled(self):
    '''
    unmask random fraction of residues according to timestep

    Builds a second template from the current predicted structure
    (self.features['xyz']). Each residue is kept with probability
    alphas_cumprod[t]; residues not kept, residues flagged in mask_str and
    every atom slot after the first three are set to nan. t1d, t2d and
    xyz_t are each rebuilt as [first template, self conditioning template]
    and the torsion features (alpha, alpha_t) are recomputed.
    '''
    print('self_conditioning on strcuture')

    # int() so the schedule can be indexed whether self.t is a tensor or an int
    t = int(self.t)
    str_mask = self.features['mask_str'][0][0]

    xyz_prev_template = torch.clone(self.features['xyz'])[None]
    # True = residue keeps its predicted coordinates in the template
    self_conditioning_mask = torch.rand(self.features['L']) < self.diffuser.alphas_cumprod[t]
    xyz_prev_template[:,:,~self_conditioning_mask] = float('nan')
    xyz_prev_template[:,:,str_mask] = float('nan')
    xyz_prev_template[:,:,:,3:] = float('nan')
    t2d_sc = xyz_to_t2d(xyz_prev_template)

    xyz_t_sc = torch.zeros_like(self.features['xyz_t'][:,:1])
    xyz_t_sc[:,:,:,:3] = xyz_prev_template[:,:,:,:3]
    xyz_t_sc[:,:,:,3:] = float('nan')

    # zero the 1D template features wherever the coordinates were removed
    t1d_sc = torch.clone(self.features['t1d'][:,:1])
    t1d_sc[:,:,~self_conditioning_mask] = 0
    t1d_sc[:,:,str_mask] = 0

    self.features['t1d'] = torch.cat([self.features['t1d'][:,:1],t1d_sc], dim=1)
    self.features['t2d'] = torch.cat([self.features['t2d'][:,:1],t2d_sc], dim=1)
    self.features['xyz_t'] = torch.cat([self.features['xyz_t'][:,:1],xyz_t_sc], dim=1)

    self.features['alpha'], self.features['alpha_t'] = diff_utils.get_alphas(self.features['t1d'], self.features['xyz_t'],
                                                                             self.features['B'], self.features['L'],
                                                                             self.ti_dev, self.ti_flip, self.ang_ref)
    self.features['xyz_t'] = get_init_xyz(self.features['xyz_t'])
def self_condition_str(self):
    '''
    self condition on structure: append the current predicted structure as
    an extra template (first three atom slots only, the rest nan) and
    extend t2d and t1d along the template dimension to match
    '''
    print("conditioning on structure for NAR structure noising")
    feats = self.features

    predicted_template = torch.zeros_like(feats['xyz_t'][:,:1])
    predicted_template[:,:,:,:3] = torch.clone(feats['xyz'])[None]
    predicted_template[:,:,:,3:] = float('nan')

    # NOTE(review): t2d is computed from the existing xyz_t, not from the
    # new predicted template - confirm this is intended
    extra_t2d = xyz_to_t2d(feats['xyz_t'])
    extra_t1d = torch.clone(feats['t1d'])

    feats['xyz_t'] = torch.cat([feats['xyz_t'], predicted_template], dim=1)
    feats['t2d'] = torch.cat([feats['t2d'], extra_t2d], dim=1)
    feats['t1d'] = torch.cat([feats['t1d'], extra_t1d], dim=1)
def save_step(self):
    '''
    record the current step in self.trajectory under the key step<t> as a
    (structure, first 21 aa logits, first 21 diffused sequence channels)
    tuple of detached cpu tensors
    '''
    feats = self.features
    structure = feats['xyz'].squeeze().detach().cpu()
    aa_logits = feats['logit_aa_s'][0,:21,:].permute(1,0).detach().cpu()
    noised_seq = feats['seq_diffused'][0,:,:21].detach().cpu()
    self.trajectory[f'step{self.t}'] = (structure, aa_logits, noised_seq)
def noise_x(self):
    '''
    noise the predicted Xo (features['seq_out']) to timestep self.t and
    write the result into every sequence carrying feature
    '''
    feats = self.features

    # sample x_t-1
    feats['post_mean'] = self.diffuser.q_sample(feats['seq_out'], self.t, DEVICE=self.DEVICE)
    if feats['sym'] > 1:
        feats['post_mean'] = self.symmetrize_seq(feats['post_mean'])

    # only positions where mask_seq is False receive the new sample
    diffused_positions = ~feats['mask_seq'][0]
    feats['seq_diffused'][0,diffused_positions,:21] = feats['post_mean'][diffused_positions,...]
    # channel 21 is always zeroed
    feats['seq_diffused'][0,:,21] = 0.0

    # keep the diffused sequence inside [-3, 3]
    feats['seq_diffused'] = feats['seq_diffused'].clamp(min=-3, max=3)

    # match other features to seq diffused
    feats['seq'] = torch.argmax(feats['seq_diffused'], dim=-1)[None]
    feats['msa_masked'][:,:,:,:,:22] = feats['seq_diffused']
    feats['msa_masked'][:,:,:,:,22:44] = feats['seq_diffused']
    feats['msa_full'][:,:,:,:,:22] = feats['seq_diffused']

    # t1d channel 22 holds 1 - t/T
    feats['t1d'][:1,:,:,22] = 1-int(self.t)/self.args['T']
def apply_potentials(self):
    '''
    add the mean gradient of all active potentials to features['seq_out']
    (in place)
    '''
    seq_out = self.features['seq_out']

    # accumulate every potential's gradient for the current prediction
    total = torch.zeros_like(seq_out)
    for potential in self.potential_list:
        total.add_(potential.get_gradients(self.features['seq_out']))

    mean_gradient = total/len(self.potential_list)
    # in place, so any tensor sharing storage with seq_out sees the update
    self.features['seq_out'] = seq_out.add_(mean_gradient)
def generate_sample(self):
    '''
    sample from the model
    this function runs the full sampling loop

    Steps t = max_t-1 ... 0. Each step predicts Xo from the current
    features, tracks the step with the highest mean pLDDT, applies self
    conditioning and the optional sequence alterations / potentials, and
    (except at t == 0) noises the prediction back to the next timestep.
    Outputs are written by save_outputs() and design_num is incremented.
    '''
    # setup example
    self.setup()

    # start time
    self.start_time = time.time()

    # set up dictionary to save at each step in trajectory
    self.trajectory = {}

    # set out prefix
    self.out_prefix = self.args['out']+f'_{self.design_num:06}'
    print(f'Generating sample {self.design_num:06} ...')

    # main sampling loop
    for j in range(self.max_t):
        self.t = torch.tensor(self.max_t-j-1).to(self.DEVICE)

        # run features through the model to get X_o prediction
        self.predict_x()

        # save step
        if self.args['save_all_steps']:
            self.save_step()

        # get seq out
        self.features['seq_out'] = torch.permute(self.features['logit_aa_s'][0], (1,0))

        # mean pLDDT over positions where mask_seq is False
        current_plddt = self.features['pred_lddt'][~self.features['mask_seq']].mean().item()

        # save best seq
        if current_plddt > self.features['best_plddt']:
            self.features['best_seq'] = torch.argmax(torch.clone(self.features['seq_out']), dim=-1)
            self.features['best_pred_lddt'] = torch.clone(self.features['pred_lddt'])
            self.features['best_xyz'] = torch.clone(self.features['xyz'])
            self.features['best_plddt'] = current_plddt

        # self condition on sequence
        self.self_condition_seq()

        # self condition on structure
        if self.args['scheduled_str_cond']:
            self.self_condition_str_scheduled()
        if self.args['struc_cond_sc']:
            self.self_condition_str()

        # sequence alterations
        if self.args['softmax_seqout']:
            # rescale the softmax from [0, 1] to [-1, 1]
            self.features['seq_out'] = torch.softmax(self.features['seq_out'],dim=-1)*2-1
        if self.args['clamp_seqout']:
            # symmetric bound derived from the schedule at the current step;
            # int() so the schedule can be indexed whether self.t is a tensor or an int
            bound = (1/self.diffuser.alphas_cumprod[int(self.t)])*0.25+5
            self.features['seq_out'] = torch.clamp(self.features['seq_out'], min=-bound, max=bound)

        # apply potentials
        if self.use_potentials:
            self.apply_potentials()

        # noise to X_t-1
        if self.t != 0:
            self.noise_x()

        print(''.join([self.conversion[i] for i in torch.argmax(self.features['seq_out'],dim=-1)]))
        print ("    TIMESTEP [%02d/%02d]   |   current PLDDT: %.4f   <<  >>   best PLDDT: %.4f"%(
            self.t+1, self.args['T'], current_plddt,
            self.features['best_pred_lddt'][~self.features['mask_seq']].mean().item()))

    # record time
    self.delta_time = time.time() - self.start_time

    # save outputs
    self.save_outputs()

    # increment design num
    self.design_num += 1

    print(f'Finished design {self.out_prefix} in {self.delta_time/60:.2f} minutes.')
def save_outputs(self):
    '''
    save the outputs from the model

    Writes, depending on the flags in self.args:
      * <out_prefix>_trajectory.pt  when 'save_all_steps' is set
      * <out_prefix>.pdb            always, through write_pdb
      * the trb file / args file    when 'dump_trb' / 'save_args' are set

    When 'save_best_plddt' is set, the 'seq', 'pred_lddt' and 'xyz' entries
    of self.features are first overwritten with clones of their 'best_*'
    counterparts, so the written files describe the best-scoring step.
    '''
    # save trajectory
    if self.args['save_all_steps']:
        fname = f'{self.out_prefix}_trajectory.pt'
        # the sampling loop creates this dictionary as self.trajectory
        torch.save(self.trajectory, fname)
    # get items from best plddt step
    if self.args['save_best_plddt']:
        self.features['seq'] = torch.clone(self.features['best_seq'])
        self.features['pred_lddt'] = torch.clone(self.features['best_pred_lddt'])
        self.features['xyz'] = torch.clone(self.features['best_xyz'])
    # get chain IDs
    # NOTE(review): when neither condition below holds, chain_ids is never
    # assigned and the write_pdb call fails with UnboundLocalError - confirm
    # which chain labelling is wanted for that case.
    no_trb_default_temp = self.args['sampling_temp'] == 1.0 and self.args['trb'] is None
    if no_trb_default_temp or (self.args['sequence'] not in ['', None]):
        chain_ids = [i[0] for i in self.features['pdb_idx']]
    elif self.args['dump_pdb']:
        chain_ids = [i[0] for i in self.features['parsed_pdb']['pdb_idx']]
    # write output pdb
    fname = self.out_prefix + '.pdb'
    # a 2-D sequence tensor is squeezed before it is handed to write_pdb
    if len(self.features['seq'].shape) == 2:
        self.features['seq'] = self.features['seq'].squeeze()
    write_pdb(fname,
              self.features['seq'].type(torch.int64),
              self.features['xyz'].squeeze(),
              Bfacts=self.features['pred_lddt'].squeeze(),
              chains=chain_ids)
    if self.args['dump_trb']:
        self.save_trb()
    if self.args['save_args']:
        self.save_args()
def save_trb(self):
    '''
    save trb file

    Pickles a dictionary to <out_prefix>.trb holding the per-position
    'pred_lddt' values, the subset of those values at positions where
    'mask_str' equals 0, the contigs, device, elapsed time and the full
    args, plus mapping entries:
      * args['sequence'] is not None -> every item of features['trb_d']
      * otherwise -> for each key of features['mappings'], the value from
        features['trb_d'] when that key exists there, else the non-empty
        mapping value itself (lists whose first element is not a tuple
        are converted to numpy arrays)
    '''
    lddt = self.features['pred_lddt'].squeeze().cpu().numpy()
    str_mask = self.features['mask_str'].squeeze().cpu().numpy()
    trb = {
        'lddt': lddt,
        # kept as a plain list of scalars, one per position with mask value 0
        'inpaint_lddt': [lddt[i] for i, flag in enumerate(str_mask) if flag == 0],
        'contigs': self.args['contigs'],
        'device': self.DEVICE,
        'time': self.delta_time,
        'args': self.args,
    }
    if self.args['sequence'] is not None:
        trb.update(self.features['trb_d'])
    else:
        for key, value in self.features['mappings'].items():
            if key in self.features['trb_d']:
                trb[key] = self.features['trb_d'][key]
            elif len(value) > 0:
                # exact type checks on purpose: subclasses of list / tuple
                # are treated as before
                if type(value) is list and type(value[0]) is not tuple:
                    value = np.array(value)
                trb[key] = value
    with open(f'{self.out_prefix}.trb', 'wb') as f_out:
        pickle.dump(trb, f_out)
def save_args(self):
    '''
    Dump self.args as JSON to <out_prefix>_args.json.
    '''
    args_path = f'{self.out_prefix}_args.json'
    with open(args_path, 'w') as handle:
        json.dump(self.args, handle)
##################################################################### | |
###################### science is cool ############################## | |
##################################################################### | |
# making a custom sampler class for HuggingFace app | |
class HuggingFace_sampler(SEQDIFF_sampler): | |
def model_init(self):
    '''
    Pick the checkpoint, load it and build the model from it.

    Uses DEFAULT_CKPT when args['checkpoint'] is None. Afterwards:
      * self.ckpt    holds the loaded checkpoint
      * self.v2_mode is True when the checkpoint has a 'model_param' entry
      * self.model   is a RoseTTAFoldModule on self.DEVICE, in eval mode,
                     with the checkpoint's 'model_state_dict' loaded
    '''
    if self.args['checkpoint'] is None:
        self.args['checkpoint'] = DEFAULT_CKPT
    self.MODEL_PARAM['d_t1d'] = self.args['d_t1d']

    # a missing checkpoint is only reported; torch.load below is still tried
    if not os.path.exists(self.args['checkpoint']):
        print('WARNING: couldn\'t find checkpoint')
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # that come from a trusted source
    self.ckpt = torch.load(self.args['checkpoint'], map_location=self.DEVICE)

    # a checkpoint that carries 'model_param' is a v2 model (v2 inference
    # has the t2d bug fixed); its stored parameters override the defaults
    self.v2_mode = 'model_param' in self.ckpt.keys()
    if self.v2_mode:
        print('You are running a new v2 model switching into v2 inference mode')
        for name in self.MODEL_PARAM:
            if name not in self.ckpt['model_param'].keys():
                print(f'no match for {name} in loaded model params')
                continue
            self.MODEL_PARAM[name] = self.ckpt['model_param'][name]

    # make model and load checkpoint
    print('Loading model checkpoint...')
    self.model = RoseTTAFoldModule(**self.MODEL_PARAM).to(self.DEVICE)
    # strict=False: keys that do not line up between checkpoint and model
    # are not treated as an error
    self.model.load_state_dict(self.ckpt['model_state_dict'], strict=False)
    self.model.eval()
    print('Successfully loaded model checkpoint')
def generate_sample(self):
    '''
    sample from the model
    this function runs the full sampling loop

    For each of the self.max_t timesteps (counted down to 0) it predicts
    X_0, tracks the best-scoring step, applies self conditioning and the
    optional sequence alterations / potentials, and re-noises to X_t-1.
    After the loop it records the elapsed time, calls save_outputs and
    increments self.design_num.
    '''
    # setup example
    self.setup()
    # start time
    self.start_time = time.time()
    # set up dictionary to save at each step in trajectory
    self.trajectory = {}
    # set out prefix
    print(f'Generating sample {self.out_prefix} ...')
    # main sampling loop
    for j in range(self.max_t):
        self.t = torch.tensor(self.max_t-j-1).to(self.DEVICE)
        # plain-int copy of the timestep for indexing and printing
        step = int(self.t)
        # run features through the model to get X_o prediction
        self.predict_x()
        # save step
        if self.args['save_all_steps']:
            self.save_step()
        # get seq out
        self.features['seq_out'] = torch.permute(self.features['logit_aa_s'][0], (1,0))
        # save best seq
        # the stored best score is the same quantity that is compared here
        # (mean over all positions), as in take_step_get_outputs
        current_plddt = self.features['pred_lddt'].mean().item()
        if current_plddt > self.features['best_plddt']:
            self.features['best_seq'] = torch.argmax(torch.clone(self.features['seq_out']), dim=-1)
            self.features['best_pred_lddt'] = torch.clone(self.features['pred_lddt'])
            self.features['best_xyz'] = torch.clone(self.features['xyz'])
            self.features['best_plddt'] = current_plddt
        # self condition on sequence
        self.self_condition_seq()
        # self condition on structure
        if self.args['scheduled_str_cond']:
            self.self_condition_str_scheduled()
        if self.args['struc_cond_sc']:
            self.self_condition_str()
        # sequence alterations
        if self.args['softmax_seqout']:
            self.features['seq_out'] = torch.softmax(self.features['seq_out'],dim=-1)*2-1
        if self.args['clamp_seqout']:
            # the bound is indexed by the current timestep; the bare name
            # `t` used previously is not defined in this method
            bound = float((1/self.diffuser.alphas_cumprod[step])*0.25+5)
            self.features['seq_out'] = torch.clamp(self.features['seq_out'],
                                                   min=-bound,
                                                   max=bound)
        # apply potentials
        if self.use_potentials:
            self.apply_potentials()
        # noise to X_t-1
        if step != 0:
            self.noise_x()
        print(''.join([self.conversion[i] for i in torch.argmax(self.features['seq_out'],dim=-1)]))
        print (" TIMESTEP [%02d/%02d] | current PLDDT: %.4f << >> best PLDDT: %.4f"%(
            step+1, self.args['T'], self.features['pred_lddt'][~self.features['mask_seq']].mean().item(),
            self.features['best_pred_lddt'][~self.features['mask_seq']].mean().item()))
    # record time
    self.delta_time = time.time() - self.start_time
    # save outputs
    self.save_outputs()
    # increment design num
    self.design_num += 1
    print(f'Finished design {self.out_prefix} in {self.delta_time/60:.2f} minutes.')
def take_step_get_outputs(self, j):
    '''
    run a single step of the sampling loop and write a pdb for display

    Args:
        j: loop counter of the step; the timestep used is max_t - j - 1

    Returns:
        (aa_seq, fname, plddt): the sequence written to the pdb as a string,
        the path of the pdb file that was written, and the mean of
        features['pred_lddt'] over all positions
    '''
    self.t = torch.tensor(self.max_t-j-1).to(self.DEVICE)
    # plain-int copy of the timestep for indexing and printing
    step = int(self.t)
    # run features through the model to get X_o prediction
    self.predict_x()
    # save step
    if self.args['save_all_steps']:
        self.save_step()
    # get seq out
    self.features['seq_out'] = torch.permute(self.features['logit_aa_s'][0], (1,0))
    # save best seq
    current_plddt = self.features['pred_lddt'].mean().item()
    if current_plddt > self.features['best_plddt']:
        self.features['best_seq'] = torch.argmax(torch.clone(self.features['seq_out']), dim=-1)
        self.features['best_pred_lddt'] = torch.clone(self.features['pred_lddt'])
        self.features['best_xyz'] = torch.clone(self.features['xyz'])
        self.features['best_plddt'] = current_plddt
    # WRITE OUTPUT TO GET TEMPORARY PDB TO DISPLAY
    if step != 0:
        self.features['seq'] = torch.argmax(torch.clone(self.features['seq_out']), dim=-1)
    else:
        # prepare final output
        if self.args['save_args']:
            self.save_args()
        # get items from best plddt step
        if self.args['save_best_plddt']:
            self.features['seq'] = torch.clone(self.features['best_seq'])
            self.features['pred_lddt'] = torch.clone(self.features['best_pred_lddt'])
            self.features['xyz'] = torch.clone(self.features['best_xyz'])
    # get chain IDs
    # NOTE(review): when neither condition below holds, chain_ids is never
    # assigned and the write_pdb call fails with UnboundLocalError - confirm
    # which chain labelling is wanted for that case.
    no_trb_default_temp = self.args['sampling_temp'] == 1.0 and self.args['trb'] is None
    if no_trb_default_temp or (self.args['sequence'] not in ['', None]):
        chain_ids = [i[0] for i in self.features['pdb_idx']]
    elif self.args['dump_pdb']:
        chain_ids = [i[0] for i in self.features['parsed_pdb']['pdb_idx']]
    # write output pdb
    if len(self.features['seq'].shape) == 2:
        self.features['seq'] = self.features['seq'].squeeze()
    fname = f'{self.out_prefix}.pdb'
    write_pdb(fname, self.features['seq'].type(torch.int64),
              self.features['xyz'].squeeze(),
              Bfacts=self.features['pred_lddt'].squeeze(),
              chains=chain_ids)
    aa_seq = ''.join([self.conversion[x] for x in self.features['seq'].tolist()])
    # self condition on sequence
    self.self_condition_seq()
    # self condition on structure
    if self.args['scheduled_str_cond']:
        self.self_condition_str_scheduled()
    if self.args['struc_cond_sc']:
        self.self_condition_str()
    # sequence alterations
    if self.args['softmax_seqout']:
        self.features['seq_out'] = torch.softmax(self.features['seq_out'],dim=-1)*2-1
    if self.args['clamp_seqout']:
        # the bound is indexed by the current timestep; the bare name `t`
        # used previously is not defined in this method
        bound = float((1/self.diffuser.alphas_cumprod[step])*0.25+5)
        self.features['seq_out'] = torch.clamp(self.features['seq_out'],
                                               min=-bound,
                                               max=bound)
    # apply potentials
    if self.use_potentials:
        self.apply_potentials()
    # noise to X_t-1
    if step != 0:
        self.noise_x()
    print(''.join([self.conversion[i] for i in torch.argmax(self.features['seq_out'],dim=-1)]))
    print (" TIMESTEP [%02d/%02d] | current PLDDT: %.4f << >> best PLDDT: %.4f"%(
        step+1, self.args['T'], self.features['pred_lddt'][~self.features['mask_seq']].mean().item(),
        self.features['best_pred_lddt'][~self.features['mask_seq']].mean().item()))
    return aa_seq, fname, self.features['pred_lddt'].mean().item()
def get_outputs(self):
    '''
    Return the current design as (aa_seq, path_to_pdb, mean pred_lddt).

    The sequence string is built by looking up every entry of
    features['seq'] in self.conversion; the path is only assembled from
    self.out_prefix, no file is written here.
    '''
    residue_ids = self.features['seq'].tolist()
    letters = ''.join(self.conversion[idx] for idx in residue_ids)
    pdb_path = self.out_prefix+'.pdb'
    mean_plddt = self.features['pred_lddt'].mean().item()
    return letters, pdb_path, mean_plddt