5 changes: 3 additions & 2 deletions swift/llm/template/base.py
@@ -1146,7 +1146,8 @@ def _swift_encode(self, inputs: StdTemplateInputs):
         if template_meta.auto_add_bos and sep_token:
             res_context_list.append(sep_token)
             res_context_types.append(ContextType.SUFFIX)
-        res_context_list, loss_scale_list = self.loss_scale(res_context_list, res_context_types, inputs.messages)
+        res_context_list, loss_scale_list = self.loss_scale(res_context_list, res_context_types, inputs.messages,
+                                                             **inputs.extra_kwargs)
Comment on lines +1149 to +1150
Contributor
Severity: high

This change is correct: it passes extra_kwargs through to self.loss_scale, i.e. to the LossScale.__call__ method.

However, for the feature to take full effect, the LossScale.__call__ method in swift/plugin/loss_scale/loss_scale.py also needs to be updated so that it forwards the received **kwargs down to the get_loss_scale method.

In the current LossScale.__call__ implementation in swift/plugin/loss_scale/loss_scale.py, the call to get_loss_scale does not pass these extra arguments along. Without that change, the extra_kwargs you pass in base.py will never reach the actual loss_scale computation logic.

Since the PR description says this feature has already been tested successfully, please confirm whether a commit touching loss_scale.py was left out, or whether some other mechanism handles these parameters.
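To make the comment concrete, here is a minimal sketch of the forwarding it asks for. The class and method names (LossScale, __call__, get_loss_scale) are taken from the comment itself; the signatures, loop structure, and default behavior are assumptions for illustration, not the actual implementation in swift/plugin/loss_scale/loss_scale.py:

```python
# Minimal sketch only: the real LossScale plugin in
# swift/plugin/loss_scale/loss_scale.py may have different signatures.
from typing import Dict, List, Tuple


class LossScale:

    def get_loss_scale(self, context: str, context_type: str,
                       **kwargs) -> Tuple[List[str], List[float]]:
        # Default policy: keep the context as-is with a uniform scale of 1.0.
        # Subclasses can override this and read values from **kwargs.
        return [context], [1.0]

    def __call__(self, context_list: List[str], context_types: List[str],
                 messages: List[Dict[str, str]],
                 **kwargs) -> Tuple[List[str], List[float]]:
        # messages is unused in this sketch; it mirrors the call in base.py.
        res_context_list: List[str] = []
        res_loss_scale: List[float] = []
        for context, context_type in zip(context_list, context_types):
            # The fix the comment asks for: forward the extra kwargs received
            # from _swift_encode so get_loss_scale can actually use them.
            contexts, scales = self.get_loss_scale(context, context_type, **kwargs)
            res_context_list += contexts
            res_loss_scale += scales
        return res_context_list, res_loss_scale
```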

         if self.is_training:
             answer_len = len(extra_context_list) + bool(response is not None)
         else:
@@ -1673,7 +1674,7 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int]
             seq_len = max(seq_lens) if padding_to is None else padding_to
             res['attention_mask'] = torch.tril(torch.ones(
                 (len(seq_lens), seq_len, seq_len), dtype=torch.bool)).view(len(seq_lens), 1, seq_len, seq_len)
-            assert res['attention_mask'].dtype is torch.bool, f'attention_mask.dtype: {res['attention_mask'].dtype}'
+            assert res['attention_mask'].dtype is torch.bool, f'attention_mask.dtype: {res["attention_mask"].dtype}'
             for i, seq_len in enumerate(seq_lens):
                 res['attention_mask'][i, :, seq_len:] = 0
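As a side note on the hunk above, the mask construction can be exercised standalone. A self-contained sketch, assuming only torch; seq_lens and padding_to are example stand-ins for the values _data_collator computes from the batch:

```python
# Standalone sketch of the mask construction above; seq_lens and padding_to
# are example stand-ins for values computed earlier in _data_collator.
import torch

seq_lens = [3, 5]
padding_to = None
seq_len = max(seq_lens) if padding_to is None else padding_to

# Lower-triangular (causal) mask of shape (batch, 1, seq_len, seq_len), dtype bool.
attention_mask = torch.tril(torch.ones(
    (len(seq_lens), seq_len, seq_len), dtype=torch.bool)).view(len(seq_lens), 1, seq_len, seq_len)
assert attention_mask.dtype is torch.bool, f'attention_mask.dtype: {attention_mask.dtype}'

# Zero the query rows past each sample's true length so padding
# positions attend to nothing.
for i, n in enumerate(seq_lens):
    attention_mask[i, :, n:] = 0

print(attention_mask.shape)  # torch.Size([2, 1, 5, 5])
```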
