diff --git a/pyatlan/cache/abstract_asset_cache.py b/pyatlan/cache/abstract_asset_cache.py new file mode 100644 index 00000000..021d5b76 --- /dev/null +++ b/pyatlan/cache/abstract_asset_cache.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 Atlan Pte. Ltd. +from __future__ import annotations + +import threading +from abc import ABC, abstractmethod +from typing import Any + +from pyatlan.errors import ErrorCode +from pyatlan.model.assets import Asset +from pyatlan.model.enums import AtlanConnectorType + + +class AbstractAssetCache(ABC): + """ + Base class for reusable components that are common + to all caches, where a cache is populated entry-by-entry. + """ + + def __init__(self, client): + self.client = client + self.lock = threading.Lock() + self.name_to_guid = dict() + self.guid_to_asset = dict() + self.qualified_name_to_guid = dict() + + @classmethod + @abstractmethod + def get_cache(cls): + """Abstract method to retreive cache.""" + + @abstractmethod + def lookup_by_guid(self, guid: str): + """Abstract method to lookup asset by guid.""" + + @abstractmethod + def lookup_by_qualified_name(self, qualified_name: str): + """Abstract method to lookup asset by qualified name.""" + + @abstractmethod + def lookup_by_name(self, name: Any): + """Abstract method to lookup asset by name.""" + + @abstractmethod + def get_name(self, asset: Asset): + """Abstract method to get name from asset.""" + + def is_guid_known(self, guid: str) -> bool: + """ + Checks whether the provided Atlan-internal UUID is known. + NOTE: will not refresh the cache itself to determine this. + + :param guid: Atlan-internal UUID of the object + :returns: `True` if the object is known, `False` otherwise + """ + return guid in self.guid_to_asset + + def is_qualified_name_known(self, qualified_name: str): + """ + Checks whether the provided Atlan-internal ID string is known. + NOTE: will not refresh the cache itself to determine this. + + :param qualified_name: Atlan-internal ID string of the object + :returns: `True` if the object is known, `False` otherwise + """ + return qualified_name in self.qualified_name_to_guid + + def is_name_known(self, name: str): + """ + Checks whether the provided Atlan-internal ID string is known. + NOTE: will not refresh the cache itself to determine this. + + :param name: human-constructable name of the object + :returns: `True` if the object is known, `False` otherwise + """ + return name in self.name_to_guid + + def cache(self, asset: Asset): + """ + Add an entry to the cache. + + :param asset: to be cached + """ + name = asset and self.get_name(asset) + if not all([name, asset.guid, asset.qualified_name]): + return + self.name_to_guid[name] = asset.guid + self.guid_to_asset[asset.guid] = asset + self.qualified_name_to_guid[asset.qualified_name] = asset.guid + + def _get_by_guid(self, guid: str, allow_refresh: bool = True): + """ + Retrieve an asset from the cache by its UUID. + If the asset is not found, it will be looked up and added to the cache. + + :param guid: UUID of the asset in Atlan + :returns: the asset (if found) + :raises AtlanError: on any API communication problem if the cache needs to be refreshed + :raises NotFoundError: if the asset cannot be found (does not exist) in Atlan + :raises InvalidRequestError: if no UUID was provided for the asset to retrieve + """ + if not guid: + raise ErrorCode.MISSING_ID.exception_with_parameters() + asset = self.guid_to_asset.get(guid) + if not asset and allow_refresh: + self.lookup_by_guid(guid) + asset = self.guid_to_asset.get(guid) + if not asset: + raise ErrorCode.ASSET_NOT_FOUND_BY_GUID.exception_with_parameters(guid) + return asset + + def _get_by_qualified_name(self, qualified_name: str, allow_refresh: bool = True): + """ + Retrieve an asset from the cache by its unique Atlan-internal name. + + :param qualified_name: unique Atlan-internal name of the asset + :param allow_refresh: whether to allow a refresh of the cache (`True`) or not (`False`) + :returns: the asset (if found) + :raises AtlanError: on any API communication problem if the cache needs to be refreshed + :raises NotFoundError: if the object cannot be found (does not exist) in Atlan + :raises InvalidRequestError: if no qualified_name was provided for the object to retrieve + """ + if not qualified_name: + raise ErrorCode.MISSING_ID.exception_with_parameters() + guid = self.qualified_name_to_guid.get(qualified_name) + if not guid and allow_refresh: + self.lookup_by_qualified_name(qualified_name) + guid = self.qualified_name_to_guid.get(qualified_name) + if not guid: + raise ErrorCode.ASSET_NOT_FOUND_BY_QN.exception_with_parameters( + qualified_name, + AtlanConnectorType._get_connector_type_from_qualified_name( + qualified_name + ), + ) + return self._get_by_guid(guid=guid, allow_refresh=False) + + def _get_by_name(self, name: AbstractAssetName, allow_refresh: bool = True): + """ + Retrieve an asset from the cache by its uniquely identifiable name. + + :param name: uniquely identifiable name of the asset in Atlan + :param allow_refresh: whether to allow a refresh of the cache (`True`) or not (`False`) + :returns: the asset (if found) + :raises AtlanError: on any API communication problem if the cache needs to be refreshed + :raises NotFoundError: if the object cannot be found (does not exist) in Atlan + :raises InvalidRequestError: if no name was provided for the object to retrieve + """ + if not isinstance(name, AbstractAssetName): + raise ErrorCode.MISSING_NAME.exception_with_parameters() + guid = self.name_to_guid.get(str(name)) + if not guid and allow_refresh: + self.lookup_by_name(name) + guid = self.name_to_guid.get(str(name)) + if not guid: + raise ErrorCode.ASSET_NOT_FOUND_BY_NAME.exception_with_parameters( + name._TYPE_NAME, name + ) + return self._get_by_guid(guid=guid, allow_refresh=False) + + +class AbstractAssetName(ABC): + """ + Base class for reusable components common to all asset names + used by the cache's find methods, such as AssetCache.get_by_name(). + """ + + _TYPE_NAME = str() + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def __str__(self): + pass diff --git a/pyatlan/cache/connection_cache.py b/pyatlan/cache/connection_cache.py new file mode 100644 index 00000000..4784fa2e --- /dev/null +++ b/pyatlan/cache/connection_cache.py @@ -0,0 +1,204 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 Atlan Pte. Ltd. +from __future__ import annotations + +import logging +import threading +from typing import Dict, Optional, Union + +from pyatlan.cache.abstract_asset_cache import AbstractAssetCache, AbstractAssetName +from pyatlan.client.atlan import AtlanClient +from pyatlan.model.assets import Asset, Connection +from pyatlan.model.enums import AtlanConnectorType +from pyatlan.model.fluent_search import FluentSearch +from pyatlan.model.search import Term + +LOGGER = logging.getLogger(__name__) + +lock = threading.Lock() + + +class ConnectionCache(AbstractAssetCache): + """ + Lazily-loaded cache for translating between + a connection's simplified name its details. + + - guid = UUID of the connection + for eg: 9c677e77-e01d-40e0-85b7-8ba4cd7d0ea9 + - qualified_name = Atlan-internal name of the connection (with epoch) + for eg: default/snowflake/1234567890 + - name = simple name of the form {{connectorType}}/{{connectorName}}, + for eg: snowflake/development + """ + + _SEARCH_FIELDS = [ + Connection.NAME, + Connection.STATUS, + Connection.CONNECTOR_NAME, + ] + SEARCH_ATTRIBUTES = [field.atlan_field_name for field in _SEARCH_FIELDS] + caches: Dict[int, ConnectionCache] = dict() + + def __init__(self, client: AtlanClient): + super().__init__(client) + + @classmethod + def get_cache(cls) -> ConnectionCache: + from pyatlan.client.atlan import AtlanClient + + with lock: + default_client = AtlanClient.get_default_client() + cache_key = default_client.cache_key + if cache_key not in cls.caches: + cls.caches[cache_key] = ConnectionCache(client=default_client) + return cls.caches[cache_key] + + @classmethod + def get_by_guid(cls, guid: str, allow_refresh: bool = True) -> Connection: + """ + Retrieve a connection from the cache by its UUID. + If the asset is not found, it will be looked up and added to the cache. + + :param guid: UUID of the connection in Atlan + for eg: 9c677e77-e01d-40e0-85b7-8ba4cd7d0ea9 + :returns: connection (if found) + :raises AtlanError: on any API communication problem if the cache needs to be refreshed + :raises NotFoundError: if the connection cannot be found (does not exist) in Atlan + :raises InvalidRequestError: if no UUID was provided for the connection to retrieve + """ + return cls.get_cache()._get_by_guid(guid=guid, allow_refresh=allow_refresh) + + @classmethod + def get_by_qualified_name( + cls, qualified_name: str, allow_refresh: bool = True + ) -> Connection: + """ + Retrieve a connection from the cache by its unique Atlan-internal name. + + :param qualified_name: unique Atlan-internal name of the connection + for eg: default/snowflake/1234567890 + :param allow_refresh: whether to allow a refresh of the cache (`True`) or not (`False`) + :param qualified_name: unique Atlan-internal name of the connection + :returns: connection (if found) + :raises AtlanError: on any API communication problem if the cache needs to be refreshed + :raises NotFoundError: if the connection cannot be found (does not exist) in Atlan + :raises InvalidRequestError: if no qualified_name was provided for the connection to retrieve + """ + return cls.get_cache()._get_by_qualified_name( + qualified_name=qualified_name, allow_refresh=allow_refresh + ) + + @classmethod + def get_by_name( + cls, name: ConnectionName, allow_refresh: bool = True + ) -> Connection: + """ + Retrieve an connection from the cache by its uniquely identifiable name. + + :param name: uniquely identifiable name of the connection in Atlan + In the form of {{connectorType}}/{{connectorName}} + for eg: snowflake/development + :param allow_refresh: whether to allow a refresh of the cache (`True`) or not (`False`) + :returns: connection (if found) + :raises AtlanError: on any API communication problem if the cache needs to be refreshed + :raises NotFoundError: if the connection cannot be found (does not exist) in Atlan + :raises InvalidRequestError: if no name was provided for the connection to retrieve + """ + return cls.get_cache()._get_by_name(name=name, allow_refresh=allow_refresh) + + def lookup_by_guid(self, guid: str) -> None: + if not guid: + return + with self.lock: + response = ( + FluentSearch(_includes_on_results=self.SEARCH_ATTRIBUTES) + .where(Term.with_state("ACTIVE")) + .where(Term.with_super_type_names("Asset")) + .where(Connection.GUID.eq(guid)) + .execute(self.client) + ) + candidate = (response.current_page() and response.current_page()[0]) or None + if candidate and isinstance(candidate, Connection): + self.cache(candidate) + + def lookup_by_qualified_name(self, connection_qn: str) -> None: + if not connection_qn: + return + with self.lock: + response = ( + FluentSearch(_includes_on_results=self.SEARCH_ATTRIBUTES) + .where(Term.with_state("ACTIVE")) + .where(Term.with_super_type_names("Asset")) + .where(Connection.QUALIFIED_NAME.eq(connection_qn)) + .execute(self.client) + ) + candidate = (response.current_page() and response.current_page()[0]) or None + if candidate and isinstance(candidate, Connection): + self.cache(candidate) + + def lookup_by_name(self, name: ConnectionName) -> None: + if not isinstance(name, ConnectionName): + return + results = self.client.asset.find_connections_by_name( + name=name.name, + connector_type=name.type, + attributes=self.SEARCH_ATTRIBUTES, + ) + if not results: + return + if len(results) > 1: + LOGGER.warning( + ( + "Found multiple connections of the same type " + "with the same name, caching only the first: %s" + ), + name, + ) + self.cache(results[0]) + + def get_name(self, asset: Asset): + if not isinstance(asset, Connection): + return + return str(ConnectionName(asset)) + + +class ConnectionName(AbstractAssetName): + """ + Unique identity for a connection, + in the form: {{type}}/{{name}} + + - For eg: snowflake/development + """ + + _TYPE_NAME = "Connection" + + def __init__( + self, + connection: Union[ + str, + Optional[Connection], + ] = None, + ): + self.name = None + self.type = None + + if isinstance(connection, Connection): + self.name = connection.name + self.type = connection.connector_name + + elif isinstance(connection, str): + tokens = connection.split("/") + if len(tokens) > 1: + self.type = AtlanConnectorType(tokens[0]) # type: ignore[call-arg] + self.name = connection[len(tokens[0]) + 1 :] # noqa + + def __hash__(self): + return hash((self.name, self.type)) + + def __str__(self): + return f"{self.type}/{self.name}" + + def __eq__(self, other): + if isinstance(other, ConnectionName): + return self.name == other.name and self.type == other.type + return False diff --git a/pyatlan/cache/source_tag_cache.py b/pyatlan/cache/source_tag_cache.py new file mode 100644 index 00000000..98d58a43 --- /dev/null +++ b/pyatlan/cache/source_tag_cache.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 Atlan Pte. Ltd. +from __future__ import annotations + +import logging +import threading +from typing import Dict, Union + +from pyatlan.cache.abstract_asset_cache import AbstractAssetCache, AbstractAssetName +from pyatlan.cache.connection_cache import ConnectionCache, ConnectionName +from pyatlan.client.atlan import AtlanClient +from pyatlan.errors import AtlanError +from pyatlan.model.assets import Asset, Tag +from pyatlan.model.fluent_search import FluentSearch +from pyatlan.model.search import Term + +LOGGER = logging.getLogger(__name__) + +lock = threading.Lock() + + +class SourceTagCache(AbstractAssetCache): + """ + Lazily-loaded cache for translating between + source-synced tags and the qualifiedName of such. + + - guid = UUID of the source tag + for eg: 9c677e77-e01d-40e0-85b7-8ba4cd7d0ea9 + - qualified_name = of the source tag (with epoch) + for eg: default/snowflake/1234567890/DB/SCHEMA/TAG_NAME + - name = simple name of the form {{connectorType}}/{{connectorName}}@@DB/SCHEMA/TAG_NAME + for eg: snowflake/development@@DB/SCHEMA/TAG_NAME + """ + + _SEARCH_FIELDS = [Asset.NAME] + SEARCH_ATTRIBUTES = [field.atlan_field_name for field in _SEARCH_FIELDS] + caches: Dict[int, SourceTagCache] = dict() + + def __init__(self, client: AtlanClient): + super().__init__(client) + + @classmethod + def get_cache(cls) -> SourceTagCache: + from pyatlan.client.atlan import AtlanClient + + with lock: + default_client = AtlanClient.get_default_client() + cache_key = default_client.cache_key + if cache_key not in cls.caches: + cls.caches[cache_key] = SourceTagCache(client=default_client) + return cls.caches[cache_key] + + @classmethod + def get_by_guid(cls, guid: str, allow_refresh: bool = True) -> Tag: + """ + Retrieve a source tag from the cache by its UUID. + If the asset is not found, it will be looked up and added to the cache. + + :param guid: UUID of the source tag in Atlan + for eg: 9c677e77-e01d-40e0-85b7-8ba4cd7d0ea9 + :returns: source tag (if found) + :raises AtlanError: on any API communication problem if the cache needs to be refreshed + :raises NotFoundError: if the source tag cannot be found (does not exist) in Atlan + :raises InvalidRequestError: if no UUID was provided for the source tag to retrieve + """ + return cls.get_cache()._get_by_guid(guid=guid, allow_refresh=allow_refresh) + + @classmethod + def get_by_qualified_name( + cls, qualified_name: str, allow_refresh: bool = True + ) -> Tag: + """ + Retrieve a source tag from the cache by its unique Atlan-internal name. + + :param qualified_name: unique Atlan-internal name of the source tag + for eg: default/snowflake/1234567890/DB/SCHEMA/TAG_NAME + :param allow_refresh: whether to allow a refresh of the cache (`True`) or not (`False`) + :param qualified_name: unique Atlan-internal name of the source tag + :returns: source tag (if found) + :raises AtlanError: on any API communication problem if the cache needs to be refreshed + :raises NotFoundError: if the source tag cannot be found (does not exist) in Atlan + :raises InvalidRequestError: if no qualified_name was provided for the source tag to retrieve + """ + return cls.get_cache()._get_by_qualified_name( + qualified_name=qualified_name, allow_refresh=allow_refresh + ) + + @classmethod + def get_by_name(cls, name: SourceTagName, allow_refresh: bool = True) -> Tag: + """ + Retrieve an connection from the cache by its uniquely identifiable name. + + :param name: uniquely identifiable name of the connection in Atlan. + In the form of {{connectorType}}/{{connectorName}}@@DB/SCHEMA/TAG_NAME + for eg: snowflake/development@@DB/SCHEMA/TAG_NAME + :param allow_refresh: whether to allow a refresh of the cache (`True`) or not (`False`) + :returns: the connection (if found) + :raises AtlanError: on any API communication problem if the cache needs to be refreshed + :raises NotFoundError: if the object cannot be found (does not exist) in Atlan + :raises InvalidRequestError: if no name was provided for the object to retrieve + """ + return cls.get_cache()._get_by_name(name=name, allow_refresh=allow_refresh) + + def lookup_by_guid(self, guid: str) -> None: + if not guid: + return + with self.lock: + response = ( + FluentSearch(_includes_on_results=self.SEARCH_ATTRIBUTES) + .where(Term.with_state("ACTIVE")) + .where(Asset.SUPER_TYPE_NAMES.eq(Tag.__name__)) + .where(Asset.GUID.eq(guid)) + .execute(self.client) + ) + candidate = (response.current_page() and response.current_page()[0]) or None + # NOTE: Checking if the first result is an "Asset" since in pyatlan, + # "DbtTag" extends "Dbt" (unlike other tags like "SnowflakeTag" that extend the "Tag" model), + # preventing Dbt tags from being excluded from caching: + if candidate and isinstance(candidate, Asset): + self.cache(candidate) + + def lookup_by_qualified_name(self, source_tag_qn: str) -> None: + if not source_tag_qn: + return + with self.lock: + response = ( + FluentSearch(_includes_on_results=self.SEARCH_ATTRIBUTES) + .where(Term.with_state("ACTIVE")) + .where(Asset.SUPER_TYPE_NAMES.eq(Tag.__name__)) + .where(Asset.QUALIFIED_NAME.eq(source_tag_qn)) + .execute(self.client) + ) + candidate = (response.current_page() and response.current_page()[0]) or None + # NOTE: Checking if the first result is an "Asset" since in pyatlan, + # "DbtTag" extends "Dbt" (unlike other tags like "SnowflakeTag" that extend the "Tag" model), + # preventing Dbt tags from being excluded from caching: + if candidate and isinstance(candidate, Asset): + self.cache(candidate) + + def lookup_by_name(self, stn: SourceTagName) -> None: + if not isinstance(stn, SourceTagName): + return + connection_name = stn.connection + connection_qn = ConnectionCache.get_by_name(connection_name).qualified_name # type: ignore[arg-type] + source_tag_qn = f"{connection_qn}/{stn.partial_tag_name}" + + with self.lock: + response = ( + FluentSearch(_includes_on_results=self.SEARCH_ATTRIBUTES) + .where(Term.with_state("ACTIVE")) + .where(Asset.SUPER_TYPE_NAMES.eq(Tag.__name__)) + .where(Asset.QUALIFIED_NAME.eq(source_tag_qn)) + .execute(self.client) + ) + candidate = (response.current_page() and response.current_page()[0]) or None + # NOTE: Checking if the first result is an "Asset" since in pyatlan, + # "DbtTag" extends "Dbt" (unlike other tags like "SnowflakeTag" that extend the "Tag" model), + # preventing Dbt tags from being excluded from caching: + if candidate and isinstance(candidate, Asset): + self.cache(candidate) + + def get_name(self, asset: Asset): + # NOTE: Checking if the first result is an "Asset" since in pyatlan, + # "DbtTag" extends "Dbt" (unlike other tags like "SnowflakeTag" that extend the "Tag" model), + # preventing Dbt tags from being excluded from caching: + if not isinstance(asset, Asset): + return + try: + source_tag_name = str(SourceTagName(asset)) + except AtlanError as e: + LOGGER.error( + "Unable to construct a source tag name for: %s", asset.qualified_name + ) + LOGGER.debug("Details: %s", e) + return source_tag_name + + +class SourceTagName(AbstractAssetName): + """ + Unique identity for a source tag, + in the form: {{connectorType}}/{{connectorName}}@@DB/SCHEMA/TAG_NAME + + - For eg: snowflake/development + """ + + _TYPE_NAME = "SourceTagAttachment" + _CONNECTION_DELIMITER = "@@" + + def __init__(self, tag: Union[str, Asset]): + self.connection = None + self.partial_tag_name = None + + # NOTE: Checking if the first result is an "Asset" since in pyatlan, + # "DbtTag" extends "Dbt" (unlike other tags like "SnowflakeTag" that extend the "Tag" model), + # preventing Dbt tags from being excluded from caching: + if isinstance(tag, Asset): + source_tag_qn = tag.qualified_name or "" + tokens = source_tag_qn.split("/") + connection_qn = "/".join(tokens[:3]) if len(tokens) >= 3 else "" + conn = ConnectionCache.get_by_qualified_name(connection_qn) + self.connection = ConnectionName(conn) + self.partial_tag_name = source_tag_qn[len(connection_qn) + 1 :] # noqa + + elif isinstance(tag, str): + tokens = tag.split(self._CONNECTION_DELIMITER) + if len(tokens) == 2: + self.connection = ConnectionName(tokens[0]) + self.partial_tag_name = tokens[1] + + def __str__(self): + return f"{self.connection}{self._CONNECTION_DELIMITER}{self.partial_tag_name}" diff --git a/pyatlan/errors.py b/pyatlan/errors.py index 093435a1..1ff799b5 100644 --- a/pyatlan/errors.py +++ b/pyatlan/errors.py @@ -555,6 +555,20 @@ class ErrorCode(Enum): + "configure OpenLineage for this connector before you can send events for it.", InvalidRequestError, ) + MISSING_ID = ( + 400, + "ATLAN-PYTHON-400-065", + "No ID was provided when attempting to retrieve or update the object.", + "You must provide an ID when attempting to retrieve or update an object.", + InvalidRequestError, + ) + MISSING_NAME = ( + 400, + "ATLAN-PYTHON-400-065", + "No name instance was provided when attempting to retrieve an object.", + "You must provide the name of the object when attempting to retrieve one.", + InvalidRequestError, + ) AUTHENTICATION_PASSTHROUGH = ( 401, "ATLAN-PYTHON-401-000", diff --git a/pyatlan/model/enums.py b/pyatlan/model/enums.py index f3ef5c05..6c047f8f 100644 --- a/pyatlan/model/enums.py +++ b/pyatlan/model/enums.py @@ -146,11 +146,18 @@ def _get_connector_type_from_qualified_name( cls, qualified_name: str ) -> "AtlanConnectorType": tokens = qualified_name.split("/") - if len(tokens) > 1: - return AtlanConnectorType[tokens[1].upper()] - raise ValueError( - f"Could not determine AtlanConnectorType from {qualified_name}" - ) + if len(tokens) < 2: + raise ValueError( + f"Qualified name '{qualified_name}' does not contain enough segments." + ) + connector_type_key = tokens[1].upper() + # Check if the connector_type_key exists in AtlanConnectorType + if connector_type_key not in AtlanConnectorType.__members__: + raise ValueError( + f"Could not determine AtlanConnectorType from '{qualified_name}'; " + f"'{connector_type_key}' is not a valid connector type." + ) + return AtlanConnectorType[connector_type_key] def __new__( cls, value: str, category: AtlanConnectionCategory diff --git a/tests/unit/test_connection_cache.py b/tests/unit/test_connection_cache.py new file mode 100644 index 00000000..8c51f4b9 --- /dev/null +++ b/tests/unit/test_connection_cache.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 Atlan Pte. Ltd. +from unittest.mock import Mock, patch + +import pytest + +from pyatlan.cache.connection_cache import ConnectionCache, ConnectionName +from pyatlan.client.atlan import AtlanClient +from pyatlan.errors import ErrorCode, InvalidRequestError, NotFoundError +from pyatlan.model.assets import Connection + + +@pytest.fixture(autouse=True) +def set_env(monkeypatch): + monkeypatch.setenv("ATLAN_BASE_URL", "https://test.atlan.com") + monkeypatch.setenv("ATLAN_API_KEY", "test-api-key") + + +def test_get_by_guid_with_not_found_error(monkeypatch): + with pytest.raises(InvalidRequestError, match=ErrorCode.MISSING_ID.error_message): + ConnectionCache.get_by_guid("") + + +@patch.object(ConnectionCache, "lookup_by_guid") +@patch.object( + ConnectionCache, "get_cache", return_value=ConnectionCache(client=AtlanClient()) +) +def test_get_by_guid_with_no_invalid_request_error(mock_get_cache, mock_lookup_by_guid): + test_guid = "test-guid-123" + with pytest.raises( + NotFoundError, + match=ErrorCode.ASSET_NOT_FOUND_BY_GUID.error_message.format(test_guid), + ): + ConnectionCache.get_by_guid(test_guid) + mock_get_cache.assert_called_once() + + +def test_get_by_qualified_name_with_not_found_error(monkeypatch): + with pytest.raises(InvalidRequestError, match=ErrorCode.MISSING_ID.error_message): + ConnectionCache.get_by_qualified_name("") + + +@patch.object(ConnectionCache, "lookup_by_qualified_name") +@patch.object( + ConnectionCache, "get_cache", return_value=ConnectionCache(client=AtlanClient()) +) +def test_get_by_qualified_name_with_no_invalid_request_error( + mock_get_cache, mock_lookup_by_qualified_name +): + test_qn = "default/snowflake/123456789" + test_connector = "snowflake" + with pytest.raises( + NotFoundError, + match=ErrorCode.ASSET_NOT_FOUND_BY_QN.error_message.format( + test_qn, test_connector + ), + ): + ConnectionCache.get_by_qualified_name(test_qn) + mock_get_cache.assert_called_once() + + +def test_get_by_name_with_not_found_error(monkeypatch): + with pytest.raises(InvalidRequestError, match=ErrorCode.MISSING_NAME.error_message): + ConnectionCache.get_by_name("") + + +@patch.object(ConnectionCache, "lookup_by_name") +@patch.object( + ConnectionCache, "get_cache", return_value=ConnectionCache(client=AtlanClient()) +) +def test_get_by_name_with_no_invalid_request_error(mock_get_cache, mock_lookup_by_name): + test_name = ConnectionName("snowflake/test") + with pytest.raises( + NotFoundError, + match=ErrorCode.ASSET_NOT_FOUND_BY_NAME.error_message.format( + ConnectionName._TYPE_NAME, + test_name, + ), + ): + ConnectionCache.get_by_name(test_name) + mock_get_cache.assert_called_once() + + +@patch.object(ConnectionCache, "lookup_by_guid") +@patch.object( + ConnectionCache, "get_cache", return_value=ConnectionCache(client=AtlanClient()) +) +def test_get_by_guid(mock_get_cache, mock_lookup_by_guid): + test_guid = "test-guid-123" + test_qn = "test-qualified-name" + conn = Connection() + conn.guid = test_guid + conn.qualified_name = test_qn + test_asset = conn + + mock_guid_to_asset = Mock() + mock_name_to_guid = Mock() + mock_qualified_name_to_guid = Mock() + + # 1 - Not found in the cache, triggers a lookup call + # 2, 3, 4 - Uses the cached entry from the map + mock_guid_to_asset.get.side_effect = [ + None, + test_asset, + test_asset, + test_asset, + ] + mock_name_to_guid.get.side_effect = [test_guid, test_guid, test_guid, test_guid] + mock_qualified_name_to_guid.get.side_effect = [ + test_guid, + test_guid, + test_guid, + test_guid, + ] + + # Assign mock caches to the return value of get_cache + mock_get_cache.return_value.guid_to_asset = mock_guid_to_asset + mock_get_cache.return_value.name_to_guid = mock_name_to_guid + mock_get_cache.return_value.qualified_name_to_guid = mock_qualified_name_to_guid + + connection = ConnectionCache.get_by_guid(test_guid) + + # Multiple calls with the same GUID result in no additional API lookups + # as the object is already cached + connection = ConnectionCache.get_by_guid(test_guid) + connection = ConnectionCache.get_by_guid(test_guid) + + assert test_guid == connection.guid + assert test_qn == connection.qualified_name + + # The method is called three times, but the lookup is triggered only once + assert mock_get_cache.call_count == 3 + mock_lookup_by_guid.assert_called_once() + + +@patch.object(ConnectionCache, "lookup_by_guid") +@patch.object(ConnectionCache, "lookup_by_qualified_name") +@patch.object( + ConnectionCache, "get_cache", return_value=ConnectionCache(client=AtlanClient()) +) +def test_get_by_qualified_name(mock_get_cache, mock_lookup_by_qn, mock_lookup_by_guid): + test_guid = "test-guid-123" + test_qn = "test-qualified-name" + conn = Connection() + conn.guid = test_guid + conn.qualified_name = test_qn + test_asset = conn + + mock_guid_to_asset = Mock() + mock_name_to_guid = Mock() + mock_qualified_name_to_guid = Mock() + + # 1 - Not found in the cache, triggers a lookup call + # 2, 3, 4 - Uses the cached entry from the map + mock_qualified_name_to_guid.get.side_effect = [ + None, + test_guid, + test_guid, + test_guid, + ] + + # Other caches will be populated once + # the lookup call for get_by_qualified_name is made + mock_guid_to_asset.get.side_effect = [ + test_asset, + test_asset, + test_asset, + test_asset, + ] + mock_name_to_guid.get.side_effect = [test_guid, test_guid, test_guid, test_guid] + + mock_get_cache.return_value.guid_to_asset = mock_guid_to_asset + mock_get_cache.return_value.name_to_guid = mock_name_to_guid + mock_get_cache.return_value.qualified_name_to_guid = mock_qualified_name_to_guid + + connection = ConnectionCache.get_by_qualified_name(test_qn) + + # Multiple calls with the same + # qualified name result in no additional API lookups + # as the object is already cached + connection = ConnectionCache.get_by_qualified_name(test_qn) + connection = ConnectionCache.get_by_qualified_name(test_qn) + + assert test_guid == connection.guid + assert test_qn == connection.qualified_name + + # The method is called three times + # but the lookup is triggered only once + assert mock_get_cache.call_count == 3 + mock_lookup_by_qn.assert_called_once() + + # No call to guid lookup since the object is already in the cache + assert mock_lookup_by_guid.call_count == 0 + + +@patch.object(ConnectionCache, "lookup_by_guid") +@patch.object(ConnectionCache, "lookup_by_name") +@patch.object( + ConnectionCache, "get_cache", return_value=ConnectionCache(client=AtlanClient()) +) +def test_get_by_name(mock_get_cache, mock_lookup_by_name, mock_lookup_by_guid): + test_name = ConnectionName("snowflake/test") + test_guid = "test-guid-123" + test_qn = "test-qualified-name" + conn = Connection() + conn.guid = test_guid + conn.qualified_name = test_qn + test_asset = conn + + mock_guid_to_asset = Mock() + mock_name_to_guid = Mock() + mock_qualified_name_to_guid = Mock() + + # 1 - Not found in the cache, triggers a lookup call + # 2, 3, 4 - Uses the cached entry from the map + mock_name_to_guid.get.side_effect = [ + None, + test_guid, + test_guid, + test_guid, + ] + + # Other caches will be populated once + # the lookup call for get_by_qualified_name is made + mock_guid_to_asset.get.side_effect = [ + test_asset, + test_asset, + test_asset, + test_asset, + ] + mock_qualified_name_to_guid.get.side_effect = [ + test_guid, + test_guid, + test_guid, + test_guid, + ] + + mock_get_cache.return_value.guid_to_asset = mock_guid_to_asset + mock_get_cache.return_value.name_to_guid = mock_name_to_guid + mock_get_cache.return_value.qualified_name_to_guid = mock_qualified_name_to_guid + + connection = ConnectionCache.get_by_name(test_name) + + # Multiple calls with the same + # qualified name result in no additional API lookups + # as the object is already cached + connection = ConnectionCache.get_by_name(test_name) + connection = ConnectionCache.get_by_name(test_name) + + assert test_guid == connection.guid + assert test_qn == connection.qualified_name + + # The method is called three times + # but the lookup is triggered only once + assert mock_get_cache.call_count == 3 + mock_lookup_by_name.assert_called_once() + + # No call to guid lookup since the object is already in the cache + assert mock_lookup_by_guid.call_count == 0 diff --git a/tests/unit/test_source_cache.py b/tests/unit/test_source_cache.py new file mode 100644 index 00000000..c46b8958 --- /dev/null +++ b/tests/unit/test_source_cache.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 Atlan Pte. Ltd. +from unittest.mock import Mock, patch + +import pytest + +from pyatlan.cache.source_tag_cache import SourceTagCache, SourceTagName +from pyatlan.client.atlan import AtlanClient +from pyatlan.errors import ErrorCode, InvalidRequestError, NotFoundError +from pyatlan.model.assets import Connection + + +@pytest.fixture(autouse=True) +def set_env(monkeypatch): + monkeypatch.setenv("ATLAN_BASE_URL", "https://test.atlan.com") + monkeypatch.setenv("ATLAN_API_KEY", "test-api-key") + + +def test_get_by_guid_with_not_found_error(monkeypatch): + with pytest.raises(InvalidRequestError, match=ErrorCode.MISSING_ID.error_message): + SourceTagCache.get_by_guid("") + + +@patch.object(SourceTagCache, "lookup_by_guid") +@patch.object( + SourceTagCache, "get_cache", return_value=SourceTagCache(client=AtlanClient()) +) +def test_get_by_guid_with_no_invalid_request_error(mock_get_cache, mock_lookup_by_guid): + test_guid = "test-guid-123" + with pytest.raises( + NotFoundError, + match=ErrorCode.ASSET_NOT_FOUND_BY_GUID.error_message.format(test_guid), + ): + SourceTagCache.get_by_guid(test_guid) + mock_get_cache.assert_called_once() + + +def test_get_by_qualified_name_with_not_found_error(monkeypatch): + with pytest.raises(InvalidRequestError, match=ErrorCode.MISSING_ID.error_message): + SourceTagCache.get_by_qualified_name("") + + +@patch.object(SourceTagCache, "lookup_by_qualified_name") +@patch.object( + SourceTagCache, "get_cache", return_value=SourceTagCache(client=AtlanClient()) +) +def test_get_by_qualified_name_with_no_invalid_request_error( + mock_get_cache, mock_lookup_by_qualified_name +): + test_qn = "default/snowflake/123456789" + test_connector = "snowflake" + with pytest.raises( + NotFoundError, + match=ErrorCode.ASSET_NOT_FOUND_BY_QN.error_message.format( + test_qn, test_connector + ), + ): + SourceTagCache.get_by_qualified_name(test_qn) + mock_get_cache.assert_called_once() + + +def test_get_by_name_with_not_found_error(monkeypatch): + with pytest.raises(InvalidRequestError, match=ErrorCode.MISSING_NAME.error_message): + SourceTagCache.get_by_name("") + + +@patch.object(SourceTagCache, "lookup_by_name") +@patch.object( + SourceTagCache, "get_cache", return_value=SourceTagCache(client=AtlanClient()) +) +def test_get_by_name_with_no_invalid_request_error(mock_get_cache, mock_lookup_by_name): + test_name = SourceTagName("snowflake/test@@DB/SCHEMA/TEST_TAG") + with pytest.raises( + NotFoundError, + match=ErrorCode.ASSET_NOT_FOUND_BY_NAME.error_message.format( + SourceTagName._TYPE_NAME, + test_name, + ), + ): + SourceTagCache.get_by_name(test_name) + mock_get_cache.assert_called_once() + + +@patch.object(SourceTagCache, "lookup_by_guid") +@patch.object( + SourceTagCache, "get_cache", return_value=SourceTagCache(client=AtlanClient()) +) +def test_get_by_guid(mock_get_cache, mock_lookup_by_guid): + test_guid = "test-guid-123" + test_qn = "test-qualified-name" + conn = Connection() + conn.guid = test_guid + conn.qualified_name = test_qn + test_asset = conn + + mock_guid_to_asset = Mock() + mock_name_to_guid = Mock() + mock_qualified_name_to_guid = Mock() + + # 1 - Not found in the cache, triggers a lookup call + # 2, 3, 4 - Uses the cached entry from the map + mock_guid_to_asset.get.side_effect = [ + None, + test_asset, + test_asset, + test_asset, + ] + mock_name_to_guid.get.side_effect = [test_guid, test_guid, test_guid, test_guid] + mock_qualified_name_to_guid.get.side_effect = [ + test_guid, + test_guid, + test_guid, + test_guid, + ] + + # Assign mock caches to the return value of get_cache + mock_get_cache.return_value.guid_to_asset = mock_guid_to_asset + mock_get_cache.return_value.name_to_guid = mock_name_to_guid + mock_get_cache.return_value.qualified_name_to_guid = mock_qualified_name_to_guid + + connection = SourceTagCache.get_by_guid(test_guid) + + # Multiple calls with the same GUID result in no additional API lookups + # as the object is already cached + connection = SourceTagCache.get_by_guid(test_guid) + connection = SourceTagCache.get_by_guid(test_guid) + + assert test_guid == connection.guid + assert test_qn == connection.qualified_name + + # The method is called three times, but the lookup is triggered only once + assert mock_get_cache.call_count == 3 + mock_lookup_by_guid.assert_called_once() + + +@patch.object(SourceTagCache, "lookup_by_guid") +@patch.object(SourceTagCache, "lookup_by_qualified_name") +@patch.object( + SourceTagCache, "get_cache", return_value=SourceTagCache(client=AtlanClient()) +) +def test_get_by_qualified_name(mock_get_cache, mock_lookup_by_qn, mock_lookup_by_guid): + test_guid = "test-guid-123" + test_qn = "test-qualified-name" + conn = Connection() + conn.guid = test_guid + conn.qualified_name = test_qn + test_asset = conn + + mock_guid_to_asset = Mock() + mock_name_to_guid = Mock() + mock_qualified_name_to_guid = Mock() + + # 1 - Not found in the cache, triggers a lookup call + # 2, 3, 4 - Uses the cached entry from the map + mock_qualified_name_to_guid.get.side_effect = [ + None, + test_guid, + test_guid, + test_guid, + ] + + # Other caches will be populated once + # the lookup call for get_by_qualified_name is made + mock_guid_to_asset.get.side_effect = [ + test_asset, + test_asset, + test_asset, + test_asset, + ] + mock_name_to_guid.get.side_effect = [test_guid, test_guid, test_guid, test_guid] + + mock_get_cache.return_value.guid_to_asset = mock_guid_to_asset + mock_get_cache.return_value.name_to_guid = mock_name_to_guid + mock_get_cache.return_value.qualified_name_to_guid = mock_qualified_name_to_guid + + connection = SourceTagCache.get_by_qualified_name(test_qn) + + # Multiple calls with the same + # qualified name result in no additional API lookups + # as the object is already cached + connection = SourceTagCache.get_by_qualified_name(test_qn) + connection = SourceTagCache.get_by_qualified_name(test_qn) + + assert test_guid == connection.guid + assert test_qn == connection.qualified_name + + # The method is called three times + # but the lookup is triggered only once + assert mock_get_cache.call_count == 3 + mock_lookup_by_qn.assert_called_once() + + # No call to guid lookup since the object is already in the cache + assert mock_lookup_by_guid.call_count == 0 + + +@patch.object(SourceTagCache, "lookup_by_guid") +@patch.object(SourceTagCache, "lookup_by_name") +@patch.object( + SourceTagCache, "get_cache", return_value=SourceTagCache(client=AtlanClient()) +) +def test_get_by_name(mock_get_cache, mock_lookup_by_name, mock_lookup_by_guid): + test_name = SourceTagName("snowflake/test@@DB/SCHEMA/TEST_TAG") + test_guid = "test-guid-123" + test_qn = "test-qualified-name" + conn = Connection() + conn.guid = test_guid + conn.qualified_name = test_qn + test_asset = conn + + mock_guid_to_asset = Mock() + mock_name_to_guid = Mock() + mock_qualified_name_to_guid = Mock() + + # 1 - Not found in the cache, triggers a lookup call + # 2, 3, 4 - Uses the cached entry from the map + mock_name_to_guid.get.side_effect = [ + None, + test_guid, + test_guid, + test_guid, + ] + + # Other caches will be populated once + # the lookup call for get_by_qualified_name is made + mock_guid_to_asset.get.side_effect = [ + test_asset, + test_asset, + test_asset, + test_asset, + ] + mock_qualified_name_to_guid.get.side_effect = [ + test_guid, + test_guid, + test_guid, + test_guid, + ] + + mock_get_cache.return_value.guid_to_asset = mock_guid_to_asset + mock_get_cache.return_value.name_to_guid = mock_name_to_guid + mock_get_cache.return_value.qualified_name_to_guid = mock_qualified_name_to_guid + + connection = SourceTagCache.get_by_name(test_name) + + # Multiple calls with the same + # qualified name result in no additional API lookups + # as the object is already cached + connection = SourceTagCache.get_by_name(test_name) + connection = SourceTagCache.get_by_name(test_name) + + assert test_guid == connection.guid + assert test_qn == connection.qualified_name + + # The method is called three times + # but the lookup is triggered only once + assert mock_get_cache.call_count == 3 + mock_lookup_by_name.assert_called_once() + + # No call to guid lookup since the object is already in the cache + assert mock_lookup_by_guid.call_count == 0