Skip to content

Commit 77906cb

Browse files
committed
Add logic to read from COPILOT_EMBEDDING_CONFIG_DIR, if it exists.
1 parent bc7fffc commit 77906cb

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

packages/cloudera-ai-inference-package/cloudera_ai_inference_package/cloudera_ai_embedding_provider.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ class ClouderaAIInferenceEmbeddingModelProvider(BaseEmbeddingsProvider, Embeddin
2222
os.getenv("COPILOT_CONFIG_DIR", ""), model_type="embedding"
2323
)
2424

25+
# Read from both config files, as embedding models could still be in the old config file for an older CML version.
26+
embedding_ai_inference_models, embedding_models = getCopilotModels(
27+
os.getenv("COPILOT_EMBEDDING_CONFIG_DIR", ""), model_type="embedding"
28+
)
29+
30+
# Merge lists, removing duplicates.
31+
ai_inference_models = ai_inference_models + embedding_ai_inference_models
32+
models = models + embedding_models
33+
2534
def __init__(self, **kwargs):
2635
super().__init__(**kwargs)
2736
self.model_endpoint = self._get_inference_endpoint()

packages/jupyter-ai/jupyter_ai/handlers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,13 @@ def getConfiguredThirdPartyModels(self):
455455
if copilot_config and "thirdPartyModels" in copilot_config and copilot_config["thirdPartyModels"]:
456456
third_party_models = copilot_config["thirdPartyModels"]
457457

458+
copilot_embedding_config_dir = os.getenv("COPILOT_EMBEDDING_CONFIG_DIR")
459+
if copilot_embedding_config_dir and os.path.exists(copilot_embedding_config_dir):
460+
f = open(copilot_embedding_config_dir)
461+
copilot_embedding_config = json.load(f)
462+
if copilot_embedding_config and "thirdPartyModels" in copilot_embedding_config and copilot_embedding_config["thirdPartyModels"]:
463+
third_party_models += copilot_embedding_config["thirdPartyModels"]
464+
458465
# Fill in provider_id if it is missing.
459466
for third_party_model in third_party_models:
460467
if len(third_party_model["provider_id"]) == 0:

0 commit comments

Comments
 (0)