Skip to content

Add missing logging keys for GNN #729

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion graph_neural_network/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ or
```bash
git clone https://github.com/mlcommons/training.git
```
once `GNN node classification` is merged into `mlcommons/training`.


#### 2. Build the docker image:

Expand Down
20 changes: 17 additions & 3 deletions graph_neural_network/dist_train_rgnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,

training_start = time.time()
for epoch in range(epochs):
if rank == 0:
mllogger.start(
key=mllog_constants.EPOCH_START,
metadata={mllog_constants.EPOCH_NUM: epoch},
)
model.train()
total_loss = 0
train_acc = 0
Expand Down Expand Up @@ -282,6 +287,12 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
if with_gpu:
torch.cuda.synchronize()
torch.distributed.barrier()

if rank == 0:
mllogger.end(
key=mllog_constants.EPOCH_STOP,
metadata={mllog_constants.EPOCH_NUM: epoch},
)

#checkpoint at the end of epoch
if checkpoint_on_epoch_end:
Expand Down Expand Up @@ -412,10 +423,9 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
if args.split_training_sampling:
assert(not args.num_training_procs > torch.cuda.device_count() // 2)

world_size = args.num_nodes * args.num_training_procs
if args.node_rank == 0:
world_size = args.num_nodes * args.num_training_procs
submission_info(mllogger, 'GNN', 'reference_implementation')

submission_info(mllogger, mllog_constants.GNN, 'reference_implementation')
mllogger.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=world_size*args.train_batch_size)
mllogger.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
mllogger.event(key=mllog_constants.OPT_NAME, value='adam')
Expand Down Expand Up @@ -443,6 +453,10 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
)
train_idx.share_memory_()
val_idx.share_memory_()

if args.node_rank == 0:
mllogger.event(key=mllog_constants.TRAIN_SAMPLES, value=train_idx.size(0) * world_size)
mllogger.event(key=mllog_constants.EVAL_SAMPLES, value=val_idx.size(0) * world_size)

print('--- Launching training processes ...\n')
torch.multiprocessing.spawn(
Expand Down
16 changes: 15 additions & 1 deletion graph_neural_network/train_rgnn_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ def run_training_proc(rank, world_size,

training_start = time.time()
for epoch in tqdm.tqdm(range(epochs)):
if rank == 0:
mllogger.start(
key=mllog_constants.EPOCH_START,
metadata={mllog_constants.EPOCH_NUM: epoch},
)
model.train()
total_loss = 0
train_acc = 0
Expand Down Expand Up @@ -208,6 +213,12 @@ def run_training_proc(rank, world_size,
torch.cuda.synchronize()
dist.barrier()

if rank == 0:
mllogger.end(
key=mllog_constants.EPOCH_STOP,
metadata={mllog_constants.EPOCH_NUM: epoch},
)

#checkpoint at the end of epoch
if checkpoint_on_epoch_end:
if rank == 0:
Expand Down Expand Up @@ -311,7 +322,7 @@ def run_training_proc(rank, world_size,

glt.utils.common.seed_everything(args.random_seed)
world_size = torch.cuda.device_count()
submission_info(mllogger, 'GNN', 'reference_implementation')
submission_info(mllogger, mllog_constants.GNN, 'reference_implementation')
mllogger.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=world_size*args.train_batch_size)
mllogger.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
mllogger.event(key=mllog_constants.OPT_NAME, value='adam')
Expand Down Expand Up @@ -342,6 +353,9 @@ def run_training_proc(rank, world_size,
train_idx = igbh_dataset.train_idx.clone().share_memory_()
val_idx = igbh_dataset.val_idx.clone().share_memory_()

mllogger.event(key=mllog_constants.TRAIN_SAMPLES, value=train_idx.size(0))
mllogger.event(key=mllog_constants.EVAL_SAMPLES, value=val_idx.size(0))

print('--- Launching training processes ...\n')
torch.multiprocessing.spawn(
run_training_proc,
Expand Down