Skip to content

Add Support for Custom Embeddings #829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ Every array will produce the combinations of flat configurations when the method

## Description of embedding models config

`embedding_model` is an array containing the configuration for the embedding models to use. Embedding model `type` must be `azure` for Azure OpenAI models and `sentence-transformer` for HuggingFace sentence transformer models.
`embedding_model` is an array containing the configuration for the embedding models to use. Embedding model `type` must be `azure` for Azure OpenAI models, `sentence-transformer` for HuggingFace sentence transformer models and `custom-embedding` for custom embeddings deployed as Azure Online Endpoints.

### Azure OpenAI embedding model config

Expand Down Expand Up @@ -408,6 +408,18 @@ When using the [newer embeddings models (v3)](https://openai.com/blog/new-embedd
}
```

### Custom embedding model

```json
{
"type": "custom-embedding",
"model_name": "the name of the Azure deployment of the custom embedding model",
"dimension": "the dimension of the custom embedding model. This field is not required"
}
```

The variables `azure_model_api_key` and `azure_model_api_endpoint` should also be set in the environment variables (.env file).

## Query Expansion

Giving an example of a hypothetical answer for the question in the query, a hypothetical passage which holds an answer to the query, or generating a few alternative related questions might improve retrieval and thus yield more accurate chunks of docs to pass into the LLM context.
Expand Down
2 changes: 2 additions & 0 deletions rag_experiment_accelerator/config/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class Environment:
azure_document_intelligence_endpoint: Optional[str]
azure_document_intelligence_admin_key: Optional[str]
azure_key_vault_endpoint: Optional[str]
azure_model_api_key: Optional[str]
azure_model_api_endpoint: Optional[str]

@classmethod
def _field_names(cls) -> list[str]:
Expand Down
4 changes: 3 additions & 1 deletion rag_experiment_accelerator/config/tests/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def test_to_keyvault(mock_init_keyvault):
azure_language_service_key=None,
azure_key_vault_endpoint="test_endpoint",
azure_search_use_semantic_search="True",
azure_model_api_key="mock_key",
azure_model_api_endpoint="mock_endpoint",
)
environment.to_keyvault()

assert mock_keyvault.set_secret.call_count == 17
assert mock_keyvault.set_secret.call_count == 19
124 changes: 124 additions & 0 deletions rag_experiment_accelerator/embedding/custom_embedding_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import urllib.request
import json
import os
import ssl
from typing import Union

from rag_experiment_accelerator.config.environment import Environment
from rag_experiment_accelerator.embedding.embedding_model import EmbeddingModel
from rag_experiment_accelerator.utils.logging import get_logger

logger = get_logger(__name__)


class CustomEmbeddingModel(EmbeddingModel):
    """
    A custom embedding model deployed as an AzureML online endpoint.

    The endpoint URL and API key are read from the environment
    (``azure_model_api_endpoint`` / ``azure_model_api_key``).

    Args:
        model_name (str): The name of the AzureML deployment to call.
        environment (Environment): The initialized environment.
        dimension (int, optional): The dimension of the embedding. Defaults to 1536.
        **kwargs: Additional keyword arguments forwarded to EmbeddingModel.
    """

    def __init__(
        self, model_name: str, environment: Environment, dimension: int = 1536, **kwargs
    ):
        super().__init__(name=model_name, dimension=dimension, **kwargs)
        self.environment = environment

    def prepare_request(self, body: Union[dict, list]) -> tuple[dict, bytes]:
        """
        Prepares the headers and encoded body for the AzureML online endpoint.

        Args:
            body (Union[dict, list]): The input data.

        Returns:
            tuple[dict, bytes]: The request headers and the JSON-encoded body.
        """
        # replace the format based on the model input
        data_format = {
            "input": body,
        }

        encoded_body = json.dumps(data_format).encode("utf-8")

        headers = {
            "Content-Type": "application/json",
            "Authorization": ("Bearer " + self.environment.azure_model_api_key),
            # Routes the request to this specific deployment behind the endpoint.
            "azureml-model-deployment": self.name,
        }

        return headers, encoded_body

    def make_request(self, body: bytes, headers: dict) -> list[float]:
        """
        Makes a request to the AzureML online endpoint.

        Args:
            body (bytes): The request body.
            headers (dict): The request headers.

        Returns:
            list[float]: The parsed JSON response from the AzureML online endpoint.

        Raises:
            urllib.error.HTTPError: Re-raised after logging when the endpoint
                returns an HTTP error status.
        """
        try:
            logger.info("Calling Custom Embedding Model API")
            req = urllib.request.Request(
                self.environment.azure_model_api_endpoint, body, headers
            )
            response = urllib.request.urlopen(req)
            try:
                logger.info("Custom Embedding Model response received")
                data = json.loads(response.read())
            finally:
                # Always release the HTTP connection, even if parsing fails.
                response.close()
            logger.info("Custom Embedding Model response parsed")

            return data

        except urllib.error.HTTPError as error:
            # Lazy %-formatting so the message is only built when emitted.
            logger.exception("The request failed with status code: %s", error.code)
            raise

    def allowSelfSignedHttps(self, allowed: bool) -> None:
        """
        Allows self-signed HTTPS requests.

        Args:
            allowed (bool): Whether to allow self-signed HTTPS requests.

        """

        # bypass the server certificate verification on client side;
        # respects an explicit PYTHONHTTPSVERIFY override in the environment
        if (
            allowed
            and not os.environ.get("PYTHONHTTPSVERIFY", "")
            and getattr(ssl, "_create_unverified_context", None)
        ):
            ssl._create_default_https_context = ssl._create_unverified_context
        else:
            ssl._create_default_https_context = ssl.create_default_context

    def generate_embedding(self, chunk: str) -> list[float]:
        """
        Generates the embedding for a given chunk of text.

        Args:
            chunk (str): The input text.

        Returns:
            list[float]: The generated embedding.

        """
        self.allowSelfSignedHttps(
            True
        )  # this line is needed if you use self-signed certificate in your scoring service.

        headers, body = self.prepare_request(chunk)

        result = self.make_request(body, headers)

        return result
5 changes: 5 additions & 0 deletions rag_experiment_accelerator/embedding/factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from rag_experiment_accelerator.embedding.aoai_embedding_model import AOAIEmbeddingModel
from rag_experiment_accelerator.embedding.st_embedding_model import STEmbeddingModel
from rag_experiment_accelerator.embedding.custom_embedding_model import (
CustomEmbeddingModel,
)


def create_embedding_model(model_type: str, **kwargs):
Expand All @@ -8,6 +11,8 @@ def create_embedding_model(model_type: str, **kwargs):
return AOAIEmbeddingModel(**kwargs)
case "sentence-transformer":
return STEmbeddingModel(**kwargs)
case "custom_embedding":
return CustomEmbeddingModel(**kwargs)
case _:
raise ValueError(
f"Invalid embedding type: {model_type}. Must be one of ['azure', 'sentence-transformer', 'custom-embedding']"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from unittest.mock import patch, MagicMock
import json
import urllib
from rag_experiment_accelerator.embedding.custom_embedding_model import (
CustomEmbeddingModel,
)
import ssl


def test_can_set_embedding_dimension():
    """The dimension passed at construction is exposed on the model."""
    mock_env = MagicMock()

    embedding_model = CustomEmbeddingModel("custom-embedding-model", mock_env, 123)

    assert embedding_model.dimension == 123


def test_prepare_request_success():
    """prepare_request wraps the payload under "input" and builds auth headers."""
    mock_env = MagicMock()
    mock_env.azure_model_api_key = "api_key"
    embedding_model = CustomEmbeddingModel("custom-embedding-deployment", mock_env)

    payload = {"text": "Hello world"}

    headers, encoded_body = embedding_model.prepare_request(payload)

    assert headers == {
        "Content-Type": "application/json",
        "Authorization": "Bearer api_key",
        "azureml-model-deployment": "custom-embedding-deployment",
    }
    assert encoded_body == json.dumps({"input": payload}).encode()


@patch("urllib.request.urlopen")
def test_make_request_success(mock_urlopen):
environment = MagicMock()
environment.azure_model_api_endpoint = "http://fake-endpoint"
model = CustomEmbeddingModel("custom-embedding-model", environment)

mock_response = MagicMock()
mock_response.read.return_value = json.dumps([0.1, 0.2, 0.3]).encode("utf-8")
mock_urlopen.return_value = mock_response

headers = {"Content-Type": "application/json"}
body = b'{"input": {"text": "Hello world"}}'

result = model.make_request(body, headers)
assert result == [0.1, 0.2, 0.3]


@patch("urllib.request.urlopen")
def test_make_request_http_error(mock_urlopen):
environment = MagicMock()
environment.azure_model_api_endpoint = "http://fake-endpoint"
model = CustomEmbeddingModel("custom-embedding-model", environment)

mock_urlopen.side_effect = urllib.error.HTTPError(
url=None, code=500, msg="Internal Server Error", hdrs=None, fp=None
)

headers = {"Content-Type": "application/json"}
body = b'{"input": {"text": "Hello world"}}'

try:
model.make_request(body, headers)
except urllib.error.HTTPError as e:
assert e.code == 500


@patch(
    "rag_experiment_accelerator.embedding.custom_embedding_model.CustomEmbeddingModel.make_request"
)
def test_generate_embedding_success(mock_make_request):
    """generate_embedding delegates to make_request and returns its result."""
    embedding_model = CustomEmbeddingModel("custom-embedding-model", MagicMock())
    mock_make_request.return_value = [0.1, 0.2, 0.3]

    embedding = embedding_model.generate_embedding("Hello world")

    assert embedding == [0.1, 0.2, 0.3]


def test_allow_self_signed_http_true():
    """Enabling self-signed HTTPS installs the unverified SSL context."""
    embedding_model = CustomEmbeddingModel("custom-embedding-model", MagicMock())

    embedding_model.allowSelfSignedHttps(True)

    # NOTE(review): this mutates process-wide ssl state; later tests see it.
    assert ssl._create_default_https_context == ssl._create_unverified_context


def test_allow_self_signed_http_false():
    """Disabling self-signed HTTPS restores the default SSL context factory."""
    embedding_model = CustomEmbeddingModel("custom-embedding-model", MagicMock())

    embedding_model.allowSelfSignedHttps(False)

    assert ssl._create_default_https_context == ssl.create_default_context
Loading