@@ -704,6 +704,143 @@ class OllamaProvider(BaseProvider, Ollama):
704
704
TextField (key = "base_url" , label = "Base API URL (optional)" , format = "text" ),
705
705
]
706
706
707
+
708
class JsonContentHandler(LLMContentHandler):
    """Content handler that maps prompts onto a user-supplied JSON request
    schema and extracts the completion from the JSON response via JSONPath.

    The request schema is a JSON document in which every occurrence of the
    literal string ``"<prompt>"`` is replaced with the prompt text. The
    response path is a JSONPath expression selecting the generated text in
    the endpoint's JSON response.
    """

    content_type = "application/json"
    accepts = "application/json"

    def __init__(self, request_schema, response_path):
        """
        :param request_schema: JSON string template for the request body.
        :param response_path: JSONPath string locating the completion in
            the response document.
        """
        self.request_schema = json.loads(request_schema)
        self.response_path = response_path
        # Compile the JSONPath once; reused for every response.
        self.response_parser = parse(response_path)

    def replace_values(self, old_val, new_val, d: Dict[str, Any]):
        """Replace every value equal to ``old_val`` with ``new_val``,
        recursing into nested dicts and lists.

        Mutates ``d`` in place and also returns it for convenience.
        """
        for key, val in d.items():
            if val == old_val:
                d[key] = new_val
            if isinstance(val, dict):
                self.replace_values(old_val, new_val, val)
            # Fix: also descend into lists so schemas like
            # {"messages": [{"content": "<prompt>"}]} are substituted;
            # previously only dict values were visited.
            if isinstance(val, list):
                for i, item in enumerate(val):
                    if item == old_val:
                        val[i] = new_val
                    elif isinstance(item, dict):
                        self.replace_values(old_val, new_val, item)

        return d

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Build the request body by substituting the prompt into a deep
        copy of the schema (the template itself is never mutated)."""
        request_obj = copy.deepcopy(self.request_schema)
        self.replace_values("<prompt>", prompt, request_obj)
        request = json.dumps(request_obj).encode("utf-8")
        return request

    def transform_output(self, output: bytes) -> str:
        """Decode the endpoint response and return the first JSONPath match.

        NOTE(review): ``output`` is annotated as ``bytes`` but ``.read()``
        is called on it, so it is actually a stream-like body (e.g. a
        botocore StreamingBody) — confirm against the caller. Raises
        IndexError if the JSONPath matches nothing.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        matches = self.response_parser.find(response_json)
        return matches[0].value
737
+
738
+
739
class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
    """Provider for an arbitrary model hosted on a SageMaker endpoint.

    The wire format is user-defined: callers supply a JSON request schema
    (with ``"<prompt>"`` as a placeholder) and a JSONPath into the
    response, which are wired into a ``JsonContentHandler``.
    """

    id = "sagemaker-endpoint"
    name = "SageMaker endpoint"
    # Any endpoint name is accepted — SageMaker endpoints are user-defined.
    # (Fix: this attribute was previously declared twice.)
    models = ["*"]
    model_id_key = "endpoint_name"
    model_id_label = "Endpoint name"
    # This all needs to be on one line of markdown, for use in a table.
    # Fix: the previous help text was copy-pasted from the Ollama provider
    # and pointed users at the Ollama model library.
    help = (
        "Specify an endpoint name as the model ID. "
        "In addition, you must provide a JSON request schema and a response "
        "path matching your endpoint's wire format."
    )
    registry = True
    # NOTE(review): `base_url` appears copy-pasted from the Ollama provider;
    # nothing in this class reads it. Confirm whether the intended fields are
    # region_name / request_schema / response_path.
    fields = [
        TextField(key="base_url", label="Base API URL (optional)", format="text"),
    ]

    def __init__(self, *args, **kwargs):
        """Pop the handler settings out of kwargs and delegate the rest.

        Raises KeyError if `request_schema` or `response_path` is missing —
        both are required.
        """
        request_schema = kwargs.pop("request_schema")
        response_path = kwargs.pop("response_path")
        content_handler = JsonContentHandler(
            request_schema=request_schema, response_path=response_path
        )

        super().__init__(*args, **kwargs, content_handler=content_handler)

    async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
        """Run the blocking synchronous `_call` on an executor thread so
        the event loop is not stalled."""
        return await self._call_in_executor(*args, **kwargs)
767
+
768
+
769
# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockProvider(BaseProvider, Bedrock):
    """Provider for text-completion models served through Amazon Bedrock.

    Authenticates with standard AWS credentials; users may optionally pick
    a credentials profile and a region via the configurable fields below.
    """

    id = "bedrock"
    name = "Amazon Bedrock"
    model_id_key = "model_id"

    # Completion (non-chat) model IDs available on Bedrock.
    models = [
        "amazon.titan-text-express-v1",
        "amazon.titan-text-lite-v1",
        "ai21.j2-ultra-v1",
        "ai21.j2-mid-v1",
        "cohere.command-light-text-v14",
        "cohere.command-text-v14",
        "cohere.command-r-v1:0",
        "cohere.command-r-plus-v1:0",
        "meta.llama2-13b-chat-v1",
        "meta.llama2-70b-chat-v1",
        "meta.llama3-8b-instruct-v1:0",
        "meta.llama3-70b-instruct-v1:0",
        "meta.llama3-1-8b-instruct-v1:0",
        "meta.llama3-1-70b-instruct-v1:0",
        "mistral.mistral-7b-instruct-v0:2",
        "mistral.mixtral-8x7b-instruct-v0:1",
        "mistral.mistral-large-2402-v1:0",
    ]

    pypi_package_deps = ["boto3"]
    auth_strategy = AwsAuthStrategy()

    # Optional per-user connection settings surfaced in the settings UI.
    fields = [
        TextField(
            key="credentials_profile_name",
            label="AWS profile (optional)",
            format="text",
        ),
        TextField(key="region_name", label="Region name (optional)", format="text"),
    ]

    async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
        """Delegate to the synchronous `_call` on an executor thread."""
        return await self._call_in_executor(*args, **kwargs)
806
+
807
+
808
# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockChatProvider(BaseProvider, BedrockChat):
    """Provider for chat models served through Amazon Bedrock.

    Authenticates with standard AWS credentials; users may optionally pick
    a credentials profile and a region via the configurable fields below.
    """

    id = "bedrock-chat"
    name = "Amazon Bedrock Chat"
    # Chat model IDs available on Bedrock.
    models = [
        "anthropic.claude-v2",
        "anthropic.claude-v2:1",
        "anthropic.claude-instant-v1",
        "anthropic.claude-3-sonnet-20240229-v1:0",
        "anthropic.claude-3-haiku-20240307-v1:0",
        "anthropic.claude-3-opus-20240229-v1:0",
        "anthropic.claude-3-5-sonnet-20240620-v1:0",
    ]
    model_id_key = "model_id"
    pypi_package_deps = ["boto3"]
    auth_strategy = AwsAuthStrategy()
    # Optional per-user connection settings surfaced in the settings UI.
    fields = [
        TextField(
            key="credentials_profile_name",
            label="AWS profile (optional)",
            format="text",
        ),
        TextField(key="region_name", label="Region name (optional)", format="text"),
    ]

    async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
        """Delegate to the synchronous `_call` on an executor thread."""
        return await self._call_in_executor(*args, **kwargs)

    async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
        """Delegate to the synchronous `_generate` on an executor thread."""
        return await self._generate_in_executor(*args, **kwargs)

    @property
    def allows_concurrency(self):
        """Concurrency is disabled for Anthropic models; all other model
        IDs allow concurrent requests.

        Fix: rewrote the non-idiomatic `not "anthropic" in ...` as
        `"anthropic" not in ...` (same truth table, PEP 8 idiom).
        """
        return "anthropic" not in self.model_id
842
+
843
+
707
844
class TogetherAIProvider (BaseProvider , Together ):
708
845
id = "togetherai"
709
846
name = "Together AI"
0 commit comments