@@ -403,6 +403,7 @@ def predict_structure(
403
403
if "multimer" in model_type :
404
404
# TODO: add multimer padding
405
405
input_features = processed_feature_dict
406
+ input_features ["asym_id" ] = input_features ["asym_id" ] - input_features ["asym_id" ][...,0 ]
406
407
else :
407
408
# TODO: move asym_id processing to "process_features"
408
409
r = processed_feature_dict ["aatype" ].shape [0 ]
@@ -427,28 +428,24 @@ def callback(prediction_result, recycles):
427
428
print_line += f" { y } ={ prediction_result [x ]:.3g} "
428
429
logger .info (f"{ tag } recycle={ recycles } { print_line } " )
429
430
430
- if save_recycles or save_all :
431
- prediction_result = _jnp_to_np (prediction_result )
432
- prediction_result ["representations" ] = prediction_result .pop ("prev" )
433
-
434
431
if save_recycles :
435
- final_atom_mask = prediction_result ["structure_module" ]["final_atom_mask" ]
436
- b_factors = prediction_result ["plddt" ][:, None ] * final_atom_mask
432
+ result = _jnp_to_np (prediction_result )
433
+ final_atom_mask = result ["structure_module" ]["final_atom_mask" ]
434
+ b_factors = result ["plddt" ][:, None ] * final_atom_mask
437
435
unrelaxed_protein = protein .from_prediction (features = input_features ,
438
- result = prediction_result , b_factors = b_factors ,
436
+ result = result , b_factors = b_factors ,
439
437
remove_leading_feature_dimension = ("ptm" in model_type ))
440
438
441
439
unrelaxed_pdb_lines = protein .to_pdb (class_to_np (unrelaxed_protein ))
442
440
files .get ("unrelaxed" ,f"r{ recycles } .pdb" ).write_text (unrelaxed_pdb_lines )
443
441
444
- if save_all :
445
- with files .get ("all" ,f"r{ recycles } .pickle" ).open ("wb" ) as handle :
446
- pickle .dump (prediction_result , handle )
442
+ if save_all :
443
+ with files .get ("all" ,f"r{ recycles } .pickle" ).open ("wb" ) as handle :
444
+ pickle .dump (result , handle )
447
445
448
446
prediction_result , recycles = \
449
447
model_runner .predict (input_features , random_seed = seed , prediction_callback = callback )
450
448
prediction_result = _jnp_to_np (prediction_result )
451
- prediction_result ["representations" ] = prediction_result .pop ("prev" )
452
449
prediction_times .append (time .time () - start )
453
450
454
451
########################
@@ -482,19 +479,23 @@ def callback(prediction_result, recycles):
482
479
483
480
#########################
484
481
# save results
485
- #########################
482
+ #########################
483
+
486
484
# save pdb
487
485
protein_lines = protein .to_pdb (unrelaxed_protein )
488
486
files .get ("unrelaxed" ,"pdb" ).write_text (protein_lines )
489
487
unrelaxed_pdb_lines .append (protein_lines )
490
488
491
489
# save raw outputs
492
- if save_single_representations or save_pair_representations :
493
- rep = prediction_result ["representations" ]
494
- if save_single_representations :
495
- np .save (files .get ("single_repr" ,"npy" ), rep ["prev_msa_first_row" ])
496
- if save_pair_representations :
497
- np .save (files .get ("pair_repr" ,"npy" ), rep ["prev_pair" ])
490
+ if save_all :
491
+ with files .get ("all" ,"pickle" ).open ("wb" ) as handle :
492
+ pickle .dump (prediction_result , handle )
493
+ if save_single_representations :
494
+ np .save (files .get ("single_repr" ,"npy" ),
495
+ prediction_result ["prev" ]["prev_msa_first_row" ])
496
+ if save_pair_representations :
497
+ np .save (files .get ("pair_repr" ,"npy" ),
498
+ prediction_result ["prev" ]["prev_pair" ])
498
499
499
500
# write an easy-to-use format (pAE and pLDDT)
500
501
with files .get ("scores" ,"json" ).open ("w" ) as handle :
@@ -1186,6 +1187,7 @@ def run(
1186
1187
dpi : int = 200 ,
1187
1188
max_seq : Optional [int ] = None ,
1188
1189
max_extra_seq : Optional [int ] = None ,
1190
+ use_cluster_profile : bool = True ,
1189
1191
feature_dict_callback : Callable [[Any ], Any ] = None ,
1190
1192
** kwargs
1191
1193
):
@@ -1234,7 +1236,6 @@ def run(
1234
1236
pair_mode = old_names .get (pair_mode ,pair_mode )
1235
1237
feature_dict_callback = kwargs .pop ("input_features_callback" , feature_dict_callback )
1236
1238
use_dropout = kwargs .pop ("training" , use_dropout )
1237
- use_cluster_profile = kwargs .pop ("use_cluster_profile" , None )
1238
1239
use_fuse = kwargs .pop ("use_fuse" , True )
1239
1240
use_bfloat16 = kwargs .pop ("use_bfloat16" , True )
1240
1241
max_msa = kwargs .pop ("max_msa" ,None )
@@ -1659,7 +1660,7 @@ def main():
1659
1660
help = "rank models by auto, plddt or ptmscore" ,
1660
1661
type = str ,
1661
1662
default = "auto" ,
1662
- choices = ["auto" , "plddt" , "ptmscore " , "multimer" ],
1663
+ choices = ["auto" , "plddt" , "ptm" , "iptm " , "multimer" ],
1663
1664
)
1664
1665
parser .add_argument (
1665
1666
"--pair-mode" ,
@@ -1711,6 +1712,12 @@ def main():
1711
1712
type = str ,
1712
1713
default = None ,
1713
1714
)
1715
+ parser .add_argument (
1716
+ "--disable-cluster-profile" ,
1717
+ default = False ,
1718
+ action = "store_true" ,
1719
+ help = "EXPERIMENTAL: for multimer models, disable cluster profiles" ,
1720
+ )
1714
1721
parser .add_argument (
1715
1722
"--zip" ,
1716
1723
default = False ,
@@ -1798,6 +1805,7 @@ def main():
1798
1805
max_seq = args .max_seq ,
1799
1806
max_extra_seq = args .max_extra_seq ,
1800
1807
max_msa = args .max_msa ,
1808
+ use_cluster_profile = not args .disable_cluster_profile ,
1801
1809
use_gpu_relax = args .use_gpu_relax ,
1802
1810
save_all = args .save_all ,
1803
1811
save_recycles = args .save_recycles ,
0 commit comments