diff --git a/osbenchmark/benchmark.py b/osbenchmark/benchmark.py
index 0dd6de593..ab0b83d2f 100644
--- a/osbenchmark/benchmark.py
+++ b/osbenchmark/benchmark.py
@@ -43,6 +43,7 @@ from osbenchmark.workload_generator import workload_generator
 from osbenchmark.utils import io, convert, process, console, net, opts, versions
 from osbenchmark import aggregator
+from osbenchmark.database.registry import DatabaseType
 
 
 def create_arg_parser():
     def positive_number(v):
@@ -598,6 +599,13 @@ def add_workload_source(subparser):
                                  help=f"Define a comma-separated list of client options to use. The options will be passed to the OpenSearch "
                                       f"Python client (default: {opts.ClientOptions.DEFAULT_CLIENT_OPTIONS}).",
                                  default=opts.ClientOptions.DEFAULT_CLIENT_OPTIONS)
+    test_run_parser.add_argument(
+        "--database-type",
+        help="Target database backend. Selects the DatabaseClient adapter used to run "
+             "the workload (default: opensearch). Choices are populated from the "
+             "registered DatabaseType enum.",
+        choices=[d.value for d in DatabaseType],
+        default=DatabaseType.OPENSEARCH.value)
     test_run_parser.add_argument("--on-error",
                                  choices=["continue", "abort"],
                                  help="Controls how OSB behaves on response errors (default: continue).",
@@ -1079,6 +1087,11 @@ def configure_connection_params(arg_parser, args, cfg):
     # Configure gRPC target hosts
     grpc_target_hosts = opts.TargetHosts(args.grpc_target_hosts) if hasattr(args, "grpc_target_hosts") and args.grpc_target_hosts else None
     cfg.add(config.Scope.applicationOverride, "client", "grpc_hosts", grpc_target_hosts)
+
+    # Configure database backend; worker_coordinator reads cfg.opts("database", "type")
+    # to pick the DatabaseClient factory via the database/ registry.
+    database_type = getattr(args, "database_type", "opensearch")
+    cfg.add(config.Scope.applicationOverride, "database", "type", database_type)
     if "timeout" not in client_options.default:
         console.info("You did not provide an explicit timeout in the client options. 
Assuming default of 10 seconds.")
     if list(target_hosts.all_hosts) != list(client_options.all_client_options):
diff --git a/osbenchmark/test_run_orchestrator.py b/osbenchmark/test_run_orchestrator.py
index d4ec442ab..9e323370b 100644
--- a/osbenchmark/test_run_orchestrator.py
+++ b/osbenchmark/test_run_orchestrator.py
@@ -184,7 +184,13 @@ def setup(self, sources=False):
         # but there are rare cases (external pipeline and user did not specify the distribution version) where we need
         # to derive it ourselves. For source builds we always assume "master"
         oss_distribution_version = "2.11.0"
-        if not sources and not self.cfg.exists("builder", "distribution.version"):
+        # Distribution-version auto-detection probes the target with a raw opensearchpy
+        # client and the legacy wait_for_rest_layer. For non-OpenSearch backends the
+        # OS-flavored health/info routes either don't exist or differ enough to make
+        # this probe meaningless. Skip the whole branch for non-OS databases; the
+        # OS-min-version check below is also OS-specific. 
+        database_type = self.cfg.opts("database", "type", default_value="opensearch", mandatory=False)
+        if not sources and not self.cfg.exists("builder", "distribution.version") and database_type.lower() == "opensearch":
             distribution_version = builder.cluster_distribution_version(self.cfg)
             if distribution_version == 'oss':
                 self.logger.info("Automatically derived serverless collection, setting distribution version to 2.11.0")
diff --git a/osbenchmark/worker_coordinator/worker_coordinator.py b/osbenchmark/worker_coordinator/worker_coordinator.py
index b650fe89c..4f843d953 100644
--- a/osbenchmark/worker_coordinator/worker_coordinator.py
+++ b/osbenchmark/worker_coordinator/worker_coordinator.py
@@ -972,16 +972,33 @@ def __init__(self, target, config, os_client_factory_class=client.OsClientFactory):
         self.complete_current_task_sent = False
         self.telemetry = None
+        # Caches the DatabaseClientFactory per cluster for non-OpenSearch backends
+        # so the same instance can be reused by wait_for_rest_api without a second
+        # construction.
+        self._database_factories = {}
 
     def create_os_clients(self):
         all_hosts = self.config.opts("client", "hosts").all_hosts
+        database_type = self.config.opts("database", "type", default_value="opensearch", mandatory=False)
         opensearch = {}
         for cluster_name, cluster_hosts in all_hosts.items():
             all_client_options = self.config.opts("client", "options").all_client_options
             cluster_client_options = dict(all_client_options[cluster_name])
             # Use retries to avoid aborts on long living connections for telemetry devices
             cluster_client_options["retry-on-timeout"] = True
-            opensearch[cluster_name] = self.os_client_factory(cluster_hosts, cluster_client_options).create()
+            if database_type.lower() == "opensearch":
+                opensearch[cluster_name] = self.os_client_factory(cluster_hosts, cluster_client_options).create()
+            else:
+                # Non-OpenSearch backends route through the registry so the right
+                # url_prefix / transport configuration is applied. 
This path
+                # mirrors the async create path at WorkerCoordinator.os_clients.
+                # Cache the factory so wait_for_rest_api can reuse it for the
+                # readiness probe rather than constructing a second instance. NOTE(review): no hunk in this diff imports DatabaseClientFactory into worker_coordinator.py (only benchmark.py gains a registry import, and only of DatabaseType) — confirm the module already imports it from osbenchmark.database.registry, otherwise this call raises NameError.
+                db_factory = DatabaseClientFactory.create_client_factory(
+                    database_type, cluster_hosts, cluster_client_options,
+                )
+                self._database_factories[cluster_name] = db_factory
+                opensearch[cluster_name] = db_factory.create()
         return opensearch
 
     def prepare_telemetry(self, opensearch, enable):
@@ -1021,6 +1038,19 @@
     def wait_for_rest_api(self, opensearch):
         os_default = opensearch["default"]
+        database_type = self.config.opts("database", "type", default_value="opensearch", mandatory=False)
+        # The legacy wait_for_rest_layer probes /_cluster/health then falls back
+        # to /_cat/indices — both OS-API-shape-specific. For non-OS backends,
+        # delegate to a factory-provided probe.
+        if database_type.lower() != "opensearch":
+            db_factory = self._database_factories.get("default")
+            self.logger.info("Checking if non-OS REST layer is available (database_type=%s).", database_type)
+            if db_factory is not None and hasattr(db_factory, "wait_for_rest_layer") \
+                and db_factory.wait_for_rest_layer(max_attempts=40):
+                self.logger.info("REST layer is available.")
+                return
+            self.logger.error("Non-OS REST layer is not yet available. Stopping benchmark.")
+            raise exceptions.SystemSetupError(f"{database_type} REST layer is not available.")
         self.logger.info("Checking if REST API is available.")
         if client.wait_for_rest_layer(os_default, max_attempts=40):
             self.logger.info("REST API is available.")