in video 103 why we used enumerate to loop through train_dataloader but didn't use enumerate to loop through test_dataloader #584
-
I have added 2 comments, each preceded by many question marks, in the code below. We used enumerate to loop through train_dataloader but didn't use enumerate to loop through test_dataloader. Why so? (The code is from the 16:10:30 timeframe, video 103: Training and testing loops for batch data.)

```python
from tqdm.auto import tqdm

# Set the seed and start the timer
torch.manual_seed(42)
train_time_start_on_cpu = timer()

# Set the number of epochs
epochs = 3

# Create training and test loop
for epoch in tqdm(range(epochs)):
    print(f"Epoch: {epoch}\n---------")

    ### Training
    train_loss = 0
    # Add a loop to loop through the training batches
    for batch, (X, y) in enumerate(train_dataloader):  # ??????? why do we use enumerate here?
        model_0.train()
        # 1. Forward pass
        y_pred = model_0(X)
        # 2. Calculate the loss (per batch)
        loss = loss_fn(y_pred, y)
        train_loss += loss
        # 3. Optimizer zero grad
        optimizer.zero_grad()
        # 4. Loss backward
        loss.backward()
        # 5. Optimizer step
        optimizer.step()
        # Print out what's happening
        if batch % 400 == 0:
            print(f"Looked at {batch * len(X)}/{len(train_dataloader.dataset)} samples")

    # Divide total train loss by length of train dataloader
    train_loss /= len(train_dataloader)

    ### Testing
    test_loss, test_acc = 0, 0
    model_0.eval()
    with torch.inference_mode():
        for X_test, y_test in test_dataloader:  # ??????? but not here?
            # 1. Forward pass
            test_pred = model_0(X_test)
            # 2. Calculate the loss
            test_loss += loss_fn(test_pred, y_test)
            # 3. Calculate the accuracy
            test_acc += accuracy_fn(y_true=y_test, y_pred=test_pred.argmax(dim=1))
        # Calculate the avg test loss per batch
        test_loss /= len(test_dataloader)
        # Calculate the avg test accuracy per batch
        test_acc /= len(test_dataloader)

    # Print out what's happening
    print(f"\nTrain loss: {train_loss:.4f} | Test loss: {test_loss:.4f}, Test acc: {test_acc:.4f}")

# Calculate the training time
train_time_end_on_cpu = timer()
total_train_time_model_0 = print_train_time(train_time_start_on_cpu,
                                            train_time_end_on_cpu,
                                            device=str(next(model_0.parameters()).device))
```
-
Enumerate (have a look at the documentation) returns an iterator that yields pairs of elements and their indexes. In the first loop, the index is assigned to `batch`, while the values are put into `X` and `y`. It looks like this is done so that every 400 batches, it prints an update: `if batch % 400 == 0: ...`. There's no need to have access to the batch number in the second loop because there are no progress updates there, so no need for `enumerate`.
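Here's a toy illustration of that pairing (a plain list instead of a DataLoader, but enumerate behaves the same way):

```python
# enumerate wraps any iterable and yields (index, element) pairs
batches = ["batch_a", "batch_b", "batch_c"]

for batch, data in enumerate(batches):
    print(batch, data)

# Prints:
# 0 batch_a
# 1 batch_b
# 2 batch_c
```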
-
Hello, first of all, thank you for the very useful course. I have a follow-up question on this topic. When we iterate through the batches, do we provide the full batch (32 images) at a time to the neural network? How does that work? Since the input size of the network is 28*28, how can it accept a full batch of 32 images? Thank you very much.
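For what it's worth, a minimal sketch of how the batch dimension is handled (assuming a single `nn.Linear` layer as a stand-in for `model_0`): PyTorch layers treat the first dimension of the input as the batch dimension, so a layer built for 28*28 = 784 input features accepts a tensor of shape `[32, 784]` and applies the same weights to each of the 32 rows in one vectorised call.

```python
import torch
from torch import nn

# A simplified stand-in for model_0: one linear layer expecting
# 28*28 = 784 features per sample
layer = nn.Linear(in_features=28 * 28, out_features=10)

# A batch of 32 flattened images: shape [32, 784]
X = torch.randn(32, 28 * 28)

# The same weights are applied to every row of the batch,
# giving one row of 10 outputs per image
y_pred = layer(X)
print(y_pred.shape)  # torch.Size([32, 10])
```

In other words, the "input size" of 28*28 describes each individual sample; the leading batch dimension is handled automatically.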