diff --git a/src/agentlab/llm/base_api.py b/src/agentlab/llm/base_api.py index 9c1ebf5f..b6d1a7be 100644 --- a/src/agentlab/llm/base_api.py +++ b/src/agentlab/llm/base_api.py @@ -21,6 +21,7 @@ class BaseModelArgs(ABC): max_new_tokens: int = None temperature: float = 0.1 vision_support: bool = False + log_probs: bool = False @abstractmethod def make_model(self) -> AbstractChatModel: diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 7392e666..bf3380b2 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -87,6 +87,7 @@ def make_model(self): model_name=self.model_name, temperature=self.temperature, max_tokens=self.max_new_tokens, + log_probs=self.log_probs, ) @@ -100,6 +101,7 @@ def make_model(self): model_name=self.model_name, temperature=self.temperature, max_tokens=self.max_new_tokens, + log_probs=self.log_probs, ) @@ -115,6 +117,7 @@ def make_model(self): temperature=self.temperature, max_tokens=self.max_new_tokens, deployment_name=self.deployment_name, + log_probs=self.log_probs, ) @@ -225,6 +228,7 @@ def __init__( client_class=OpenAI, client_args=None, pricing_func=None, + log_probs=False, ): assert max_retry > 0, "max_retry should be greater than 0" @@ -233,6 +237,7 @@ def __init__( self.max_tokens = max_tokens self.max_retry = max_retry self.min_retry_wait_time = min_retry_wait_time + self.logprobs = log_probs # Get the API key from the environment variable if not provided if api_key_env_var: @@ -279,6 +284,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float n=n_samples, temperature=temperature, max_tokens=self.max_tokens, + logprobs=self.logprobs, ) if completion.usage is None: @@ -308,7 +314,10 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float tracking.TRACKER.instance(input_tokens, output_tokens, cost) if n_samples == 1: - return AIMessage(completion.choices[0].message.content) + res = AIMessage(completion.choices[0].message.content) + if self.logprobs: + res["logprobs"] = completion.choices[0].logprobs + return res else: return [AIMessage(c.message.content) for c in completion.choices] @@ -328,6 +337,7 @@ def __init__( max_tokens=100, max_retry=4, min_retry_wait_time=60, + log_probs=False, ): super().__init__( model_name=model_name, @@ -339,6 +349,7 @@ def __init__( api_key_env_var="OPENAI_API_KEY", client_class=OpenAI, pricing_func=tracking.get_pricing_openai, + log_probs=log_probs, ) @@ -351,6 +362,7 @@ def __init__( max_tokens=100, max_retry=4, min_retry_wait_time=60, + log_probs=False, ): client_args = { "base_url": "https://openrouter.ai/api/v1", @@ -366,6 +378,7 @@ def __init__( client_class=OpenAI, client_args=client_args, pricing_func=tracking.get_pricing_openrouter, + log_probs=log_probs, ) @@ -379,6 +392,7 @@ def __init__( max_tokens=100, max_retry=4, min_retry_wait_time=60, + log_probs=False, ): api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY") endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") @@ -399,6 +413,7 @@ def __init__( client_class=AzureOpenAI, client_args=client_args, pricing_func=tracking.get_pricing_openai, + log_probs=log_probs, ) @@ -412,6 +427,7 @@ def __init__( temperature: Optional[int] = 1e-1, max_new_tokens: Optional[int] = 512, n_retry_server: Optional[int] = 4, + log_probs: Optional[bool] = False, ): super().__init__(model_name, base_model_name, n_retry_server) if temperature < 1e-3: @@ -422,4 +438,4 @@ def __init__( token = os.environ["TGI_TOKEN"] client = InferenceClient(model=model_url, token=token) - self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens) + self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index 2bbf219d..d6f9e822 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -382,9 +382,10 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image): class BaseMessage(dict): - def __init__(self, role: str, content: Union[str, list[dict]]): + def __init__(self, role: str, content: Union[str, list[dict]], **kwargs): self["role"] = role self["content"] = deepcopy(content) + self.update(kwargs) def __str__(self, warn_if_image=False) -> str: if isinstance(self["content"], str):