Skip to content

Commit

Permalink
Python: provide methods to register single native function to the ker…
Browse files Browse the repository at this point in the history
…nel (#2390)

### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->
solve #2321 

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄

---------

Co-authored-by: Abby Harrison <[email protected]>
Co-authored-by: Abby Harrison <[email protected]>
Co-authored-by: Lee Miller <[email protected]>
  • Loading branch information
4 people authored Aug 23, 2023
1 parent 0c15862 commit 4bc5ff7
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 0 deletions.
33 changes: 33 additions & 0 deletions python/semantic_kernel/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,39 @@ def register_semantic_function(

return function

def register_native_function(
self,
skill_name: Optional[str],
sk_function: Callable,
) -> SKFunctionBase:
if not hasattr(sk_function, "__sk_function__"):
raise KernelException(
KernelException.ErrorCodes.InvalidFunctionType,
"sk_function argument must be decorated with @sk_function",
)
function_name = sk_function.__sk_function_name__

if skill_name is None or skill_name == "":
skill_name = SkillCollection.GLOBAL_SKILL
assert skill_name is not None # for type checker

validate_skill_name(skill_name)
validate_function_name(function_name)

function = SKFunction.from_native_method(sk_function, skill_name, self.logger)

if self.skills.has_function(skill_name, function_name):
raise KernelException(
KernelException.ErrorCodes.FunctionOverloadNotSupported,
"Overloaded functions are not supported, "
"please differentiate function names.",
)

function.set_default_skill_collection(self.skills)
self._skill_collection.add_native_function(function)

return function

async def run_stream_async(
self,
*functions: Any,
Expand Down
61 changes: 61 additions & 0 deletions python/tests/unit/kernel_extensions/test_register_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) Microsoft. All rights reserved.


import pytest

from semantic_kernel import Kernel
from semantic_kernel.kernel_exception import KernelException
from semantic_kernel.orchestration.sk_function_base import SKFunctionBase
from semantic_kernel.skill_definition.sk_function_decorator import sk_function
from semantic_kernel.skill_definition.skill_collection import SkillCollection


def not_decorated_native_function(arg1: str) -> str:
return "test"


@sk_function(name="getLightStatus")
def decorated_native_function(arg1: str) -> str:
return "test"


def test_register_valid_native_function():
kernel = Kernel()

registered_func = kernel.register_native_function(
"TestSkill", decorated_native_function
)

assert isinstance(registered_func, SKFunctionBase)
assert (
kernel.skills.get_native_function("TestSkill", "getLightStatus")
== registered_func
)
assert registered_func.invoke("testtest").result == "test"


def test_register_undecorated_native_function():
kernel = Kernel()

with pytest.raises(KernelException):
kernel.register_native_function("TestSkill", not_decorated_native_function)


def test_register_with_none_skill_name():
kernel = Kernel()

registered_func = kernel.register_native_function(None, decorated_native_function)
assert registered_func.skill_name == SkillCollection.GLOBAL_SKILL


def test_register_overloaded_native_function():
kernel = Kernel()

kernel.register_native_function("TestSkill", decorated_native_function)

with pytest.raises(KernelException):
kernel.register_native_function("TestSkill", decorated_native_function)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 4bc5ff7

Please sign in to comment.