add async query to improve latency #62
Open

aditya1503 wants to merge 38 commits into main from async_query
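The change this PR makes is to overlap the Codex expert-answer lookup with TrustworthyRAG scoring inside `validate()` instead of running them back to back. A simplified sketch of that concurrency pattern follows; it is an illustration only, and the function names below are placeholders rather than the diff's actual API:

```python
# Illustration of the latency idea: start both pieces of work before awaiting either,
# so their latencies overlap instead of adding up. Placeholder names, not the PR's API.
import asyncio
from typing import Any, Optional


async def query_codex(query: str) -> Optional[str]:
    """Placeholder for the Codex lookup (illustration only)."""
    ...


def score_response(query: str, context: str, response: str) -> dict[str, Any]:
    """Placeholder for synchronous TrustworthyRAG scoring (illustration only)."""
    ...


async def validate_concurrently(query: str, context: str, response: str) -> tuple[Optional[str], dict[str, Any]]:
    loop = asyncio.get_running_loop()
    codex_task = asyncio.ensure_future(query_codex(query))
    detect_future = loop.run_in_executor(None, score_response, query, context, response)
    return await codex_task, await detect_future
```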
Commits (38)
1bc7370  add cleanlab-tlm as a dependency in pyproject.toml (elisno)
2529ae6  Add response validation functionality using TrustworthyRAG (elisno)
722d287  alt_answer -> expert_answer (elisno)
6f64a12  address comments (elisno)
a2c0ea5  have is_bad_response function take the BadResponseThreshold object in… (elisno)
b8a1e97  Enhance Validator with flexible thresholds and improved error handling (elisno)
db5fe24  move BadResponseThresholds (elisno)
29e231a  add prompt and form_prompt (elisno)
a741e15  fix formatting and type hints (elisno)
380b1ef  update docstrings (elisno)
4f40e3d  Add unit tests for Validator and BadResponseThresholds (elisno)
02b16e0  include type hints and fix formatting (elisno)
873f552  set "expert_answer" as first key (elisno)
b471371  clean up imports, type hints and docs (elisno)
be4745c  Update pyproject.toml (elisno)
54e866b  Update response_validation.py docstring to indicate module deprecatio… (elisno)
0a21649  add async query to improve latency (aditya1503)
c632625  make remediate method private (elisno)
d422bcf  update docstrings (elisno)
d7bc592  revert and wait outside (aditya1503)
2407b88  add event lopping (aditya1503)
0ac8e5d  add thread correctly (aditya1503)
94c626a  add try catch (aditya1503)
ae49baf  Merge branch 'validator' into async_query (aditya1503)
86707d9  Update validator.py (aditya1503)
d57e2c9  merge main (aditya1503)
0f1b838  docstring (aditya1503)
2556833  add tab to docstring (aditya1503)
cee4f13  add bool run_async (aditya1503)
84cc0f7  linting (aditya1503)
640a194  typing (aditya1503)
158e1b2  entry fix (aditya1503)
c4330fd  format fix (aditya1503)
c9e1357  add docstring (aditya1503)
63d2614  simpler cod (aditya1503)
bc45c23  noqa (aditya1503)
acb3beb  linting (aditya1503)
573426d  Merge branch 'main' into async_query (aditya1503)
@@ -0,0 +1,53 @@ (new file)

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, Sequence, cast

from cleanlab_tlm.utils.rag import Eval, TrustworthyRAGScore, get_default_evals

from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore

if TYPE_CHECKING:
    from cleanlab_codex.validator import BadResponseThresholds


"""Evaluation metrics (excluding trustworthiness) that are used to determine if a response is bad."""
DEFAULT_EVAL_METRICS = ["response_helpfulness"]


def get_default_evaluations() -> list[Eval]:
    """Get the default evaluations for the TrustworthyRAG.

    Note:
        This excludes trustworthiness, which is automatically computed by TrustworthyRAG.
    """
    return [evaluation for evaluation in get_default_evals() if evaluation.name in DEFAULT_EVAL_METRICS]


def get_default_trustworthyrag_config() -> dict[str, Any]:
    """Get the default configuration for the TrustworthyRAG."""
    return {
        "options": {
            "log": ["explanation"],
        },
    }


def update_scores_based_on_thresholds(
    scores: TrustworthyRAGScore | Sequence[TrustworthyRAGScore], thresholds: BadResponseThresholds
) -> ThresholdedTrustworthyRAGScore:
    """Adds an `is_bad` flag to the score dictionaries based on the thresholds."""

    # Helper function to check if a score is bad
    def is_bad(score: Optional[float], threshold: float) -> bool:
        return score is not None and score < threshold

    if isinstance(scores, Sequence):
        raise NotImplementedError("Batching is not supported yet.")

    thresholded_scores = {}
    for eval_name, score_dict in scores.items():
        thresholded_scores[eval_name] = {
            **score_dict,
            "is_bad": is_bad(score_dict["score"], thresholds.get_threshold(eval_name)),
        }
    return cast(ThresholdedTrustworthyRAGScore, thresholded_scores)
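To make the intended behavior of `update_scores_based_on_thresholds` concrete, here is an illustrative sketch (not part of the diff) using the default 0.5 thresholds; a plain dict stands in for a real `TrustworthyRAGScore`:

```python
# Illustrative sketch: flag each eval whose score falls below its threshold.
from cleanlab_codex.internal.validator import update_scores_based_on_thresholds
from cleanlab_codex.validator import BadResponseThresholds

scores = {
    "trustworthiness": {"score": 0.92},
    "response_helpfulness": {"score": 0.35},
}
thresholded = update_scores_based_on_thresholds(
    scores=scores,  # type: ignore[arg-type]  # plain dict standing in for TrustworthyRAGScore
    thresholds=BadResponseThresholds(),  # defaults: 0.5 for every eval
)
assert thresholded["trustworthiness"]["is_bad"] is False      # 0.92 >= 0.5
assert thresholded["response_helpfulness"]["is_bad"] is True  # 0.35 < 0.5
```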
@@ -0,0 +1,35 @@ (new file)

from cleanlab_tlm.utils.rag import EvalMetric


class ThresholdedEvalMetric(EvalMetric):
    is_bad: bool


ThresholdedEvalMetric.__doc__ = f"""
{EvalMetric.__doc__}

    is_bad: bool
        Whether the score is below a certain threshold.
"""


class ThresholdedTrustworthyRAGScore(dict[str, ThresholdedEvalMetric]):
    """Object returned by `Validator.detect` containing evaluation scores from [TrustworthyRAGScore](/tlm/api/python/utils.rag/#class-trustworthyragscore)
    along with a boolean flag, `is_bad`, indicating whether the score is below the threshold.

    Example:
        ```python
        {
            "trustworthiness": {
                "score": 0.92,
                "log": {"explanation": "Did not find a reason to doubt trustworthiness."},
                "is_bad": False
            },
            "response_helpfulness": {
                "score": 0.35,
                "is_bad": True
            },
            ...
        }
        ```
    """
@@ -0,0 +1,254 @@ (new file)

"""
Leverage Cleanlab's Evals and Codex to detect and remediate bad responses in RAG applications.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Optional, cast
import asyncio

from cleanlab_tlm import TrustworthyRAG
from pydantic import BaseModel, Field, field_validator

from cleanlab_codex.internal.validator import (
    get_default_evaluations,
    get_default_trustworthyrag_config,
)
from cleanlab_codex.internal.validator import update_scores_based_on_thresholds as _update_scores_based_on_thresholds
from cleanlab_codex.project import Project

if TYPE_CHECKING:
    from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore


class BadResponseThresholds(BaseModel):
    """Config for determining if a response is bad.
    Each key is an evaluation metric and the value is a threshold such that if the score is below the threshold, the response is bad.

    Default Thresholds:
        - trustworthiness: 0.5
        - response_helpfulness: 0.5
        - Any custom eval: 0.5 (if not explicitly specified in bad_response_thresholds)
    """

    trustworthiness: float = Field(
        description="Threshold for trustworthiness.",
        default=0.5,
        ge=0.0,
        le=1.0,
    )
    response_helpfulness: float = Field(
        description="Threshold for response helpfulness.",
        default=0.5,
        ge=0.0,
        le=1.0,
    )

    @property
    def default_threshold(self) -> float:
        """The default threshold to use when a specific evaluation metric's threshold is not set. This threshold is set to 0.5."""
        return 0.5

    def get_threshold(self, eval_name: str) -> float:
        """Get threshold for an eval, if it exists.

        For fields defined in the model, returns their value (which may be the field's default).
        For custom evals not defined in the model, returns the default threshold value (see `default_threshold`).
        """

        # For fields defined in the model, use their value (which may be the field's default)
        if eval_name in self.model_fields:
            return cast(float, getattr(self, eval_name))

        # For custom evals, use the default threshold
        return getattr(self, eval_name, self.default_threshold)

    @field_validator("*")
    @classmethod
    def validate_threshold(cls, v: Any) -> float:
        """Validate that all fields (including dynamic ones) are floats between 0 and 1."""
        if not isinstance(v, (int, float)):
            error_msg = f"Threshold must be a number, got {type(v)}"
            raise TypeError(error_msg)
        if not 0 <= float(v) <= 1:
            error_msg = f"Threshold must be between 0 and 1, got {v}"
            raise ValueError(error_msg)
        return float(v)

    model_config = {
        "extra": "allow"  # Allow additional fields for custom eval thresholds
    }


class Validator:
    def __init__(
        self,
        codex_access_key: str,
        tlm_api_key: Optional[str] = None,
        trustworthy_rag_config: Optional[dict[str, Any]] = None,
        bad_response_thresholds: Optional[dict[str, float]] = None,
    ):
        """Real-time detection and remediation of bad responses in RAG applications, powered by Cleanlab's TrustworthyRAG and Codex.

        This object combines Cleanlab's TrustworthyRAG evaluation scores with configurable thresholds to detect potentially bad responses
        in your RAG application. When a bad response is detected, it automatically attempts to remediate by retrieving an expert-provided
        answer from your Codex project.

        For most use cases, we recommend using the `validate()` method, which provides a complete validation workflow including
        both detection and Codex remediation. The `detect()` method is available separately for testing and threshold-tuning purposes,
        without triggering a Codex lookup.

        By default, this uses the same default configurations as [`TrustworthyRAG`](/tlm/api/python/utils.rag/#class-trustworthyrag), except:
        - Explanations are returned in logs for better debugging
        - Only the `response_helpfulness` eval is run

        Args:
            codex_access_key (str): The [access key](/codex/web_tutorials/create_project/#access-keys) for a Codex project. Used to retrieve expert-provided answers
                when bad responses are detected.

            tlm_api_key (str, optional): API key for accessing [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag). If not provided, this must be specified
                in trustworthy_rag_config.

            trustworthy_rag_config (dict[str, Any], optional): Optional initialization arguments for [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag),
                which is used to detect response issues. If not provided, the default configuration will be used.

            bad_response_thresholds (dict[str, float], optional): Detection score thresholds used to flag whether
                a response is considered bad. Each key corresponds to an Eval from TrustworthyRAG, and the value
                indicates a threshold (between 0 and 1) below which scores are considered detected issues. A response
                is flagged as bad if any issues are detected. If not provided, default thresholds will be used. See
                [`BadResponseThresholds`](/codex/api/python/validator/#class-badresponsethresholds) for more details.

        Raises:
            ValueError: If both tlm_api_key and api_key in trustworthy_rag_config are provided.
            ValueError: If bad_response_thresholds contains thresholds for non-existent evaluation metrics.
            TypeError: If any threshold value is not a number.
            ValueError: If any threshold value is not between 0 and 1.
        """
        trustworthy_rag_config = trustworthy_rag_config or get_default_trustworthyrag_config()
        if tlm_api_key is not None and "api_key" in trustworthy_rag_config:
            error_msg = "Cannot specify both tlm_api_key and api_key in trustworthy_rag_config"
            raise ValueError(error_msg)
        if tlm_api_key is not None:
            trustworthy_rag_config["api_key"] = tlm_api_key

        self._project: Project = Project.from_access_key(access_key=codex_access_key)

        trustworthy_rag_config.setdefault("evals", get_default_evaluations())
        self._tlm_rag = TrustworthyRAG(**trustworthy_rag_config)

        # Validate that all the necessary thresholds are present in the TrustworthyRAG.
        _evals = [e.name for e in self._tlm_rag.get_evals()] + ["trustworthiness"]

        self._bad_response_thresholds = BadResponseThresholds.model_validate(bad_response_thresholds or {})

        _threshold_keys = self._bad_response_thresholds.model_dump().keys()

        # Check if there are any thresholds without corresponding evals (this is an error)
        _extra_thresholds = set(_threshold_keys) - set(_evals)
        if _extra_thresholds:
            error_msg = f"Found thresholds for non-existent evaluation metrics: {_extra_thresholds}"
            raise ValueError(error_msg)

    def validate(
        self,
        query: str,
        context: str,
        response: str,
        prompt: Optional[str] = None,
        form_prompt: Optional[Callable[[str, str], str]] = None,
    ) -> dict[str, Any]:
        """Evaluate whether the AI-generated response is bad, and if so, request an alternate expert response.

        Args:
            query (str): The user query that was used to generate the response.
            context (str): The context that was retrieved from the RAG Knowledge Base and used to generate the response.
            response (str): A response from your LLM/RAG system.

        Returns:
            dict[str, Any]: A dictionary containing:
                - 'expert_answer': Alternate SME-provided answer from Codex if the response was flagged as bad and an answer was found, or None otherwise.
                - 'is_bad_response': True if the response is flagged as potentially bad (when True, a lookup in Codex is performed), False otherwise.
                - Additional keys: Various keys from a [`ThresholdedTrustworthyRAGScore`](/cleanlab_codex/types/validator/#class-thresholdedtrustworthyragscore) dictionary, with raw scores from [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) for each evaluation metric, plus an `is_bad` flag indicating whether the score is below the threshold.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:  # No running loop
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        expert_task = loop.create_task(self.remediate_async(query))
        detect_task = loop.run_in_executor(None, self.detect, query, context, response, prompt, form_prompt)
        expert_answer, maybe_entry = loop.run_until_complete(expert_task)
        scores, is_bad_response = loop.run_until_complete(detect_task)
        loop.close()
        if is_bad_response:
            if expert_answer is None:
                self._project._sdk_client.projects.entries.add_question(
                    self._project._id, question=query,
                ).model_dump()
        else:
            expert_answer = None

        return {
            "expert_answer": expert_answer,
            "is_bad_response": is_bad_response,
            **scores,
        }

    def detect(
        self,
        query: str,
        context: str,
        response: str,
        prompt: Optional[str] = None,
        form_prompt: Optional[Callable[[str, str], str]] = None,
    ) -> tuple[ThresholdedTrustworthyRAGScore, bool]:
        """Score response quality using TrustworthyRAG and flag bad responses based on configured thresholds.

        Note:
            This method is primarily intended for testing and threshold tuning purposes. For production use cases,
            we recommend using the `validate()` method, which provides a complete validation workflow including
            Codex remediation.

        Args:
            query (str): The user query that was used to generate the response.
            context (str): The context that was retrieved from the RAG Knowledge Base and used to generate the response.
            response (str): A response from your LLM/RAG system.

        Returns:
            tuple[ThresholdedTrustworthyRAGScore, bool]: A tuple containing:
                - ThresholdedTrustworthyRAGScore: Quality scores for different evaluation metrics like trustworthiness
                  and response helpfulness. Each metric has a score between 0 and 1, along with a boolean flag, `is_bad`,
                  indicating whether the score is below the given threshold.
                - bool: True if the response is determined to be bad based on the evaluation scores
                  and configured thresholds, False otherwise.
        """
        scores = self._tlm_rag.score(
            response=response,
            query=query,
            context=context,
            prompt=prompt,
            form_prompt=form_prompt,
        )

        thresholded_scores = _update_scores_based_on_thresholds(
            scores=scores,
            thresholds=self._bad_response_thresholds,
        )

        is_bad_response = any(score_dict["is_bad"] for score_dict in thresholded_scores.values())
        return thresholded_scores, is_bad_response

    def _remediate(self, query: str) -> str | None:
        """Request an SME-provided answer for this query, if one is available in Codex.

        Args:
            query (str): The user's original query to get an SME-provided answer for.

        Returns:
            str | None: The SME-provided answer from Codex, or None if no answer could be found in the Codex Project.
        """
        codex_answer, _ = self._project.query(question=query)
        return codex_answer

    async def remediate_async(self, query: str):
        codex_answer, entry = self._project.query(question=query, read_only=True)
        return codex_answer, entry
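For context, here is a minimal usage sketch of the `Validator` introduced in this file. It is an illustration only: the access keys are placeholders and the query, context, and response strings are made up.

```python
# Usage sketch (not part of the diff); keys and strings are placeholders.
from cleanlab_codex.validator import Validator

validator = Validator(
    codex_access_key="<codex-access-key>",
    tlm_api_key="<tlm-api-key>",
    bad_response_thresholds={"trustworthiness": 0.6},  # override one threshold; others default to 0.5
)

result = validator.validate(
    query="How do I reset my password?",
    context="To reset your password, go to Settings > Security and click 'Reset password'.",
    response="You cannot reset your password.",
)

if result["is_bad_response"]:
    # Prefer the SME-provided answer from Codex when one exists.
    final_response = result["expert_answer"] or "I'm not sure; escalating to support."
else:
    final_response = "You cannot reset your password."
```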
@@ -0,0 +1,29 @@ (new file)

from typing import cast

from cleanlab_tlm.utils.rag import TrustworthyRAGScore

from cleanlab_codex.internal.validator import get_default_evaluations
from cleanlab_codex.validator import BadResponseThresholds


def make_scores(trustworthiness: float, response_helpfulness: float) -> TrustworthyRAGScore:
    scores = {
        "trustworthiness": {
            "score": trustworthiness,
        },
        "response_helpfulness": {
            "score": response_helpfulness,
        },
    }
    return cast(TrustworthyRAGScore, scores)


def make_is_bad_response_config(trustworthiness: float, response_helpfulness: float) -> BadResponseThresholds:
    return BadResponseThresholds(
        trustworthiness=trustworthiness,
        response_helpfulness=response_helpfulness,
    )


def test_get_default_evaluations() -> None:
    assert {evaluation.name for evaluation in get_default_evaluations()} == {"response_helpfulness"}
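A natural follow-on test (a sketch, not part of the diff) would exercise the helpers above against the internal thresholding function:

```python
# Hypothetical additional test building on make_scores / make_is_bad_response_config above.
from cleanlab_codex.internal.validator import update_scores_based_on_thresholds


def test_update_scores_based_on_thresholds() -> None:
    scores = make_scores(trustworthiness=0.9, response_helpfulness=0.2)
    thresholds = make_is_bad_response_config(trustworthiness=0.5, response_helpfulness=0.5)
    thresholded = update_scores_based_on_thresholds(scores=scores, thresholds=thresholds)
    assert thresholded["trustworthiness"]["is_bad"] is False      # 0.9 >= 0.5
    assert thresholded["response_helpfulness"]["is_bad"] is True  # 0.2 < 0.5
```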
Conversations

Review comment: If `is_bad_response` == True and `expert_answer` = None, then there's extra work being done.

Review comment: Make `add_question` async as well.

Review comment: Why did you mark this resolved? This seems like a critical consideration to think about.
If `is_bad_response` == True and `expert_answer` = None, then there may be extra compute being run in these cases. Need to time this implementation vs. the original implementation over a bunch of cases where `is_bad_response` == True and `expert_answer` = None.
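One way the "make `add_question` async as well" suggestion could be implemented is sketched below. It assumes the underlying SDK call stays synchronous; `add_question_async` is a hypothetical helper, not part of this PR.

```python
# Sketch only: offload the synchronous Codex add_question call to a worker thread
# so it does not block the event loop, per the review suggestion.
import asyncio


async def add_question_async(project, query: str) -> None:
    # asyncio.to_thread runs the blocking SDK call in the default executor.
    await asyncio.to_thread(
        project._sdk_client.projects.entries.add_question,
        project._id,
        question=query,
    )
```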