diff --git a/python/.env.example b/python/.env.example
index ba598ada40a8..853022f34fc3 100644
--- a/python/.env.example
+++ b/python/.env.example
@@ -18,3 +18,7 @@ AZCOSMOS_API = "" // should be mongo-vcore for now, as CosmosDB only supports ve
 AZCOSMOS_CONNSTR = ""
 AZCOSMOS_DATABASE_NAME = ""
 AZCOSMOS_CONTAINER_NAME = ""
+ASTRADB_APP_TOKEN="" // Starts with AstraCS:
+ASTRADB_ID=""
+ASTRADB_REGION=""
+ASTRADB_KEYSPACE=""
\ No newline at end of file
diff --git a/python/semantic_kernel/__init__.py b/python/semantic_kernel/__init__.py
index 6dc36a0b3bd5..3772a68b7744 100644
--- a/python/semantic_kernel/__init__.py
+++ b/python/semantic_kernel/__init__.py
@@ -16,6 +16,7 @@ from semantic_kernel.utils.logging import setup_logging
 from semantic_kernel.utils.null_logger import NullLogger
 from semantic_kernel.utils.settings import (
+    astradb_settings_from_dot_env,
     azure_aisearch_settings_from_dot_env,
     azure_aisearch_settings_from_dot_env_as_dict,
     azure_cosmos_db_settings_from_dot_env,
@@ -39,6 +40,7 @@
     "azure_aisearch_settings_from_dot_env_as_dict",
     "postgres_settings_from_dot_env",
     "pinecone_settings_from_dot_env",
+    "astradb_settings_from_dot_env",
     "bing_search_settings_from_dot_env",
     "mongodb_atlas_settings_from_dot_env",
     "google_palm_settings_from_dot_env",
diff --git a/python/semantic_kernel/connectors/memory/astradb/__init__.py b/python/semantic_kernel/connectors/memory/astradb/__init__.py
new file mode 100644
index 000000000000..b8907d83882b
--- /dev/null
+++ b/python/semantic_kernel/connectors/memory/astradb/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+from semantic_kernel.connectors.memory.astradb.astradb_memory_store import (
+    AstraDBMemoryStore,
+)
+
+__all__ = ["AstraDBMemoryStore"]
diff --git a/python/semantic_kernel/connectors/memory/astradb/astra_client.py b/python/semantic_kernel/connectors/memory/astradb/astra_client.py
new file mode 100644
index 000000000000..d0aa30870b92
--- /dev/null
+++ b/python/semantic_kernel/connectors/memory/astradb/astra_client.py
@@ -0,0 +1,157 @@
+import json
+from typing import Dict, List, Optional
+
+import aiohttp
+
+from semantic_kernel.connectors.memory.astradb.utils import AsyncSession
+
+
+class AstraClient:
+    def __init__(
+        self,
+        astra_id: str,
+        astra_region: str,
+        astra_application_token: str,
+        keyspace_name: str,
+        embedding_dim: int,
+        similarity_function: str,
+        session: Optional[aiohttp.ClientSession] = None,
+    ):
+        self.astra_id = astra_id
+        self.astra_application_token = astra_application_token
+        self.astra_region = astra_region
+        self.keyspace_name = keyspace_name
+        self.embedding_dim = embedding_dim
+        self.similarity_function = similarity_function
+
+        self.request_base_url = (
+            f"https://{self.astra_id}-{self.astra_region}.apps.astra.datastax.com/api/json/v1/{self.keyspace_name}"
+        )
+        self.request_header = {
+            "x-cassandra-token": self.astra_application_token,
+            "Content-Type": "application/json",
+        }
+        self._session = session
+
+    async def _run_query(self, request_url: str, query: Dict):
+        async with AsyncSession(self._session) as session:
+            async with session.post(request_url, data=json.dumps(query), headers=self.request_header) as response:
+                if response.status == 200:
+                    response_dict = await response.json()
+                    if "errors" in response_dict:
+                        raise Exception(f"Astra DB request error - {response_dict['errors']}")
+                    else:
+                        return response_dict
+                else:
+                    raise Exception(f"Astra DB not available. Status: {response}")
+
+    async def find_collections(self, include_detail: bool = True):
+        query = {"findCollections": {"options": {"explain": include_detail}}}
+        result = await self._run_query(self.request_base_url, query)
+        return result["status"]["collections"]
+
+    async def find_collection(self, collection_name: str):
+        collections = await self.find_collections(False)
+        found = False
+        for collection in collections:
+            if collection == collection_name:
+                found = True
+                break
+        return found
+
+    async def create_collection(
+        self,
+        collection_name: str,
+        embedding_dim: Optional[int] = None,
+        similarity_function: Optional[str] = None,
+    ):
+        query = {
+            "createCollection": {
+                "name": collection_name,
+                "options": {
+                    "vector": {
+                        "dimension": embedding_dim if embedding_dim is not None else self.embedding_dim,
+                        "metric": similarity_function if similarity_function is not None else self.similarity_function,
+                    }
+                },
+            }
+        }
+        result = await self._run_query(self.request_base_url, query)
+        return result["status"]["ok"] == 1
+
+    async def delete_collection(self, collection_name: str):
+        query = {"deleteCollection": {"name": collection_name}}
+        result = await self._run_query(self.request_base_url, query)
+        return result["status"]["ok"] == 1
+
+    def _build_request_collection_url(self, collection_name: str):
+        return f"{self.request_base_url}/{collection_name}"
+
+    async def find_documents(
+        self,
+        collection_name: str,
+        filter: Optional[Dict] = None,
+        vector: Optional[List[float]] = None,
+        limit: Optional[int] = None,
+        include_vector: Optional[bool] = None,
+        include_similarity: Optional[bool] = None,
+    ) -> List[Dict]:
+        find_query = {}
+
+        if filter is not None:
+            find_query["filter"] = filter
+
+        if vector is not None:
+            find_query["sort"] = {"$vector": vector}
+
+        if include_vector is False:
+            find_query["projection"] = {"$vector": 0}
+
+        if limit is not None:
+            find_query["options"] = {"limit": limit}
+
+        if include_similarity is not None:
+            if "options" in find_query:
+                find_query["options"]["includeSimilarity"] = int(include_similarity)
+            else:
+                find_query["options"] = {"includeSimilarity": int(include_similarity)}
+
+        query = {"find": find_query}
+        result = await self._run_query(self._build_request_collection_url(collection_name), query)
+        return result["data"]["documents"]
+
+    async def insert_document(self, collection_name: str, document: Dict) -> str:
+        query = {"insertOne": {"document": document}}
+        result = await self._run_query(self._build_request_collection_url(collection_name), query)
+        return result["status"]["insertedIds"][0]
+
+    async def insert_documents(self, collection_name: str, documents: List[Dict]) -> List[str]:
+        query = {"insertMany": {"documents": documents}}
+        result = await self._run_query(self._build_request_collection_url(collection_name), query)
+        return result["status"]["insertedIds"]
+
+    async def update_document(self, collection_name: str, filter: Dict, update: Dict, upsert: bool = True) -> Dict:
+        query = {
+            "findOneAndUpdate": {
+                "filter": filter,
+                "update": update,
+                "options": {"returnDocument": "after", "upsert": upsert},
+            }
+        }
+        result = await self._run_query(self._build_request_collection_url(collection_name), query)
+        return result["status"]
+
+    async def update_documents(self, collection_name: str, filter: Dict, update: Dict):
+        query = {
+            "updateMany": {
+                "filter": filter,
+                "update": update,
+            }
+        }
+        result = await self._run_query(self._build_request_collection_url(collection_name), query)
+        return result["status"]
+
+    async def delete_documents(self, collection_name: str, filter: Dict) -> int:
+        query = {"deleteMany": {"filter": filter}}
+        result = await self._run_query(self._build_request_collection_url(collection_name), query)
+        return result["status"]["deletedCount"]
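
`AstraClient` is a thin wrapper over Astra DB's JSON API, one method per API command. A minimal sketch of driving it directly; the database ID, region, token, keyspace, and `demo_collection` values below are placeholders:

```python
import asyncio

from semantic_kernel.connectors.memory.astradb.astra_client import AstraClient


async def main() -> None:
    # Placeholder credentials; real values come from your Astra DB account.
    client = AstraClient(
        astra_id="00000000-0000-0000-0000-000000000000",
        astra_region="us-east1",
        astra_application_token="AstraCS:placeholder",
        keyspace_name="default_keyspace",
        embedding_dim=2,
        similarity_function="cosine",
    )
    await client.create_collection("demo_collection")
    # insertOne returns the new document's _id.
    doc_id = await client.insert_document(
        "demo_collection", {"_id": "doc1", "$vector": [0.5, 0.5], "text": "sample"}
    )
    # Vector search: sort by $vector and request similarity scores.
    docs = await client.find_documents(
        "demo_collection", vector=[0.5, 0.5], limit=1, include_similarity=True
    )
    print(doc_id, docs)


asyncio.run(main())
```
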
diff --git a/python/semantic_kernel/connectors/memory/astradb/astradb_memory_store.py b/python/semantic_kernel/connectors/memory/astradb/astradb_memory_store.py
new file mode 100644
index 000000000000..68db6789b2dc
--- /dev/null
+++ b/python/semantic_kernel/connectors/memory/astradb/astradb_memory_store.py
@@ -0,0 +1,303 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+import asyncio
+import logging
+from typing import List, Optional, Tuple
+
+import aiohttp
+from numpy import ndarray
+
+from semantic_kernel.connectors.memory.astradb.astra_client import AstraClient
+from semantic_kernel.connectors.memory.astradb.utils import (
+    build_payload,
+    parse_payload,
+)
+from semantic_kernel.memory.memory_record import MemoryRecord
+from semantic_kernel.memory.memory_store_base import MemoryStoreBase
+
+MAX_DIMENSIONALITY = 20000
+MAX_UPSERT_BATCH_SIZE = 100
+MAX_QUERY_WITHOUT_METADATA_BATCH_SIZE = 10000
+MAX_QUERY_WITH_METADATA_BATCH_SIZE = 1000
+MAX_FETCH_BATCH_SIZE = 1000
+MAX_DELETE_BATCH_SIZE = 1000
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class AstraDBMemoryStore(MemoryStoreBase):
+    """A memory store that uses Astra database as the backend."""
+
+    def __init__(
+        self,
+        astra_application_token: str,
+        astra_id: str,
+        astra_region: str,
+        keyspace_name: str,
+        embedding_dim: int,
+        similarity: str,
+        session: Optional[aiohttp.ClientSession] = None,
+    ) -> None:
+        """Initializes a new instance of the AstraDBMemoryStore class.
+
+        Arguments:
+            astra_application_token {str} -- The Astra application token.
+            astra_id {str} -- The Astra ID of the database.
+            astra_region {str} -- The Astra region of the database.
+            keyspace_name {str} -- The Astra keyspace to use.
+            embedding_dim {int} -- The dimensionality to use for new collections.
+            similarity {str} -- The similarity metric for new collections (dot_product, euclidean, or cosine).
+            session -- An optional aiohttp client session to reuse across requests.
+        """
+        self._embedding_dim = embedding_dim
+        self._similarity = similarity
+        self._session = session
+
+        if self._embedding_dim > MAX_DIMENSIONALITY:
+            raise ValueError(
+                f"Dimensionality of {self._embedding_dim} exceeds "
+                + f"the maximum allowed value of {MAX_DIMENSIONALITY}."
+            )
+
+        self._client = AstraClient(
+            astra_id=astra_id,
+            astra_region=astra_region,
+            astra_application_token=astra_application_token,
+            keyspace_name=keyspace_name,
+            embedding_dim=embedding_dim,
+            similarity_function=similarity,
+            session=self._session,
+        )
+
+    async def get_collections_async(self) -> List[str]:
+        """Gets the list of collections.
+
+        Returns:
+            List[str] -- The list of collections.
+        """
+        return await self._client.find_collections(False)
+
+    async def create_collection_async(
+        self,
+        collection_name: str,
+        dimension_num: Optional[int] = None,
+        distance_type: Optional[str] = "cosine",
+    ) -> None:
+        """Creates a new collection in Astra if it does not exist.
+
+        Arguments:
+            collection_name {str} -- The name of the collection to create.
+            dimension_num {int} -- The dimension of the vectors to be stored in this collection.
+            distance_type {str} -- Specifies the similarity metric to be used when querying or comparing vectors within
+                this collection. The available options are dot_product, euclidean, and cosine.
+
+        Returns:
+            None
+        """
+        dimension_num = dimension_num if dimension_num is not None else self._embedding_dim
+        distance_type = distance_type if distance_type is not None else self._similarity
+
+        if dimension_num > MAX_DIMENSIONALITY:
+            raise ValueError(
+                f"Dimensionality of {dimension_num} exceeds " + f"the maximum allowed value of {MAX_DIMENSIONALITY}."
+            )
+
+        result = await self._client.create_collection(collection_name, dimension_num, distance_type)
+        if result is True:
+            logger.info(f"Collection {collection_name} created.")
+
+    async def delete_collection_async(self, collection_name: str) -> None:
+        """Deletes a collection.
+
+        Arguments:
+            collection_name {str} -- The name of the collection to delete.
+
+        Returns:
+            None
+        """
+        result = await self._client.delete_collection(collection_name)
+        logger.log(
+            logging.INFO if result is True else logging.WARNING,
+            f"Collection {collection_name} {'deleted.' if result is True else 'does not exist.'}",
+        )
+
+    async def does_collection_exist_async(self, collection_name: str) -> bool:
+        """Checks if a collection exists.
+
+        Arguments:
+            collection_name {str} -- The name of the collection to check.
+
+        Returns:
+            bool -- True if the collection exists; otherwise, False.
+        """
+        return await self._client.find_collection(collection_name)
+
+    async def upsert_async(self, collection_name: str, record: MemoryRecord) -> str:
+        """Upserts a memory record into the data store. Does not guarantee that the collection exists.
+            If the record already exists, it will be updated.
+            If the record does not exist, it will be created.
+
+        Arguments:
+            collection_name {str} -- The name associated with a collection of embeddings.
+            record {MemoryRecord} -- The memory record to upsert.
+
+        Returns:
+            str -- The unique identifier for the memory record.
+        """
+        filter = {"_id": record._id}
+        update = {"$set": build_payload(record)}
+        status = await self._client.update_document(collection_name, filter, update, True)
+
+        return status["upsertedId"] if "upsertedId" in status else record._id
+
+    async def upsert_batch_async(self, collection_name: str, records: List[MemoryRecord]) -> List[str]:
+        """Upserts a batch of memory records into the data store. Does not guarantee that the collection exists.
+            If a record already exists, it will be updated.
+            If a record does not exist, it will be created.
+
+        Arguments:
+            collection_name {str} -- The name associated with a collection of embeddings.
+            records {List[MemoryRecord]} -- The memory records to upsert.
+
+        Returns:
+            List[str] -- The unique identifiers for the memory records.
+        """
+        return await asyncio.gather(*[self.upsert_async(collection_name, record) for record in records])
+
+    async def get_async(self, collection_name: str, key: str, with_embedding: bool = False) -> MemoryRecord:
+        """Gets a record. Does not guarantee that the collection exists.
+
+        Arguments:
+            collection_name {str} -- The name of the collection to get the record from.
+            key {str} -- The unique database key of the record.
+            with_embedding {bool} -- Whether to include the embedding in the result. (default: {False})
+
+        Returns:
+            MemoryRecord -- The record.
+        """
+        filter = {"_id": key}
+        documents = await self._client.find_documents(
+            collection_name=collection_name,
+            filter=filter,
+            include_vector=with_embedding,
+        )
+
+        if len(documents) == 0:
+            raise KeyError(f"Record with key '{key}' does not exist")
+
+        return parse_payload(documents[0])
+
+    async def get_batch_async(
+        self, collection_name: str, keys: List[str], with_embeddings: bool = False
+    ) -> List[MemoryRecord]:
+        """Gets a batch of records. Does not guarantee that the collection exists.
+
+        Arguments:
+            collection_name {str} -- The name of the collection to get the records from.
+            keys {List[str]} -- The unique database keys of the records.
+            with_embeddings {bool} -- Whether to include the embeddings in the results. (default: {False})
+
+        Returns:
+            List[MemoryRecord] -- The records.
+        """
+
+        filter = {"_id": {"$in": keys}}
+        documents = await self._client.find_documents(
+            collection_name=collection_name,
+            filter=filter,
+            include_vector=with_embeddings,
+        )
+        return [parse_payload(document) for document in documents]
+
+    async def remove_async(self, collection_name: str, key: str) -> None:
+        """Removes a memory record from the data store. Does not guarantee that the collection exists.
+
+        Arguments:
+            collection_name {str} -- The name of the collection to remove the record from.
+            key {str} -- The unique id associated with the memory record to remove.
+
+        Returns:
+            None
+        """
+        filter = {"_id": key}
+        await self._client.delete_documents(collection_name, filter)
+
+    async def remove_batch_async(self, collection_name: str, keys: List[str]) -> None:
+        """Removes a batch of records. Does not guarantee that the collection exists.
+
+        Arguments:
+            collection_name {str} -- The name of the collection to remove the records from.
+            keys {List[str]} -- The unique ids associated with the memory records to remove.
+
+        Returns:
+            None
+        """
+        filter = {"_id": {"$in": keys}}
+        await self._client.delete_documents(collection_name, filter)
+
+    async def get_nearest_match_async(
+        self,
+        collection_name: str,
+        embedding: ndarray,
+        min_relevance_score: float = 0.0,
+        with_embedding: bool = False,
+    ) -> Tuple[MemoryRecord, float]:
+        """Gets the nearest match to an embedding using cosine similarity.
+        Arguments:
+            collection_name {str} -- The name of the collection to get the nearest match from.
+            embedding {ndarray} -- The embedding to find the nearest match to.
+            min_relevance_score {float} -- The minimum relevance score of the match. (default: {0.0})
+            with_embedding {bool} -- Whether to include the embedding in the result. (default: {False})
+
+        Returns:
+            Tuple[MemoryRecord, float] -- The record and the relevance score.
+        """
+        matches = await self.get_nearest_matches_async(
+            collection_name=collection_name,
+            embedding=embedding,
+            limit=1,
+            min_relevance_score=min_relevance_score,
+            with_embeddings=with_embedding,
+        )
+        return matches[0]
+
+    async def get_nearest_matches_async(
+        self,
+        collection_name: str,
+        embedding: ndarray,
+        limit: int,
+        min_relevance_score: float = 0.0,
+        with_embeddings: bool = False,
+    ) -> List[Tuple[MemoryRecord, float]]:
+        """Gets the nearest matches to an embedding using cosine similarity.
+        Arguments:
+            collection_name {str} -- The name of the collection to get the nearest matches from.
+            embedding {ndarray} -- The embedding to find the nearest matches to.
+            limit {int} -- The maximum number of matches to return.
+            min_relevance_score {float} -- The minimum relevance score of the matches. (default: {0.0})
+            with_embeddings {bool} -- Whether to include the embeddings in the results. (default: {False})
+
+        Returns:
+            List[Tuple[MemoryRecord, float]] -- The records and their relevance scores.
+        """
+        matches = await self._client.find_documents(
+            collection_name=collection_name,
+            vector=embedding.tolist(),
+            limit=limit,
+            include_similarity=True,
+            include_vector=with_embeddings,
+        )
+
+        if min_relevance_score:
+            matches = [match for match in matches if match["$similarity"] >= min_relevance_score]
+
+        return (
+            [
+                (
+                    parse_payload(match),
+                    match["$similarity"],
+                )
+                for match in matches
+            ]
+            if len(matches) > 0
+            else []
+        )
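
`AstraDBMemoryStore` layers the `MemoryStoreBase` contract on top of that client. A short sketch of the typical flow, assuming placeholder credentials and the same 2-dimensional toy vectors the tests use:

```python
import asyncio

import numpy as np

from semantic_kernel.connectors.memory.astradb import AstraDBMemoryStore
from semantic_kernel.memory.memory_record import MemoryRecord


async def main() -> None:
    # Placeholder credentials; embedding_dim=2 matches the toy vectors below.
    store = AstraDBMemoryStore("AstraCS:placeholder", "<astra-id>", "<region>", "<keyspace>", 2, "cosine")
    await store.create_collection_async("demo")

    record = MemoryRecord.local_record(
        id="doc1",
        text="sample text",
        description=None,
        additional_metadata=None,
        embedding=np.array([0.5, 0.5]),
    )
    await store.upsert_async("demo", record)

    # Returns the best match and its similarity score as a tuple.
    match, score = await store.get_nearest_match_async("demo", np.array([0.5, 0.5]), with_embedding=True)
    print(match._id, score)


asyncio.run(main())
```
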
diff --git a/python/semantic_kernel/connectors/memory/astradb/utils.py b/python/semantic_kernel/connectors/memory/astradb/utils.py
new file mode 100644
index 000000000000..a5a69a0595b4
--- /dev/null
+++ b/python/semantic_kernel/connectors/memory/astradb/utils.py
@@ -0,0 +1,53 @@
+# Copyright (c) Microsoft. All rights reserved.
+from typing import Any, Dict, Optional
+
+import aiohttp
+import numpy
+
+from semantic_kernel.memory.memory_record import MemoryRecord
+
+
+class AsyncSession:
+    def __init__(self, session: Optional[aiohttp.ClientSession] = None):
+        # Only close the session on exit if this wrapper created it.
+        self._owns_session = session is None
+        self._session = session if session else aiohttp.ClientSession()
+
+    async def __aenter__(self):
+        return await self._session.__aenter__()
+
+    async def __aexit__(self, *args, **kwargs):
+        if self._owns_session:
+            await self._session.close()
+
+
+def build_payload(record: MemoryRecord) -> Dict[str, Any]:
+    """
+    Builds a metadata payload to be sent to AstraDb from a MemoryRecord.
+    """
+    payload: Dict[str, Any] = {}
+    payload["$vector"] = record.embedding.tolist()
+    if record._text:
+        payload["text"] = record._text
+    if record._description:
+        payload["description"] = record._description
+    if record._additional_metadata:
+        payload["additional_metadata"] = record._additional_metadata
+    return payload
+
+
+def parse_payload(document: Dict[str, Any]) -> MemoryRecord:
+    """
+    Parses a record from AstraDb into a MemoryRecord.
+    """
+    text = document.get("text", None)
+    description = document.get("description", None)
+    additional_metadata = document.get("additional_metadata", None)
+
+    return MemoryRecord.local_record(
+        id=document["_id"],
+        description=description,
+        text=text,
+        additional_metadata=additional_metadata,
+        embedding=document["$vector"] if "$vector" in document else numpy.array([]),
+    )
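
The two payload helpers are near-inverses: `build_payload` stores the embedding under `$vector` and drops empty fields, while `parse_payload` relies on the `_id` that Astra DB keeps on every document. A quick round-trip sketch:

```python
import numpy as np

from semantic_kernel.connectors.memory.astradb.utils import build_payload, parse_payload
from semantic_kernel.memory.memory_record import MemoryRecord

record = MemoryRecord.local_record(
    id="doc1",
    text="sample text",
    description="a description",
    additional_metadata="extra",
    embedding=np.array([0.25, 0.75]),
)

# build_payload does not copy the id; parse_payload expects the "_id"
# field that Astra DB adds to every stored document, so set it here.
payload = build_payload(record)
payload["_id"] = record._id

round_tripped = parse_payload(payload)
assert round_tripped._id == record._id
assert round_tripped._text == record._text
```
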
diff --git a/python/semantic_kernel/utils/settings.py b/python/semantic_kernel/utils/settings.py
index cab5c96d5ab7..3bba2168bc81 100644
--- a/python/semantic_kernel/utils/settings.py
+++ b/python/semantic_kernel/utils/settings.py
@@ -130,6 +130,46 @@ def pinecone_settings_from_dot_env() -> Tuple[str, Optional[str]]:
     return api_key, environment
 
 
+def astradb_settings_from_dot_env() -> Tuple[str, str, str, str]:
+    """
+    Reads the AstraDB application token, database ID, region, and keyspace from the .env file.
+
+    Returns:
+        Tuple[str, str, str, str]: The AstraDB application token, database ID, region, and keyspace
+    """
+
+    app_token, db_id, region, keyspace = None, None, None, None
+    with open(".env", "r") as f:
+        lines = f.readlines()
+
+        for line in lines:
+            if line.startswith("ASTRADB_APP_TOKEN"):
+                parts = line.split("=")[1:]
+                app_token = "=".join(parts).strip().strip('"')
+                continue
+
+            if line.startswith("ASTRADB_ID"):
+                parts = line.split("=")[1:]
+                db_id = "=".join(parts).strip().strip('"')
+                continue
+
+            if line.startswith("ASTRADB_REGION"):
+                parts = line.split("=")[1:]
+                region = "=".join(parts).strip().strip('"')
+                continue
+
+            if line.startswith("ASTRADB_KEYSPACE"):
+                parts = line.split("=")[1:]
+                keyspace = "=".join(parts).strip().strip('"')
+                continue
+
+    assert app_token, "AstraDB application token not found in .env file"
+    assert db_id, "AstraDB ID not found in .env file"
+    assert region, "AstraDB region not found in .env file"
+    assert keyspace, "AstraDB keyspace not found in .env file"
+
+    return app_token, db_id, region, keyspace
+
+
 def weaviate_settings_from_dot_env() -> Tuple[Optional[str], str]:
     """
     Reads the Weaviate API key and URL from the .env file.
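
The loader returns the four values in the order the `AstraDBMemoryStore` constructor expects, so the two wire together directly. A sketch, assuming a populated `python/.env` (the 1536 dimension is only an example, matching e.g. OpenAI ada-002 embeddings):

```python
import asyncio

import semantic_kernel as sk
from semantic_kernel.connectors.memory.astradb import AstraDBMemoryStore


async def main() -> None:
    # Reads ASTRADB_APP_TOKEN, ASTRADB_ID, ASTRADB_REGION, ASTRADB_KEYSPACE from .env.
    app_token, db_id, region, keyspace = sk.astradb_settings_from_dot_env()
    # Pick the dimensionality of your embedding model; 1536 is a common choice.
    store = AstraDBMemoryStore(app_token, db_id, region, keyspace, 1536, "cosine")
    print(await store.get_collections_async())


asyncio.run(main())
```
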
diff --git a/python/tests/integration/connectors/memory/test_astradb.py b/python/tests/integration/connectors/memory/test_astradb.py
new file mode 100644
index 000000000000..8816a0301c09
--- /dev/null
+++ b/python/tests/integration/connectors/memory/test_astradb.py
@@ -0,0 +1,266 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+import os
+import time
+
+import numpy as np
+import pytest
+
+import semantic_kernel as sk
+from semantic_kernel.connectors.memory.astradb import AstraDBMemoryStore
+from semantic_kernel.memory.memory_record import MemoryRecord
+
+astradb_installed: bool
+try:
+    # A falsy value (e.g. "") disables the tests, same as a missing variable.
+    astradb_installed = bool(os.environ["ASTRADB_INTEGRATION_TEST"])
+except KeyError:
+    astradb_installed = False
+
+
+pytestmark = pytest.mark.skipif(not astradb_installed, reason="ASTRADB_INTEGRATION_TEST is not set")
+
+
+async def retry(func, retries=1):
+    for i in range(retries):
+        try:
+            return await func()
+        except Exception as e:
+            print(e)
+            time.sleep(i * 2)
+
+
+@pytest.fixture(autouse=True, scope="module")
+def slow_down_tests():
+    yield
+    time.sleep(3)
+
+
+@pytest.fixture(scope="session")
+def get_astradb_config():
+    if "Python_Integration_Tests" in os.environ:
+        app_token = os.environ["ASTRADB_APP_TOKEN"]
+        db_id = os.environ["ASTRADB_ID"]
+        region = os.environ["ASTRADB_REGION"]
+        keyspace = os.environ["ASTRADB_KEYSPACE"]
+    else:
+        # Load credentials from .env file
+        app_token, db_id, region, keyspace = sk.astradb_settings_from_dot_env()
+
+    return app_token, db_id, region, keyspace
+
+
+@pytest.fixture
+def memory_record1():
+    return MemoryRecord(
+        id="test_id1",
+        text="sample text1",
+        is_reference=False,
+        embedding=np.array([0.5, 0.5]),
+        description="description",
+        additional_metadata="additional metadata",
+        external_source_name="external source",
+        timestamp="timestamp",
+    )
+
+
+@pytest.fixture
+def memory_record2():
+    return MemoryRecord(
+        id="test_id2",
+        text="sample text2",
+        is_reference=False,
+        embedding=np.array([0.25, 0.75]),
+        description="description",
+        additional_metadata="additional metadata",
+        external_source_name="external source",
+        timestamp="timestamp",
+    )
+
+
+@pytest.fixture
+def memory_record3():
+    return MemoryRecord(
+        id="test_id3",
+        text="sample text3",
+        is_reference=False,
+        embedding=np.array([0.25, 0.80]),
+        description="description",
+        additional_metadata="additional metadata",
+        external_source_name="external source",
+        timestamp="timestamp",
+    )
+
+
+@pytest.mark.asyncio
+async def test_constructor(get_astradb_config):
+    app_token, db_id, region, keyspace = get_astradb_config
+    memory = AstraDBMemoryStore(app_token, db_id, region, keyspace, 2, "cosine")
+    result = await retry(lambda: memory.get_collections_async())
+
+    assert result is not None
+
+
+@pytest.mark.asyncio
+async def test_create_and_get_collection_async(get_astradb_config):
+    app_token, db_id, region, keyspace = get_astradb_config
+    memory = AstraDBMemoryStore(app_token, db_id, region, keyspace, 2, "cosine")
+
+    await retry(lambda: memory.create_collection_async("test_collection"))
+    result = await retry(lambda: memory.does_collection_exist_async("test_collection"))
+    assert result is not None
+    assert result is True
+
+
+@pytest.mark.asyncio
+async def test_get_collections_async(get_astradb_config):
+    app_token, db_id, region, keyspace = get_astradb_config
+    memory = AstraDBMemoryStore(app_token, db_id, region, keyspace, 2, "cosine")
+
+    await retry(lambda: memory.create_collection_async("test_collection"))
+    result = await retry(lambda: memory.get_collections_async())
+    assert "test_collection" in result
+
+
+@pytest.mark.asyncio
+async def test_delete_collection_async(get_astradb_config):
+    app_token, db_id, region, keyspace = get_astradb_config
+    memory = AstraDBMemoryStore(app_token, db_id, region, keyspace, 2, "cosine")
+
+    await retry(lambda: memory.create_collection_async("test_collection"))
+    await retry(lambda: memory.delete_collection_async("test_collection"))
+    result = await retry(lambda: memory.get_collections_async())
+    assert "test_collection" not in result
+
+
+@pytest.mark.asyncio
+async def test_does_collection_exist_async(get_astradb_config):
+    app_token, db_id, region, keyspace = get_astradb_config
+    memory = AstraDBMemoryStore(app_token, db_id, region, keyspace, 2, "cosine")
+
+    await retry(lambda: memory.create_collection_async("test_collection"))
+    result = await retry(lambda: memory.does_collection_exist_async("test_collection"))
+    assert result is True
+
+
+@pytest.mark.asyncio
+async def test_upsert_async_and_get_async(get_astradb_config, memory_record1):
+    app_token, db_id, region, keyspace = get_astradb_config
+    memory = AstraDBMemoryStore(app_token, db_id, region, keyspace, 2, "cosine")
+
+    await retry(lambda: memory.create_collection_async("test_collection"))
+    await retry(lambda: memory.upsert_async("test_collection", memory_record1))
+
+    result = await retry(
+        lambda: memory.get_async(
+            "test_collection",
+            memory_record1._id,
+            with_embedding=True,
+        )
+    )
+
+    assert result is not None
+    assert result._id == memory_record1._id
+    assert result._description == memory_record1._description
+    assert result._text == memory_record1._text
+    assert result.embedding is not None
+
+
+@pytest.mark.asyncio
+async def test_upsert_batch_async_and_get_batch_async(get_astradb_config, memory_record1, memory_record2):
+    app_token, db_id, region, keyspace = get_astradb_config
+    memory = AstraDBMemoryStore(app_token, db_id, region, keyspace, 2, "cosine")
+
+    await retry(lambda: memory.create_collection_async("test_collection"))
+    await retry(lambda: memory.upsert_batch_async("test_collection", [memory_record1, memory_record2]))
+
+    results = await retry(
+        lambda: memory.get_batch_async(
+            "test_collection",
+            [memory_record1._id, memory_record2._id],
+            with_embeddings=True,
+        )
+    )
+
+    assert len(results) >= 2
+    assert results[0]._id in [memory_record1._id, memory_record2._id]
+    assert results[1]._id in [memory_record1._id, memory_record2._id]
+
+
+@pytest.mark.asyncio
+async def test_remove_async(get_astradb_config, memory_record1):
+    app_token, db_id, region, keyspace = get_astradb_config
+    memory = AstraDBMemoryStore(app_token, db_id, region, keyspace, 2, "cosine")
+
+    await retry(lambda: memory.create_collection_async("test_collection"))
+    await retry(lambda: memory.upsert_async("test_collection", memory_record1))
+    await retry(lambda: memory.remove_async("test_collection", memory_record1._id))
+
+    with pytest.raises(KeyError):
+        _ = await memory.get_async("test_collection", memory_record1._id, with_embedding=True)
+
+
+@pytest.mark.asyncio
+async def test_remove_batch_async(get_astradb_config, memory_record1, memory_record2):
+    app_token, db_id, region, keyspace = get_astradb_config
+    memory = AstraDBMemoryStore(app_token, db_id, region, keyspace, 2, "cosine")
+
+    await retry(lambda: memory.create_collection_async("test_collection"))
+    await retry(lambda: memory.upsert_batch_async("test_collection", [memory_record1, memory_record2]))
+    await retry(lambda: memory.remove_batch_async("test_collection", [memory_record1._id, memory_record2._id]))
+
+    with pytest.raises(KeyError):
+        _ = await memory.get_async("test_collection", memory_record1._id, with_embedding=True)
+
+    with pytest.raises(KeyError):
+        _ = await memory.get_async("test_collection", memory_record2._id, with_embedding=True)
+
+
+@pytest.mark.asyncio
+async def test_get_nearest_match_async(get_astradb_config, memory_record1, memory_record2):
+    app_token, db_id, region, keyspace = get_astradb_config
+    memory = AstraDBMemoryStore(app_token, db_id, region, keyspace, 2, "cosine")
+
+    await retry(lambda: memory.create_collection_async("test_collection"))
+    await retry(lambda: memory.upsert_batch_async("test_collection", [memory_record1, memory_record2]))
+
+    test_embedding = memory_record1.embedding
+    test_embedding[0] = test_embedding[0] + 0.01
+
+    result = await retry(
+        lambda: memory.get_nearest_match_async(
+            "test_collection",
+            test_embedding,
+            min_relevance_score=0.0,
+            with_embedding=True,
+        )
+    )
+
+    assert result is not None
+    assert result[0]._id == memory_record1._id
+
+
+@pytest.mark.asyncio
+async def test_get_nearest_matches_async(get_astradb_config, memory_record1, memory_record2, memory_record3):
+    app_token, db_id, region, keyspace = get_astradb_config
+    memory = AstraDBMemoryStore(app_token, db_id, region, keyspace, 2, "cosine")
+
+    await retry(lambda: memory.create_collection_async("test_collection"))
+    await retry(lambda: memory.upsert_batch_async("test_collection", [memory_record1, memory_record2, memory_record3]))
+
+    test_embedding = memory_record2.embedding
+    test_embedding[0] = test_embedding[0] + 0.025
+
+    result = await retry(
+        lambda: memory.get_nearest_matches_async(
+            "test_collection",
+            test_embedding,
+            limit=2,
+            min_relevance_score=0.0,
+            with_embeddings=True,
+        )
+    )
+
+    assert len(result) == 2
+    assert result[0][0]._id in [memory_record3._id, memory_record2._id]
+    assert result[1][0]._id in [memory_record3._id, memory_record2._id]
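
The suite is opt-in: with the four `ASTRADB_*` values exported (or present in `python/.env`), something like `ASTRADB_INTEGRATION_TEST=1 pytest python/tests/integration/connectors/memory/test_astradb.py` should enable and run it against a live database.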