Code cleaning suggestions from arogozhnikov
Joseph Watson authored and joewatchwell committed May 5, 2023
1 parent 5c6f2f1 commit 0d629aa
Showing 18 changed files with 35 additions and 681 deletions.
2 changes: 0 additions & 2 deletions config/inference/base.yaml
@@ -64,8 +64,6 @@ model:
   l1_in_features: 3
   l1_out_features: 2
   num_edge_features: 64
-  d_time_emb: null
-  d_time_emb_proj: null
   freeze_track_motif: False
   use_motif_timestep: False

3 changes: 1 addition & 2 deletions docker/Dockerfile
@@ -40,12 +40,11 @@ RUN apt-get -q update \
 decorator==5.1.0 \
 hydra-core==1.3.2 \
 pyrsistent==0.19.3 \
-icecream==2.1.3 \
 /app/RFdiffusion/env/SE3Transformer \
 && pip install --no-cache-dir /app/RFdiffusion --no-deps
 
 WORKDIR /app/RFdiffusion
 
 ENV DGLBACKEND="pytorch"
 
-ENTRYPOINT ["python3.9", "scripts/run_inference.py"]
+ENTRYPOINT ["python3.9", "scripts/run_inference.py"]
1 change: 0 additions & 1 deletion env/SE3nv.yml
@@ -12,7 +12,6 @@ dependencies:
   - torchvision
   - cudatoolkit=11.1
   - dgl-cuda11.1
-  - icecream
   - pip
   - pip:
     - hydra-core
2 changes: 1 addition & 1 deletion rfdiffusion/Attention_module.py
@@ -31,7 +31,7 @@ def forward(self, src):
 
 class Attention(nn.Module):
     # calculate multi-head attention
-    def __init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1):
+    def __init__(self, d_query, d_key, n_head, d_hidden, d_out):
         super(Attention, self).__init__()
         self.h = n_head
         self.dim = d_hidden
2 changes: 1 addition & 1 deletion rfdiffusion/AuxiliaryPredictor.py
@@ -34,7 +34,7 @@ def forward(self, x):
         return logits_dist, logits_omega, logits_theta, logits_phi
 
 class MaskedTokenNetwork(nn.Module):
-    def __init__(self, n_feat, p_drop=0.1):
+    def __init__(self, n_feat):
         super(MaskedTokenNetwork, self).__init__()
         self.proj = nn.Linear(n_feat, 21)

88 changes: 2 additions & 86 deletions rfdiffusion/Embeddings.py
@@ -12,82 +12,6 @@
 
 # Module contains classes and functions to generate initial embeddings
 
-def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
-    # Code from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
-    assert len(timesteps.shape) == 1
-    half_dim = embedding_dim // 2
-    emb = math.log(max_positions) / (half_dim - 1)
-
-    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
-    emb = timesteps.float()[:, None] * emb[None, :]
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1: # zero pad
-        emb = F.pad(emb, (0, 1), mode='constant')
-    assert emb.shape == (timesteps.shape[0], embedding_dim)
-    return emb
-
-class Timestep_emb(nn.Module):
-
-    def __init__(
-        self,
-        input_size,
-        output_size,
-        T,
-        use_motif_timestep=True
-    ):
-        super(Timestep_emb, self).__init__()
-
-        self.input_size = input_size
-        self.output_size = output_size
-        self.T = T
-
-        # get source for timestep embeddings at all t AND zero (for the motif)
-        self.source_embeddings = get_timestep_embedding(torch.arange(self.T+1), self.input_size)
-        self.source_embeddings.requires_grad = False
-
-        # Layers to use for projection
-        self.node_embedder = nn.Sequential(
-            nn.Linear(input_size, output_size, bias=False),
-            nn.ReLU(),
-            nn.Linear(output_size, output_size, bias=True),
-            nn.LayerNorm(output_size),
-        )
-
-
-    def get_init_emb(self, t, L, motif_mask):
-        """
-        Calculates and stacks a timestep embedding to project
-        Parameters:
-            t (int, required): Current timestep
-            L (int, required): Length of protein
-            motif_mask (torch.tensor, required): Boolean mask where True denotes a fixed motif position
-        """
-        assert t > 0, 't should be 1-indexed and cant have t=0'
-
-        t_emb = torch.clone(self.source_embeddings[t.squeeze()]).to(motif_mask.device)
-        zero_emb = torch.clone(self.source_embeddings[0]).to(motif_mask.device)
-
-        # timestep embedding for all residues
-        timestep_embedding = torch.stack([t_emb]*L)
-
-        # slice in motif zero timestep features
-        timestep_embedding[motif_mask] = zero_emb
-
-        return timestep_embedding
-
-
-    def forward(self, L, t, motif_mask):
-        """
-        Constructs and projects a timestep embedding
-        """
-        emb_in = self.get_init_emb(t,L,motif_mask)
-        emb_out = self.node_embedder(emb_in)
-        return emb_out
-
 class PositionalEncoding2D(nn.Module):
     # Add relative positional encoding to pair features
     def __init__(self, d_model, minpos=-32, maxpos=32, p_drop=0.1):
@@ -194,14 +118,6 @@ def forward(self, msa, seq, idx):
 
         # Sergey's one hot trick
         seq = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
-        """
-        #TODO delete this once verified
-        if self.input_seq_onehot:
-            # Sergey's one hot trick
-            seq = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
-        else:
-            seq = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
-        """
         msa = msa + seq.expand(-1, N, -1, -1) # adding query embedding to MSA
         return self.drop(msa)

@@ -278,14 +194,14 @@ def __init__(self, d_t1d=21+1+1, d_t2d=43+1, d_tor=30, d_pair=128, d_state=32,
         self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
                                              d_hidden=d_hidden, p_drop=p_drop)
 
-        self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop)
+        self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair)
 
         # process torsion angles
         self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ)
         self.proj_t1d = nn.Linear(d_templ, d_templ)
         #self.tor_stack = TemplateTorsionStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
         #                                      d_hidden=d_hidden, p_drop=p_drop)
-        self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state, p_drop=p_drop)
+        self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state)
 
         self.reset_parameter()

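For context on the Embeddings.py deletions above: the removed get_timestep_embedding/Timestep_emb pair implemented the standard sinusoidal timestep embedding (the deleted comment points to Ho et al.'s diffusion codebase), with the t=0 row reserved for fixed motif positions. The condensed sketch below only restates that construction for reference; it is not part of the repository after this commit, and the T, t, L, and motif_mask values are made up for illustration.

import math
import torch
import torch.nn.functional as F

def sinusoidal_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    # Same math as the deleted get_timestep_embedding: sin/cos features at
    # geometrically spaced frequencies, zero-padded when embedding_dim is odd.
    half_dim = embedding_dim // 2
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32)
        * -(math.log(max_positions) / (half_dim - 1))
    )
    emb = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        emb = F.pad(emb, (0, 1), mode='constant')
    return emb  # shape (len(timesteps), embedding_dim)

# Mirrors the deleted Timestep_emb.get_init_emb: precompute rows for t = 0..T,
# broadcast the current timestep over all L residues, then overwrite motif
# positions with the t = 0 row.
T, t, L, d = 50, 25, 100, 128                       # illustrative sizes only
source = sinusoidal_timestep_embedding(torch.arange(T + 1), d)
motif_mask = torch.zeros(L, dtype=torch.bool)
motif_mask[10:20] = True
per_residue = source[t].repeat(L, 1)                # (L, d)
per_residue[motif_mask] = source[0]                 # motif residues get t = 0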
16 changes: 2 additions & 14 deletions rfdiffusion/RoseTTAFoldModel.py
@@ -1,11 +1,10 @@
 import torch
 import torch.nn as nn
-from rfdiffusion.Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling, Timestep_emb
+from rfdiffusion.Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling
 from rfdiffusion.Track_module import IterativeSimulator
 from rfdiffusion.AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, ExpResolvedNetwork, LDDTNetwork
 from opt_einsum import contract as einsum
-
 
 class RoseTTAFoldModule(nn.Module):
     def __init__(self,
                  n_extra_block,
@@ -23,8 +22,6 @@ def __init__(self,
                  p_drop,
                  d_t1d,
                  d_t2d,
-                 d_time_emb, # total dims for input timestep emb
-                 d_time_emb_proj, # size of projected timestep emb
                  T, # total timesteps (used in timestep emb
                  use_motif_timestep, # Whether to have a distinct emb for motif
                  freeze_track_motif, # Whether to freeze updates to motif in track
@@ -47,15 +44,6 @@ def __init__(self,
                                    n_head=n_head_templ,
                                    d_hidden=d_hidden_templ, p_drop=0.25, d_t1d=d_t1d, d_t2d=d_t2d)
 
-        # timestep embedder
-        if d_time_emb:
-            print('NOTE: Using sinusoidal timestep embeddings of dim ',d_time_emb, ' projected to dim ',d_time_emb_proj)
-            assert d_t1d >= 22 + d_time_emb_proj, 'timestep projection size doesn\'t fit into RF t1d projection layers'
-            self.timestep_embedder = Timestep_emb(input_size=d_time_emb,
-                                                  output_size=d_time_emb_proj,
-                                                  T=T,
-                                                  use_motif_timestep=use_motif_timestep)
-
 
         # Update inputs with outputs from previous round
         self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state)
@@ -72,7 +60,7 @@ def __init__(self,
                                            p_drop=p_drop)
         ##
         self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
-        self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
+        self.aa_pred = MaskedTokenNetwork(d_msa)
         self.lddt_pred = LDDTNetwork(d_state)
 
         self.exp_pred = ExpResolvedNetwork(d_msa, d_state)
23 changes: 3 additions & 20 deletions rfdiffusion/coords6d.py
@@ -1,23 +1,7 @@
 import numpy as np
 import scipy
 import scipy.spatial
-
-# calculate dihedral angles defined by 4 sets of points
-def get_dihedrals(a, b, c, d):
-
-    b0 = -1.0*(b - a)
-    b1 = c - b
-    b2 = d - c
-
-    b1 /= np.linalg.norm(b1, axis=-1)[:,None]
-
-    v = b0 - np.sum(b0*b1, axis=-1)[:,None]*b1
-    w = b2 - np.sum(b2*b1, axis=-1)[:,None]*b1
-
-    x = np.sum(v*w, axis=-1)
-    y = np.sum(np.cross(b1, v)*w, axis=-1)
-
-    return np.arctan2(y, x)
+from rfdiffusion.kinematics import get_dih
 
 # calculate planar angles defined by 3 sets of points
 def get_angles(a, b, c):
@@ -65,11 +49,10 @@ def get_coords6d(xyz, dmax):
 
     # matrix of Ca-Cb-Cb-Ca dihedrals
     omega6d = np.zeros((nres, nres), dtype=np.float32)
-    omega6d[idx0,idx1] = get_dihedrals(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])
-
+    omega6d[idx0,idx1] = get_dih(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])
     # matrix of polar coord theta
     theta6d = np.zeros((nres, nres), dtype=np.float32)
-    theta6d[idx0,idx1] = get_dihedrals(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])
+    theta6d[idx0,idx1] = get_dih(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])
 
     # matrix of polar coord phi
     phi6d = np.zeros((nres, nres), dtype=np.float32)
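For reference on the coords6d.py change above: the deleted get_dihedrals helper computed torsion angles by projecting the two outer bond vectors onto the plane perpendicular to the central bond and taking arctan2; per this commit, rfdiffusion.kinematics.get_dih is assumed to return the same angles for the Ca/Cb/N quadruples used in get_coords6d. The sketch below merely restates the deleted NumPy computation as a standalone reference, with made-up example coordinates.

import numpy as np

def dihedrals_reference(a, b, c, d):
    # Row-wise torsion angle between the planes (a, b, c) and (b, c, d); this is
    # the same arithmetic as the deleted get_dihedrals (b0 = a - b == -1.0*(b - a)).
    b0 = a - b
    b1 = c - b
    b2 = d - c

    b1 = b1 / np.linalg.norm(b1, axis=-1)[:, None]

    # components of b0 and b2 perpendicular to the central bond b1
    v = b0 - np.sum(b0 * b1, axis=-1)[:, None] * b1
    w = b2 - np.sum(b2 * b1, axis=-1)[:, None] * b1

    x = np.sum(v * w, axis=-1)
    y = np.sum(np.cross(b1, v) * w, axis=-1)
    return np.arctan2(y, x)  # radians in (-pi, pi]

# Call pattern matching the omega6d line above, on made-up coordinates:
rng = np.random.default_rng(0)
Ca_i, Cb_i, Cb_j, Ca_j = (rng.standard_normal((5, 3)) for _ in range(4))
omega = dihedrals_reference(Ca_i, Cb_i, Cb_j, Ca_j)  # shape (5,)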