@@ -36,7 +36,6 @@
     AutoModelForCausalLM,
     AutoTokenizer,
     BitsAndBytesConfig,
-    GenerationConfig,
     GPTQConfig,
     PretrainedConfig,
 )
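Note: the `GenerationConfig` import is dropped because, as the hunks below show, the generation parameters are now kept as a plain dict (`self.generation_config_dict`) and splatted directly into `generate()` instead of being wrapped in a `GenerationConfig` object.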
@@ -656,7 +655,7 @@ def greedy_until_multi_turn(  # noqa: C901
             ]
         )

-        generation_config = GenerationConfig.from_dict(self.generation_config_dict)
+        generation_config = self.generation_config_dict.copy()
         generation_config.update(
             {
                 "max_new_tokens": max_generated_tokens,
@@ -669,7 +668,7 @@ def greedy_until_multi_turn(  # noqa: C901
         )

         model_outputs: GenerateOutput = self.model.generate(
-            **model_inputs, stopping_criteria=stopping_criteria, generation_config=generation_config
+            **model_inputs, stopping_criteria=stopping_criteria, **generation_config
         )
         model_outputs = model_outputs.sequences[0, model_inputs["input_ids"].size(1) :]
         model_generations = [model_outputs]
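The pattern introduced above copies the shared parameter dict before mutating it, then expands the entries into keyword arguments of `generate()`. A minimal, self-contained sketch of that pattern, assuming a standard `transformers` setup (the model name and parameter values are illustrative, not taken from this PR):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

generation_config_dict = {"do_sample": False}       # shared defaults
generation_config = generation_config_dict.copy()   # copy so the shared defaults stay untouched
generation_config.update({"max_new_tokens": 16})    # per-call overrides

inputs = tokenizer("Hello", return_tensors="pt")
# Each dict entry becomes a keyword argument of generate(); transformers
# folds such kwargs into the effective generation config for this call.
outputs = model.generate(**inputs, **generation_config)
print(tokenizer.decode(outputs[0]))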
@@ -699,7 +698,7 @@ def greedy_until_multi_turn(  # noqa: C901
                 ]
             )

-            generation_config = GenerationConfig.from_dict(self.generation_config_dict)
+            generation_config = self.generation_config_dict.copy()
             generation_config.update(
                 {
                     "max_new_tokens": max_generated_tokens,
@@ -715,7 +714,7 @@ def greedy_until_multi_turn(  # noqa: C901
                 input_ids=model_inputs["input_ids"],
                 attention_mask=model_inputs["attention_mask"],
                 stopping_criteria=stopping_criteria,
-                generation_config=generation_config,
+                **generation_config,
             )
             model_outputs = model_outputs.sequences[0, model_inputs["input_ids"].size(1) :]
             model_generations.append(model_outputs)
@@ -896,7 +895,7 @@ def _generate(
         stopping_criteria = stop_sequences_criteria(self.tokenizer, stop_sequences=stop_tokens, batch=batch)
         batch_size, _ = batch.input_ids.shape

-        generation_config = GenerationConfig.from_dict(self.generation_config_dict)
+        generation_config = self.generation_config_dict.copy()
         generation_config.update(
             max_new_tokens=max_new_tokens,
             pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id,
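One detail worth noting in the hunk above: `_generate` updates the dict with keyword arguments rather than a mapping, which plain `dict.update()` also supports. A tiny illustration (the key and value are illustrative):

cfg = {}
cfg.update({"max_new_tokens": 16})   # mapping argument, as in greedy_until_multi_turn
cfg.update(max_new_tokens=16)        # keyword arguments, as in _generate; same result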
@@ -912,7 +911,7 @@ def _generate(
             input_ids=batch.input_ids,
             attention_mask=batch.input_mask,
             stopping_criteria=stopping_criteria,
-            generation_config=generation_config,
+            **generation_config,
         )
         generations = outputs.sequences[:, batch.input_ids.size(1) :]
         generations = torch.reshape(generations, (batch_size, num_samples, -1))
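For context, a hedged before/after comparison of the two call styles this PR swaps. `generate()`'s `generation_config` parameter expects a `GenerationConfig` instance, so a plain dict has to be either wrapped via `GenerationConfig.from_dict()` (the old style) or splatted as keyword arguments that `generate()` folds into its config (the new style). A self-contained sketch, assuming a standard `transformers` model (names and values are illustrative, not from this PR):

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")
params = {"max_new_tokens": 8, "do_sample": False}

# Old style: build a GenerationConfig object from the dict first.
out_old = model.generate(**inputs, generation_config=GenerationConfig.from_dict(params))

# New style: splat the dict; generate() merges the kwargs itself.
out_new = model.generate(**inputs, **params)

# Greedy decoding, so both calls should produce the same continuation.
print(tokenizer.decode(out_old[0]), tokenizer.decode(out_new[0]))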