Skip to content

Commit a9af2e3

Browse files
committed
No MPI for cifar10
1 parent edbb7a3 commit a9af2e3

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

examples/cifar10/main.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ using Wandb # Logging to Weights and
3030
using Zygote # Our AD Engine
3131

3232
# Distributed Training
33-
FluxMPI.Init(; verbose=true)
33+
# FluxMPI.Init(; verbose=true)
3434
CUDA.allowscalar(false)
3535

3636
# Training Options
@@ -216,7 +216,7 @@ function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, lo
216216
)
217217
forward_pass_time(time() - _t, B)
218218
_t = time()
219-
gs = back((one(loss) / total_workers(), nothing, nothing))[1]
219+
gs = back((one(loss), nothing, nothing))[1]
220220
backward_pass_time(time() - _t, B)
221221
st = Lux.update_state(st, :update_mask, Val(true))
222222
if is_distributed()
@@ -341,9 +341,9 @@ function main(args)
341341

342342
logging_header = ["Epoch", "Train/Batch Time", "Train/Data Time", "Train/Forward Pass Time", "Train/Backward Pass Time", "Train/Cross Entropy Loss", "Train/Skip Loss", "Train/Net Loss", "Train/NFE", "Train/Accuracy", "Train/Residual", "Test/Batch Time", "Test/Data Time", "Test/Cross Entropy Loss", "Test/Skip Loss", "Test/Net Loss", "Test/NFE", "Test/Accuracy", "Test/Residual"]
343343
csv_logger = CSVLogger(log_path, logging_header)
344-
wandb_logger = WandbLoggerMPI(project="deep_equilibrium_models",
345-
name=store_in,
346-
config=loggable_config)
344+
wandb_logger = WandbLogger(project="deep_equilibrium_models",
345+
name=store_in,
346+
config=loggable_config)
347347

348348
values_to_loggable_dict(args...) = Dict(zip(logging_header, args))
349349

0 commit comments

Comments
 (0)