
Commit 06ae1cb

Finalize removing icecream as a dependency.
Icecream was removed from the installation dependencies, but it was still being imported (though never used).
roccomoretti authored and joewatchwell committed May 30, 2023
1 parent e783762 · commit 06ae1cb
Showing 2 changed files with 22 additions and 24 deletions.
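
Before reading the diff, one illustrative way to confirm that no other module in the package still imports the removed dependency is a small AST scan like the sketch below. It is not part of this commit; it only assumes the rfdiffusion package directory is present in the working tree.

import ast
import pathlib

def find_icecream_imports(package_dir="rfdiffusion"):
    """Yield (path, line) for every import of icecream under package_dir."""
    for path in pathlib.Path(package_dir).rglob("*.py"):
        tree = ast.parse(path.read_text(), filename=str(path))
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                # Catches plain `import icecream` (or `import icecream.x`).
                if any(alias.name.split(".")[0] == "icecream" for alias in node.names):
                    yield path, node.lineno
            elif isinstance(node, ast.ImportFrom):
                # Catches `from icecream import ic`.
                if node.module and node.module.split(".")[0] == "icecream":
                    yield path, node.lineno

if __name__ == "__main__":
    for path, lineno in find_icecream_imports():
        print(f"{path}:{lineno}: still imports icecream")
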
45 changes: 22 additions & 23 deletions rfdiffusion/Embeddings.py
@@ -7,7 +7,6 @@
from rfdiffusion.util_module import Dropout, create_custom_forward, rbf, init_lecun_normal
from rfdiffusion.Attention_module import Attention, FeedForwardLayer, AttentionWithBias
from rfdiffusion.Track_module import PairStr2Pair
-from icecream import ic
import math

# Module contains classes and functions to generate initial embeddings
@@ -43,11 +42,11 @@ def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=22+22+2+2,
self.emb_state = nn.Embedding(22, d_state)
self.drop = nn.Dropout(p_drop)
self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos, p_drop=p_drop)

self.input_seq_onehot=input_seq_onehot

self.reset_parameter()

def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
self.emb_q = init_lecun_normal(self.emb_q)
@@ -67,17 +66,17 @@ def forward(self, msa, seq, idx):
# - pair: Initial Pair embedding (B, L, L, d_pair)

N = msa.shape[1] # number of sequenes in MSA

# msa embedding
msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding

# Sergey's one hot trick
tmp = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding

msa = msa + tmp.expand(-1, N, -1, -1) # adding query embedding to MSA
msa = self.drop(msa)

-# pair embedding
+# pair embedding
# Sergey's one hot trick
left = (seq @ self.emb_left.weight)[:,None] # (B, 1, L, d_pair)
right = (seq @ self.emb_right.weight)[:,:,None] # (B, L, 1, d_pair)
@@ -99,9 +98,9 @@ def __init__(self, d_msa=256, d_init=22+1+2, p_drop=0.1, input_seq_onehot=False)
self.drop = nn.Dropout(p_drop)

self.input_seq_onehot=input_seq_onehot

self.reset_parameter()

def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
nn.init.zeros_(self.emb.bias)
@@ -115,7 +114,7 @@ def forward(self, msa, seq, idx):
# - msa: Initial MSA embedding (B, N, L, d_msa)
N = msa.shape[1] # number of sequenes in MSA
msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding

# Sergey's one hot trick
seq = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
msa = msa + seq.expand(-1, N, -1, -1) # adding query embedding to MSA
@@ -163,7 +162,7 @@ def forward(self, tors, pair, rbf_feat, use_checkpoint=False):
pair = pair.reshape(B*T, L, L, -1)
pair = torch.cat((pair, rbf_feat), dim=-1)
pair = self.proj_pair(pair)

for i_block in range(self.n_block):
if use_checkpoint:
tors = tors + checkpoint.checkpoint(create_custom_forward(self.row_attn[i_block]), tors, pair)
@@ -183,19 +182,19 @@ class Templ_emb(nn.Module):
# - confidence (1)
# - contacting or note (1). NB this is added for diffusion model. Used only in complex training examples - 1 signifies that a residue in the non-diffused chain\
# i.e. the context, is in contact with the diffused chain.
-#
+#
#Added extra t1d dimension for contacting or not
-def __init__(self, d_t1d=21+1+1, d_t2d=43+1, d_tor=30, d_pair=128, d_state=32,
+def __init__(self, d_t1d=21+1+1, d_t2d=43+1, d_tor=30, d_pair=128, d_state=32,
n_block=2, d_templ=64,
n_head=4, d_hidden=16, p_drop=0.25):
super(Templ_emb, self).__init__()
# process 2D features
self.emb = nn.Linear(d_t1d*2+d_t2d, d_templ)
self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
d_hidden=d_hidden, p_drop=p_drop)

self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair)

# process torsion angles
self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ)
self.proj_t1d = nn.Linear(d_templ, d_templ)
@@ -204,14 +203,14 @@ def __init__(self, d_t1d=21+1+1, d_t2d=43+1, d_tor=30, d_pair=128, d_state=32,
self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state)

self.reset_parameter()

def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
nn.init.zeros_(self.emb.bias)

nn.init.kaiming_normal_(self.emb_t1d.weight, nonlinearity='relu')
nn.init.zeros_(self.emb_t1d.bias)

self.proj_t1d = init_lecun_normal(self.proj_t1d)
nn.init.zeros_(self.proj_t1d.bias)

@@ -237,7 +236,7 @@ def forward(self, t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=False):

# process each template features
t1d = self.proj_t1d(F.relu_(self.emb_t1d(t1d)))

# mixing query state features to template state features
state = state.reshape(B*L, 1, -1)
t1d = t1d.permute(0,2,1,3).reshape(B*L, T, -1)
@@ -270,9 +269,9 @@ def __init__(self, d_msa=256, d_pair=128, d_state=32):
self.norm_state = nn.LayerNorm(d_state)
self.norm_pair = nn.LayerNorm(d_pair)
self.norm_msa = nn.LayerNorm(d_msa)

self.reset_parameter()

def reset_parameter(self):
self.proj_dist = init_lecun_normal(self.proj_dist)
nn.init.zeros_(self.proj_dist.bias)
@@ -283,7 +282,7 @@ def forward(self, seq, msa, pair, xyz, state):
#
left = state.unsqueeze(2).expand(-1,-1,L,-1)
right = state.unsqueeze(1).expand(-1,L,-1,-1)

# three anchor atoms
N = xyz[:,:,0]
Ca = xyz[:,:,1]
@@ -293,8 +292,8 @@ def forward(self, seq, msa, pair, xyz, state):
b = Ca - N
c = C - Ca
a = torch.cross(b, c, dim=-1)
-Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
+Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca

dist = rbf(torch.cdist(Cb, Cb))
dist = torch.cat((dist, left, right), dim=-1)
dist = self.proj_dist(dist)
1 change: 0 additions & 1 deletion rfdiffusion/contigs.py
@@ -1,7 +1,6 @@
import sys
import numpy as np
import random
-from icecream import ic


class ContigMap:
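
With both imports gone, the touched modules no longer reference icecream at all. A minimal smoke test (a sketch, assuming the package's remaining requirements such as torch are installed and icecream is not) could be:

import importlib

# Both modules touched by this commit should import cleanly even when
# icecream is absent from the environment.
for name in ("rfdiffusion.Embeddings", "rfdiffusion.contigs"):
    importlib.import_module(name)
    print(f"{name}: imported OK")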
