diff --git a/src/cleanlab_codex/internal/validator.py b/src/cleanlab_codex/internal/validator.py
index 8921673..e5aa518 100644
--- a/src/cleanlab_codex/internal/validator.py
+++ b/src/cleanlab_codex/internal/validator.py
@@ -2,11 +2,14 @@
 
 from typing import TYPE_CHECKING, Any, Optional, Sequence, cast
 
+from cleanlab_tlm.tlm import TLMResponse
 from cleanlab_tlm.utils.rag import Eval, TrustworthyRAGScore, get_default_evals
 
 from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
 
 if TYPE_CHECKING:
+    from cleanlab_tlm import TLM
+
     from cleanlab_codex.validator import BadResponseThresholds
 
 
@@ -21,6 +24,17 @@
     "context_sufficiency": "is_not_context_sufficient",
 }
 
+REWRITE_QUERY = """Given a conversational Message History and a Query, rewrite the Query to be a self-contained question. If the Query is already self-contained, return it as-is.
+
+Query: {query}
+
+--
+Message History: \n{messages}
+
+--
+Remember, return the Query as-is except in cases where the Query is missing key words or has content that should be additionally clarified."""
+REWRITE_QUERY_TRUSTWORTHINESS_THRESHOLD = 0.5
+
 
 def get_default_evaluations() -> list[Eval]:
     """Get the default evaluations for the TrustworthyRAG.
@@ -40,6 +54,16 @@ def get_default_trustworthyrag_config() -> dict[str, Any]:
     }
 
 
+def get_default_tlm_config() -> dict[str, Any]:
+    """Get the default configuration for the TLM."""
+
+    return {
+        "quality_preset": "medium",
+        "verbose": False,
+        "options": {"model": "gpt-4.1-nano"},
+    }
+
+
 def update_scores_based_on_thresholds(
     scores: TrustworthyRAGScore | Sequence[TrustworthyRAGScore], thresholds: BadResponseThresholds
 ) -> ThresholdedTrustworthyRAGScore:
@@ -108,3 +132,38 @@ def is_bad(metric: str) -> bool:
     if is_bad("trustworthiness"):
         return "hallucination"
     return "other_issues"
+
+
+def validate_messages(messages: Optional[list[dict[str, Any]]] = None) -> None:
+    """Validate that messages follow the OpenAI chat message format."""
+
+    if messages is None:
+        return
+    if not isinstance(messages, list):
+        raise TypeError("Messages must be a list of dictionaries.")  # noqa: TRY003
+    for message in messages:
+        if not isinstance(message, dict) or "role" not in message or "content" not in message:
+            raise TypeError("Each message must be a dictionary containing 'role' and 'content' keys.")  # noqa: TRY003
+        if not isinstance(message["content"], str):
+            raise TypeError("Message content must be a string.")  # noqa: TRY003
+
+
+def prompt_tlm_for_rewrite_query(query: str, messages: list[dict[str, Any]], tlm: TLM) -> TLMResponse:
+    """Given the query and message history, prompt the TLM to rewrite the query as a self-contained question.
+    If the TLM call fails, the original query is returned.
+    """
+
+    messages_str = ""
+    for message in messages:
+        messages_str += f"{message['role']}: {message['content']}\n"
+
+    response = tlm.prompt(
+        REWRITE_QUERY.format(
+            query=query,
+            messages=messages_str,
+        )
+    )
+
+    if response is None:
+        return TLMResponse(response=query, trustworthiness_score=1.0)
+    return cast(TLMResponse, response)
diff --git a/src/cleanlab_codex/validator.py b/src/cleanlab_codex/validator.py
index c18922f..c4c03a4 100644
--- a/src/cleanlab_codex/validator.py
+++ b/src/cleanlab_codex/validator.py
@@ -7,19 +7,21 @@
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Callable, Optional, cast
 
-from cleanlab_tlm import TrustworthyRAG
+from cleanlab_tlm import TLM, TrustworthyRAG
 from pydantic import BaseModel, Field, field_validator
 
 from cleanlab_codex.internal.validator import (
+    REWRITE_QUERY_TRUSTWORTHINESS_THRESHOLD,
     get_default_evaluations,
+    get_default_tlm_config,
     get_default_trustworthyrag_config,
 )
 from cleanlab_codex.internal.validator import (
     process_score_metadata as _process_score_metadata,
 )
-from cleanlab_codex.internal.validator import (
-    update_scores_based_on_thresholds as _update_scores_based_on_thresholds,
-)
+from cleanlab_codex.internal.validator import prompt_tlm_for_rewrite_query as _prompt_tlm_for_rewrite_query
+from cleanlab_codex.internal.validator import update_scores_based_on_thresholds as _update_scores_based_on_thresholds
+from cleanlab_codex.internal.validator import validate_messages as _validate_messages
 from cleanlab_codex.project import Project
 
 if TYPE_CHECKING:
@@ -32,6 +34,7 @@ def __init__(
         codex_access_key: str,
         tlm_api_key: Optional[str] = None,
         trustworthy_rag_config: Optional[dict[str, Any]] = None,
+        tlm_config: Optional[dict[str, Any]] = None,
        bad_response_thresholds: Optional[dict[str, float]] = None,
     ):
         """Real-time detection and remediation of bad responses in RAG applications, powered by Cleanlab's TrustworthyRAG and Codex.
@@ -74,11 +77,20 @@ def __init__(
             ValueError: If any threshold value is not between 0 and 1.
         """
         trustworthy_rag_config = trustworthy_rag_config or get_default_trustworthyrag_config()
+        tlm_config = tlm_config or get_default_tlm_config()
+        self._tlm: Optional[TLM] = None
+
         if tlm_api_key is not None and "api_key" in trustworthy_rag_config:
             error_msg = "Cannot specify both tlm_api_key and api_key in trustworthy_rag_config"
             raise ValueError(error_msg)
 
         if tlm_api_key is not None:
             trustworthy_rag_config["api_key"] = tlm_api_key
+            if "api_key" not in tlm_config:
+                tlm_config["api_key"] = tlm_api_key
+            else:
+                error_msg = "Cannot specify both tlm_api_key and api_key in tlm_config"
+                raise ValueError(error_msg)
+        self._tlm = TLM(**tlm_config)
 
         self._project: Project = Project.from_access_key(access_key=codex_access_key)
@@ -108,6 +120,7 @@ def validate(
         query: str,
         context: str,
         response: str,
+        messages: Optional[list[dict[str, Any]]] = None,
         prompt: Optional[str] = None,
         form_prompt: Optional[Callable[[str, str], str]] = None,
         metadata: Optional[dict[str, Any]] = None,
@@ -129,6 +142,10 @@
             - 'is_bad_response': True if the response is flagged as potentially bad, False otherwise. When True, a Codex lookup is performed, which logs this query into the Codex Project for SMEs to answer.
             - Additional keys from a [`ThresholdedTrustworthyRAGScore`](/codex/api/python/types.validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
         """
+        _validate_messages(messages)
+        if messages is not None:
+            query = self._maybe_rewrite_query(query=query, messages=messages)
+
         scores, is_bad_response = self.detect(
             query=query, context=context, response=response, prompt=prompt, form_prompt=form_prompt
         )
@@ -154,6 +171,7 @@ async def validate_async(
         query: str,
         context: str,
         response: str,
+        messages: Optional[list[dict[str, Any]]] = None,
         prompt: Optional[str] = None,
         form_prompt: Optional[Callable[[str, str], str]] = None,
         metadata: Optional[dict[str, Any]] = None,
@@ -175,6 +193,10 @@
             - 'is_bad_response': True if the response is flagged as potentially bad, False otherwise. When True, a Codex lookup is performed, which logs this query into the Codex Project for SMEs to answer.
             - Additional keys from a [`ThresholdedTrustworthyRAGScore`](/codex/api/python/types.validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
         """
+        _validate_messages(messages)
+        if messages is not None:
+            query = self._maybe_rewrite_query(query=query, messages=messages)
+
         scores, is_bad_response = await self.detect_async(query, context, response, prompt, form_prompt)
         final_metadata = metadata.copy() if metadata else {}
         if log_results:
@@ -296,6 +318,25 @@ def _remediate(self, *, query: str, metadata: dict[str, Any] | None = None) -> s
         codex_answer, _ = self._project.query(question=query, metadata=metadata)
         return codex_answer
 
+    def _maybe_rewrite_query(self, *, query: str, messages: list[dict[str, Any]]) -> str:
+        """Rewrite the query using the message history when the query is not a self-contained question but can be made one with context from the messages.
+
+        Args:
+            query (str): The original query.
+            messages (list[dict[str, Any]]): The message history to use for rewriting the query.
+
+        Returns:
+            final_query (str): Either the original query or a rewritten, self-contained version of it.
+        """
+        if self._tlm is not None:
+            response = _prompt_tlm_for_rewrite_query(query=query, messages=messages, tlm=self._tlm)
+            if (
+                response["trustworthiness_score"] is not None
+                and response["trustworthiness_score"] >= REWRITE_QUERY_TRUSTWORTHINESS_THRESHOLD
+            ):
+                return str(response["response"])
+        return query  # If the trustworthiness score is below the threshold or we don't have access to the TLM, skip the rewrite
+
 
 class BadResponseThresholds(BaseModel):
     """Config for determining if a response is bad.
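Note (illustrative, not part of the diff): a minimal sketch of how the new `messages` argument to `Validator.validate` might be used once the changes above land. The access key and API key values are placeholders, and the exact result keys depend on the configured evals and thresholds.

    from cleanlab_codex.validator import Validator

    # Placeholder credentials -- substitute real values.
    validator = Validator(
        codex_access_key="<your-codex-access-key>",
        tlm_api_key="<your-tlm-api-key>",
    )

    # The message history lets the validator rewrite a non-self-contained query
    # (e.g. "What is its capital?") before scoring the response.
    result = validator.validate(
        query="What is its capital?",
        context="France is a country in Western Europe. Its capital is Paris.",
        response="The capital is Paris.",
        messages=[
            {"role": "user", "content": "Tell me about France."},
            {"role": "assistant", "content": "France is a country in Western Europe."},
        ],
    )

    # result["is_bad_response"] flags whether remediation is needed; each configured
    # eval (e.g. result["trustworthiness"]) carries a "score" and an "is_bad" flag.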
diff --git a/tests/internal/test_validator.py b/tests/internal/test_validator.py
index 48e6503..13e5986 100644
--- a/tests/internal/test_validator.py
+++ b/tests/internal/test_validator.py
@@ -1,11 +1,15 @@
 from typing import cast
+from unittest.mock import MagicMock
 
+import pytest
 from cleanlab_tlm.utils.rag import TrustworthyRAGScore
 
 from cleanlab_codex.internal.validator import (
     get_default_evaluations,
     process_score_metadata,
+    prompt_tlm_for_rewrite_query,
     update_scores_based_on_thresholds,
+    validate_messages,
 )
 from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
 from cleanlab_codex.validator import BadResponseThresholds
@@ -108,3 +112,40 @@ def test_update_scores_based_on_thresholds() -> None:
     for metric, expected in expected_is_bad.items():
         assert scores[metric]["is_bad"] is expected
     assert all(scores[k]["score"] == raw_scores[k]["score"] for k in raw_scores)
+
+
+def test_prompt_tlm_with_message_history() -> None:
+    messages = [
+        {"role": "user", "content": "What is the capital of France?"},
+        {"role": "assistant", "content": "The capital of France is Paris."},
+    ]
+
+    dummy_tlm = MagicMock()
+    dummy_tlm.prompt.return_value = {
+        "response": "What is the capital of France?",
+        "trustworthiness_score": 0.99,
+    }
+
+    mocked_response = prompt_tlm_for_rewrite_query(query="What is the capital?", messages=messages, tlm=dummy_tlm)
+    dummy_tlm.prompt.assert_called_once()
+
+    assert mocked_response["response"] == "What is the capital of France?"
+    assert mocked_response["trustworthiness_score"] == 0.99
+
+
+def test_validate_messages() -> None:
+    # Valid messages
+    valid_messages = [
+        {"role": "user", "content": "Hello"},
+        {"role": "assistant", "content": "Hi there!"},
+    ]
+    validate_messages(valid_messages)  # Should not raise
+    validate_messages(None)
+    validate_messages()
+
+    # Invalid messages
+    with pytest.raises(TypeError, match="Each message must be a dictionary containing 'role' and 'content' keys."):
+        validate_messages([{"role": "assistant"}])  # Missing 'content'
+
+    with pytest.raises(TypeError, match="Message content must be a string."):
+        validate_messages([{"role": "user", "content": 123}])  # content not string
diff --git a/tests/test_validator.py b/tests/test_validator.py
index 574ed31..f9f09d0 100644
--- a/tests/test_validator.py
+++ b/tests/test_validator.py
@@ -64,6 +64,18 @@ def mock_trustworthy_rag() -> Generator[Mock, None, None]:
         yield mock_class
 
 
+@pytest.fixture
+def mock_tlm() -> Generator[Mock, None, None]:
+    mock = Mock()
+    mock.prompt.return_value = {
+        "response": "rewritten test query",
+        "trustworthiness_score": 0.99,
+    }
+    with patch("cleanlab_codex.validator.TLM") as mock_class:
+        mock_class.return_value = mock
+        yield mock_class
+
+
 def assert_threshold_equal(validator: Validator, eval_name: str, threshold: float) -> None:
     assert validator._bad_response_thresholds.get_threshold(eval_name) == threshold  # noqa: SLF001
 
@@ -109,6 +121,56 @@ def test_validate(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None:
             assert "score" in result[metric]
             assert "is_bad" in result[metric]
 
+    @pytest.mark.usefixtures("mock_project")
+    def test_validate_with_message_history(self, mock_trustworthy_rag: Mock, mock_tlm: Mock) -> None:
+        validator = Validator(codex_access_key="test")
+
+        result = validator.validate(
+            query="test query",
+            context="test context",
+            response="test response",
+            messages=[
+                {"role": "user", "content": "previous message"},
+                {"role": "assistant", "content": "previous response"},
+            ],
+        )
+
+        # Verify TrustworthyRAG.score was called
+        mock_trustworthy_rag.return_value.score.assert_called_once_with(
+            response="test response",
+            query="test query",
+            context="test context",
+            prompt=None,
+            form_prompt=None,
+        )
+
+        # Verify TLM was called for rewriting query
+        mock_tlm.return_value.prompt.assert_called_once()
+
+        # Verify expected result structure
+        assert result["is_bad_response"] is False
+        assert result["expert_answer"] is None
+
+        eval_metrics = ["trustworthiness", "response_helpfulness"]
+        for metric in eval_metrics:
+            assert metric in result
+            assert "score" in result[metric]
+            assert "is_bad" in result[metric]
+
+    def test_maybe_rewrite_query(self, mock_project: Mock, mock_trustworthy_rag: Mock, mock_tlm: Mock) -> None:  # noqa: ARG002
+        validator = Validator(codex_access_key="test")
+
+        result = validator._maybe_rewrite_query(  # noqa: SLF001 -- intentional test of private method
+            query="confusing test query",
+            messages=[
+                {"role": "user", "content": "previous message"},
+                {"role": "assistant", "content": "previous response"},
+            ],
+        )
+
+        mock_tlm.assert_called_once()
+        assert result == "rewritten test query"
+
     def test_validate_expert_answer(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None:  # noqa: ARG002
         # Setup mock project query response
         mock_project.from_access_key.return_value.query.return_value = ("expert answer", None)
diff --git a/tests/utils/test_function.py b/tests/utils/test_function.py
index dc4706c..916d697 100644
--- a/tests/utils/test_function.py
+++ b/tests/utils/test_function.py
@@ -23,7 +23,7 @@ def function_without_annotations(a) -> None:  # type: ignore # noqa: ARG001
     fn_schema = pydantic_model_from_function("test_function", function_without_annotations)
 
     assert fn_schema.model_json_schema()["title"] == "test_function"
-    assert fn_schema.model_fields["a"].annotation is Any  # type: ignore[comparison-overlap]
+    assert fn_schema.model_fields["a"].annotation is Any
     assert fn_schema.model_fields["a"].is_required()
     assert fn_schema.model_fields["a"].description is None
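Note (illustrative, not part of the diff): a small sketch of exercising the new `prompt_tlm_for_rewrite_query` helper directly with a stand-in TLM object, mirroring the mocked tests above. Any object exposing a compatible `prompt()` works here, and the threshold is the `REWRITE_QUERY_TRUSTWORTHINESS_THRESHOLD` default from the diff.

    from cleanlab_codex.internal.validator import (
        REWRITE_QUERY_TRUSTWORTHINESS_THRESHOLD,
        prompt_tlm_for_rewrite_query,
    )


    class StubTLM:
        """Stand-in for cleanlab_tlm.TLM, as the tests above do with MagicMock."""

        def prompt(self, prompt: str) -> dict:
            # A real TLM returns a rewritten query plus a trustworthiness score for it.
            return {"response": "What is the capital of France?", "trustworthiness_score": 0.9}


    history = [
        {"role": "user", "content": "Tell me about France."},
        {"role": "assistant", "content": "France is a country in Western Europe."},
    ]
    result = prompt_tlm_for_rewrite_query(query="What is its capital?", messages=history, tlm=StubTLM())

    # The Validator only adopts the rewrite when the score clears the threshold.
    score = result["trustworthiness_score"]
    if score is not None and score >= REWRITE_QUERY_TRUSTWORTHINESS_THRESHOLD:
        print(result["response"])  # -> "What is the capital of France?"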