diff --git a/config/inference/base.yaml b/config/inference/base.yaml
index f2147c2..1c8cae1 100644
--- a/config/inference/base.yaml
+++ b/config/inference/base.yaml
@@ -64,8 +64,6 @@ model:
   l1_in_features: 3
   l1_out_features: 2
   num_edge_features: 64
-  d_time_emb: null
-  d_time_emb_proj: null
   freeze_track_motif: False
   use_motif_timestep: False
 
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 6e74dcb..8364cb1 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -40,7 +40,6 @@ RUN apt-get -q update \
   decorator==5.1.0 \
   hydra-core==1.3.2 \
   pyrsistent==0.19.3 \
-  icecream==2.1.3 \
   /app/RFdiffusion/env/SE3Transformer \
   && pip install --no-cache-dir /app/RFdiffusion --no-deps
 
@@ -48,4 +47,4 @@ WORKDIR /app/RFdiffusion
 
 ENV DGLBACKEND="pytorch"
 
-ENTRYPOINT ["python3.9", "scripts/run_inference.py"]
\ No newline at end of file
+ENTRYPOINT ["python3.9", "scripts/run_inference.py"]
diff --git a/env/SE3nv.yml b/env/SE3nv.yml
index cd02194..a51bcce 100644
--- a/env/SE3nv.yml
+++ b/env/SE3nv.yml
@@ -12,7 +12,6 @@ dependencies:
   - torchvision
   - cudatoolkit=11.1
   - dgl-cuda11.1
-  - icecream
   - pip
   - pip:
     - hydra-core
diff --git a/rfdiffusion/Attention_module.py b/rfdiffusion/Attention_module.py
index 8a30573..f8868fc 100644
--- a/rfdiffusion/Attention_module.py
+++ b/rfdiffusion/Attention_module.py
@@ -31,7 +31,7 @@ def forward(self, src):
 
 class Attention(nn.Module):
     # calculate multi-head attention
-    def __init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1):
+    def __init__(self, d_query, d_key, n_head, d_hidden, d_out):
         super(Attention, self).__init__()
         self.h = n_head
         self.dim = d_hidden
diff --git a/rfdiffusion/AuxiliaryPredictor.py b/rfdiffusion/AuxiliaryPredictor.py
index af392f9..dd246e1 100644
--- a/rfdiffusion/AuxiliaryPredictor.py
+++ b/rfdiffusion/AuxiliaryPredictor.py
@@ -34,7 +34,7 @@ def forward(self, x):
         return logits_dist, logits_omega, logits_theta, logits_phi
 
 class MaskedTokenNetwork(nn.Module):
-    def __init__(self, n_feat, p_drop=0.1):
+    def __init__(self, n_feat):
         super(MaskedTokenNetwork, self).__init__()
         self.proj = nn.Linear(n_feat, 21)
 
diff --git a/rfdiffusion/Embeddings.py b/rfdiffusion/Embeddings.py
index 45f709c..fac242f 100644
--- a/rfdiffusion/Embeddings.py
+++ b/rfdiffusion/Embeddings.py
@@ -12,82 +12,6 @@
 
 # Module contains classes and functions to generate initial embeddings
 
-def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
-    # Code from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
-    assert len(timesteps.shape) == 1
-    half_dim = embedding_dim // 2
-    emb = math.log(max_positions) / (half_dim - 1)
-
-    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
-    emb = timesteps.float()[:, None] * emb[None, :]
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1: # zero pad
-        emb = F.pad(emb, (0, 1), mode='constant')
-    assert emb.shape == (timesteps.shape[0], embedding_dim)
-    return emb
-
-class Timestep_emb(nn.Module):
-
-    def __init__(
-        self,
-        input_size,
-        output_size,
-        T,
-        use_motif_timestep=True
-    ):
-        super(Timestep_emb, self).__init__()
-
-        self.input_size = input_size
-        self.output_size = output_size
-        self.T = T
-
-        # get source for timestep embeddings at all t AND zero (for the motif)
-        self.source_embeddings = get_timestep_embedding(torch.arange(self.T+1), self.input_size)
-        self.source_embeddings.requires_grad = False
-
-        # Layers to use for projection
-        self.node_embedder = nn.Sequential(
-            nn.Linear(input_size, output_size, bias=False),
-            nn.ReLU(),
-            nn.Linear(output_size, output_size, bias=True),
-            nn.LayerNorm(output_size),
-        )
-
-
-    def get_init_emb(self, t, L, motif_mask):
-        """
-        Calculates and stacks a timestep embedding to project
-
-        Parameters:
-
-            t (int, required): Current timestep
-
-            L (int, required): Length of protein
-
-            motif_mask (torch.tensor, required): Boolean mask where True denotes a fixed motif position
-        """
-        assert t > 0, 't should be 1-indexed and cant have t=0'
-
-        t_emb = torch.clone(self.source_embeddings[t.squeeze()]).to(motif_mask.device)
-        zero_emb = torch.clone(self.source_embeddings[0]).to(motif_mask.device)
-
-        # timestep embedding for all residues
-        timestep_embedding = torch.stack([t_emb]*L)
-
-        # slice in motif zero timestep features
-        timestep_embedding[motif_mask] = zero_emb
-
-        return timestep_embedding
-
-
-    def forward(self, L, t, motif_mask):
-        """
-        Constructs and projects a timestep embedding
-        """
-        emb_in = self.get_init_emb(t,L,motif_mask)
-        emb_out = self.node_embedder(emb_in)
-        return emb_out
-
 class PositionalEncoding2D(nn.Module):
     # Add relative positional encoding to pair features
     def __init__(self, d_model, minpos=-32, maxpos=32, p_drop=0.1):
@@ -194,14 +118,6 @@ def forward(self, msa, seq, idx):
 
         # Sergey's one hot trick
         seq = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
-        """
-        #TODO delete this once verified
-        if self.input_seq_onehot:
-            # Sergey's one hot trick
-            seq = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
-        else:
-            seq = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
-        """
 
         msa = msa + seq.expand(-1, N, -1, -1) # adding query embedding to MSA
         return self.drop(msa)
@@ -278,14 +194,14 @@ def __init__(self, d_t1d=21+1+1, d_t2d=43+1, d_tor=30, d_pair=128, d_state=32,
 
         self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
                                              d_hidden=d_hidden, p_drop=p_drop)
 
-        self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop)
+        self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair)
 
         # process torsion angles
         self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ)
         self.proj_t1d = nn.Linear(d_templ, d_templ)
         #self.tor_stack = TemplateTorsionStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
         #                                      d_hidden=d_hidden, p_drop=p_drop)
-        self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state, p_drop=p_drop)
+        self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state)
 
         self.reset_parameter()
diff --git a/rfdiffusion/RoseTTAFoldModel.py b/rfdiffusion/RoseTTAFoldModel.py
index 3e9bbb3..84fbac4 100644
--- a/rfdiffusion/RoseTTAFoldModel.py
+++ b/rfdiffusion/RoseTTAFoldModel.py
@@ -1,11 +1,10 @@
 import torch
 import torch.nn as nn
-from rfdiffusion.Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling, Timestep_emb
+from rfdiffusion.Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling
 from rfdiffusion.Track_module import IterativeSimulator
 from rfdiffusion.AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, ExpResolvedNetwork, LDDTNetwork
 from opt_einsum import contract as einsum
-
 
 class RoseTTAFoldModule(nn.Module):
     def __init__(self,
                  n_extra_block,
@@ -23,8 +22,6 @@ def __init__(self,
                  p_drop,
                  d_t1d,
                  d_t2d,
-                 d_time_emb,        # total dims for input timestep emb
-                 d_time_emb_proj,   # size of projected timestep emb
                  T,                 # total timesteps (used in timestep emb
                  use_motif_timestep, # Whether to have a distinct emb for motif
                  freeze_track_motif, # Whether to freeze updates to motif in track
@@ -47,15 +44,6 @@ def __init__(self,
                                    n_head=n_head_templ,
                                    d_hidden=d_hidden_templ, p_drop=0.25, d_t1d=d_t1d, d_t2d=d_t2d)
 
-        # timestep embedder
-        if d_time_emb:
-            print('NOTE: Using sinusoidal timestep embeddings of dim ',d_time_emb, ' projected to dim ',d_time_emb_proj)
-            assert d_t1d >= 22 + d_time_emb_proj, 'timestep projection size doesn\'t fit into RF t1d projection layers'
-            self.timestep_embedder = Timestep_emb(input_size=d_time_emb,
-                                                  output_size=d_time_emb_proj,
-                                                  T=T,
-                                                  use_motif_timestep=use_motif_timestep)
-
         # Update inputs with outputs from previous round
         self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state)
 
@@ -72,7 +60,7 @@ def __init__(self,
                                            p_drop=p_drop)
         ##
         self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
-        self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
+        self.aa_pred = MaskedTokenNetwork(d_msa)
         self.lddt_pred = LDDTNetwork(d_state)
 
         self.exp_pred = ExpResolvedNetwork(d_msa, d_state)
diff --git a/rfdiffusion/coords6d.py b/rfdiffusion/coords6d.py
index 9f10d59..d322454 100644
--- a/rfdiffusion/coords6d.py
+++ b/rfdiffusion/coords6d.py
@@ -1,23 +1,7 @@
 import numpy as np
 import scipy
 import scipy.spatial
-
-# calculate dihedral angles defined by 4 sets of points
-def get_dihedrals(a, b, c, d):
-
-    b0 = -1.0*(b - a)
-    b1 = c - b
-    b2 = d - c
-
-    b1 /= np.linalg.norm(b1, axis=-1)[:,None]
-
-    v = b0 - np.sum(b0*b1, axis=-1)[:,None]*b1
-    w = b2 - np.sum(b2*b1, axis=-1)[:,None]*b1
-
-    x = np.sum(v*w, axis=-1)
-    y = np.sum(np.cross(b1, v)*w, axis=-1)
-
-    return np.arctan2(y, x)
+from rfdiffusion.kinematics import get_dih
 
 # calculate planar angles defined by 3 sets of points
 def get_angles(a, b, c):
@@ -65,11 +49,10 @@ def get_coords6d(xyz, dmax):
 
     # matrix of Ca-Cb-Cb-Ca dihedrals
     omega6d = np.zeros((nres, nres), dtype=np.float32)
-    omega6d[idx0,idx1] = get_dihedrals(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])
-
+    omega6d[idx0,idx1] = get_dih(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])
     # matrix of polar coord theta
     theta6d = np.zeros((nres, nres), dtype=np.float32)
-    theta6d[idx0,idx1] = get_dihedrals(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])
+    theta6d[idx0,idx1] = get_dih(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])
 
     # matrix of polar coord phi
     phi6d = np.zeros((nres, nres), dtype=np.float32)
diff --git a/rfdiffusion/diff_util.py b/rfdiffusion/diff_util.py
deleted file mode 100644
index 4a40964..0000000
--- a/rfdiffusion/diff_util.py
+++ /dev/null
@@ -1,255 +0,0 @@
-import torch
-import numpy as np
-import random
-
-from rfdiffusion.chemical import INIT_CRDS
-from icecream import ic
-
-
-def th_min_angle(start, end, radians=False):
-    """
-    Finds the angle you would add to <start> in order to get to <end>
-    on the shortest path.
-    """
-    a,b,c = (np.pi, 2*np.pi, 3*np.pi) if radians else (180, 360, 540)
-    shortest_angle = ((((end - start) % b) + c) % b) - a
-    return shortest_angle
-
-
-def th_interpolate_angles(start, end, T, n_diffuse,mindiff=None, radians=True):
-    """
-
-    """
-    # find the minimum angle to add to get from start to end
-    angle_diffs = th_min_angle(start, end, radians=radians)
-    if mindiff is not None:
-        assert torch.sum(mindiff.flatten()-angle_diffs) == 0.
-    if n_diffuse is None:
-        # default is to diffuse for max steps
-        n_diffuse = torch.full((len(angle_diffs)), T)
-
-
-    interps = []
-    for i,diff in enumerate(angle_diffs):
-        N = int(n_diffuse[i])
-        actual_interp = torch.linspace(start[i], start[i]+diff, N)
-        whole_interp = torch.full((T,), float(start[i]+diff))
-        temp=torch.clone(whole_interp)
-        whole_interp[:N] = actual_interp
-
-        interps.append(whole_interp)
-
-    return torch.stack(interps, dim=0)
-
-
-def th_interpolate_angle_single(start, end, step, T, mindiff=None, radians=True):
-    """
-
-    """
-    # find the minimum angle to add to get from start to end
-    angle_diffs = th_min_angle(start, end, radians=radians)
-    if mindiff is not None:
-        assert torch.sum(mindiff.flatten()-angle_diffs) == 0.
-
-    # linearly interpolate between x = [0, T-1], y = [start, start + diff]
-    x_range = T-1
-    interps = step / x_range * angle_diffs + start
-
-    return interps
-
-
-def get_aa_schedule(T, L, nsteps=100):
-    """
-    Returns the steps t when each amino acid should be decoded,
-    as well as how many steps that amino acids chi angles will be diffused
-
-    Parameters:
-        T (int, required): Total number of steps we are decoding the sequence over
-
-        L (int, required): Length of protein sequence
-
-        nsteps (int, optional): Number of steps over the course of which to decode the amino acids
-
-    Returns: three items
-        decode_times (list): List of times t when the positions in <decode_order> should be decoded
-
-        decode_order (list): List of lists, each element containing which positions are going to be decoded at
-                             the corresponding time in <decode_times>
-
-        idx2diffusion_steps (np.array): Array mapping the index of the residue to how many diffusion steps it will require
-
-    """
-    # nsteps can't be more than T or more than length of protein
-    if (nsteps > T) or (nsteps > L):
-        nsteps = min([T,L])
-
-
-    decode_order = [[a] for a in range(L)]
-    random.shuffle(decode_order)
-
-    while len(decode_order) > nsteps:
-        # pop an element and then add those positions randomly to some other step
-        tmp_seqpos = decode_order.pop()
-        decode_order[random.randint(0,len(decode_order)-1)] += tmp_seqpos
-        random.shuffle(decode_order)
-
-    decode_times = np.arange(nsteps)+1
-
-    # now given decode times, calculate number of diffusion steps each position gets
-    aa_masks = np.full((200,L), False)
-
-    idx2diffusion_steps = np.full((L,),float(np.nan))
-    for i,t in enumerate(decode_times):
-        decode_pos = decode_order[i] # positions to be decoded at this step
-
-        for j,pos in enumerate(decode_pos):
-            # calculate number of diffusion steps this residue gets
-            idx2diffusion_steps[pos] = int(t)
-            aa_masks[t,pos] = True
-
-    aa_masks = np.cumsum(aa_masks, axis=0)
-
-    return decode_times, decode_order, idx2diffusion_steps, ~(aa_masks.astype(bool))
-
-####################
-### for SecStruc ###
-####################
-
-def ss_to_tensor(ss_dict):
-    """
-    Function to convert ss files to indexed tensors
-    0 = Helix
-    1 = Strand
-    2 = Loop
-    3 = Mask/unknown
-    4 = idx for pdb
-    """
-    ss_conv = {'H':0,'E':1,'L':2}
-    ss_int = np.array([int(ss_conv[i]) for i in ss_dict['ss']])
-    return ss_int
-
-def mask_ss(ss, min_mask = 0, max_mask = 0.75):
-    """
-    Function to take ss array, find the junctions, and randomly mask these until a random proportion (up to 75%) is masked
-    Input: numpy array of ss (H=0,E=1,L=2,mask=3)
-    output: tensor with some proportion of junctions masked
-    """
-    mask_prop = random.uniform(min_mask, max_mask)
-    transitions = np.where(ss[:-1] - ss[1:] != 0)[0] #gets last index of each block of ss
-    counter = 0
-    #TODO think about masking whole ss elements
-    while len(ss[ss == 3])/len(ss) < mask_prop and counter < 100: #very hacky - do better
-        try:
-            width = random.randint(1,9)
-            start = random.choice(transitions)
-            offset = random.randint(-8,1)
-            ss[start+offset:start+offset+width] = 3
-            counter += 1
-        except:
-            counter += 1
-    ss = torch.tensor(ss)
-    mask = torch.where(ss == 3, True, False)
-    ss = torch.nn.functional.one_hot(ss, num_classes=4)
-    return ss, mask
-
-def construct_block_adj_matrix( sstruct, xyz, nan_mask, cutoff=6, include_loops=False ):
-    '''
-    Given a sstruct specification and backbone coordinates, build a block adjacency matrix.
-
-    Input:
-
-        sstruct (torch.FloatTensor): (L) length tensor with numeric encoding of sstruct at each position
-
-        xyz (torch.FloatTensor): (L,3,3) tensor of Cartesian coordinates of backbone N,Ca,C atoms
-
-        cutoff (float): The Cb distance cutoff under which residue pairs are considered adjacent
-            By eye, Nate thinks 6A is a good Cb distance cutoff
-
-    Output:
-
-        block_adj (torch.FloatTensor): (L,L) boolean matrix where adjacent secondary structure contacts are 1
-    '''
-
-    # Remove nans at this stage, as ss doesn't consider nans
-    xyz_nonan = xyz[nan_mask]
-    L = xyz_nonan.shape[0]
-    assert L == sstruct.shape[0]
-
-    # three anchor atoms
-    N = xyz_nonan[:,0]
-    Ca = xyz_nonan[:,1]
-    C = xyz_nonan[:,2]
-
-    # recreate Cb given N,Ca,C
-    Cb = generate_Cbeta(N,Ca,C)
-
-    dist = get_pair_dist(Cb,Cb) # [L,L]
-    dist[torch.isnan(dist)] = 999.9
-    assert torch.sum(torch.isnan(dist)) == 0
-    dist += 999.9*torch.eye(L,device=xyz.device)
-
-    # Now we have dist matrix and sstruct specification, turn this into a block adjacency matrix
-
-    # First: Construct a list of segments and the index at which they begin and end
-    in_segment = True
-    segments = []
-
-    begin = -1
-    end = -1
-
-    # need to expand ss out to size L
-
-    for i in range(sstruct.shape[0]):
-        # Starting edge case
-        if i == 0:
-            begin = 0
-            continue
-
-        if not sstruct[i] == sstruct[i-1]:
-            end = i
-            segments.append( (sstruct[i-1], begin, end) )
-
-            begin = i
-
-    # Ending edge case: last segment is length one
-    if not end == sstruct.shape[0]:
-        segments.append( (sstruct[-1], begin, sstruct.shape[0]) )
-
-    # Second: Using segments and dgram, determine adjacent blocks
-    block_adj = torch.zeros_like(dist)
-    for i in range(len(segments)):
-        curr_segment = segments[i]
-
-        if curr_segment[0] == 2 and not include_loops: continue
-
-        begin_i = curr_segment[1]
-        end_i = curr_segment[2]
-        for j in range(i+1, len(segments)):
-            j_segment = segments[j]
-
-            if j_segment[0] == 2 and not include_loops: continue
-
-            begin_j = j_segment[1]
-            end_j = j_segment[2]
-
-            if torch.any( dist[begin_i:end_i, begin_j:end_j] < cutoff ):
-                # Matrix is symmetic
-                block_adj[begin_i:end_i, begin_j:end_j] = torch.ones(end_i - begin_i, end_j - begin_j)
-                block_adj[begin_j:end_j, begin_i:end_i] = torch.ones(end_j - begin_j, end_i - begin_i)
-
-    return block_adj
-
-def get_pair_dist(a, b):
-    """calculate pair distances between two sets of points
-
-    Parameters
-    ----------
-    a,b : pytorch tensors of shape [batch,nres,3]
-          store Cartesian coordinates of two sets of atoms
-    Returns
-    -------
-    dist : pytorch tensor of shape [batch,nres,nres]
-           stores paitwise distances between atoms in a and b
-    """
-
-    dist = torch.cdist(a, b, p=2)
-    return dist
diff --git a/rfdiffusion/diffusion.py b/rfdiffusion/diffusion.py
index 5db5d9d..a67e579 100644
--- a/rfdiffusion/diffusion.py
+++ b/rfdiffusion/diffusion.py
@@ -237,8 +237,7 @@ def _calc_igso3_vals(self, L=2000):
                 num_sigma=self.num_sigma,
                 min_sigma=self.min_sigma,
                 max_sigma=self.max_sigma,
-                num_omega=self.num_omega,
-                L=L,
+                num_omega=self.num_omega
             )
             write_pkl(cache_fname, igso3_vals)
 
diff --git a/rfdiffusion/igso3.py b/rfdiffusion/igso3.py
index 94b626f..6d90bdb 100644
--- a/rfdiffusion/igso3.py
+++ b/rfdiffusion/igso3.py
@@ -71,7 +71,7 @@ def igso3_score(R, t, L=L_default):
     unit_vector = np.einsum('Nij,Njk->Nik', R, log(R))/omega[:, None, None]
     return unit_vector * d_logf_d_omega(omega, t, L)[:, None, None]
 
-def calculate_igso3(*, num_sigma, num_omega, min_sigma, max_sigma, L=L_default):
+def calculate_igso3(*, num_sigma, num_omega, min_sigma, max_sigma):
     """calculate_igso3 pre-computes numerical approximations to the IGSO3 cdfs
     and score norms and expected squared score norms.
 
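
Note on the IGSO3 change: `calculate_igso3` loses its unused `L` truncation parameter (the series length now always comes from the module-level `L_default`), and the cached-table call in `diffusion.py` drops `L=L` accordingly. A minimal sketch of the updated keyword-only call; the grid values here are illustrative, not the repo defaults:

    from rfdiffusion import igso3

    # All arguments are keyword-only; the series-truncation length is
    # fixed to igso3.L_default internally after this change.
    igso3_vals = igso3.calculate_igso3(
        num_sigma=100,   # illustrative sigma-grid size
        num_omega=500,   # illustrative omega-grid size
        min_sigma=0.02,
        max_sigma=1.5,
    )
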
diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py
index 8c55cfb..26dd1b3 100644
--- a/rfdiffusion/inference/model_runners.py
+++ b/rfdiffusion/inference/model_runners.py
@@ -247,7 +247,6 @@ def construct_denoiser(self, L, visible):
             'L': L,
             'diffuser': self.diffuser,
             'potential_manager': self.potential_manager,
-            'visible': visible
         })
         return iu.Denoise(**denoise_kwargs)
 
diff --git a/rfdiffusion/inference/utils.py b/rfdiffusion/inference/utils.py
index 59d1a74..43e7e99 100644
--- a/rfdiffusion/inference/utils.py
+++ b/rfdiffusion/inference/utils.py
@@ -223,8 +223,6 @@ def __init__(
         T,
         L,
         diffuser,
-        visible,
-        seq_diffuser=None,
         b_0=0.001,
         b_T=0.1,
         min_b=1.0,
@@ -256,7 +254,6 @@ def __init__(
         self.T = T
         self.L = L
         self.diffuser = diffuser
-        self.seq_diffuser = seq_diffuser
         self.b_0 = b_0
         self.b_T = b_T
         self.noise_level = noise_level
@@ -301,8 +298,6 @@ def align_to_xt_motif(self, px0, xT, diffusion_mask, eps=1e-6):
         Third, centre at origin
         """
 
-        # if True:
-        #     return px0
         def rmsd(V, W, eps=0):
             # First sum down atoms, then sum down xyz
             N = V.shape[-2]
@@ -358,17 +353,12 @@ def rmsd(V, W, eps=0):
         px0[~atom_mask] = 0 # convert nans to 0
         px0 = px0.reshape(-1, 3) - px0_motif_mean
         px0_ = px0 @ R
-        # xT_motif_out = xT_motif.reshape(-1,3)
-        # xT_motif_out = (xT_motif_out @ R ) + px0_motif_mean
-        # ic(xT_motif_out.shape)
-        # xT_motif_out = xT_motif_out.reshape((diffusion_mask.sum(),3,3))
 
         # 3 put in same global position as xT
         px0_ = px0_ + xT_motif_mean
         px0_ = px0_.reshape([L, n_atom, 3])
         px0_[~atom_mask] = float("nan")
         return torch.Tensor(px0_)
-        # return torch.tensor(xT_motif_out)
 
     def get_potential_gradients(self, xyz, diffusion_mask):
         """
diff --git a/rfdiffusion/kinematics.py b/rfdiffusion/kinematics.py
index 32c7959..8d54839 100644
--- a/rfdiffusion/kinematics.py
+++ b/rfdiffusion/kinematics.py
@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 from rfdiffusion.chemical import INIT_CRDS
+from rfdiffusion.util import generate_Cbeta
 
 PARAMS = {
     "DMIN"    : 2.0,
@@ -55,13 +56,18 @@ def get_dih(a, b, c, d):
 
     Parameters
    ----------
-    a,b,c,d : pytorch tensors of shape [batch,nres,3]
+    a,b,c,d : pytorch tensors or numpy array of shape [batch,nres,3]
              store Cartesian coordinates of four sets of atoms
     Returns
     -------
-    dih : pytorch tensor of shape [batch,nres]
+    dih : pytorch tensor or numpy array of shape [batch,nres]
           stores resulting dihedrals
     """
+    convert_to_torch = lambda *arrays: [torch.from_numpy(arr) for arr in arrays]
+    output_np=False
+    if isinstance(a, np.ndarray):
+        output_np=True
+        a,b,c,d = convert_to_torch(a,b,c,d)
     b0 = a - b
     b1 = c - b
     b2 = d - c
@@ -73,18 +79,10 @@ def get_dih(a, b, c, d):
 
     x = torch.sum(v*w, dim=-1)
     y = torch.sum(torch.cross(b1,v,dim=-1)*w, dim=-1)
-
-    return torch.atan2(y, x)
-
-def get_Cb(xyz):
-    '''recreate Cb given N,Ca,C'''
-    N = xyz[...,0,:]
-    Ca = xyz[...,1,:]
-    C = xyz[...,2,:]
-    b = Ca - N
-    c = C - Ca
-    a = torch.cross(b, c, dim=-1)
-    return -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
+    output = torch.atan2(y, x)
+    if output_np:
+        return output.numpy()
+    return output
 
 # ============================================================
 def xyz_to_c6d(xyz, params=PARAMS):
@@ -108,7 +106,7 @@ def xyz_to_c6d(xyz, params=PARAMS):
     N  = xyz[:,:,0]
     Ca = xyz[:,:,1]
     C  = xyz[:,:,2]
-    Cb = get_Cb(xyz)
+    Cb = generate_Cbeta(N, Ca, C)
 
     # 6d coordinates order: (dist,omega,theta,phi)
     c6d = torch.zeros([batch,nres,nres,4],dtype=xyz.dtype,device=xyz.device)
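
Note on the dihedral consolidation: `coords6d.py` previously carried its own NumPy-only `get_dihedrals`; it now reuses `kinematics.get_dih`, which per the hunk above converts NumPy inputs to torch and mirrors the input type on the output. A minimal round-trip sketch of that behavior, with random illustrative coordinates shaped the way `coords6d.py` calls it:

    import numpy as np
    import torch
    from rfdiffusion.kinematics import get_dih

    # four sets of atom coordinates, shape [nres, 3]
    a, b, c, d = (np.random.rand(8, 3).astype(np.float32) for _ in range(4))

    dih_np = get_dih(a, b, c, d)                            # NumPy in -> NumPy out
    dih_pt = get_dih(*map(torch.from_numpy, (a, b, c, d)))  # torch in -> torch out

    assert isinstance(dih_np, np.ndarray) and torch.is_tensor(dih_pt)
    assert np.allclose(dih_np, dih_pt.numpy())
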
diff --git a/rfdiffusion/potentials/manager.py b/rfdiffusion/potentials/manager.py
index 1be86ff..98b0d5e 100644
--- a/rfdiffusion/potentials/manager.py
+++ b/rfdiffusion/potentials/manager.py
@@ -113,9 +113,6 @@ def __init__(self,
             if setting['type'] in potentials.require_binderlen:
                 setting.update(binderlen_update)
 
-            if setting['type'] in potentials.require_hotspot_res:
-                setting.update(hotspot_res_update)
-
         self.potentials_to_apply = self.initialize_all_potentials(setting_list)
         self.T = diffuser_config.T
 
@@ -199,7 +196,7 @@ def get_guide_scale(self, t):
         # Linear interpolation with y2: 0, y1: guide_scale, x2: 0, x1: T, x: t
         'linear'    : lambda t: t/self.T * self.guide_scale,
         'quadratic' : lambda t: t**2/self.T**2 * self.guide_scale,
-        'cubic'     : lambda t: t**3/self.T**3
+        'cubic'     : lambda t: t**3/self.T**3 * self.guide_scale
         }
 
         if self.guide_decay not in implemented_decay_types:
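
Note on the guide-scale fix: the cubic decay lambda was missing its `* self.guide_scale` factor, so a cubic schedule silently ignored the configured scale (its value could never exceed 1.0). With the fix, every schedule interpolates from `guide_scale` at t = T down to 0 at t = 0. A standalone sketch of the corrected schedule family; the names mirror the hunk above, but this is not the full PotentialManager:

    # Corrected decay schedules, written standalone for illustration.
    # T and guide_scale are illustrative values; t counts down from T to 1.
    def get_guide_scale(t, T=50, guide_scale=10.0, decay="cubic"):
        schedules = {
            "linear":    lambda t: t / T * guide_scale,
            "quadratic": lambda t: t**2 / T**2 * guide_scale,
            "cubic":     lambda t: t**3 / T**3 * guide_scale,  # the fixed line
        }
        return schedules[decay](t)

    assert get_guide_scale(50) == 10.0   # full scale at t = T
    assert get_guide_scale(25) == 1.25   # 10 * (25/50)**3
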
diff --git a/rfdiffusion/potentials/potentials.py b/rfdiffusion/potentials/potentials.py
index 1eaafa5..b43a2a6 100644
--- a/rfdiffusion/potentials/potentials.py
+++ b/rfdiffusion/potentials/potentials.py
@@ -146,52 +146,6 @@ def compute(self, xyz):
 
         #Potential value is the average of both radii of gyration (is avg. the best way to do this?)
         return self.weight * binder_ncontacts.sum()
 
-
-class dimer_ncontacts(Potential):
-
-    '''
-    Differentiable way to maximise number of contacts for two individual monomers in a dimer
-
-    Motivation is given here: https://www.plumed.org/doc-v2.7/user-doc/html/_c_o_o_r_d_i_n_a_t_i_o_n.html
-
-    Author: PV
-    '''
-
-    def __init__(self, binderlen, weight=1, r_0=8, d_0=4):
-
-        self.binderlen = binderlen
-        self.r_0 = r_0
-        self.weight = weight
-        self.d_0 = d_0
-
-    def compute(self, xyz):
-
-        # Only look at binder Ca residues
-        Ca = xyz[:self.binderlen,1] # [Lb,3]
-
-        #cdist needs a batch dimension - NRB
-        dgram = torch.cdist(Ca[None,...].contiguous(), Ca[None,...].contiguous(), p=2) # [1,Lb,Lb]
-        divide_by_r_0 = (dgram - self.d_0) / self.r_0
-        numerator = torch.pow(divide_by_r_0,6)
-        denominator = torch.pow(divide_by_r_0,12)
-        binder_ncontacts = (1 - numerator) / (1 - denominator)
-        #Potential is the sum of values in the tensor
-        binder_ncontacts = binder_ncontacts.sum()
-
-        # Only look at target Ca residues
-        Ca = xyz[self.binderlen:,1] # [Lb,3]
-        dgram = torch.cdist(Ca[None,...].contiguous(), Ca[None,...].contiguous(), p=2) # [1,Lb,Lb]
-        divide_by_r_0 = (dgram - self.d_0) / self.r_0
-        numerator = torch.pow(divide_by_r_0,6)
-        denominator = torch.pow(divide_by_r_0,12)
-        target_ncontacts = (1 - numerator) / (1 - denominator)
-        #Potential is the sum of values in the tensor
-        target_ncontacts = target_ncontacts.sum()
-
-        print("DIMER NCONTACTS:", (binder_ncontacts+target_ncontacts)/2)
-        #Returns average of n contacts withiin monomer 1 and monomer 2
-        return self.weight * (binder_ncontacts+target_ncontacts)/2
-
 class interface_ncontacts(Potential):
 
     '''
@@ -266,42 +220,6 @@ def compute(self, xyz):
 
         return self.weight * ncontacts.sum()
 
-
-def make_contact_matrix(nchain, contact_string=None):
-    """
-    Calculate a matrix of inter/intra chain contact indicators
-
-    Parameters:
-        nchain (int, required): How many chains are in this design
-
-        contact_str (str, optional): String denoting how to define contacts, comma delimited between pairs of chains
-            '!' denotes repulsive, '&' denotes attractive
-    """
-    alphabet = [a for a in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ']
-    letter2num = {a:i for i,a in enumerate(alphabet)}
-
-    contacts = np.zeros((nchain,nchain))
-    written = np.zeros((nchain,nchain))
-
-    contact_list = contact_string.split(',')
-    for c in contact_list:
-        if not len(c) == 3:
-            raise SyntaxError('Invalid contact(s) specification')
-
-        i,j = letter2num[c[0]],letter2num[c[2]]
-        symbol = c[1]
-
-        # denote contacting/repulsive
-        assert symbol in ['!','&']
-        if symbol == '!':
-            contacts[i,j] = -1
-            contacts[j,i] = -1
-        else:
-            contacts[i,j] = 1
-            contacts[j,i] = 1
-
-    return contacts
-
-
 class olig_contacts(Potential):
     """
     Applies PV's num contacts potential within/between chains in symmetric oligomers
@@ -343,17 +261,6 @@ def __init__(self,
 
         self.nchain=shape[0]
 
-        # self._compute_chain_indices()
-
-    # def _compute_chain_indices(self):
-    #     # make list of shape [i,N] for indices of each chain in total length
-    #     indices = []
-    #     start = 0
-    #     for l in self.chain_lengths:
-    #         indices.append(torch.arange(start,start+l))
-    #         start += l
-    #     self.indices = indices
-
     def _get_idx(self,i,L):
         """
         Returns the zero-indexed indices of the residues in chain i
@@ -398,51 +305,6 @@ def compute(self, xyz):
 
         return all_contacts
 
-
-class olig_intra_contacts(Potential):
-    """
-    Applies PV's num contacts potential for each chain individually in an oligomer design
-
-    Author: DJ
-    """
-
-    def __init__(self, chain_lengths, weight=1):
-        """
-        Parameters:
-
-            chain_lengths (list, required): Ordered list of chain lengths
-
-            weight (int/float, optional): Scaling/weighting factor
-        """
-        self.chain_lengths = chain_lengths
-        self.weight = weight
-
-
-    def compute(self, xyz):
-        """
-        Computes intra-chain num contacts potential
-        """
-        assert sum(self.chain_lengths) == xyz.shape[0], 'given chain lengths do not match total protein length'
-
-        all_contacts = 0
-        start = 0
-        for Lc in self.chain_lengths:
-            Ca = xyz[start:start+Lc] # slice out crds for this chain
-            dgram = torch.cdist(Ca[None,...].contiguous(), Ca[None,...].contiguous(), p=2) # [1,Lb,Lb]
-            divide_by_r_0 = (dgram - self.d_0) / self.r_0
-            numerator = torch.pow(divide_by_r_0,6)
-            denominator = torch.pow(divide_by_r_0,12)
-            ncontacts = (1 - numerator) / (1 - denominator)
-
-            # add contacts for this chain to all contacts
-            all_contacts += ncontacts.sum()
-
-            # increment the start to be at the next chain
-            start += Lc
-
-
-        return self.weight * all_contacts
-
 def get_damped_lj(r_min, r_lin,p1=6,p2=12):
 
     y_at_r_lin = lj(r_lin, r_min, p1, p2)
@@ -592,131 +454,15 @@ def _grab_motif_residues(self, xyz) -> None:
         self.motif_frame = xyz[rand_idx[0],:4]
         self.motif_mapping = [(rand_idx, i) for i in range(4)]
 
-class binder_distance_ReLU(Potential):
-
-    '''
-    Given the current coordinates of the diffusion trajectory, calculate a potential that is the distance between each residue
-    and the closest target residue.
-
-    This potential is meant to encourage the binder to interact with a certain subset of residues on the target that
-    define the binding site.
-
-    Author: NRB
-    '''
-
-    def __init__(self, binderlen, hotspot_res, weight=1, min_dist=15, use_Cb=False):
-
-        self.binderlen = binderlen
-        self.hotspot_res = [res + binderlen for res in hotspot_res]
-        self.weight = weight
-        self.min_dist = min_dist
-        self.use_Cb = use_Cb
-
-    def compute(self, xyz):
-        binder = xyz[:self.binderlen,:,:] # (Lb,27,3)
-        target = xyz[self.hotspot_res,:,:] # (N,27,3)
-
-        if self.use_Cb:
-            N = binder[:,0]
-            Ca = binder[:,1]
-            C = binder[:,2]
-
-            Cb = generate_Cbeta(N,Ca,C) # (Lb,3)
-
-            N_t = target[:,0]
-            Ca_t = target[:,1]
-            C_t = target[:,2]
-
-            Cb_t = generate_Cbeta(N_t,Ca_t,C_t) # (N,3)
-
-            dgram = torch.cdist(Cb[None,...], Cb_t[None,...], p=2) # (1,Lb,N)
-
-        else:
-            # Use Ca dist for potential
-
-            Ca = binder[:,1] # (Lb,3)
-
-            Ca_t = target[:,1] # (N,3)
-
-            dgram = torch.cdist(Ca[None,...], Ca_t[None,...], p=2) # (1,Lb,N)
-
-        closest_dist = torch.min(dgram.squeeze(0), dim=1)[0] # (Lb)
-
-        # Cap the distance at a minimum value
-        min_distance = self.min_dist * torch.ones_like(closest_dist) # (Lb)
-        potential = torch.maximum(min_distance, closest_dist) # (Lb)
-
-        # torch.Tensor.backward() requires the potential to be a single value
-        potential = torch.sum(potential, dim=-1)
-
-        return -1 * self.weight * potential
-
-class binder_any_ReLU(Potential):
-
-    '''
-    Given the current coordinates of the diffusion trajectory, calculate a potential that is the minimum distance between
-    ANY residue and the closest target residue.
-
-    In contrast to binder_distance_ReLU this potential will only penalize a pose if all of the binder residues are outside
-    of a certain distance from the target residues.
-
-    Author: NRB
-    '''
-
-    def __init__(self, binderlen, hotspot_res, weight=1, min_dist=15, use_Cb=False):
-
-        self.binderlen = binderlen
-        self.hotspot_res = [res + binderlen for res in hotspot_res]
-        self.weight = weight
-        self.min_dist = min_dist
-        self.use_Cb = use_Cb
-
-    def compute(self, xyz):
-        binder = xyz[:self.binderlen,:,:] # (Lb,27,3)
-        target = xyz[self.hotspot_res,:,:] # (N,27,3)
-
-        if use_Cb:
-            N = binder[:,0]
-            Ca = binder[:,1]
-            C = binder[:,2]
-
-            Cb = generate_Cbeta(N,Ca,C) # (Lb,3)
-
-            N_t = target[:,0]
-            Ca_t = target[:,1]
-            C_t = target[:,2]
-
-            Cb_t = generate_Cbeta(N_t,Ca_t,C_t) # (N,3)
-
-            dgram = torch.cdist(Cb[None,...], Cb_t[None,...], p=2) # (1,Lb,N)
-
-        else:
-            # Use Ca dist for potential
-
-            Ca = binder[:,1] # (Lb,3)
-
-            Ca_t = target[:,1] # (N,3)
-
-            dgram = torch.cdist(Ca[None,...], Ca_t[None,...], p=2) # (1,Lb,N)
-
-        closest_dist = torch.min(dgram.squeeze(0)) # (1)
-
-        potential = torch.maximum(min_dist, closest_dist) # (1)
-
-        return -1 * self.weight * potential
-
 # Dictionary of types of potentials indexed by name of potential. Used by PotentialManager.
 # If you implement a new potential you must add it to this dictionary for it to be used by
 # the PotentialManager
 implemented_potentials = { 'monomer_ROG':          monomer_ROG,
                            'binder_ROG':           binder_ROG,
-                           'binder_distance_ReLU': binder_distance_ReLU,
-                           'binder_any_ReLU':      binder_any_ReLU,
                            'dimer_ROG':            dimer_ROG,
                            'binder_ncontacts':     binder_ncontacts,
-                           'dimer_ncontacts':      dimer_ncontacts,
                            'interface_ncontacts':  interface_ncontacts,
                            'monomer_contacts':     monomer_contacts,
-                           'olig_intra_contacts':  olig_intra_contacts,
                            'olig_contacts':        olig_contacts,
                            'substrate_contacts':   substrate_contacts}
 
@@ -725,9 +471,5 @@ def compute(self, xyz):
                       'binder_any_ReLU',
                       'dimer_ROG',
                       'binder_ncontacts',
-                      'dimer_ncontacts',
                       'interface_ncontacts'}
 
-require_hotspot_res = { 'binder_distance_ReLU',
-                        'binder_any_ReLU' }
-
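
Note on the potential registry: after this cleanup, `implemented_potentials` remains the single registration point, and the comment above still applies: a new potential subclasses `Potential`, implements `compute(xyz)`, and registers itself in the dict. A hypothetical sketch under those assumptions; the `ca_rog` name and class are invented for illustration and are not part of this patch:

    import torch
    from rfdiffusion.potentials import potentials

    class ca_rog(potentials.Potential):
        '''Hypothetical guiding potential: rewards a compact Ca radius of gyration.'''
        def __init__(self, weight=1):
            self.weight = weight

        def compute(self, xyz):
            Ca = xyz[:, 1]  # (L,3) Ca coordinates, matching the potentials above
            rog = torch.sqrt(((Ca - Ca.mean(dim=0)) ** 2).sum(dim=-1).mean())
            return -self.weight * rog  # negated: higher potential = more compact

    potentials.implemented_potentials['ca_rog'] = ca_rog
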
diff --git a/rfdiffusion/util.py b/rfdiffusion/util.py
index ca3ee89..19c30f5 100644
--- a/rfdiffusion/util.py
+++ b/rfdiffusion/util.py
@@ -8,9 +8,10 @@ def generate_Cbeta(N, Ca, C):
     b = Ca - N
     c = C - Ca
     a = torch.cross(b, c, dim=-1)
-    # Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
+    # These are the values used during training
+    Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
     # fd: below matches sidechain generator (=Rosetta params)
-    Cb = -0.57910144 * a + 0.5689693 * b - 0.5441217 * c + Ca
+    # Cb = -0.57910144 * a + 0.5689693 * b - 0.5441217 * c + Ca
 
     return Cb
 
diff --git a/rfdiffusion/util_module.py b/rfdiffusion/util_module.py
index 38380b2..20ba2dc 100644
--- a/rfdiffusion/util_module.py
+++ b/rfdiffusion/util_module.py
@@ -7,7 +7,7 @@ import dgl
 from rfdiffusion.util import base_indices, RTs_by_torsion, xyzs_in_base_frame, rigid_from_3_points
 
-def init_lecun_normal(module, scale=1.0):
+def init_lecun_normal(module):
     def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
         normal = torch.distributions.normal.Normal(0, 1)
 
@@ -23,14 +23,14 @@ def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
 
         return x
 
-    def sample_truncated_normal(shape, scale=1.0):
-        stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
+    def sample_truncated_normal(shape):
+        stddev = np.sqrt(1.0/shape[-1])/.87962566103423978 # shape[-1] = fan_in
         return stddev * truncated_normal(torch.rand(shape))
 
     module.weight = torch.nn.Parameter( (sample_truncated_normal(module.weight.shape)) )
     return module
 
-def init_lecun_normal_param(weight, scale=1.0):
+def init_lecun_normal_param(weight):
     def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
         normal = torch.distributions.normal.Normal(0, 1)
 
@@ -46,8 +46,8 @@ def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
 
         return x
 
-    def sample_truncated_normal(shape, scale=1.0):
-        stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
+    def sample_truncated_normal(shape):
+        stddev = np.sqrt(1.0/shape[-1])/.87962566103423978 # shape[-1] = fan_in
         return stddev * truncated_normal(torch.rand(shape))
 
     weight = torch.nn.Parameter( (sample_truncated_normal(weight.shape)) )
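
Note on the Cb switch: the rfdiffusion/util.py hunk above reactivates the idealized-Cb coefficients that were used during training and comments out the Rosetta sidechain-generator values, and kinematics.py now calls this one shared helper instead of its removed private `get_Cb`. A minimal usage sketch with illustrative coordinates:

    import torch
    from rfdiffusion.util import generate_Cbeta

    # backbone N, Ca, C for a single residue, shape (1,3) each; values illustrative
    N  = torch.tensor([[ 1.458, 0.000, 0.000]])
    Ca = torch.tensor([[ 0.000, 0.000, 0.000]])
    C  = torch.tensor([[-0.551, 1.420, 0.000]])

    # Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca,
    # where b = Ca - N, c = C - Ca, a = b x c (the training-time coefficients)
    Cb = generate_Cbeta(N, Ca, C)
    print(Cb.shape)  # torch.Size([1, 3])
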