use torch.gather instead of direct indexing #5
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Instead of this line:
log_liklihoods.append(output[:, target])
have this line:
log_liklihoods.append(torch.gather(output, dim=1, index=target.unsqueeze(-1)))
Why?
Assume our output is 100x4 which means batch size is 100 and we have 4 classes. Target is a (100,) vector of classes, by indexing output[:, target] we will create a 100x100 matrix, instead of gathering the loglikelihoods 100x1 that we desire.
The torch.gather function does this propoerly.