[Codex] Add handling for Conversational RAG to Validator API #84

Open · wants to merge 6 commits into main
59 changes: 59 additions & 0 deletions src/cleanlab_codex/internal/validator.py
@@ -2,11 +2,14 @@

from typing import TYPE_CHECKING, Any, Optional, Sequence, cast

from cleanlab_tlm.tlm import TLMResponse
from cleanlab_tlm.utils.rag import Eval, TrustworthyRAGScore, get_default_evals

from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore

if TYPE_CHECKING:
from cleanlab_tlm import TLM

from cleanlab_codex.validator import BadResponseThresholds


@@ -21,6 +24,17 @@
"context_sufficiency": "is_not_context_sufficient",
}

REWRITE_QUERY = """Given a conversational Message History and a Query, rewrite the Query to be a self-contained question. If the query is already self contained, return it as-is.
Query: {query}
--
Message History: \n{messages}
Member:
You're already using triple-quotes.

Suggested change:
- Message History: \n{messages}
+ Message History:
+ {messages}

--
Remember, return the Query as-is except in cases where the Query is missing key words or has content that should be additionally clarified."""
Member:
Suggested change:
- Remember, return the Query as-is except in cases where the Query is missing key words or has content that should be additionally clarified."""
+ Remember, return the Query as-is, except in cases where the Query is missing key words or has content that should be additionally clarified."""

Member:
So when the Query is missing key words or has content that should be additionally clarified, what should it do then?

REWRITE_QUERY_TRUSTWORTHINESS_THRESHOLD = 0.5


def get_default_evaluations() -> list[Eval]:
"""Get the default evaluations for the TrustworthyRAG.
@@ -40,6 +54,16 @@ def get_default_trustworthyrag_config() -> dict[str, Any]:
}


def get_default_tlm_config() -> dict[str, Any]:
"""Get the default configuration for the TLM."""

return {
"quality_preset": "medium",
Member (@elisno, Jun 3, 2025):
Any reason for using this default preset here? How were these config options chosen?

AFAICT, the quality preset and verbosity flag match the library defaults, but the default model is gpt-4.1-mini rather than the gpt-4.1-nano used here?

Member:
I'd expect us to pick defaults that favor lower latency, right?

Member:
The default TLM trustworthiness score in Validator must remain identical to the default TLM at all times, unless there is a spec explicitly written to change it.

This whole config should not be hardcoded here, I think. Instead, it can use these helper methods from the cleanlab_tlm library:

https://github.com/cleanlab/cleanlab-tlm/blob/main/src/cleanlab_tlm/utils/config.py
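A possible shape for that suggestion, sketched below. This is only a sketch: the helper names get_default_quality_preset and get_default_model are assumptions, not confirmed against cleanlab_tlm.utils.config.

def get_default_tlm_config() -> dict[str, Any]:
    """Delegate defaults to the TLM library instead of hardcoding them here."""
    try:
        # Hypothetical helper names -- substitute whatever cleanlab_tlm.utils.config actually exposes.
        from cleanlab_tlm.utils.config import get_default_model, get_default_quality_preset

        return {
            "quality_preset": get_default_quality_preset(),
            "options": {"model": get_default_model()},
        }
    except ImportError:
        # If no such helpers exist, pass no overrides so TLM applies its own defaults.
        return {}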

"verbose": False,
"options": {"model": "gpt-4.1-nano"},
}


def update_scores_based_on_thresholds(
scores: TrustworthyRAGScore | Sequence[TrustworthyRAGScore], thresholds: BadResponseThresholds
) -> ThresholdedTrustworthyRAGScore:
@@ -108,3 +132,38 @@ def is_bad(metric: str) -> bool:
if is_bad("trustworthiness"):
return "hallucination"
return "other_issues"


def validate_messages(messages: Optional[list[dict[str, Any]]] = None) -> None:
Member (@elisno, Jun 3, 2025):
I think the name validate_messages should be chosen more carefully, since the validator module already reserves the method name validate in Validator for looking at the trustworthiness & Eval scores.

I'd bet we wouldn't change the Validator.validate API, but we could find a different name for validate_messages since it behaves quite differently.

Member (@elisno, Jun 3, 2025):
Consider having validate_messages take messages as a required (positional) argument:

Suggested change:
- def validate_messages(messages: Optional[list[dict[str, Any]]] = None) -> None:
+ def validate_messages(messages: list[dict[str, Any]]) -> None:

Everywhere it's being called, it takes in a messages argument. The caller already sets a default value for that argument, so I'd advise against setting default values in two function signatures.

"""Validate the format of messages based on OpenAI."""

if messages is None:
return
if not isinstance(messages, list):
raise TypeError("Messages must be a list of dictionaries.") # noqa: TRY003
for message in messages:
if not isinstance(message, dict) or "role" not in message or "content" not in message:
raise TypeError("Each message must be a dictionary containing 'role' and 'content' keys.") # noqa: TRY003
if not isinstance(message["content"], str):
raise TypeError("Message content must be a string.") # noqa: TRY003


def prompt_tlm_for_rewrite_query(query: str, messages: list[dict[str, Any]], tlm: TLM) -> TLMResponse:
"""Given the query and message history, prompt the TLM for a response that could possibly be self contained.
Member:
Did we decide this should be TLM in the end? I thought we were thinking this could just be a regular OpenAI call.

If sticking with TLM here, you must ensure (see the sketch below):

  • this is a different instance of TLM than the one being used for trustworthiness scoring of the response
  • this instance of TLM is minimal latency (gpt-4.1-nano model, quality_preset='base')
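A minimal sketch of that setup. The constructor options mirror the tlm_config dict used elsewhere in this PR; the api_key value is a placeholder.

from cleanlab_tlm import TLM

# Separate, low-latency TLM instance used only for query rewriting; the TrustworthyRAG
# instance that scores trustworthiness stays untouched.
rewrite_tlm = TLM(
    quality_preset="base",              # lowest-latency preset, per the review comment
    options={"model": "gpt-4.1-nano"},  # small model is enough for rewriting
    api_key="<tlm-api-key>",            # placeholder
)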

If the tlm call fails, then the original query is returned.
"""

messages_str = ""
for message in messages:
messages_str += f"{message['role']}: {message['content']}\n"
Comment on lines +156 to +158
Member:
Nit: use a list comprehension with join to build the string instead of leaving trailing whitespace after the last entry.

Suggested change:
- messages_str = ""
- for message in messages:
-     messages_str += f"{message['role']}: {message['content']}\n"
+ messages_str = "\n".join([f"{m['role']}: {m['content']}" for m in messages])


response = tlm.prompt(
REWRITE_QUERY.format(
query=query,
messages=messages_str,
)
)

if response is None:
return TLMResponse(response=query, trustworthiness_score=1.0)
return cast(TLMResponse, response)
49 changes: 45 additions & 4 deletions src/cleanlab_codex/validator.py
@@ -7,19 +7,21 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Optional, cast

from cleanlab_tlm import TrustworthyRAG
from cleanlab_tlm import TLM, TrustworthyRAG
from pydantic import BaseModel, Field, field_validator

from cleanlab_codex.internal.validator import (
REWRITE_QUERY_TRUSTWORTHINESS_THRESHOLD,
get_default_evaluations,
get_default_tlm_config,
get_default_trustworthyrag_config,
)
from cleanlab_codex.internal.validator import (
process_score_metadata as _process_score_metadata,
)
from cleanlab_codex.internal.validator import (
update_scores_based_on_thresholds as _update_scores_based_on_thresholds,
)
from cleanlab_codex.internal.validator import prompt_tlm_for_rewrite_query as _prompt_tlm_for_rewrite_query
from cleanlab_codex.internal.validator import update_scores_based_on_thresholds as _update_scores_based_on_thresholds
from cleanlab_codex.internal.validator import validate_messages as _validate_messages
from cleanlab_codex.project import Project

if TYPE_CHECKING:
@@ -32,6 +34,7 @@ def __init__(
codex_access_key: str,
tlm_api_key: Optional[str] = None,
trustworthy_rag_config: Optional[dict[str, Any]] = None,
tlm_config: Optional[dict[str, Any]] = None,
bad_response_thresholds: Optional[dict[str, float]] = None,
):
"""Real-time detection and remediation of bad responses in RAG applications, powered by Cleanlab's TrustworthyRAG and Codex.
@@ -74,11 +77,20 @@ def __init__(
ValueError: If any threshold value is not between 0 and 1.
"""
trustworthy_rag_config = trustworthy_rag_config or get_default_trustworthyrag_config()
tlm_config = tlm_config or get_default_tlm_config()
self._tlm: Optional[TLM] = None

if tlm_api_key is not None and "api_key" in trustworthy_rag_config:
error_msg = "Cannot specify both tlm_api_key and api_key in trustworthy_rag_config"
raise ValueError(error_msg)
if tlm_api_key is not None:
trustworthy_rag_config["api_key"] = tlm_api_key
if "api_key" not in tlm_config:
tlm_config["api_key"] = tlm_api_key
else:
error_msg = "Cannot specify both tlm_api_key and api_key in tlm_config"
raise ValueError(error_msg)
self._tlm = TLM(**tlm_config)

self._project: Project = Project.from_access_key(access_key=codex_access_key)

@@ -108,6 +120,7 @@ def validate(
query: str,
context: str,
response: str,
messages: Optional[list[dict[str, Any]]] = None,
prompt: Optional[str] = None,
form_prompt: Optional[Callable[[str, str], str]] = None,
metadata: Optional[dict[str, Any]] = None,
@@ -129,6 +142,10 @@
- 'is_bad_response': True if the response is flagged as potentially bad, False otherwise. When True, a Codex lookup is performed, which logs this query into the Codex Project for SMEs to answer.
- Additional keys from a [`ThresholdedTrustworthyRAGScore`](/codex/api/python/types.validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
"""
_validate_messages(messages)
if messages is not None:
query = self._maybe_rewrite_query(query=query, messages=messages)

scores, is_bad_response = self.detect(
query=query, context=context, response=response, prompt=prompt, form_prompt=form_prompt
)
@@ -154,6 +171,7 @@ async def validate_async(
query: str,
context: str,
response: str,
messages: Optional[list[dict[str, Any]]] = None,
prompt: Optional[str] = None,
form_prompt: Optional[Callable[[str, str], str]] = None,
metadata: Optional[dict[str, Any]] = None,
@@ -175,6 +193,10 @@
- 'is_bad_response': True if the response is flagged as potentially bad, False otherwise. When True, a Codex lookup is performed, which logs this query into the Codex Project for SMEs to answer.
- Additional keys from a [`ThresholdedTrustworthyRAGScore`](/codex/api/python/types.validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
"""
_validate_messages(messages)
if messages is not None:
query = self._maybe_rewrite_query(query=query, messages=messages)

scores, is_bad_response = await self.detect_async(query, context, response, prompt, form_prompt)
final_metadata = metadata.copy() if metadata else {}
if log_results:
@@ -296,6 +318,25 @@ def _remediate(self, *, query: str, metadata: dict[str, Any] | None = None) -> s
codex_answer, _ = self._project.query(question=query, metadata=metadata)
return codex_answer

def _maybe_rewrite_query(self, *, query: str, messages: list[dict[str, Any]]) -> str:
Member:
This _maybe... prefix implies that we might get back something other than a string from the method. Should the check for self._tlm be done by the caller?

"""Rewrite the query based on the message history if the query is not a self-contained question but could be improved to be one with context from the messages.
Args:
query (str): The original query.
messages (list[dict[str, Any]]): The message history to use for rewriting the query.
Returns:
final_query (str): Either the original query or a rewritten, self-contained version of the original query.
"""
if self._tlm is not None:
response = _prompt_tlm_for_rewrite_query(query=query, messages=messages, tlm=self._tlm)
if (
response["trustworthiness_score"] is not None
and response["trustworthiness_score"] >= REWRITE_QUERY_TRUSTWORTHINESS_THRESHOLD
):
return str(response["response"])
return query # If the trustworthiness score is below the threshold or we don't have access to the TLM, omit the rewrite


class BadResponseThresholds(BaseModel):
"""Config for determining if a response is bad.
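For orientation, a usage sketch of the new conversational parameters introduced in this file. This is illustrative only and not part of the diff; the key values are placeholders.

from cleanlab_codex.validator import Validator

validator = Validator(
    codex_access_key="<codex-access-key>",  # placeholder
    tlm_api_key="<tlm-api-key>",            # placeholder
    tlm_config={"quality_preset": "base", "options": {"model": "gpt-4.1-nano"}},
)

# Passing `messages` lets the Validator rewrite a non-self-contained follow-up query
# before scoring the response against the retrieved context.
result = validator.validate(
    query="What about its population?",
    context="Paris is the capital of France; its population is about 2.1 million.",
    response="About 2.1 million people live there.",
    messages=[
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ],
)
print(result["is_bad_response"], result["expert_answer"])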
41 changes: 41 additions & 0 deletions tests/internal/test_validator.py
@@ -1,11 +1,15 @@
from typing import cast
from unittest.mock import MagicMock

import pytest
from cleanlab_tlm.utils.rag import TrustworthyRAGScore

from cleanlab_codex.internal.validator import (
get_default_evaluations,
process_score_metadata,
prompt_tlm_for_rewrite_query,
update_scores_based_on_thresholds,
validate_messages,
)
from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
from cleanlab_codex.validator import BadResponseThresholds
@@ -108,3 +112,40 @@ def test_update_scores_based_on_thresholds() -> None:
for metric, expected in expected_is_bad.items():
assert scores[metric]["is_bad"] is expected
assert all(scores[k]["score"] == raw_scores[k]["score"] for k in raw_scores)


def test_prompt_tlm_with_message_history() -> None:
Member:
Add a test to confirm there is no query rewriting happening whenever this is the first user message (a sketch of such a test is included after test_maybe_rewrite_query in tests/test_validator.py below).

Member:
Add a test to confirm that the primary TrustworthyRAG.score(prompt, response) call happens with a prompt reflecting the full chat history, not with a prompt reflecting the rewritten query.

Confirm you are using this TLM utils method to turn the chat history into a prompt string:
cleanlab/cleanlab-tlm@a479e32

messages = [
{"role": "user", "content": "What is the capital of France?"},
{"role": "assistant", "content": "The capital of France is Paris."},
]

dummy_tlm = MagicMock()
dummy_tlm.prompt.return_value = {
"response": "What is the capital of France?",
"trustworthiness_score": 0.99,
}

mocked_response = prompt_tlm_for_rewrite_query(query="What is the capital?", messages=messages, tlm=dummy_tlm)
dummy_tlm.prompt.assert_called_once()

assert mocked_response["response"] == "What is the capital of France?"
assert mocked_response["trustworthiness_score"] == 0.99


def test_validate_messages() -> None:
# Valid messages
valid_messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
validate_messages(valid_messages) # Should not raise
validate_messages(None)
validate_messages()

# Invalid messages
with pytest.raises(TypeError, match="Each message must be a dictionary containing 'role' and 'content' keys."):
validate_messages([{"role": "assistant"}]) # Missing 'content'

with pytest.raises(TypeError, match="Message content must be a string."):
validate_messages([{"role": "user", "content": 123}]) # content not string
62 changes: 62 additions & 0 deletions tests/test_validator.py
@@ -64,6 +64,18 @@ def mock_trustworthy_rag() -> Generator[Mock, None, None]:
yield mock_class


@pytest.fixture
def mock_tlm() -> Generator[Mock, None, None]:
mock = Mock()
mock.prompt.return_value = {
"response": "rewritten test query",
"trustworthiness_score": 0.99,
}
with patch("cleanlab_codex.validator.TLM") as mock_class:
mock_class.return_value = mock
yield mock_class


def assert_threshold_equal(validator: Validator, eval_name: str, threshold: float) -> None:
assert validator._bad_response_thresholds.get_threshold(eval_name) == threshold # noqa: SLF001

@@ -109,6 +121,56 @@ def test_validate(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None:
assert "score" in result[metric]
assert "is_bad" in result[metric]

@pytest.mark.usefixtures("mock_project")
def test_validate_with_message_history(self, mock_trustworthy_rag: Mock, mock_tlm: Mock) -> None:
validator = Validator(codex_access_key="test")

result = validator.validate(
query="test query",
context="test context",
response="test response",
messages=[
{"role": "user", "content": "previous message"},
{"role": "assistant", "content": "previous response"},
],
)

# Verify TrustworthyRAG.score was called
mock_trustworthy_rag.return_value.score.assert_called_once_with(
response="test response",
query="test query",
context="test context",
prompt=None,
form_prompt=None,
)

# Verify TLM was called for rewriting query
mock_tlm.return_value.prompt.assert_called_once()

# Verify expected result structure
assert result["is_bad_response"] is False
assert result["expert_answer"] is None

eval_metrics = ["trustworthiness", "response_helpfulness"]
for metric in eval_metrics:
assert metric in result
assert "score" in result[metric]
assert "is_bad" in result[metric]

def test_maybe_rewrite_query(self, mock_project: Mock, mock_trustworthy_rag: Mock, mock_tlm: Mock) -> None: # noqa: ARG002
validator = Validator(codex_access_key="test")

result = validator._maybe_rewrite_query( # noqa: SLF001 -- intentional test of private method
query="confusing test query",
messages=[
{"role": "user", "content": "previous message"},
{"role": "assistant", "content": "previous response"},
],
)

mock_tlm.assert_called_once()
assert result == "rewritten test query"

def test_validate_expert_answer(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None: # noqa: ARG002
# Setup mock project query response
mock_project.from_access_key.return_value.query.return_value = ("expert answer", None)
2 changes: 1 addition & 1 deletion tests/utils/test_function.py
@@ -23,7 +23,7 @@ def function_without_annotations(a) -> None: # type: ignore # noqa: ARG001

fn_schema = pydantic_model_from_function("test_function", function_without_annotations)
assert fn_schema.model_json_schema()["title"] == "test_function"
assert fn_schema.model_fields["a"].annotation is Any # type: ignore[comparison-overlap]
assert fn_schema.model_fields["a"].annotation is Any
assert fn_schema.model_fields["a"].is_required()
assert fn_schema.model_fields["a"].description is None
