@@ -704,6 +704,143 @@ class OllamaProvider(BaseProvider, Ollama):
704704 TextField (key = "base_url" , label = "Base API URL (optional)" , format = "text" ),
705705 ]
706706
707+
class JsonContentHandler(LLMContentHandler):
    """Serializes prompts into a caller-supplied JSON request body and pulls
    the model's completion back out of the JSON response via a JSONPath
    expression.
    """

    content_type = "application/json"
    accepts = "application/json"

    def __init__(self, request_schema, response_path):
        # `request_schema` arrives as a JSON string; keep the parsed template
        # plus the raw JSONPath alongside its compiled parser.
        self.request_schema = json.loads(request_schema)
        self.response_path = response_path
        self.response_parser = parse(response_path)

    def replace_values(self, old_val, new_val, d: Dict[str, Any]):
        """Replaces values of a dictionary recursively."""
        # Iterative depth-first traversal instead of recursion; mutates the
        # nested dictionaries of `d` in place and returns `d` for chaining.
        pending = [d]
        while pending:
            current = pending.pop()
            for key, value in current.items():
                if value == old_val:
                    current[key] = new_val
                if isinstance(value, dict):
                    pending.append(value)

        return d

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Fill every "<prompt>" placeholder in a fresh copy of the template,
        # then serialize it for the endpoint.
        body = copy.deepcopy(self.request_schema)
        self.replace_values("<prompt>", prompt, body)
        return json.dumps(body).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # NOTE(review): despite the `bytes` annotation, `output` is read via
        # .read() — presumably a botocore streaming body; confirm upstream.
        payload = json.loads(output.read().decode("utf-8"))
        return self.response_parser.find(payload)[0].value
737+
738+
class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
    """Provider backed by a user-deployed SageMaker inference endpoint.

    Endpoints accept arbitrary request/response formats, so the user must
    supply a JSON request schema (with ``"<prompt>"`` placeholders) and a
    JSONPath for extracting the completion from the response; both are
    consumed in ``__init__`` and handed to a :class:`JsonContentHandler`.
    """

    id = "sagemaker-endpoint"
    name = "SageMaker endpoint"
    # Endpoint names are user-chosen, so any model ID is accepted.
    models = ["*"]
    model_id_key = "endpoint_name"
    model_id_label = "Endpoint name"
    # Same AWS requirements as the Bedrock providers below.
    pypi_package_deps = ["boto3"]
    auth_strategy = AwsAuthStrategy()
    registry = True
    # This all needs to be on one line of markdown, for use in a table
    help = (
        "Specify an endpoint name as the model ID. "
        "In addition, you must specify a region name, request schema, and response path. "
        "See the [SageMaker endpoint deployment documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-deploy-models.html) for details."
    )
    # Expose the settings __init__ actually consumes; the endpoint has no
    # base URL, so the previous `base_url` field was inapplicable here.
    fields = [
        TextField(key="region_name", label="Region name (required)", format="text"),
        TextField(key="request_schema", label="Request schema (required)", format="text"),
        TextField(key="response_path", label="Response path (required)", format="text"),
    ]

    def __init__(self, *args, **kwargs):
        # Pop these before delegating: SagemakerEndpoint.__init__ does not
        # accept them; they only parameterize the content handler.
        request_schema = kwargs.pop("request_schema")
        response_path = kwargs.pop("response_path")
        content_handler = JsonContentHandler(
            request_schema=request_schema, response_path=response_path
        )

        super().__init__(*args, **kwargs, content_handler=content_handler)

    async def _acall(self, *args, **kwargs) -> str:
        # The underlying _call is synchronous; run it in an executor so the
        # event loop is not blocked. `async def` declares the awaited result
        # type, hence `-> str` rather than Coroutine[...].
        return await self._call_in_executor(*args, **kwargs)
767+
768+
# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockProvider(BaseProvider, Bedrock):
    """Provider for Amazon Bedrock text-completion models."""

    id = "bedrock"
    name = "Amazon Bedrock"
    models = [
        "amazon.titan-text-express-v1",
        "amazon.titan-text-lite-v1",
        "ai21.j2-ultra-v1",
        "ai21.j2-mid-v1",
        "cohere.command-light-text-v14",
        "cohere.command-text-v14",
        "cohere.command-r-v1:0",
        "cohere.command-r-plus-v1:0",
        "meta.llama2-13b-chat-v1",
        "meta.llama2-70b-chat-v1",
        "meta.llama3-8b-instruct-v1:0",
        "meta.llama3-70b-instruct-v1:0",
        "meta.llama3-1-8b-instruct-v1:0",
        "meta.llama3-1-70b-instruct-v1:0",
        "mistral.mistral-7b-instruct-v0:2",
        "mistral.mixtral-8x7b-instruct-v0:1",
        "mistral.mistral-large-2402-v1:0",
    ]
    model_id_key = "model_id"
    pypi_package_deps = ["boto3"]
    auth_strategy = AwsAuthStrategy()
    fields = [
        TextField(
            key="credentials_profile_name",
            label="AWS profile (optional)",
            format="text",
        ),
        TextField(key="region_name", label="Region name (optional)", format="text"),
    ]

    async def _acall(self, *args, **kwargs) -> str:
        # The underlying _call is synchronous; delegate to an executor so the
        # event loop is not blocked. Annotated `-> str`: an `async def`
        # declares the awaited result type, not `Coroutine[Any, Any, str]`
        # (that is the type of the unawaited call expression).
        return await self._call_in_executor(*args, **kwargs)
806+
807+
# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockChatProvider(BaseProvider, BedrockChat):
    """Provider for Amazon Bedrock chat models."""

    id = "bedrock-chat"
    name = "Amazon Bedrock Chat"
    models = [
        "anthropic.claude-v2",
        "anthropic.claude-v2:1",
        "anthropic.claude-instant-v1",
        "anthropic.claude-3-sonnet-20240229-v1:0",
        "anthropic.claude-3-haiku-20240307-v1:0",
        "anthropic.claude-3-opus-20240229-v1:0",
        "anthropic.claude-3-5-sonnet-20240620-v1:0",
    ]
    model_id_key = "model_id"
    pypi_package_deps = ["boto3"]
    auth_strategy = AwsAuthStrategy()
    fields = [
        TextField(
            key="credentials_profile_name",
            label="AWS profile (optional)",
            format="text",
        ),
        TextField(key="region_name", label="Region name (optional)", format="text"),
    ]

    async def _acall(self, *args, **kwargs) -> str:
        # The underlying _call is synchronous; run it in an executor. An
        # `async def` is annotated with the awaited result type (`str`),
        # not `Coroutine[Any, Any, str]`.
        return await self._call_in_executor(*args, **kwargs)

    async def _agenerate(self, *args, **kwargs) -> LLMResult:
        # Same executor pattern for batch generation.
        return await self._generate_in_executor(*args, **kwargs)

    @property
    def allows_concurrency(self):
        # PEP 8 spelling: `x not in y`, not `not x in y`. Anthropic models
        # are deliberately serialized here — presumably to respect Bedrock
        # throttling; confirm before changing.
        return "anthropic" not in self.model_id
842+
843+
707844class TogetherAIProvider (BaseProvider , Together ):
708845 id = "togetherai"
709846 name = "Together AI"
0 commit comments