-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsubmit.py
118 lines (109 loc) · 3.99 KB
/
submit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Dict, Iterator, List, Optional, Union, Literal, Tuple
from tqdm import tqdm
import json
import numpy as np
from alms.args import SubmitArgs
from alms.database import *
from alms.aimstools.rdkit.smiles import *
from alms.ml.chemprop.features.features_generators import get_features
from alms.ml.chemprop.args import PredictArgs
from alms.ml.chemprop.train import make_predictions
def mol_filter(smiles: str,
               smarts_bad: Dict[str, str],
               heavy_atom: Optional[Tuple[int, int]] = None) -> Optional[str]:
    """Return the canonical SMILES of *smiles* if it passes every filter, else None.

    Args:
        smiles: Input SMILES string.
        smarts_bad: Mapping of a descriptive name to a SMARTS pattern. A
            molecule matching any of the patterns is rejected; the names are
            only labels and are not used.
        heavy_atom: Optional inclusive ``(min, max)`` bound on the atom count
            reported by ``mol.GetNumAtoms()``. ``None`` disables the check.

    Returns:
        The SMILES produced by ``get_rdkit_smiles`` for accepted molecules, or
        None when the SMILES cannot be parsed, the atom count is out of range,
        or a forbidden substructure is present.
    """
    mol = Chem.MolFromSmiles(smiles)
    # Validity must be checked first: MolFromSmiles returns None for
    # unparsable input, and calling GetNumAtoms() on None would raise.
    if mol is None:
        print('Ignore invalid SMILES: %s.' % smiles)
        return None
    # Reject molecules whose atom count falls outside the requested range.
    if heavy_atom is not None:
        if not heavy_atom[0] <= mol.GetNumAtoms() <= heavy_atom[1]:
            return None
    # Reject molecules containing any forbidden substructure. One match is
    # enough to decide, hence maxMatches=1.
    for smarts in smarts_bad.values():
        pattern = Chem.MolFromSmarts(smarts)
        if mol.GetSubstructMatches(pattern, useChirality=True, maxMatches=1):
            return None
    return get_rdkit_smiles(smiles)
def submit(args: SubmitArgs):
    """Filter ``args.smiles_list`` and store the surviving molecules in the database.

    Each SMILES is passed through ``mol_filter`` with the SMARTS blacklist
    below and the ``args.heavy_atoms`` range. The distinct accepted SMILES
    are then inserted (or looked up by ``smiles``) via ``add_or_query``, with
    their features from ``get_features(s, args.features_generator)`` stored as
    JSON. All changes are committed once at the end.
    """
    smiles = args.smiles_list
    # SMARTS patterns whose presence excludes a molecule.
    smarts_bad = {
        # Not covered by TEAM FF
        'radicalC': '[#6;v0,v1,v2,v3]',
        '*=*=*': '*=*=*',
        '*#*~*#*': '*#*~*#*',
        '[F,Cl,Br]~[!#6]': '[F,Cl,Br]~[!#6]',
        '*#*~[!#6]': '*#*~[!#6]',
        '[NX2,NX4]': '[NX2,NX4]',
        'O~N(~[!$([OX1])])~[!$([OX1])]': 'O~N(~[!$([OX1])])~[!$([OX1])]',
        'peroxide': 'O~O',
        'N~N': 'N~N',
        '[O,N]*[O,N;H1,H2]': '[O,N]*[O,N;H1,H2]',
        'C=C~[O,N;H1,H2]': 'C=C~[O,N;H1,H2]',
        'beta-dicarbonyl': 'O=C~*~C=O',
        'a=*': 'a=*',
        'o': 'o',
        '[n;r5]': '[n;r5]',
        'pyridine-N-oxide': '[nX3;r6]',
        'triazine(zole)': '[$(nnn),$(nnan),$(nanan)]',
        '[R3]': '[R3]',
        '[r3,r4;R2]': '[r3,r4;R2]',
        '[r3,r4;#6X3]': '[r3,r4;#6X3]',
        '[r3,r4]~[!#6]': '[r3,r4]~[!#6]',
        'nitrate': 'O[NX3](~[OX1])~[OX1]',
        'amide': 'O=C[NX3]',
        'acyl-halide': 'O=C[F,Cl,Br]',
        'polybenzene': 'c1ccc2c(c1)cccc2',
        # Covered by TEAM FF, but the results are not good
        '[r5;#6X3]': '[r5;#6X3]',
        '[r5]~[!#6]': '[r5]~[!#6]',
        'cyclo-ester': '[C;R](=O)O',
        'C=C~[O,N;H0]': 'C=C~[O,N;H0]',
        'C=C-X': 'C=C[F,Cl,Br]',
        '[F,Cl,Br][#6][F,Cl,Br]': '[F,Cl,Br][#6][F,Cl,Br]',
        'alkyne': '[CX2]#[CX2]',
        'acid': 'C(=O)[OH]',
        'nitrile': '[NX1]#[CX2][C,c]',
        'nitro': '[C,c][NX3](~[OX1])~[OX1]',
        'N-aromatic': 'n',
        'halogen': '[F,Cl,Br]',
    }
    smiles_valid = []
    for s in smiles:
        can_s = mol_filter(s, smarts_bad, args.heavy_atoms)
        if can_s is not None:
            smiles_valid.append(can_s)
    # Deduplicate once; the original computed np.unique twice (once for the
    # iterable and again for the progress-bar total).
    unique_smiles = np.unique(smiles_valid)
    for s in tqdm(unique_smiles, total=unique_smiles.size):
        mol = Molecule(smiles=s,
                       features=json.dumps(get_features(s, args.features_generator)))
        add_or_query(mol, ['smiles'])
    session.commit()
def predict(target: str):
    """Predict property *target* for every stored molecule and save the results.

    Loads the Chemprop checkpoints in ``ml-models/<target>``. The stored
    ``rdkit_2d_normalized`` features are passed as extra descriptors, with
    feature scaling disabled. For each molecule, the first column of its
    prediction is written to ``property_ml[target]`` via ``update_dict``.
    Changes are committed at the end.

    Does nothing when the database holds no molecules.
    """
    # Materialise the query once. Iterating a lazy query re-executes it, and
    # with no ORDER BY a second pass is not guaranteed to yield the same rows
    # in the same order. That would pair predictions with the wrong molecules.
    mols = list(session.query(Molecule))
    if not mols:
        return
    smiles = [[mol.smiles] for mol in mols]
    features = np.asarray(
        [json.loads(mol.features)['rdkit_2d_normalized'] for mol in mols])
    args = PredictArgs()
    args.checkpoint_dir = 'ml-models/%s' % target
    # The stored features are already normalized, so skip feature scaling.
    args.no_features_scaling = True
    args.process_args()
    preds = make_predictions(args, smiles, features, save_prediction=False)
    for i, mol in enumerate(mols):
        mol.update_dict('property_ml', {target: preds[i][0]})
    session.commit()
if __name__ == '__main__':
    # Populate the database first, then run each property model over all
    # stored molecules in turn.
    submit(args=SubmitArgs().parse_args())
    for target_name in ('tt', 'tb', 'tc'):
        predict(target_name)