Skip to content

Commit 29695d7

Browse files
author
Xiangci Li
committed
Fix a bug due to Pytorch version update
1 parent c4bb72a commit 29695d7

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pipeline/paragraph_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def forward(self, encoded_dict, transformation_indices, N_tokens, discourse_labe
158158

159159

160160
discourse_pred = torch.argmax(discourse_out.cpu(), dim=-1) # (Batch_size, N_sep)
161-
discourse_out = [discourse_pred_paragraph[mask].detach().numpy().tolist() for discourse_pred_paragraph, mask in zip(discourse_pred, sentence_mask.bool())]
161+
discourse_out = [discourse_pred_paragraph[mask].detach().numpy().tolist() for discourse_pred_paragraph, mask in zip(discourse_pred, sentence_mask.bool().cpu())]
162162
citation_pred = torch.argmax(citation_out.cpu(), dim=-1) # (Batch_size, N_sep)
163163
citation_out = [citation_pred_paragraph[:n_token].detach().numpy().tolist() for citation_pred_paragraph, n_token in zip(citation_pred, N_tokens)]
164164
span_pred = torch.argmax(span_out.cpu(), dim=-1) # (Batch_size, N_sep)

0 commit comments

Comments
 (0)