diff --git a/lagent/llms/huggingface.py b/lagent/llms/huggingface.py index 6d61a8d0..57ea17a8 100644 --- a/lagent/llms/huggingface.py +++ b/lagent/llms/huggingface.py @@ -145,6 +145,11 @@ def stream_generate( new_gen_params = self.update_gen_params(**kwargs) generation_config.update(**new_gen_params) generation_config.update(**kwargs) + temperature = generation_config.temperature + if isinstance(temperature, (float, int)): + if int(temperature) == 0: + temperature = 1e-10 + generation_config.update(temperature=temperature) model_kwargs = generation_config.to_dict() model_kwargs['attention_mask'] = attention_mask _, eos_token_id = ( # noqa: F841 # pylint: disable=W0612