Skip to content

Commit

Permalink
Merge pull request #1085 from guardrails-ai/feat/validation-summary
Browse files Browse the repository at this point in the history
Add Validation failure information to Validation Outcome
  • Loading branch information
dtam authored Oct 3, 2024
2 parents c7967db + 46dc05b commit fe2c0b7
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 29 deletions.
54 changes: 54 additions & 0 deletions guardrails/classes/validation/validation_summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# TODO Temp to update once generated class is in
from typing import Iterator, List

from guardrails.classes.generic.arbitrary_model import ArbitraryModel
from guardrails.classes.validation.validation_result import FailResult
from guardrails.classes.validation.validator_logs import ValidatorLogs
from guardrails_api_client import ValidationSummary as IValidationSummary


class ValidationSummary(IValidationSummary, ArbitraryModel):
@staticmethod
def _generate_summaries_from_validator_logs(
validator_logs: List[ValidatorLogs],
) -> Iterator["ValidationSummary"]:
"""
Generate a list of ValidationSummary objects from a list of
ValidatorLogs objects. Using an iterator to allow serializing
the summaries to other formats.
"""
for log in validator_logs:
validation_result = log.validation_result
is_fail_result = isinstance(validation_result, FailResult)
failure_reason = validation_result.error_message if is_fail_result else None
error_spans = validation_result.error_spans if is_fail_result else []
yield ValidationSummary(
validatorName=log.validator_name,
validatorStatus=log.validation_result.outcome, # type: ignore
propertyPath=log.property_path,
failureReason=failure_reason,
errorSpans=error_spans, # type: ignore
)

@staticmethod
def from_validator_logs(
validator_logs: List[ValidatorLogs],
) -> List["ValidationSummary"]:
summaries = []
for summary in ValidationSummary._generate_summaries_from_validator_logs(
validator_logs
):
summaries.append(summary)
return summaries

@staticmethod
def from_validator_logs_only_fails(
validator_logs: List[ValidatorLogs],
) -> List["ValidationSummary"]:
summaries = []
for summary in ValidationSummary._generate_summaries_from_validator_logs(
validator_logs
):
if summary.failure_reason:
summaries.append(summary)
return summaries
13 changes: 12 additions & 1 deletion guardrails/classes/validation_outcome.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generic, Iterator, Optional, Tuple, Union, cast
from typing import Generic, Iterator, List, Optional, Tuple, Union, cast

from pydantic import Field
from rich.pretty import pretty_repr
Expand All @@ -11,6 +11,7 @@
from guardrails.classes.history import Call, Iteration
from guardrails.classes.output_type import OT
from guardrails.classes.generic.arbitrary_model import ArbitraryModel
from guardrails.classes.validation.validation_summary import ValidationSummary
from guardrails.constants import pass_status
from guardrails.utils.safe_get import safe_get

Expand All @@ -31,6 +32,11 @@ class ValidationOutcome(IValidationOutcome, ArbitraryModel, Generic[OT]):
error: If the validation failed, this field will contain the error message
"""

validation_summaries: Optional[List["ValidationSummary"]] = Field(
description="The summaries of the validation results.", default=[]
)
"""The summaries of the validation results."""

raw_llm_output: Optional[str] = Field(
description="The raw, unchanged output from the LLM call.", default=None
)
Expand Down Expand Up @@ -75,6 +81,10 @@ def from_guard_history(cls, call: Call):
list(last_iteration.reasks), 0
)
validation_passed = call.status == pass_status
validator_logs = last_iteration.validator_logs or []
validation_summaries = ValidationSummary.from_validator_logs_only_fails(
validator_logs
)
reask = last_output if isinstance(last_output, ReAsk) else None
error = call.error
output = cast(OT, call.guarded_output)
Expand All @@ -84,6 +94,7 @@ def from_guard_history(cls, call: Call):
validated_output=output,
reask=reask,
validation_passed=validation_passed,
validation_summaries=validation_summaries,
error=error,
)

Expand Down
9 changes: 9 additions & 0 deletions guardrails/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from guardrails.classes.output_type import OT
from guardrails.classes.rc import RC
from guardrails.classes.validation.validation_result import ErrorSpan
from guardrails.classes.validation.validation_summary import ValidationSummary
from guardrails.classes.validation_outcome import ValidationOutcome
from guardrails.classes.execution import GuardExecutionOptions
from guardrails.classes.generic import Stack
Expand Down Expand Up @@ -1217,6 +1218,13 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O
)
self.history.extend([Call.from_interface(call) for call in guard_history])

validation_summaries = []
if self.history.last and self.history.last.iterations.last:
validator_logs = self.history.last.iterations.last.validator_logs
validation_summaries = ValidationSummary.from_validator_logs_only_fails(
validator_logs
)

# TODO: See if the below statement is still true
# Our interfaces are too different for this to work right now.
# Once we move towards shared interfaces for both the open source
Expand All @@ -1232,6 +1240,7 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O
raw_llm_output=validation_output.raw_llm_output,
validated_output=validated_output,
validation_passed=(validation_output.validation_passed is True),
validation_summaries=validation_summaries,
)
else:
raise ValueError("Guard does not have an api client!")
Expand Down
40 changes: 13 additions & 27 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ opentelemetry-sdk = "^1.24.0"
opentelemetry-exporter-otlp-proto-grpc = "^1.24.0"
opentelemetry-exporter-otlp-proto-http = "^1.24.0"
guardrails-hub-types = "^0.0.4"
guardrails-api-client = ">=0.3.8"
guardrails-api-client = ">=0.3.13"
diff-match-patch = "^20230430"
guardrails-api = ">=0.0.1"
mlflow = {version = ">=2.0.1", optional = true}
Expand Down

0 comments on commit fe2c0b7

Please sign in to comment.