Skip to content

Commit e7073d6

Browse files
committed
First commit
0 parents  commit e7073d6

File tree

68 files changed

+23095
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+23095
-0
lines changed

.vscode/launch.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
// Use IntelliSense to learn about possible attributes.
3+
// Hover to view descriptions of existing attributes.
4+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5+
"version": "0.2.0",
6+
"configurations": [
7+
{
8+
"name": "Python: Current File",
9+
"type": "python",
10+
"request": "launch",
11+
"program": "${file}",
12+
"console": "integratedTerminal",
13+
"justMyCode": true
14+
}
15+
]
16+
}

README.md

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
![](model_architecture.png)
2+
3+
# Scalable and Robust Physics-Informed Graph Neural Networks for Water Distribution Systems.
4+
5+
Official Code for the paper "Scalable and Robust Physics-Informed Graph Neural Networks for Water Distribution Systems" (under submission, preprint available at arXiv: ). \
6+
All system and package requirements are listed in the document 'environment.yml'. A corresponding conda environment can be setup via `conda env create -f environment.yml`.
7+
8+
## Simulating scenarios to generate data
9+
10+
WDS datasets can be generated (Vrachimis et al. https://github.com/KIOS-Research/BattLeDIM) using:
11+
12+
``` python
13+
python utils/dataset_generator.py
14+
```
15+
16+
A number of arguments can be passed to dataset generation parameters:
17+
18+
``` python
19+
'--wds' "Specify the WDS for which you want to simulate the scenarios; default is anytown. Choices are ['anytown', 'hanoi', 'pescara', 'area_c', 'zhijiang', 'modena', 'pa1', 'balerma', 'area_a', 'l_town', 'kl']."
20+
'--sim_start_time' "Specify the start time of the simulation; default is 2018-01-01 00:00, the simulation will be done every 30 minutes starting from this time."
21+
'--sim_end_time' "Specify the end time of the simulation; default is 2018-01-01 02:30."
22+
'--start_scenario' "Specify the start scenario name, must be an integer; default is 1000"
23+
'--end_scenario' "Specify the end scenario name, must be an integer; default is 1999"
24+
```
25+
26+
The simulation will produce an xlsx file for each scenario in the folder 'results' in the respective directories. These xlsx files will be used for training the models.
27+
28+
## Training and Evaluation
29+
30+
31+
Models can be trained using
32+
```python
33+
python run.py
34+
```
35+
A number of arguments can be passed to specify model types and hyperparameters:
36+
37+
``` python
38+
'--wds' "Specify the WDS for which you want to simulate the scenarios; default is anytown. Choices are ['anytown', 'hanoi', 'pescara', 'area_c', 'zhijiang', 'modena', 'pa1', 'balerma', 'area_a', 'l_town', 'kl']."
39+
'--mode' "train_test i.e. train and test a new model, or evaluate i.e. evaluate on an already trained model; default is train_test. "
40+
'--warm_start' "Specify True if you want to further train a partially trained model. model_path must also be specified; default is False."
41+
'--model_path' "Specify model path in case of re-training or evaluation; default is None."
42+
'--model' "Choose the model between PI_GNN and SPI_GNN; default is SPI_GNN."
43+
'--start_scenario' "Specify the start scenario name, must be an integer; default is 1"
44+
'--end_scenario' "Specify the end scenario name, must be an integer; default is 20"
45+
'--n_samples' "Specify the number of samples for each scenario to be used for training; default is 6."
46+
'--batch_size' "Specify the mini-batch size; default is 96."
47+
'--n_epochs' "Specify the number of epochs of training; default is 1500."
48+
'--lr' "Specify the learning rate; default is 1e-4."
49+
'--decay_step' "Specify the step size of the lr scheduler; default is 150."
50+
'--decay_rate' "Specify the decay rate of the lr scheduler; default is 0.75."
51+
'--I' "Specify the number of GCN layers; default is 5."
52+
'--n_iter' "Specify the minimum number of iterations; default is 5."
53+
'--r_iter' "Specify the maximum number of additional (random) iterations; default is 5."
54+
'--n_mlp' "Specify the number of layers in the MLP; default is 1."
55+
'--M_l' "Specify the latent dimension; default is 128."
56+
'--wandb' "Specify True if you want to use Weights and Biases during training; default is False."
57+
58+
```
59+
60+
Trained models can be used for evaluation using run.py by specifying the 'evaluate' mode and 'model_path'.
61+
62+
## Robustness Evaluation
63+
64+
Trained models can be evaluated for robustness using
65+
66+
```python
67+
python robustness_eval.py
68+
```
69+
A number of arguments can be passed:
70+
71+
``` python
72+
'--wds' "Specify the WDS for which you want to simulate the scenarios; default is anytown. Choices are ['anytown', 'hanoi', 'pescara', 'area_c', 'zhijiang', 'modena', 'pa1', 'balerma', 'area_a', 'l_town', 'kl']."
73+
'--mode' "'demands' or 'diameters', evaluate robustness by changing demands or diameters; default is demands. "
74+
'--model_path' "Specify the trained model path; default is the trained model for Anytown."
75+
'--model' "Choose the model between PI_GNN and SPI_GNN; default is SPI_GNN."
76+
'--batch_size' "Specify the mini-batch size; default is 1000."
77+
'--I' "Specify the number of GCN layers; default is 5."
78+
'--n_iter' "Specify the minimum number of iterations; default is 5."
79+
'--r_iter' "Specify the maximum number of additional iterations; default is 5."
80+
'--n_mlp' "Specify the number of layers in the MLP; default is 1."
81+
'--M_l' "Specify the latent dimension; default is 128."
82+
83+
```
84+
The robustness evaluation results can further be analyzed using the notebooks in 'results' directory.
85+
86+
## Important Information
87+
88+
Every WDS is specified by an '.inp' file. We have included those files for all WDSs. Moreover, we also include trained models for all WDSs.
89+
90+
## Citation
91+
### Preprint:
92+
```
93+
@misc{ashraf2025spignn_wds,
94+
author = {Ashraf, Inaam and Artelt, Andr{\'{e}} and Hammer, Barbara},
95+
title = {Scalable and Robust Physics-Informed Graph Neural Networks for Water Distribution Systems.},
96+
year = {2025},
97+
month = feb,
98+
archiveprefix = {arXiv},
99+
eprint = {},
100+
copyright = {Creative Commons Attribution Share Alike 4.0 International}
101+
}
102+
```
103+
### Repository:
104+
```
105+
@misc{SPIGNNs_for_WDSs,
106+
author = {Ashraf, Inaam and Artelt, Andr{\'{e}} and Hammer, Barbara},
107+
title = {{SPIGNNs_for_WDSs}},
108+
year = {2025},
109+
publisher = {GitHub}
110+
journal = {GitHub repository},
111+
organization = {CITEC, Bielefeld University, Germany},
112+
howpublished = {\url{https://github.com/inaamashraf/SPIGNNs_for_WDSs}},
113+
}
114+
```
115+
116+
117+
## Acknowledgments
118+
We gratefully acknowledge funding from the European
119+
Research Council (ERC) under the ERC Synergy Grant Water-Futures (Grant
120+
agreement No. 951424).
4.78 KB
Binary file not shown.

environment.yml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
name: waterfutures
2+
channels:
3+
- pyg
4+
- pytorch
5+
- nvidia
6+
- conda-forge
7+
- defaults
8+
dependencies:
9+
- cudatoolkit=11.8.0=h6a678d5_0
10+
- numpy=1.26.4=py39h5f9d8c6_0
11+
- pandas=2.2.2=py39h6a678d5_0
12+
- pygad=3.3.1=pyhd8ed1ab_0
13+
- python=3.9.18=h955ad1f_0
14+
- pytorch=2.2.1=py3.9_cuda12.1_cudnn8.9.2_0
15+
- pytorch-scatter=2.1.2=py39_torch_2.2.0_cu121
16+
- pyyaml=6.0.1=py39h5eee18b_0
17+
- tabulate=0.9.0=py39h06a4308_0
18+
- torchaudio=2.2.1=py39_cu121
19+
- torchmetrics=1.3.1=pyhd8ed1ab_0
20+
- torchtriton=2.2.0=py39
21+
- torchvision=0.17.1=py39_cu121
22+
- tqdm=4.66.2=pyhd8ed1ab_0
23+
- xlsxwriter=3.1.1=py39h06a4308_0
24+
- yaml=0.2.5=h7b6447c_0
25+
- pip:
26+
- matplotlib==3.9.1
27+
- networkx==3.2.1
28+
- optuna==3.6.1
29+
- python-calamine==0.2.0
30+
- ray==2.34.0
31+
- scikit-learn==1.4.1.post1
32+
- scipy==1.13.1
33+
- seaborn==0.13.2
34+
- torch-geometric==2.5.0
35+
- wandb==0.16.4
36+
- watchfiles==0.22.0
37+
- wntr==1.2.0

model_architecture.png

464 KB
Loading
3.26 KB
Binary file not shown.
6.34 KB
Binary file not shown.

models/layers.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import torch
2+
from typing import List
3+
from torch import Tensor
4+
from torch.nn import Dropout, Sequential
5+
from torch_scatter import scatter
6+
from torch_geometric.nn.conv import MessagePassing
7+
from torch_geometric.nn.dense.linear import Linear
8+
9+
"""
10+
Using partial code from torch_geometric.nn.conv.GENconv
11+
https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/gen_conv.html#GENConv
12+
"""
13+
14+
class MLP(Sequential):
15+
"""
16+
Defining an MLP object with multiple linear layers,
17+
activation functions and dropout.
18+
"""
19+
def __init__(self, dims: List[int], bias: bool = True, dropout: float = 0., activ=None):
20+
m = []
21+
for i in range(1, len(dims)):
22+
m.append(Linear(dims[i - 1], dims[i], bias=bias))
23+
24+
if i < len(dims) - 1:
25+
if activ is not None:
26+
m.append(activ)
27+
m.append(Dropout(dropout))
28+
29+
super().__init__(*m)
30+
31+
32+
33+
class GNN_Layer(MessagePassing):
34+
"""
35+
Graph Neural Network layer using message passing.
36+
The layer applies both node and edge feature update.
37+
"""
38+
def __init__(self, in_dim, out_dim, edge_dim, aggr='max', num_layers=2, bias=False, **kwargs):
39+
40+
kwargs.setdefault('aggr', None)
41+
super().__init__(**kwargs)
42+
43+
self.in_dim = in_dim
44+
self.out_dim = out_dim
45+
self.edge_dim = edge_dim
46+
self.aggr = aggr
47+
48+
""" Defining dimensions for layers in MLP gamma. """
49+
edge_dims = [2*in_dim + edge_dim]
50+
for _ in range(num_layers - 1):
51+
edge_dims.append(edge_dim)
52+
edge_dims.append(edge_dim)
53+
54+
""" Defining dimensions for layers in MLP eta. """
55+
node_dims = [edge_dim]
56+
for _ in range(num_layers - 1):
57+
node_dims.append(in_dim)
58+
node_dims.append(out_dim)
59+
60+
"""
61+
MLPs gamma and eta used for node and edge feature
62+
updates respectively.
63+
"""
64+
self.mlp_edges = MLP(edge_dims, bias=bias)
65+
self.mlp_nodes = MLP(node_dims, bias=bias)
66+
67+
def forward(self, g, edge_index, z) -> Tensor:
68+
"""
69+
Creating edge messages that are also the updated edge features
70+
using node and edge features and the MLP gamma.
71+
"""
72+
sndr_node_attr = g[edge_index[0, :], :]
73+
rcvr_node_attr = g[edge_index[1, :], :]
74+
m_e = self.mlp_edges(torch.selu(torch.cat((sndr_node_attr, rcvr_node_attr, z), dim=-1)))
75+
76+
""" Aggregating edge messages using max aggregation. """
77+
m_e_aggr = scatter(m_e, dim=0, index=edge_index[1:2, :].T, reduce='max', out=torch.zeros_like(g))
78+
79+
""" Updating node features using the MLP eta. """
80+
g = self.mlp_nodes(m_e_aggr)
81+
82+
return g, m_e
83+
84+
85+
86+
class SGNN_Layer(MessagePassing):
87+
"""
88+
Graph Neural Network layer using message passing.
89+
The layer applies both node and edge feature update.
90+
"""
91+
def __init__(self, in_dim, out_dim, edge_dim, aggr='max', num_layers=2, bias=False, **kwargs):
92+
93+
kwargs.setdefault('aggr', None)
94+
super().__init__(**kwargs)
95+
96+
self.in_dim = in_dim
97+
self.out_dim = out_dim
98+
self.edge_dim = edge_dim
99+
self.aggr = aggr
100+
101+
""" Defining dimensions for layers in MLP gamma. """
102+
edge_dims = [2*in_dim + edge_dim]
103+
for _ in range(num_layers - 1):
104+
edge_dims.append(edge_dim)
105+
edge_dims.append(edge_dim)
106+
107+
""" Defining dimensions for layers in MLP eta. """
108+
node_dims = [edge_dim]
109+
for _ in range(num_layers - 1):
110+
node_dims.append(in_dim)
111+
node_dims.append(out_dim)
112+
113+
"""
114+
MLPs gamma and eta used for node and edge feature
115+
updates respectively.
116+
"""
117+
self.mlp_edges = MLP(edge_dims, bias=bias)
118+
self.mlp_nodes = MLP(node_dims, bias=bias)
119+
120+
def forward(self, g, edge_index, z, edge_mask=None) -> Tensor:
121+
"""
122+
Creating edge messages that are also the updated edge features
123+
using node and edge features and the MLP gamma.
124+
"""
125+
sndr_node_attr = g[edge_index[0, :], :]
126+
rcvr_node_attr = g[edge_index[1, :], :]
127+
m_e = torch.relu(self.mlp_edges(torch.cat((sndr_node_attr, rcvr_node_attr, z), dim=-1)))
128+
129+
""" Aggregating edge messages using max aggregation. """
130+
m_e_aggr = scatter(m_e, dim=0, index=edge_index[1:2, :].T, reduce=self.aggr, out=torch.zeros_like(g))
131+
132+
""" Updating node features using the MLP eta. """
133+
g = self.mlp_nodes(m_e_aggr)
134+
135+
return g, m_e
136+
137+

0 commit comments

Comments
 (0)