diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index b4cd78cf..3505ebda 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -251,6 +251,7 @@ def read(self) -> Optional[OAuthToken]: self.telemetry_enabled = ( self.client_telemetry_enabled and self.server_telemetry_enabled ) + self.telemetry_batch_size = kwargs.get("telemetry_batch_size") user_agent_entry = kwargs.get("user_agent_entry") if user_agent_entry is None: @@ -312,6 +313,7 @@ def read(self) -> Optional[OAuthToken]: session_id_hex=self.get_session_id_hex(), auth_provider=auth_provider, host_url=self.host, + batch_size=self.telemetry_batch_size, ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 5eb8c6ed..1846c0b2 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -141,10 +141,13 @@ def __init__( auth_provider, host_url, executor, + batch_size=None, ): logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled - self._batch_size = self.DEFAULT_BATCH_SIZE + self._batch_size = ( + batch_size if batch_size is not None else self.DEFAULT_BATCH_SIZE + ) self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None @@ -311,7 +314,7 @@ def close(self): class TelemetryClientFactory: """ Static factory class for creating and managing telemetry clients. - It uses a thread pool to handle asynchronous operations. + It uses a thread pool to handle asynchronous operations and a single flush thread for all clients. """ _clients: Dict[ @@ -324,6 +327,11 @@ class TelemetryClientFactory: _original_excepthook = None _excepthook_installed = False + # Shared flush thread for all clients + _flush_thread = None + _flush_event = threading.Event() + _flush_interval_seconds = 90 + @classmethod def _initialize(cls): """Initialize the factory if not already initialized""" @@ -334,11 +342,39 @@ def _initialize(cls): max_workers=10 ) # Thread pool for async operations cls._install_exception_hook() + cls._start_flush_thread() cls._initialized = True logger.debug( "TelemetryClientFactory initialized with thread pool (max_workers=10)" ) + @classmethod + def _start_flush_thread(cls): + """Start the shared background thread for periodic flushing of all clients""" + cls._flush_event.clear() + cls._flush_thread = threading.Thread(target=cls._flush_worker, daemon=True) + cls._flush_thread.start() + + @classmethod + def _flush_worker(cls): + """Background worker thread for periodic flushing of all clients""" + while not cls._flush_event.wait(cls._flush_interval_seconds): + logger.debug("Performing periodic flush for all telemetry clients") + + with cls._lock: + clients_to_flush = list(cls._clients.values()) + + for client in clients_to_flush: + client._flush() + + @classmethod + def _stop_flush_thread(cls): + """Stop the shared background flush thread""" + if cls._flush_thread is not None: + cls._flush_event.set() + cls._flush_thread.join(timeout=1.0) + cls._flush_thread = None + @classmethod def _install_exception_hook(cls): """Install global exception handler for unhandled exceptions""" @@ -367,6 +403,7 @@ def initialize_telemetry_client( session_id_hex, auth_provider, host_url, + batch_size=None, ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: @@ -388,6 +425,7 @@ def initialize_telemetry_client( auth_provider=auth_provider, host_url=host_url, executor=TelemetryClientFactory._executor, + batch_size=batch_size, ) else: TelemetryClientFactory._clients[ @@ -426,6 +464,7 @@ def close(session_id_hex): "No more telemetry clients, shutting down thread pool executor" ) try: + TelemetryClientFactory._stop_flush_thread() TelemetryClientFactory._executor.shutdown(wait=True) except Exception as e: logger.debug("Failed to shutdown thread pool executor: %s", e)