Skip to content

Commit 40d32b7

Browse files
committed
Merge commit 'public/main@{v2.28.3}' into embedding_model
2 parents 77906cb + 28d5009 commit 40d32b7

File tree

1 file changed

+137
-0
lines changed

1 file changed

+137
-0
lines changed

packages/jupyter-ai-magics/jupyter_ai_magics/providers.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,143 @@ class OllamaProvider(BaseProvider, Ollama):
704704
TextField(key="base_url", label="Base API URL (optional)", format="text"),
705705
]
706706

707+
708+
class JsonContentHandler(LLMContentHandler):
    """Content handler that serializes prompts into JSON requests for a
    SageMaker endpoint and extracts the completion from the JSON response
    using a JSONPath expression.
    """

    content_type = "application/json"
    accepts = "application/json"

    def __init__(self, request_schema, response_path):
        # `request_schema` is a JSON string template containing the literal
        # placeholder "<prompt>"; `response_path` is a JSONPath expression.
        self.request_schema = json.loads(request_schema)
        self.response_path = response_path
        self.response_parser = parse(response_path)

    def replace_values(self, old_val, new_val, d: Dict[str, Any]):
        """Recursively replace every value equal to *old_val* in *d* (and any
        nested dicts) with *new_val*, mutating *d* in place.
        """
        for k, v in d.items():
            if v == old_val:
                d[k] = new_val
            if isinstance(v, dict):
                self.replace_values(old_val, new_val, v)

        return d

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Build the JSON request body by substituting the prompt into a deep
        copy of the schema template, encoded as UTF-8 bytes.
        """
        # Deep-copy so the stored template is never mutated between calls.
        body = copy.deepcopy(self.request_schema)
        self.replace_values("<prompt>", prompt, body)
        return json.dumps(body).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Decode the endpoint response and return the first JSONPath match."""
        # NOTE(review): `output` is read via .read() despite the `bytes`
        # annotation — it is presumably a streaming body; confirm upstream.
        payload = json.loads(output.read().decode("utf-8"))
        return self.response_parser.find(payload)[0].value
737+
738+
739+
class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
    """Provider for models deployed behind an Amazon SageMaker inference
    endpoint. The endpoint name is supplied as the model ID; a request
    schema and a JSONPath response path describe the endpoint's payloads.
    """

    id = "sagemaker-endpoint"
    name = "SageMaker endpoint"
    # Any endpoint name may be used, so this is a registry provider.
    models = ["*"]
    model_id_key = "endpoint_name"
    model_id_label = "Endpoint name"
    # This all needs to be on one line of markdown, for use in a table
    help = (
        "Specify an endpoint name as the model ID. "
        "In addition, you must specify a region name, request schema, and response path. "
        "For more information, see the documentation about "
        "[SageMaker endpoint deployment](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-deployment.html)."
    )
    registry = True
    # These fields must match the kwargs consumed by __init__ below
    # (the previous "base_url" field was a merge artifact from OllamaProvider).
    fields = [
        TextField(key="region_name", label="Region name (required)", format="text"),
        TextField(key="request_schema", label="Request schema (required)", format="text"),
        TextField(key="response_path", label="Response path (required)", format="text"),
    ]

    def __init__(self, *args, **kwargs):
        """Build a JsonContentHandler from the `request_schema` and
        `response_path` kwargs and pass it to SagemakerEndpoint.
        """
        request_schema = kwargs.pop("request_schema")
        response_path = kwargs.pop("response_path")
        content_handler = JsonContentHandler(
            request_schema=request_schema, response_path=response_path
        )

        super().__init__(*args, **kwargs, content_handler=content_handler)

    async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
        # SagemakerEndpoint is synchronous; run the blocking call in an executor.
        return await self._call_in_executor(*args, **kwargs)
767+
768+
769+
# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockProvider(BaseProvider, Bedrock):
    """Provider for Amazon Bedrock text-completion (non-chat) models."""

    id = "bedrock"
    name = "Amazon Bedrock"
    # Bedrock model IDs for completion-style models (see link above);
    # chat models are registered separately on BedrockChatProvider.
    models = [
        "amazon.titan-text-express-v1",
        "amazon.titan-text-lite-v1",
        "ai21.j2-ultra-v1",
        "ai21.j2-mid-v1",
        "cohere.command-light-text-v14",
        "cohere.command-text-v14",
        "cohere.command-r-v1:0",
        "cohere.command-r-plus-v1:0",
        "meta.llama2-13b-chat-v1",
        "meta.llama2-70b-chat-v1",
        "meta.llama3-8b-instruct-v1:0",
        "meta.llama3-70b-instruct-v1:0",
        "meta.llama3-1-8b-instruct-v1:0",
        "meta.llama3-1-70b-instruct-v1:0",
        "mistral.mistral-7b-instruct-v0:2",
        "mistral.mixtral-8x7b-instruct-v0:1",
        "mistral.mistral-large-2402-v1:0",
    ]
    model_id_key = "model_id"
    pypi_package_deps = ["boto3"]
    auth_strategy = AwsAuthStrategy()
    # Optional connection settings surfaced in the settings UI.
    fields = [
        TextField(
            key="credentials_profile_name",
            label="AWS profile (optional)",
            format="text",
        ),
        TextField(key="region_name", label="Region name (optional)", format="text"),
    ]

    async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
        # Bedrock's client is synchronous; run the blocking call in an executor.
        return await self._call_in_executor(*args, **kwargs)
806+
807+
808+
# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockChatProvider(BaseProvider, BedrockChat):
    """Provider for Amazon Bedrock chat models (Anthropic Claude family)."""

    id = "bedrock-chat"
    name = "Amazon Bedrock Chat"
    # Bedrock model IDs for chat-style models (see link above).
    models = [
        "anthropic.claude-v2",
        "anthropic.claude-v2:1",
        "anthropic.claude-instant-v1",
        "anthropic.claude-3-sonnet-20240229-v1:0",
        "anthropic.claude-3-haiku-20240307-v1:0",
        "anthropic.claude-3-opus-20240229-v1:0",
        "anthropic.claude-3-5-sonnet-20240620-v1:0",
    ]
    model_id_key = "model_id"
    pypi_package_deps = ["boto3"]
    auth_strategy = AwsAuthStrategy()
    # Optional connection settings surfaced in the settings UI.
    fields = [
        TextField(
            key="credentials_profile_name",
            label="AWS profile (optional)",
            format="text",
        ),
        TextField(key="region_name", label="Region name (optional)", format="text"),
    ]

    async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
        # BedrockChat's client is synchronous; run the blocking call in an executor.
        return await self._call_in_executor(*args, **kwargs)

    async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
        return await self._generate_in_executor(*args, **kwargs)

    @property
    def allows_concurrency(self) -> bool:
        """Disallow concurrent requests for Anthropic-hosted models."""
        # Fixed: use the `not in` membership idiom (PEP 8 / ruff E713)
        # instead of `not "anthropic" in ...`.
        return "anthropic" not in self.model_id
842+
843+
707844
class TogetherAIProvider(BaseProvider, Together):
708845
id = "togetherai"
709846
name = "Together AI"

0 commit comments

Comments
 (0)