-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathrun_multiprocess_infer_multimer.py
108 lines (91 loc) · 4.22 KB
/
run_multiprocess_infer_multimer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import subprocess
import os
import time
import multiprocess_functions as mpf
from datetime import datetime
# Capture the launch time once at import; every log directory created by this
# run is named after it (minute resolution).
d = datetime.now()
timestamp = d.strftime("inference_multimer_%Y%m%d%H%M")
from absl import app
from absl import flags
from absl import logging
# Command-line flags. root_home, input_dir, output_dir and model_names are
# marked as required in the __main__ guard at the bottom of the file.
flags.DEFINE_string('root_condaenv', None, 'conda environment directory path')
flags.DEFINE_string('root_home', None, 'home directory')
flags.DEFINE_string('input_dir', None, 'root directory holding all .fa files')
flags.DEFINE_string('output_dir', None, 'Path to a directory that will store the results.')
flags.DEFINE_string('model_names', None, 'Names of models to use')
# Exported to child processes through the AF2_BF16 environment variable.
flags.DEFINE_integer('AF2_BF16', 1, 'Set to 0 for FP32 precision run.')
# The help text used to claim the seed is randomly generated by default, which
# contradicts the fixed default below; the text now matches the behaviour.
flags.DEFINE_integer('random_seed', 123, 'The random seed for the data '
                     'pipeline. Defaults to a fixed seed rather than a '
                     'randomly generated one. Note '
                     'that even if this is set, Alphafold may still not be '
                     'deterministic, because processes like GPU inference are '
                     'nondeterministic.')
flags.DEFINE_integer('num_multimer_predictions_per_model', 1, 'How many '
                     'predictions (each with a different random seed) will be '
                     'generated per model. E.g. if this is 2 and there are 5 '
                     'models then there will be 10 predictions per input. '
                     'Note: this FLAG only applies in multimer mode')
FLAGS = flags.FLAGS
# Inference entry point launched once per input file.
script = "python run_modelinfer_pytorch_jit_multimer.py"

# Shell command template. Positional slots, in order: script, fasta path,
# output dir, model names, params root, random seed, predictions per model.
# Each fragment keeps its trailing space so the joined string is one
# space-separated command line.
base_fold_cmd = (
    "/usr/bin/time -v {} "
    "--fasta_paths={} "
    "--output_dir={} "
    "--model_names={} "
    "--root_params={} "
    "--random_seed={} "
    "--num_multimer_predictions_per_model={} "
)
def bash_subprocess(file_path, mem, core_list):
    """Run inference for one input file in a shell subprocess under numactl.

    The command is built from ``base_fold_cmd`` and the parsed flags, then
    prefixed with ``numactl -m <mem> -C <first>-<last>``. The call blocks
    until the command exits. Its stdout and stderr are written to
    ``<root_home>/logs/<timestamp>/inference_log_<basename>.txt``.

    Args:
        file_path: Path of the input file, passed as ``--fasta_paths``.
        mem: Value handed to ``numactl -m``; converted with ``str()``.
        core_list: Sequence of core ids. Only the first and last entries are
            used, as the ``first-last`` range given to ``numactl -C``.

    Returns:
        A tuple ``(returncode, file_path, mem, core_list)``. ``returncode`` is
        the exit status reported by ``subprocess.call``, or ``1`` if the
        subprocess could not be launched at all.
    """
    out_dir = FLAGS.output_dir
    root_params = FLAGS.root_home + "/weights/extracted/"
    log_dir = FLAGS.root_home + "/logs/" + str(timestamp) + "/"
    os.makedirs(log_dir, exist_ok=True)

    command = base_fold_cmd.format(
        script,
        file_path,
        out_dir,
        FLAGS.model_names,
        root_params,
        FLAGS.random_seed,
        FLAGS.num_multimer_predictions_per_model,
    )
    core_range = "-".join([str(core_list[0]), str(core_list[-1])])
    # str(mem) keeps the join working when the caller passes an int node id.
    numactl_args = ["numactl", "-m", str(mem), "-C", core_range, command]
    shell_command = " ".join(numactl_args)
    print(shell_command)

    # Pre-set a non-zero status so the return below is always defined; the
    # original left the name unbound when subprocess.call raised, turning a
    # handled launch error into an UnboundLocalError.
    # NOTE(review): confirm the caller treats a non-zero status as failure.
    process = 1
    log_path = log_dir + 'inference_log_' + os.path.basename(file_path) + '.txt'
    with open(log_path, 'w') as f:
        try:
            process = subprocess.call(
                shell_command,
                shell=True,
                universal_newlines=True,
                stdout=f,
                stderr=f,
            )
        except Exception as e:
            # Best effort: report the failure and let the remaining inputs
            # be processed.
            print('exception for', os.path.basename(file_path), e)
    return (process, file_path, mem, core_list)
def main(argv):
    """The main function."""
    start_time = time.time()
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Runtime settings for this process; child processes started later
    # inherit this environment.
    os.environ.update({
        "TF_ENABLE_ONEDNN_OPTS": "1",
        "MALLOC_CONF": "oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1",
        "USE_OPENMP": "1",
        "USE_AVX512": "1",
        "IPEX_ONEDNN_LAYOUT": "1",
        "PYTORCH_TENSOREXPR": "0",
        "CUDA_VISIBLE_DEVICES": "-1",
        "AF2_BF16": str(FLAGS.AF2_BF16),
    })

    # Every entry of the input directory becomes one work item.
    input_dir = FLAGS.input_dir
    input_paths = [os.path.join(input_dir, entry) for entry in os.listdir(input_dir)]

    # Scheduling parameters passed to the multiprocess helpers.
    min_mem_per_process = 32 * 1024  # 32 GB
    min_cores_per_process = 8
    load_balance_factor = 4

    process_plan = mpf.create_process_list(
        input_paths,
        min_mem_per_process,
        min_cores_per_process,
        load_balance_factor,
    )
    leftover = mpf.start_process_list(input_paths, process_plan, bash_subprocess)

    print("Following protein files couldn't be processed")
    print(leftover)

    end_time = time.time()
    print('### Total inference time: %d sec' % (end_time - start_time))
if __name__ == "__main__":
    # These flags have no usable default, so absl must reject a run that
    # omits any of them before main() is entered.
    flags.mark_flags_as_required(['root_home', 'input_dir', 'output_dir', 'model_names'])
    app.run(main)