Skip to content

Commit

Permalink
Made necessary changes for integration tests based on example submiss…
Browse files Browse the repository at this point in the history
…ions
  • Loading branch information
Joseph Watson authored and joewatchwell committed May 5, 2023
1 parent 49b9245 commit 5c6f2f1
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 60 deletions.
12 changes: 1 addition & 11 deletions examples/design_cyclic_oligos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,4 @@
# We decay this potential with quadratic form, so that it is applied more strongly initially
# We specify a total length of 480aa, so each chain is 80 residues long

python ../scripts/run_inference.py \
--config-name=symmetry \
inference.symmetry="C6" \
inference.num_designs=10 \
inference.output_prefix="example_outputs/C6_oligo" \
'potentials.guiding_potentials=["type:olig_contacts,weight_intra:1,weight_inter:0.1"]' \
potentials.olig_intra_all=True \
potentials.olig_inter_all=True \
potentials.guide_scale=2.0 \
potentials.guide_decay="quadratic" \
'contigmap.contigs=[480-480]'
python ../scripts/run_inference.py --config-name=symmetry inference.symmetry="C6" inference.num_designs=10 inference.output_prefix="example_outputs/C6_oligo" 'potentials.guiding_potentials=["type:olig_contacts,weight_intra:1,weight_inter:0.1"]' potentials.olig_intra_all=True potentials.olig_inter_all=True potentials.guide_scale=2.0 potentials.guide_decay="quadratic" 'contigmap.contigs=[480-480]'
12 changes: 1 addition & 11 deletions examples/design_dihedral_oligos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,4 @@
# We decay this potential with quadratic form, so that it is applied more strongly initially
# We specify a total length of 320aa, so each chain is 80 residues long

python ../scripts/run_inference.py \
--config-name=symmetry \
inference.symmetry="D2" \
inference.num_designs=10 \
inference.output_prefix="example_outputs/D2_oligo" \
'potentials.guiding_potentials=["type:olig_contacts,weight_intra:1,weight_inter:0.1"]' \
potentials.olig_intra_all=True \
potentials.olig_inter_all=True \
potentials.guide_scale=2.0 \
potentials.guide_decay="quadratic" \
'contigmap.contigs=[320-320]'
python ../scripts/run_inference.py --config-name=symmetry inference.symmetry="D2" inference.num_designs=10 inference.output_prefix="example_outputs/D2_oligo" 'potentials.guiding_potentials=["type:olig_contacts,weight_intra:1,weight_inter:0.1"]' potentials.olig_intra_all=True potentials.olig_inter_all=True potentials.guide_scale=2.0 potentials.guide_decay="quadratic" 'contigmap.contigs=[320-320]'
12 changes: 1 addition & 11 deletions examples/design_tetrahedral_oligos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,4 @@
# We decay this potential with quadratic form, so that it is applied more strongly initially
# We specify a total length of 1200aa, so each chain is 100 residues long

python ../scripts/run_inference.py \
--config-name=symmetry \
inference.symmetry="tetrahedral" \
inference.num_designs=10 \
inference.output_prefix="example_outputs/tetrahedral_oligo" \
'potentials.guiding_potentials=["type:olig_contacts,weight_intra:1,weight_inter:0.1"]' \
potentials.olig_intra_all=True \
potentials.olig_inter_all=True \
potentials.guide_scale=2.0 \
potentials.guide_decay="quadratic" \
'contigmap.contigs=[1200-1200]'
python ../scripts/run_inference.py --config-name=symmetry inference.symmetry="tetrahedral" inference.num_designs=10 inference.output_prefix="example_outputs/tetrahedral_oligo" 'potentials.guiding_potentials=["type:olig_contacts,weight_intra:1,weight_inter:0.1"]' potentials.olig_intra_all=True potentials.olig_inter_all=True potentials.guide_scale=2.0 potentials.guide_decay="quadratic" 'contigmap.contigs=[1200-1200]'
2 changes: 1 addition & 1 deletion rfdiffusion/inference/model_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def __init__(self, conf: DictConfig):
"""
super().__init__(conf)
# initialize BlockAdjacency sampling class
self.blockadjacency = iu.BlockAdjacency(conf.scaffoldguided, conf.inference.num_designs)
self.blockadjacency = iu.BlockAdjacency(conf, conf.inference.num_designs)

#################################################
### Initialize target, if doing binder design ###
Expand Down
63 changes: 37 additions & 26 deletions rfdiffusion/inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,8 @@ def sampler_selector(conf: DictConfig):

def parse_pdb(filename, **kwargs):
"""extract xyz coords for all heavy atoms"""
lines = open(filename, "r").readlines()
with open(filename,"r") as f:
lines=f.readlines()
return parse_pdb_lines(lines, **kwargs)


Expand Down Expand Up @@ -697,15 +698,16 @@ def __init__(self, conf, num_designs):
conf.scaffold_list as conf
conf.inference.num_designs for sanity checking
"""


self.conf=conf
# either list or path to .txt file with list of scaffolds
if conf.scaffold_list is not None:
if type(conf.scaffold_list) == list:
if self.conf.scaffoldguided.scaffold_list is not None:
if type(self.conf.scaffoldguided.scaffold_list) == list:
self.scaffold_list = scaffold_list
elif conf.scaffold_list[-4:] == ".txt":
elif self.conf.scaffoldguided.scaffold_list[-4:] == ".txt":
# txt file with list of ids
list_from_file = []
with open(conf.scaffold_list, "r") as f:
with open(self.conf.scaffoldguided.scaffold_list, "r") as f:
for line in f:
list_from_file.append(line.strip())
self.scaffold_list = list_from_file
Expand All @@ -714,43 +716,45 @@ def __init__(self, conf, num_designs):
else:
self.scaffold_list = [
os.path.split(i)[1][:-6]
for i in glob.glob(f"{conf.scaffold_dir}/*_ss.pt")
for i in glob.glob(f"{self.conf.scaffoldguided.scaffold_dir}/*_ss.pt")
]
self.scaffold_list.sort()

# path to directory with scaffolds, ss files and block_adjacency files
self.scaffold_dir = conf.scaffold_dir
self.scaffold_dir = self.conf.scaffoldguided.scaffold_dir

# maximum sampled insertion in each loop segment
if "-" in str(conf.sampled_insertion):
if "-" in str(self.conf.scaffoldguided.sampled_insertion):
self.sampled_insertion = [
int(str(conf.sampled_insertion).split("-")[0]),
int(str(conf.sampled_insertion).split("-")[1]),
int(str(self.conf.scaffoldguided.sampled_insertion).split("-")[0]),
int(str(self.conf.scaffoldguided.sampled_insertion).split("-")[1]),
]
else:
self.sampled_insertion = [0, int(conf.sampled_insertion)]
self.sampled_insertion = [0, int(self.conf.scaffoldguided.sampled_insertion)]

# maximum sampled insertion at N- and C-terminus
if "-" in str(conf.sampled_N):
if "-" in str(self.conf.scaffoldguided.sampled_N):
self.sampled_N = [
int(str(conf.sampled_N).split("-")[0]),
int(str(conf.sampled_N).split("-")[1]),
int(str(self.conf.scaffoldguided.sampled_N).split("-")[0]),
int(str(self.conf.scaffoldguided.sampled_N).split("-")[1]),
]
else:
self.sampled_N = [0, int(conf.sampled_N)]
if "-" in str(conf.sampled_C):
self.sampled_N = [0, int(self.conf.scaffoldguided.sampled_N)]
if "-" in str(self.conf.scaffoldguided.sampled_C):
self.sampled_C = [
int(str(conf.sampled_C).split("-")[0]),
int(str(conf.sampled_C).split("-")[1]),
int(str(self.conf.scaffoldguided.sampled_C).split("-")[0]),
int(str(self.conf.scaffoldguided.sampled_C).split("-")[1]),
]
else:
self.sampled_C = [0, int(conf.sampled_C)]
self.sampled_C = [0, int(self.conf.scaffoldguided.sampled_C)]

# number of residues to mask ss identity of in H/E regions (from junction)
# e.g. if ss_mask = 2, L,L,L,H,H,H,H,H,H,H,L,L,E,E,E,E,E,E,L,L,L,L,L,L would become\
# M,M,M,M,M,H,H,H,M,M,M,M,M,M,E,E,M,M,M,M,M,M,M,M where M is mask
self.ss_mask = conf.ss_mask
self.ss_mask = self.conf.scaffoldguided.ss_mask

# whether or not to work systematically through the list
self.systematic = conf.systematic
self.systematic = self.conf.scaffoldguided.systematic

self.num_designs = num_designs

Expand All @@ -765,10 +769,10 @@ def __init__(self, conf, num_designs):
self.item_n = 0

# whether to mask loops or not
if not conf.mask_loops:
assert conf.sampled_N == 0, "can't add length if not masking loops"
assert conf.sampled_C == 0, "can't add lemgth if not masking loops"
assert conf.sampled_insertion == 0, "can't add length if not masking loops"
if not self.conf.scaffoldguided.mask_loops:
assert self.conf.scaffoldguided.sampled_N == 0, "can't add length if not masking loops"
assert self.conf.scaffoldguided.sampled_C == 0, "can't add lemgth if not masking loops"
assert self.conf.scaffoldguided.sampled_insertion == 0, "can't add length if not masking loops"
self.mask_loops = False
else:
self.mask_loops = True
Expand Down Expand Up @@ -880,6 +884,13 @@ def get_scaffold(self):
"""
Wrapper method for pulling an item from the list, and preparing ss and block adj features
"""

# Handle determinism. Useful for integration tests
if self.conf.inference.deterministic:
torch.manual_seed(self.num_completed)
np.random.seed(self.num_completed)
random.seed(self.num_completed)

if self.systematic:
# reset if num designs > num_scaffolds
if self.item_n >= len(self.scaffold_list):
Expand Down
28 changes: 28 additions & 0 deletions rfdiffusion/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,3 +712,31 @@ def writepdb_multi(
ctr += 1

f.write("ENDMDL\n")

def calc_rmsd(xyz1, xyz2, eps=1e-6):
"""
Calculates RMSD between two sets of atoms (L, 3)
"""
# center to CA centroid
xyz1 = xyz1 - xyz1.mean(0)
xyz2 = xyz2 - xyz2.mean(0)

# Computation of the covariance matrix
C = xyz2.T @ xyz1

# Compute otimal rotation matrix using SVD
V, S, W = np.linalg.svd(C)

# get sign to ensure right-handedness
d = np.ones([3,3])
d[:,-1] = np.sign(np.linalg.det(V)*np.linalg.det(W))

# Rotation matrix U
U = (d*V) @ W

# Rotate xyz2
xyz2_ = xyz2 @ U
L = xyz2_.shape[0]
rmsd = np.sqrt(np.sum((xyz2_-xyz1)*(xyz2_-xyz1), axis=(0,1)) / L + eps)

return rmsd, U

0 comments on commit 5c6f2f1

Please sign in to comment.