[Codex] Add handling for Conversational RAG to Validator API #84
@@ -2,11 +2,14 @@
from typing import TYPE_CHECKING, Any, Optional, Sequence, cast

from cleanlab_tlm.tlm import TLMResponse
from cleanlab_tlm.utils.rag import Eval, TrustworthyRAGScore, get_default_evals

from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore

if TYPE_CHECKING:
    from cleanlab_tlm import TLM

    from cleanlab_codex.validator import BadResponseThresholds


@@ -21,6 +24,17 @@
"context_sufficiency": "is_not_context_sufficient", | ||||||||||
} | ||||||||||
|
||||||||||
REWRITE_QUERY = """Given a conversational Message History and a Query, rewrite the Query to be a self-contained question. If the query is already self contained, return it as-is. | ||||||||||
Query: {query} | ||||||||||
-- | ||||||||||
Message History: \n{messages} | ||||||||||
-- | ||||||||||
Remember, return the Query as-is except in cases where the Query is missing key words or has content that should be additionally clarified.""" | ||||||||||
Review comment: So when the Query is missing key words or has content that should be additionally clarified, what should it do then?
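For readers following along, here is a minimal, editor-added sketch of what this template produces once formatted. The sample query and conversation are invented; the import assumes this PR's module is installed.

```python
# Illustrative only: fill the REWRITE_QUERY template with an invented follow-up question
from cleanlab_codex.internal.validator import REWRITE_QUERY

history = "user: What is the capital of France?\nassistant: The capital of France is Paris."
prompt = REWRITE_QUERY.format(query="How many people live there?", messages=history)
print(prompt)
# A reasonable self-contained rewrite here would be: "How many people live in Paris, France?"
```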
REWRITE_QUERY_TRUSTWORTHINESS_THRESHOLD = 0.5

def get_default_evaluations() -> list[Eval]:
    """Get the default evaluations for the TrustworthyRAG.

@@ -40,6 +54,16 @@ def get_default_trustworthyrag_config() -> dict[str, Any]:
    }

def get_default_tlm_config() -> dict[str, Any]:
    """Get the default configuration for the TLM."""

    return {
        "quality_preset": "medium",
Review comment: Any reason for using this default preset here? How were these config options chosen? AFAICT, the quality preset and verbosity flag are the same by default, but we use gpt-4.1-mini by default instead of gpt-4.1-nano?
Review comment: I'd expect us to pick defaults that favor lower latency, right?
Review comment: The default TLM trustworthiness score in Validator must remain identical to default TLM at all times, unless there is a spec explicitly written to change it. This whole config should not be hardcoded here, I think. Instead, it can use these helper methods from the cleanlab_tlm library: https://github.com/cleanlab/cleanlab-tlm/blob/main/src/cleanlab_tlm/utils/config.py
        "verbose": False,
        "options": {"model": "gpt-4.1-nano"},
    }

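As a rough, editor-added illustration of the reviewer's suggestion: derive the defaults from the library rather than hardcoding them. The helper names below are assumptions about what cleanlab_tlm/utils/config.py exposes and have not been verified against that file.

```python
from typing import Any

# Assumed helpers -- check cleanlab_tlm/utils/config.py for the actual names and signatures
from cleanlab_tlm.utils.config import get_default_model, get_default_quality_preset


def get_default_tlm_config() -> dict[str, Any]:
    """Build the default TLM config from the library's own defaults instead of hardcoding values."""
    return {
        "quality_preset": get_default_quality_preset(),  # stays in sync with TLM's default preset
        "options": {"model": get_default_model()},  # stays in sync with TLM's default model
    }
```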
def update_scores_based_on_thresholds(
    scores: TrustworthyRAGScore | Sequence[TrustworthyRAGScore], thresholds: BadResponseThresholds
) -> ThresholdedTrustworthyRAGScore:

@@ -108,3 +132,38 @@ def is_bad(metric: str) -> bool:
    if is_bad("trustworthiness"):
        return "hallucination"
    return "other_issues"

def validate_messages(messages: Optional[list[dict[str, Any]]] = None) -> None:
Review comment: I think this name … I'd bet we wouldn't change the Validator.validate api, but we could find a different name for …
Review comment: Consider having …
Review comment: Everywhere it's being called, it takes in a …
    """Validate the format of messages based on OpenAI."""

    if messages is None:
        return
    if not isinstance(messages, list):
        raise TypeError("Messages must be a list of dictionaries.")  # noqa: TRY003
    for message in messages:
        if not isinstance(message, dict) or "role" not in message or "content" not in message:
            raise TypeError("Each message must be a dictionary containing 'role' and 'content' keys.")  # noqa: TRY003
        if not isinstance(message["content"], str):
            raise TypeError("Message content must be a string.")  # noqa: TRY003
def prompt_tlm_for_rewrite_query(query: str, messages: list[dict[str, Any]], tlm: TLM) -> TLMResponse:
    """Given the query and message history, prompt the TLM for a response that could possibly be self contained.
Review comment: Did we decide this should be TLM in the end? I thought we were thinking this can just be a regular OpenAI call. If sticking with TLM here, you must ensure: …
    If the tlm call fails, then the original query is returned.
    """

    messages_str = ""
    for message in messages:
        messages_str += f"{message['role']}: {message['content']}\n"
Review comment (on lines +156 to +158): Nit. Just use the list comprehension to join the different substrings instead of leaving an extra whitespace for the last entry.
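A minimal sketch of what that nit could look like (editor's illustration; the reviewer's actual suggested change was not captured in this view):

```python
messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]

# Join role/content pairs with "\n" so no trailing whitespace is left after the last entry
messages_str = "\n".join([f"{message['role']}: {message['content']}" for message in messages])
print(messages_str)
```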
    response = tlm.prompt(
        REWRITE_QUERY.format(
            query=query,
            messages=messages_str,
        )
    )

    if response is None:
        return TLMResponse(response=query, trustworthiness_score=1.0)
    return cast(TLMResponse, response)
@@ -7,19 +7,21 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Optional, cast

from cleanlab_tlm import TrustworthyRAG
from cleanlab_tlm import TLM, TrustworthyRAG
from pydantic import BaseModel, Field, field_validator

from cleanlab_codex.internal.validator import (
    REWRITE_QUERY_TRUSTWORTHINESS_THRESHOLD,
    get_default_evaluations,
    get_default_tlm_config,
    get_default_trustworthyrag_config,
)
from cleanlab_codex.internal.validator import (
    process_score_metadata as _process_score_metadata,
)
from cleanlab_codex.internal.validator import (
    update_scores_based_on_thresholds as _update_scores_based_on_thresholds,
)
from cleanlab_codex.internal.validator import prompt_tlm_for_rewrite_query as _prompt_tlm_for_rewrite_query
from cleanlab_codex.internal.validator import update_scores_based_on_thresholds as _update_scores_based_on_thresholds
from cleanlab_codex.internal.validator import validate_messages as _validate_messages
from cleanlab_codex.project import Project

if TYPE_CHECKING:
@@ -32,6 +34,7 @@ def __init__(
        codex_access_key: str,
        tlm_api_key: Optional[str] = None,
        trustworthy_rag_config: Optional[dict[str, Any]] = None,
        tlm_config: Optional[dict[str, Any]] = None,
        bad_response_thresholds: Optional[dict[str, float]] = None,
    ):
        """Real-time detection and remediation of bad responses in RAG applications, powered by Cleanlab's TrustworthyRAG and Codex.

@@ -74,11 +77,20 @@ def __init__(
            ValueError: If any threshold value is not between 0 and 1.
        """
        trustworthy_rag_config = trustworthy_rag_config or get_default_trustworthyrag_config()
        tlm_config = tlm_config or get_default_tlm_config()
        self._tlm: Optional[TLM] = None

        if tlm_api_key is not None and "api_key" in trustworthy_rag_config:
            error_msg = "Cannot specify both tlm_api_key and api_key in trustworthy_rag_config"
            raise ValueError(error_msg)
        if tlm_api_key is not None:
            trustworthy_rag_config["api_key"] = tlm_api_key
            if "api_key" not in tlm_config:
                tlm_config["api_key"] = tlm_api_key
            else:
                error_msg = "Cannot specify both tlm_api_key and api_key in tlm_config"
                raise ValueError(error_msg)
        self._tlm = TLM(**tlm_config)

        self._project: Project = Project.from_access_key(access_key=codex_access_key)

@@ -108,6 +120,7 @@ def validate(
        query: str,
        context: str,
        response: str,
        messages: Optional[list[dict[str, Any]]] = None,
        prompt: Optional[str] = None,
        form_prompt: Optional[Callable[[str, str], str]] = None,
        metadata: Optional[dict[str, Any]] = None,

@@ -129,6 +142,10 @@ def validate(
                - 'is_bad_response': True if the response is flagged as potentially bad, False otherwise. When True, a Codex lookup is performed, which logs this query into the Codex Project for SMEs to answer.
                - Additional keys from a [`ThresholdedTrustworthyRAGScore`](/codex/api/python/types.validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
        """
        _validate_messages(messages)
        if messages is not None:
            query = self._maybe_rewrite_query(query=query, messages=messages)

        scores, is_bad_response = self.detect(
            query=query, context=context, response=response, prompt=prompt, form_prompt=form_prompt
        )
@@ -154,6 +171,7 @@ async def validate_async(
        query: str,
        context: str,
        response: str,
        messages: Optional[list[dict[str, Any]]] = None,
        prompt: Optional[str] = None,
        form_prompt: Optional[Callable[[str, str], str]] = None,
        metadata: Optional[dict[str, Any]] = None,

@@ -175,6 +193,10 @@ async def validate_async(
                - 'is_bad_response': True if the response is flagged as potentially bad, False otherwise. When True, a Codex lookup is performed, which logs this query into the Codex Project for SMEs to answer.
                - Additional keys from a [`ThresholdedTrustworthyRAGScore`](/codex/api/python/types.validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
        """
        _validate_messages(messages)
        if messages is not None:
            query = self._maybe_rewrite_query(query=query, messages=messages)

        scores, is_bad_response = await self.detect_async(query, context, response, prompt, form_prompt)
        final_metadata = metadata.copy() if metadata else {}
        if log_results:
@@ -296,6 +318,25 @@ def _remediate(self, *, query: str, metadata: dict[str, Any] | None = None) -> str:
        codex_answer, _ = self._project.query(question=query, metadata=metadata)
        return codex_answer

    def _maybe_rewrite_query(self, *, query: str, messages: list[dict[str, Any]]) -> str:
Review comment: This _maybe... prefix implies that we might get something different from the method, other than a string. Should the check for …
        """Rewrite the query based on the message history if the query is not a self-contained question but could be improved to be one with context from the messages.

        Args:
            query (str): The original query.
            messages (list[dict[str, Any]]): The message history to use for rewriting the query.

        Returns:
            final_query (str): Either the original query or a rewritten self-contained version of the original query.
        """
        if self._tlm is not None:
            response = _prompt_tlm_for_rewrite_query(query=query, messages=messages, tlm=self._tlm)
            if (
                response["trustworthiness_score"] is not None
                and response["trustworthiness_score"] >= REWRITE_QUERY_TRUSTWORTHINESS_THRESHOLD
            ):
                return str(response["response"])
        return query  # If the trustworthiness score is below the threshold or we don't have access to the TLM, omit the rewrite
class BadResponseThresholds(BaseModel):
    """Config for determining if a response is bad.
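Taken together, the conversational path added in this PR would be used roughly like this from the caller's side (a hedged, editor-added sketch; the access key, API key, and conversation content are placeholders):

```python
from cleanlab_codex.validator import Validator

# Placeholder credentials -- substitute a real Codex access key and TLM API key
validator = Validator(codex_access_key="<codex-access-key>", tlm_api_key="<tlm-api-key>")

messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]

# With `messages` supplied, the follow-up query may first be rewritten into a
# self-contained question before detection and any Codex lookup run.
results = validator.validate(
    query="How many people live there?",
    context="Paris has an estimated population of about 2.1 million people.",
    response="About 2.1 million people live in Paris.",
    messages=messages,
)
print(results["is_bad_response"])
```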
@@ -1,11 +1,15 @@
from typing import cast
from unittest.mock import MagicMock

import pytest
from cleanlab_tlm.utils.rag import TrustworthyRAGScore

from cleanlab_codex.internal.validator import (
    get_default_evaluations,
    process_score_metadata,
    prompt_tlm_for_rewrite_query,
    update_scores_based_on_thresholds,
    validate_messages,
)
from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
from cleanlab_codex.validator import BadResponseThresholds
@@ -108,3 +112,40 @@ def test_update_scores_based_on_thresholds() -> None:
    for metric, expected in expected_is_bad.items():
        assert scores[metric]["is_bad"] is expected
    assert all(scores[k]["score"] == raw_scores[k]["score"] for k in raw_scores)

def test_prompt_tlm_with_message_history() -> None:
Review comment: Add a test to confirm there is no query rewriting happening whenever this is the first user message.
Review comment: Add a test to confirm that the primary … Confirm you are using this TLM utils method (…) to turn the chat history into a prompt string.
(An additional editor-added test sketch appears after the function below.)
    messages = [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ]

    dummy_tlm = MagicMock()
    dummy_tlm.prompt.return_value = {
        "response": "What is the capital of France?",
        "trustworthiness_score": 0.99,
    }

    mocked_response = prompt_tlm_for_rewrite_query(query="What is the capital?", messages=messages, tlm=dummy_tlm)
    dummy_tlm.prompt.assert_called_once()

    assert mocked_response["response"] == "What is the capital of France?"
    assert mocked_response["trustworthiness_score"] == 0.99
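Related to the review comments above asking for more coverage, here is one more editor-added sketch in the same style, reusing this file's imports. It exercises the fallback already present in prompt_tlm_for_rewrite_query when the TLM returns None; it does not cover the requested first-user-message case, which would need an implementation change first.

```python
def test_prompt_tlm_returns_original_query_when_tlm_response_is_none() -> None:
    messages = [{"role": "user", "content": "What is the capital of France?"}]

    dummy_tlm = MagicMock()
    dummy_tlm.prompt.return_value = None  # simulate a failed/empty TLM call

    response = prompt_tlm_for_rewrite_query(query="What is the capital?", messages=messages, tlm=dummy_tlm)

    # The helper falls back to the original query with a trustworthiness score of 1.0
    assert response["response"] == "What is the capital?"
    assert response["trustworthiness_score"] == 1.0
```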
def test_validate_messages() -> None:
    # Valid messages
    valid_messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]
    validate_messages(valid_messages)  # Should not raise
    validate_messages(None)
    validate_messages()

    # Invalid messages
    with pytest.raises(TypeError, match="Each message must be a dictionary containing 'role' and 'content' keys."):
        validate_messages([{"role": "assistant"}])  # Missing 'content'

    with pytest.raises(TypeError, match="Message content must be a string."):
        validate_messages([{"role": "user", "content": 123}])  # content not string
Review comment: You're already using triple-quotes.