from functools import wraps, partialmethod
from pathlib import Path
from typing import List, Optional, Tuple

import haiku

from alphafold.model import config, data, model
from alphafold.model.modules import AlphaFold
from alphafold.model.modules_multimer import AlphaFold as AlphaFoldMultimer
|
def load_models_and_params(
    num_models: int,
    use_templates: bool,
    num_recycles: Optional[int] = None,
    recycle_early_stop_tolerance: Optional[float] = None,
    num_ensemble: int = 1,
    model_order: Optional[List[int]] = None,
    model_suffix: str = "_ptm",
    data_dir: Path = Path("."),
    stop_at_score: float = 100,
    rank_by: str = "auto",
    max_seq: Optional[int] = None,
    max_extra_seq: Optional[int] = None,
    use_cluster_profile: Optional[bool] = None,
    use_fuse: bool = True,
    use_bfloat16: bool = True,
    use_dropout: bool = False,
) -> List[Tuple[str, model.RunModel, haiku.Params]]:
    """Load AlphaFold model runners and parameters, sharing compiled runners.

    We compile only the models whose parameter shapes differ (monomer models
    1/2 have template stacks, 3/4/5 do not; all multimer models match) and
    swap per-model parameters into the shared runner for the rest, which
    avoids recompiling the network for every model.

    Args:
        num_models: How many models (taken from ``model_order``) to return.
        use_templates: Whether template inputs are used; only monomer models
            1 and 2 consume templates, so they then need their own compile.
        num_recycles: Override for the number of recycle iterations
            (``None`` keeps the config default).
        recycle_early_stop_tolerance: Optional tolerance for stopping
            recycling early (``None`` keeps the config default).
        num_ensemble: Number of ensemble evaluations per model.
        model_order: Order in which models are returned (default ``[1..5]``).
        model_suffix: Parameter-set suffix, e.g. ``"_ptm"`` or a
            ``"_multimer"`` variant.
        data_dir: Directory containing the AlphaFold parameter files.
        stop_at_score: Stop predicting once this score is reached.
        rank_by: Metric name stored in the config for ranking/early stop.
        max_seq: Override for the maximum number of MSA cluster sequences.
        max_extra_seq: Override for the maximum number of extra MSA sequences.
        use_cluster_profile: Multimer-only toggle for the MSA cluster profile
            (``None`` keeps the config default).
        use_fuse: Use fused triangle-multiplication projection weights (the
            parameter files must match; passed through to the param loader).
        use_bfloat16: Run the network in bfloat16.
        use_dropout: Enable dropout at inference time (prediction sampling).

    Returns:
        At most ``num_models`` tuples ``(model_name, runner, params)``,
        ordered according to ``model_order``.
    """
    if model_order is None:
        model_order = [1, 2, 3, 4, 5]

    # Build 3/4/5 first so a shared runner exists before models that reuse it.
    model_build_order = [3, 4, 5, 1, 2]
    if "multimer" in model_suffix:
        # All multimer models share one parameter shape: compile once.
        models_need_compilation = [3]
    else:
        # Only monomer models 1 and 2 have template stacks.
        models_need_compilation = [1, 3] if use_templates else [3]

    model_runner_and_params: List[Tuple[str, model.RunModel, haiku.Params]] = []
    model_runner_and_params_build_order: List[
        Tuple[str, model.RunModel, haiku.Params]
    ] = []
    model_runner: Optional[model.RunModel] = None

    for model_number in model_build_order:
        model_name = f"model_{model_number}"
        # Load this model's parameters once; reused both for compiling the
        # runner (when needed) and for the per-model parameter subset below.
        params = data.get_model_haiku_params(
            model_name=model_name + model_suffix,
            data_dir=str(data_dir),
            fuse=use_fuse,
        )

        if model_number in models_need_compilation:
            # Base configuration for this parameter set.
            model_config = config.model_config(model_name + model_suffix)
            model_config.model.stop_at_score = float(stop_at_score)
            model_config.model.rank_by = rank_by

            # Inference-time dropout (for sampling diverse predictions).
            model_config.model.global_config.eval_dropout = use_dropout

            # Numerics: run the network in bfloat16 if requested.
            model_config.model.global_config.bfloat16 = use_bfloat16

            # Fused triangle-multiplication projection weights (Evoformer).
            evoformer = model_config.model.embeddings_and_evoformer.evoformer
            evoformer.triangle_multiplication_incoming.fuse_projection_weights = use_fuse
            evoformer.triangle_multiplication_outgoing.fuse_projection_weights = use_fuse
            if "multimer" in model_suffix or model_number in (1, 2):
                # Only these configurations contain a template pair stack.
                template_stack = (
                    model_config.model.embeddings_and_evoformer.template.template_pair_stack
                )
                template_stack.triangle_multiplication_incoming.fuse_projection_weights = use_fuse
                template_stack.triangle_multiplication_outgoing.fuse_projection_weights = use_fuse

            # MSA size overrides; multimer and monomer configs keep these
            # settings in different places.
            if max_seq is not None:
                if "multimer" in model_suffix:
                    model_config.model.embeddings_and_evoformer.num_msa = max_seq
                else:
                    model_config.data.eval.max_msa_clusters = max_seq
            if max_extra_seq is not None:
                if "multimer" in model_suffix:
                    model_config.model.embeddings_and_evoformer.num_extra_msa = max_extra_seq
                else:
                    model_config.data.common.max_extra_msa = max_extra_seq

            # Recycles and ensembles (again split by monomer/multimer layout).
            if "multimer" in model_suffix:
                if num_recycles is not None:
                    model_config.model.num_recycle = num_recycles
                if use_cluster_profile is not None:
                    model_config.model.embeddings_and_evoformer.use_cluster_profile = (
                        use_cluster_profile
                    )
                model_config.model.num_ensemble_eval = num_ensemble
            else:
                if num_recycles is not None:
                    model_config.data.common.num_recycle = num_recycles
                    model_config.model.num_recycle = num_recycles
                model_config.data.eval.num_ensemble = num_ensemble

            if recycle_early_stop_tolerance is not None:
                model_config.model.recycle_early_stop_tolerance = (
                    recycle_early_stop_tolerance
                )

            # Compile a runner for this configuration.
            model_runner = model.RunModel(model_config, params)

        # Keep only the parameter keys the compiled runner actually uses.
        params_subset = {k: params[k] for k in model_runner.params.keys()}
        model_runner_and_params_build_order.append(
            (model_name, model_runner, params_subset)
        )

    # Reorder according to model_order, truncated to num_models entries.
    for n, model_number in enumerate(model_order):
        if n == num_models:
            break
        model_name = f"model_{model_number}"
        for entry in model_runner_and_params_build_order:
            if entry[0] == model_name:
                model_runner_and_params.append(entry)
                break
    return model_runner_and_params
|