Skip to content

Commit 9fc4413

Browse files
committed
Added data and models directories
1 parent 803a9c4 commit 9fc4413

8 files changed

+1094
-4
lines changed

.gitignore

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
data
2-
logs
3-
models
4-
__pycache__
1+
.idea
2+
__pycache__

src/data/__init__.py

Whitespace-only changes.

src/data/abstract_dataset.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import numpy as np
2+
import torch
3+
import pytorch_lightning as pl
4+
5+
from torch_geometric.utils import scatter
6+
7+
from pdb import set_trace
8+
9+
10+
class AbstractDataModule(pl.LightningDataModule):
    """Lightning data module that serves pre-built train/val/test dataloaders
    and computes empirical dataset statistics (graph sizes, node/edge type
    frequencies, dummy-atom counts) used as priors by the model.

    NOTE(review): `self.dataloaders` is expected to be populated by a subclass
    as a dict with 'train'/'val'/'test' keys before any statistic is queried.
    """

    def __init__(self, batch_size, num_workers, shuffle):
        super().__init__()
        # Filled in later by the concrete subclass.
        self.dataloaders = None
        self.input_dims = None
        self.output_dims = None
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle

    def train_dataloader(self):
        """Return the training dataloader."""
        return self.dataloaders["train"]

    def val_dataloader(self):
        """Return the validation dataloader."""
        return self.dataloaders["val"]

    def test_dataloader(self):
        """Return the test dataloader."""
        return self.dataloaders["test"]

    def __getitem__(self, idx):
        """Index into the training dataloader (assumes it supports indexing)."""
        return self.dataloaders['train'][idx]

    def node_counts(self, max_nodes_possible=300):
        """Return the empirical distribution of graph sizes over all splits.

        Builds a histogram of per-graph node counts (graphs with more than
        `max_nodes_possible` nodes would overflow the histogram), trims it at
        the largest observed size, and normalizes it to sum to 1.
        """
        histogram = torch.zeros(max_nodes_possible)
        for split in ['train', 'val', 'test']:
            for batch in self.dataloaders[split]:
                # `batch.batch` maps each node to its graph id; the counts of
                # each id are the per-graph node counts.
                _, graph_sizes = torch.unique(batch.batch, return_counts=True)
                for size in graph_sizes:
                    histogram[size] += 1
        largest = max(histogram.nonzero())
        histogram = histogram[:largest + 1]
        return histogram / histogram.sum()

    def node_types(self):
        """Return the empirical frequency of each node type in the train split.

        Assumes `data.x` is a one-hot node-type matrix — TODO confirm.
        """
        loader = self.dataloaders['train']
        num_classes = next(iter(loader)).x.shape[1]

        totals = torch.zeros(num_classes)
        for batch in loader:
            totals += batch.x.sum(dim=0)

        return totals / totals.sum()

    def edge_counts(self):
        """Return the empirical frequency of each edge type in the train split.

        Class 0 is treated as "no edge": its count is the number of ordered
        node pairs within each graph that are not connected. Existing-edge
        counts come from the one-hot `edge_attr` columns 1..C-1.
        """
        loader = self.dataloaders['train']
        num_classes = next(iter(loader)).edge_attr.shape[1]

        dist = torch.zeros(num_classes, dtype=torch.float)

        for batch in loader:
            _, graph_sizes = torch.unique(batch.batch, return_counts=True)

            # Ordered pairs (i, j), i != j, summed per graph; edge_index also
            # stores both directions, so the two counts are comparable.
            ordered_pairs = 0
            for size in graph_sizes:
                ordered_pairs += size * (size - 1)

            present_edges = batch.edge_index.shape[1]
            absent_edges = ordered_pairs - present_edges
            assert absent_edges >= 0

            per_type = batch.edge_attr.sum(dim=0)
            dist[0] += absent_edges
            dist[1:] += per_type[1:]

        return dist / dist.sum()

    def dummy_atoms_counts(self, max_n_dummy_nodes):
        """Return the distribution of dummy-atom counts per training graph.

        Assumes the last column of `data.p_x` flags dummy atoms — TODO confirm.
        Graphs with more than `max_n_dummy_nodes` dummies are ignored; the
        returned tensor is normalized by the total count of all graphs seen
        (including the skipped ones).
        """
        histogram = np.zeros(max_n_dummy_nodes + 1)
        for batch in self.dataloaders['train']:
            # Per-graph sum of the dummy-atom indicator column.
            per_graph = scatter(batch.p_x[:, -1], batch.batch, reduce='sum')
            for n_dummy in per_graph.long().detach().cpu().numpy():
                if n_dummy <= max_n_dummy_nodes:
                    histogram[n_dummy] += 1

        return torch.tensor(histogram) / histogram.sum()
94+
95+
96+
class MolecularDataModule(AbstractDataModule):
    """Data module for molecular graphs with chemistry-specific statistics."""

    def valency_count(self, max_n_nodes):
        """Return the empirical distribution of atom valencies over all splits.

        For every atom, the valency is the bond-order-weighted sum of its
        incident one-hot edge attributes. The histogram length `3n - 2` is the
        maximum valency possible if everything is connected.
        """
        # Bond-order weight per edge class:
        # no bond, single, double, triple, aromatic.
        bond_orders = torch.tensor([0, 1, 2, 3, 1.5])

        histogram = torch.zeros(3 * max_n_nodes - 2)

        for split in ['train', 'val', 'test']:
            for batch in self.dataloaders[split]:
                num_atoms = batch.x.shape[0]
                for atom in range(num_atoms):
                    # One-hot attributes of edges leaving this atom.
                    incident = batch.edge_attr[batch.edge_index[0] == atom]
                    valency = (incident.sum(dim=0) * bond_orders).sum()
                    histogram[valency.long().item()] += 1

        return histogram / histogram.sum()

0 commit comments

Comments
 (0)