Skip to content

Commit

Permalink
Merge commit 'public/main@{v2.28.3}' into embedding_model
Browse files Browse the repository at this point in the history
  • Loading branch information
cl-gavan committed Dec 17, 2024
2 parents 77906cb + 28d5009 commit 40d32b7
Showing 1 changed file with 137 additions and 0 deletions.
137 changes: 137 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,143 @@ class OllamaProvider(BaseProvider, Ollama):
TextField(key="base_url", label="Base API URL (optional)", format="text"),
]


class JsonContentHandler(LLMContentHandler):
content_type = "application/json"
accepts = "application/json"

def __init__(self, request_schema, response_path):
self.request_schema = json.loads(request_schema)
self.response_path = response_path
self.response_parser = parse(response_path)

def replace_values(self, old_val, new_val, d: Dict[str, Any]):
"""Replaces values of a dictionary recursively."""
for key, val in d.items():
if val == old_val:
d[key] = new_val
if isinstance(val, dict):
self.replace_values(old_val, new_val, val)

return d

def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
request_obj = copy.deepcopy(self.request_schema)
self.replace_values("<prompt>", prompt, request_obj)
request = json.dumps(request_obj).encode("utf-8")
return request

def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
matches = self.response_parser.find(response_json)
return matches[0].value


class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
id = "sagemaker-endpoint"
name = "SageMaker endpoint"
models = ["*"]
model_id_key = "endpoint_name"
model_id_label = "Endpoint name"
# This all needs to be on one line of markdown, for use in a table
help = (
"See [https://www.ollama.com/library](https://www.ollama.com/library) for a list of models. "
"Pass a model's name; for example, `deepseek-coder-v2`."
)
models = ["*"]
registry = True
fields = [
TextField(key="base_url", label="Base API URL (optional)", format="text"),
]

def __init__(self, *args, **kwargs):
request_schema = kwargs.pop("request_schema")
response_path = kwargs.pop("response_path")
content_handler = JsonContentHandler(
request_schema=request_schema, response_path=response_path
)

super().__init__(*args, **kwargs, content_handler=content_handler)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockProvider(BaseProvider, Bedrock):
id = "bedrock"
name = "Amazon Bedrock"
models = [
"amazon.titan-text-express-v1",
"amazon.titan-text-lite-v1",
"ai21.j2-ultra-v1",
"ai21.j2-mid-v1",
"cohere.command-light-text-v14",
"cohere.command-text-v14",
"cohere.command-r-v1:0",
"cohere.command-r-plus-v1:0",
"meta.llama2-13b-chat-v1",
"meta.llama2-70b-chat-v1",
"meta.llama3-8b-instruct-v1:0",
"meta.llama3-70b-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"mistral.mistral-7b-instruct-v0:2",
"mistral.mixtral-8x7b-instruct-v0:1",
"mistral.mistral-large-2402-v1:0",
]
model_id_key = "model_id"
pypi_package_deps = ["boto3"]
auth_strategy = AwsAuthStrategy()
fields = [
TextField(
key="credentials_profile_name",
label="AWS profile (optional)",
format="text",
),
TextField(key="region_name", label="Region name (optional)", format="text"),
]

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockChatProvider(BaseProvider, BedrockChat):
id = "bedrock-chat"
name = "Amazon Bedrock Chat"
models = [
"anthropic.claude-v2",
"anthropic.claude-v2:1",
"anthropic.claude-instant-v1",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
]
model_id_key = "model_id"
pypi_package_deps = ["boto3"]
auth_strategy = AwsAuthStrategy()
fields = [
TextField(
key="credentials_profile_name",
label="AWS profile (optional)",
format="text",
),
TextField(key="region_name", label="Region name (optional)", format="text"),
]

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)

async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
return await self._generate_in_executor(*args, **kwargs)

@property
def allows_concurrency(self):
return not "anthropic" in self.model_id


class TogetherAIProvider(BaseProvider, Together):
id = "togetherai"
name = "Together AI"
Expand Down

0 comments on commit 40d32b7

Please sign in to comment.