
Commit 667c4a4

Merge pull request #1105 from guardrails-ai/fix_server_export_mismatch
fix missing exports for server
2 parents 575a3bc + 618ad21 commit 667c4a4

File tree

10 files changed: +272 / -10 lines changed


guardrails/async_guard.py

Lines changed: 10 additions & 2 deletions
@@ -369,7 +369,11 @@ async def _exec(
             output=llm_output,
             base_model=self._base_model,
             full_schema_reask=full_schema_reask,
-            disable_tracer=(not self._allow_metrics_collection),
+            disable_tracer=(
+                not self._allow_metrics_collection
+                if isinstance(self._allow_metrics_collection, bool)
+                else None
+            ),
             exec_options=self._exec_opts,
         )
         # Here we have an async generator
@@ -391,7 +395,11 @@ async def _exec(
             output=llm_output,
             base_model=self._base_model,
             full_schema_reask=full_schema_reask,
-            disable_tracer=(not self._allow_metrics_collection),
+            disable_tracer=(
+                not self._allow_metrics_collection
+                if isinstance(self._allow_metrics_collection, bool)
+                else None
+            ),
             exec_options=self._exec_opts,
         )
         # Why are we using a different method here instead of just overriding?
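The same expression lands in guardrails/guard.py below. As a rough gloss on what it changes (the helper name and comments here are mine, not a guardrails API): when the metrics setting is an explicit bool, disable_tracer is its negation as before; when it is unset (None), None is passed through so the downstream default applies.

from typing import Optional

def disable_tracer_arg(allow_metrics_collection: Optional[bool]) -> Optional[bool]:
    # Illustrative stand-in for the expression in the diff above.
    return (
        not allow_metrics_collection
        if isinstance(allow_metrics_collection, bool)
        else None
    )

assert disable_tracer_arg(True) is False   # metrics allowed -> tracing stays on
assert disable_tracer_arg(False) is True   # metrics refused -> tracing disabled
assert disable_tracer_arg(None) is None    # unset -> downstream default decides
# The old `not allow_metrics_collection` evaluated to True for None,
# disabling the tracer whenever the setting was simply absent.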

guardrails/guard.py

Lines changed: 10 additions & 2 deletions
@@ -908,7 +908,11 @@ def _exec(
             output=llm_output,
             base_model=self._base_model,
             full_schema_reask=full_schema_reask,
-            disable_tracer=(not self._allow_metrics_collection),
+            disable_tracer=(
+                not self._allow_metrics_collection
+                if isinstance(self._allow_metrics_collection, bool)
+                else None
+            ),
             exec_options=self._exec_opts,
         )
         return runner(call_log=call_log, prompt_params=prompt_params)
@@ -927,7 +931,11 @@ def _exec(
             output=llm_output,
             base_model=self._base_model,
             full_schema_reask=full_schema_reask,
-            disable_tracer=(not self._allow_metrics_collection),
+            disable_tracer=(
+                not self._allow_metrics_collection
+                if isinstance(self._allow_metrics_collection, bool)
+                else None
+            ),
             exec_options=self._exec_opts,
         )
         call = runner(call_log=call_log, prompt_params=prompt_params)

guardrails/hub_telemetry/hub_tracing.py

Lines changed: 1 addition & 1 deletion
@@ -253,7 +253,7 @@ def wrapper(*args, **kwargs):
             nonlocal origin
             origin = origin if origin is not None else name
             add_attributes(span, attrs, name, origin, *args, **kwargs)
-            return _run_async_gen(fn, *args, **kwargs)
+            return fn(*args, **kwargs)
         else:
             return fn(*args, **kwargs)
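One plausible reading of this one-line change (inferred from the diff, not stated in the commit): calling an async generator function already returns an async generator object without executing its body, so the decorator can hand the result straight back and let the caller iterate it, with no extra wrapping. A minimal standalone sketch of that behavior (names are illustrative):

import asyncio

async def stream_tokens():
    # Async generator function: calling it only builds the generator object.
    for token in ("a", "b", "c"):
        yield token

def traced():
    # Returning the call result directly is enough; nothing needs awaiting here.
    return stream_tokens()

async def main():
    async for token in traced():
        print(token)  # a, b, c

asyncio.run(main())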

guardrails/run/async_stream_runner.py

Lines changed: 5 additions & 1 deletion
@@ -153,6 +153,10 @@ async def async_step(
                 validate_subschema=True,
                 stream=True,
             )
+            # TODO why? how does it happen in the other places we handle streams
+            if validated_fragment is None:
+                validated_fragment = ""
+
             if isinstance(validated_fragment, SkeletonReAsk):
                 raise ValueError(
                     "Received fragment schema is an invalid sub-schema "
@@ -165,7 +169,7 @@ async def async_step(
                     "Reasks are not yet supported with streaming. Please "
                     "remove reasks from schema or disable streaming."
                 )
-            validation_response += cast(str, validated_fragment)
+            validation_response += validated_fragment
             passed = call_log.status == pass_status
             yield ValidationOutcome(
                 call_id=call_log.id,  # type: ignore
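For context on the None guard (my gloss; the in-code TODO shows the authors themselves are unsure why the fragment can be None): concatenating None onto the accumulated string raises a TypeError, so the fragment is coerced to an empty string first, which also makes the earlier cast(str, ...) unnecessary.

validation_response = ""
validated_fragment = None  # hypothetical: a validator produced no output for this chunk
if validated_fragment is None:
    validated_fragment = ""
validation_response += validated_fragment  # safe; "" + None would raise TypeError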

guardrails/telemetry/guard_tracing.py

Lines changed: 17 additions & 3 deletions
@@ -11,7 +11,7 @@
 )
 
 from opentelemetry import context, trace
-from opentelemetry.trace import StatusCode, Tracer, Span
+from opentelemetry.trace import StatusCode, Tracer, Span, Link, get_tracer
 
 from guardrails.settings import settings
 from guardrails.classes.generic.stack import Stack
@@ -22,6 +22,10 @@
 from guardrails.telemetry.common import add_user_attributes
 from guardrails.version import GUARDRAILS_VERSION
 
+import sys
+
+if sys.version_info.minor < 10:
+    from guardrails.utils.polyfills import anext
 
 # from sentence_transformers import SentenceTransformer
 # import numpy as np
@@ -195,8 +199,18 @@ async def trace_async_stream_guard(
     while next_exists:
         try:
             res = await anext(result)  # type: ignore
-            add_guard_attributes(guard_span, history, res)
-            add_user_attributes(guard_span)
+            if not guard_span.is_recording():
+                # Assuming you have a tracer instance
+                tracer = get_tracer(__name__)
+                # Create a new span and link it to the previous span
+                with tracer.start_as_current_span(
+                    "new_guard_span",  # type: ignore
+                    links=[Link(guard_span.get_span_context())],
+                ) as new_span:
+                    guard_span = new_span
+
+            add_guard_attributes(guard_span, history, res)
+            add_user_attributes(guard_span)
             yield res
         except StopIteration:
             next_exists = False
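Background on the span-link pattern used above (general OpenTelemetry usage, not specific to this commit): once the original guard span has ended, further stream chunks cannot be recorded on it, so a fresh span is started and tied back to the old one with a Link, letting tracing backends correlate the two. A minimal sketch, with a hypothetical helper and span name:

from opentelemetry import trace
from opentelemetry.trace import Link, get_tracer

def record_follow_up(previous_span: trace.Span) -> None:
    # If the earlier span is no longer recording, open a new span linked
    # to its context instead of writing attributes to a closed span.
    if not previous_span.is_recording():
        tracer = get_tracer(__name__)
        with tracer.start_as_current_span(
            "follow_up_span",
            links=[Link(previous_span.get_span_context())],
        ) as span:
            span.set_attribute("linked_to_previous", True)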

guardrails/telemetry/runner_tracing.py

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,10 @@
 from guardrails.utils.safe_get import safe_get
 from guardrails.version import GUARDRAILS_VERSION
 
+import sys
+
+if sys.version_info.minor < 10:
+    from guardrails.utils.polyfills import anext
 
 #########################################
 ### START Runner.step Instrumentation ###
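The same version gate appears in guardrails/telemetry/guard_tracing.py above: anext only became a builtin in Python 3.10, so older interpreters need a polyfill. A stand-in illustration of the idea (guardrails ships its own in guardrails.utils.polyfills; this is not that implementation), using a full-tuple version check rather than the minor-version-only check in the diff:

import asyncio
import sys

if sys.version_info < (3, 10):
    async def anext(aiterator):
        # Stand-in polyfill: advance an async iterator by one step.
        return await aiterator.__anext__()

async def main():
    async def numbers():
        yield 1
        yield 2

    gen = numbers()
    print(await anext(gen))  # 1
    print(await anext(gen))  # 2

asyncio.run(main())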

guardrails/utils/hub_telemetry_utils.py

Lines changed: 0 additions & 1 deletion
@@ -57,7 +57,6 @@ def initialize_tracer(
         """Initializes a tracer for Guardrails Hub."""
         if enabled is None:
             enabled = settings.rc.enable_metrics or False
-
         self._enabled = enabled
         self._carrier = {}
         self._service_name = service_name

guardrails/utils/openai_utils/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -6,6 +6,10 @@
     is_static_openai_chat_acreate_func,
     is_static_openai_chat_create_func,
     is_static_openai_create_func,
+    get_static_openai_create_func,
+    get_static_openai_chat_create_func,
+    get_static_openai_acreate_func,
+    get_static_openai_chat_acreate_func,
 )
 
 __all__ = [
@@ -16,4 +20,8 @@
     "is_static_openai_acreate_func",
     "is_static_openai_chat_acreate_func",
     "OpenAIServiceUnavailableError",
+    "get_static_openai_create_func",
+    "get_static_openai_chat_create_func",
+    "get_static_openai_acreate_func",
+    "get_static_openai_chat_acreate_func",
 ]
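This re-export is the heart of the "fix missing exports for server" change: code that imports the deprecated getters from the guardrails.utils.openai_utils package (as the hosted server apparently does, per the commit title) can resolve them again. Roughly:

# Previously these names did not resolve from the package namespace because
# they were only defined in guardrails.utils.openai_utils.v1.
from guardrails.utils.openai_utils import (
    get_static_openai_chat_create_func,
    get_static_openai_create_func,
)

# Each call emits a DeprecationWarning and returns the matching openai
# client function (see v1.py below); the async variants return None.
chat_create = get_static_openai_chat_create_func()
create = get_static_openai_create_func()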

guardrails/utils/openai_utils/v1.py

Lines changed: 33 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 import openai
 
+import warnings
 from guardrails.classes.llm.llm_response import LLMResponse
 from guardrails.utils.openai_utils.base import BaseOpenAIClient
 from guardrails.utils.openai_utils.streaming_utils import (
@@ -12,6 +13,38 @@
 from guardrails.telemetry import trace_llm_call, trace_operation
 
 
+def get_static_openai_create_func():
+    warnings.warn(
+        "This function is deprecated. " " and will be removed in 0.6.0",
+        DeprecationWarning,
+    )
+    return openai.completions.create
+
+
+def get_static_openai_chat_create_func():
+    warnings.warn(
+        "This function is deprecated and will be removed in 0.6.0",
+        DeprecationWarning,
+    )
+    return openai.chat.completions.create
+
+
+def get_static_openai_acreate_func():
+    warnings.warn(
+        "This function is deprecated and will be removed in 0.6.0",
+        DeprecationWarning,
+    )
+    return None
+
+
+def get_static_openai_chat_acreate_func():
+    warnings.warn(
+        "This function is deprecated and will be removed in 0.6.0",
+        DeprecationWarning,
+    )
+    return None
+
+
 def is_static_openai_create_func(llm_api: Optional[Callable]) -> bool:
     try:
         return llm_api == openai.completions.create
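Since each getter warns on call, downstream code or tests can surface or silence the deprecation explicitly. A small illustrative example of observing the warning (this snippet is mine, not from the repo):

import warnings

from guardrails.utils.openai_utils import get_static_openai_chat_create_func

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fn = get_static_openai_chat_create_func()

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
# fn is openai.chat.completions.create; get_static_openai_chat_acreate_func()
# warns the same way but returns None.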
Lines changed: 184 additions & 0 deletions (new file)
@@ -0,0 +1,184 @@
+# 3 tests
+# 1. Test streaming with OpenAICallable (mock openai.Completion.create)
+# 2. Test streaming with OpenAIChatCallable (mock openai.ChatCompletion.create)
+# 3. Test string schema streaming
+# Using the LowerCase Validator, and a custom validator to show new streaming behavior
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import asyncio
+import pytest
+
+import guardrails as gd
+from guardrails.utils.casting_utils import to_int
+from guardrails.validator_base import (
+    ErrorSpan,
+    FailResult,
+    OnFailAction,
+    PassResult,
+    ValidationResult,
+    Validator,
+    register_validator,
+)
+from tests.integration_tests.test_assets.validators import LowerCase, MockDetectPII
+
+
+@register_validator(name="minsentencelength", data_type=["string", "list"])
+class MinSentenceLengthValidator(Validator):
+    def __init__(
+        self,
+        min: Optional[int] = None,
+        max: Optional[int] = None,
+        on_fail: Optional[Callable] = None,
+    ):
+        super().__init__(
+            on_fail=on_fail,
+            min=min,
+            max=max,
+        )
+        self._min = to_int(min)
+        self._max = to_int(max)
+
+    def sentence_split(self, value):
+        return list(map(lambda x: x + ".", value.split(".")[:-1]))
+
+    def validate(self, value: Union[str, List], metadata: Dict) -> ValidationResult:
+        sentences = self.sentence_split(value)
+        error_spans = []
+        index = 0
+        for sentence in sentences:
+            if len(sentence) < self._min:
+                error_spans.append(
+                    ErrorSpan(
+                        start=index,
+                        end=index + len(sentence),
+                        reason=f"Sentence has length less than {self._min}. "
+                        f"Please return a longer output, "
+                        f"that is shorter than {self._max} characters.",
+                    )
+                )
+            if len(sentence) > self._max:
+                error_spans.append(
+                    ErrorSpan(
+                        start=index,
+                        end=index + len(sentence),
+                        reason=f"Sentence has length greater than {self._max}. "
+                        f"Please return a shorter output, "
+                        f"that is shorter than {self._max} characters.",
+                    )
+                )
+            index = index + len(sentence)
+        if len(error_spans) > 0:
+            return FailResult(
+                validated_chunk=value,
+                error_spans=error_spans,
+                error_message=f"Sentence has length less than {self._min}. "
+                f"Please return a longer output, "
+                f"that is shorter than {self._max} characters.",
+            )
+        return PassResult(validated_chunk=value)
+
+    def validate_stream(self, chunk: Any, metadata: Dict, **kwargs) -> ValidationResult:
+        return super().validate_stream(chunk, metadata, **kwargs)
+
+
+class Delta:
+    content: str
+
+    def __init__(self, content):
+        self.content = content
+
+
+class Choice:
+    text: str
+    finish_reason: str
+    index: int
+    delta: Delta
+
+    def __init__(self, text, delta, finish_reason, index=0):
+        self.index = index
+        self.delta = delta
+        self.text = text
+        self.finish_reason = finish_reason
+
+
+class MockOpenAIV1ChunkResponse:
+    choices: list
+    model: str
+
+    def __init__(self, choices, model):
+        self.choices = choices
+        self.model = model
+
+
+class Response:
+    def __init__(self, chunks):
+        self.chunks = chunks
+
+        async def gen():
+            for chunk in self.chunks:
+                yield MockOpenAIV1ChunkResponse(
+                    choices=[
+                        Choice(
+                            delta=Delta(content=chunk),
+                            text=chunk,
+                            finish_reason=None,
+                        )
+                    ],
+                    model="OpenAI model name",
+                )
+                await asyncio.sleep(0)  # Yield control to the event loop
+
+        self.completion_stream = gen()
+
+
+POETRY_CHUNKS = [
+    "John, under ",
+    "GOLDEN bridges",
+    ", roams,\n",
+    "SAN Francisco's ",
+    "hills, his HOME.\n",
+    "Dreams of",
+    " FOG, and salty AIR,\n",
+    "In his HEART",
+    ", he's always THERE.",
+]
+
+
+@pytest.mark.asyncio
+async def test_filter_behavior(mocker):
+    mocker.patch(
+        "litellm.acompletion",
+        return_value=Response(POETRY_CHUNKS),
+    )
+
+    guard = gd.AsyncGuard().use_many(
+        MockDetectPII(
+            on_fail=OnFailAction.FIX,
+            pii_entities="pii",
+            replace_map={"John": "<PERSON>", "SAN Francisco's": "<LOCATION>"},
+        ),
+        LowerCase(on_fail=OnFailAction.FILTER),
+    )
+    prompt = """Write me a 4 line poem about John in San Francisco.
+Make every third word all caps."""
+    gen = await guard(
+        model="gpt-3.5-turbo",
+        max_tokens=10,
+        temperature=0,
+        stream=True,
+        prompt=prompt,
+    )
+
+    text = ""
+    final_res = None
+    async for res in gen:
+        final_res = res
+        text += res.validated_output
+
+    assert final_res.raw_llm_output == ", he's always THERE."
+    # TODO deep dive this
+    assert text == (
+        "John, under GOLDEN bridges, roams,\n"
+        "SAN Francisco's Dreams of FOG, and salty AIR,\n"
+        "In his HEART"
+    )
