Skip to content

Commit

Permalink
fix: Added off and clear proeprty to Instructor base class (#1087)
Browse files Browse the repository at this point in the history
Adding support for `.off` and `.clear` in instructor
  • Loading branch information
ivanleomk authored Oct 17, 2024
1 parent 52982f3 commit 9ef3ebd
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 4 deletions.
32 changes: 32 additions & 0 deletions instructor/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,38 @@ def on(
) -> None:
self.hooks.on(hook_name, handler)

def off(
self,
hook_name: (
HookName
| Literal[
"completion:kwargs",
"completion:response",
"completion:error",
"completion:last_attempt",
"parse:error",
]
),
handler: Callable[[Any], None],
) -> None:
self.hooks.off(hook_name, handler)

def clear(
self,
hook_name: (
HookName
| Literal[
"completion:kwargs",
"completion:response",
"completion:error",
"completion:last_attempt",
"parse:error",
]
)
| None = None,
) -> None:
self.hooks.clear(hook_name)

@property
def chat(self) -> Self:
return self
Expand Down
36 changes: 32 additions & 4 deletions instructor/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,16 @@ def on(
>>> hooks.emit_completion_arguments(model="gpt-3.5-turbo", temperature=0.7)
Completion kwargs: (), {'model': 'gpt-3.5-turbo', 'temperature': 0.7}
"""
hook_name = self.get_hook_name(hook_name)
self._handlers[hook_name].append(handler)

def get_hook_name(self, hook_name: HookName | str) -> HookName:
if isinstance(hook_name, str):
try:
hook_name = HookName(hook_name)
return HookName(hook_name)
except ValueError as err:
raise ValueError(f"Invalid hook name: {hook_name}") from err
self._handlers[hook_name].append(handler)
return hook_name

def emit_completion_arguments(self, *args: Any, **kwargs: Any) -> None:
for handler in self._handlers[HookName.COMPLETION_KWARGS]:
Expand Down Expand Up @@ -128,27 +132,51 @@ def emit_parse_error(self, error: Exception) -> None:
f"Error in parse error handler:\n{error_traceback}", stacklevel=2
)

def off(self, hook_name: HookName, handler: Callable[[Any], None]) -> None:
def off(
self,
hook_name: HookName
| Literal[
"completion:kwargs",
"completion:response",
"completion:error",
"completion:last_attempt",
"parse:error",
],
handler: Callable[[Any], None],
) -> None:
"""
Removes a specific handler from an event.
Args:
hook_name (HookName): The name of the hook.
handler (Callable[[Any], None]): The handler to remove.
"""
hook_name = self.get_hook_name(hook_name)
if hook_name in self._handlers:
self._handlers[hook_name].remove(handler)
if not self._handlers[hook_name]:
del self._handlers[hook_name]

def clear(self, hook_name: HookName | None = None) -> None:
def clear(
self,
hook_name: HookName
| Literal[
"completion:kwargs",
"completion:response",
"completion:error",
"completion:last_attempt",
"parse:error",
]
| None = None,
) -> None:
"""
Clears handlers for a specific event or all events.
Args:
hook_name (HookName | None): The name of the event to clear handlers for. If None, all handlers are cleared.
"""
if hook_name is not None:
hook_name = self.get_hook_name(hook_name)
self._handlers.pop(hook_name, None)
else:
self._handlers.clear()
178 changes: 178 additions & 0 deletions tests/llm/test_openai/test_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import pytest
import instructor
from openai import OpenAI
import pprint


@pytest.fixture
def client():
return instructor.from_openai(OpenAI())


def log_kwargs(*args, **kwargs):
pprint.pprint({"args": args, "kwargs": kwargs})


def log_kwargs_1(*args, **kwargs):
pprint.pprint({"args": args, "kwargs": kwargs})


def log_kwargs_2(*args, **kwargs):
pprint.pprint({"args": args, "kwargs": kwargs})


hook_names = [item.value for item in instructor.hooks.HookName]
hook_enums = [instructor.hooks.HookName(hook_name) for hook_name in hook_names]
hook_functions = [log_kwargs, log_kwargs_1, log_kwargs_2]
hook_object = instructor.hooks.Hooks()


@pytest.mark.parametrize("hook_name", hook_names)
@pytest.mark.parametrize("num_functions", [1, 2, 3])
def test_on_method_str(
client: instructor.Instructor, hook_name: str, num_functions: int
):
functions_to_add = hook_functions[:num_functions]
hook_enum = hook_object.get_hook_name(hook_name)

assert hook_enum not in client.hooks._handlers

for func in functions_to_add:
client.on(hook_name, func)

assert hook_enum in client.hooks._handlers
assert len(client.hooks._handlers[hook_enum]) == num_functions

for func in functions_to_add:
assert func in client.hooks._handlers[hook_enum]


@pytest.mark.parametrize("hook_enum", hook_enums)
@pytest.mark.parametrize("num_functions", [1, 2, 3])
def test_on_method_enum(
client: instructor.Instructor,
hook_enum: instructor.hooks.HookName,
num_functions: int,
):
functions_to_add = hook_functions[:num_functions]
assert hook_enum not in client.hooks._handlers

for func in functions_to_add:
client.on(hook_enum, func)

assert hook_enum in client.hooks._handlers
assert len(client.hooks._handlers[hook_enum]) == num_functions

for func in functions_to_add:
assert func in client.hooks._handlers[hook_enum]


@pytest.mark.parametrize("hook_name", hook_names)
@pytest.mark.parametrize("num_functions", [1, 2, 3])
def test_off_method_str(
client: instructor.Instructor,
hook_name: str,
num_functions: int,
):
functions_to_add = hook_functions[:num_functions]
hook_enum = hook_object.get_hook_name(hook_name)
assert hook_enum not in client.hooks._handlers

for func in functions_to_add:
client.on(hook_name, func)

assert hook_enum in client.hooks._handlers
assert len(client.hooks._handlers[hook_enum]) == num_functions

for func in functions_to_add:
client.off(hook_name, func)
if client.hooks._handlers.get(hook_enum):
assert func not in client.hooks._handlers[hook_enum]
else:
assert hook_enum not in client.hooks._handlers

assert hook_enum not in client.hooks._handlers


@pytest.mark.parametrize("hook_enum", hook_enums)
@pytest.mark.parametrize("num_functions", [1, 2, 3])
def test_off_method_enum(
client: instructor.Instructor,
hook_enum: instructor.hooks.HookName,
num_functions: int,
):
functions_to_add = hook_functions[:num_functions]
assert hook_enum not in client.hooks._handlers
for func in functions_to_add:
client.on(hook_enum, func)

assert hook_enum in client.hooks._handlers
assert len(client.hooks._handlers[hook_enum]) == num_functions

for func in functions_to_add:
client.off(hook_enum, func)
if client.hooks._handlers.get(hook_enum):
assert func not in client.hooks._handlers[hook_enum]
else:
assert hook_enum not in client.hooks._handlers

assert hook_enum not in client.hooks._handlers


@pytest.mark.parametrize("hook_name", hook_names)
@pytest.mark.parametrize("num_functions", [1, 2, 3])
def test_clear_method_str(
client: instructor.Instructor,
hook_name: str,
num_functions: int,
):
functions_to_add = hook_functions[:num_functions]

for func in functions_to_add:
client.on(hook_name, func)

hook_enum = hook_object.get_hook_name(hook_name)

assert hook_enum in client.hooks._handlers
assert len(client.hooks._handlers[hook_enum]) == num_functions

client.clear(hook_name)
assert hook_enum not in client.hooks._handlers


@pytest.mark.parametrize("hook_enum", hook_enums)
@pytest.mark.parametrize("num_functions", [1, 2, 3])
def test_clear_method(
client: instructor.Instructor,
hook_enum: instructor.hooks.HookName,
num_functions: int,
):
functions_to_add = hook_functions[:num_functions]

for func in functions_to_add:
client.on(hook_enum, func)

assert hook_enum in client.hooks._handlers
assert len(client.hooks._handlers[hook_enum]) == num_functions

client.clear(hook_enum)
assert hook_enum not in client.hooks._handlers


@pytest.mark.parametrize("hook_enum", hook_enums)
@pytest.mark.parametrize("num_functions", [1, 2, 3])
def test_clear_no_args(
client: instructor.Instructor,
hook_enum: instructor.hooks.HookName,
num_functions: int,
):
functions_to_add = hook_functions[:num_functions]

for func in functions_to_add:
client.on(hook_enum, func)

assert hook_enum in client.hooks._handlers
assert len(client.hooks._handlers[hook_enum]) == num_functions

client.clear()
assert hook_enum not in client.hooks._handlers

0 comments on commit 9ef3ebd

Please sign in to comment.