diff --git a/src/cli/src/database.rs b/src/cli/src/database.rs index 7152aac59270..24c4514fbc4c 100644 --- a/src/cli/src/database.rs +++ b/src/cli/src/database.rs @@ -17,6 +17,7 @@ use std::time::Duration; use base64::engine::general_purpose; use base64::Engine; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use common_error::ext::BoxedError; use humantime::format_duration; use serde_json::Value; use servers::http::header::constants::GREPTIME_DB_HEADER_TIMEOUT; @@ -24,7 +25,9 @@ use servers::http::result::greptime_result_v1::GreptimedbV1Response; use servers::http::GreptimeQueryOutput; use snafu::ResultExt; -use crate::error::{HttpQuerySqlSnafu, Result, SerdeJsonSnafu}; +use crate::error::{ + BuildClientSnafu, HttpQuerySqlSnafu, ParseProxyOptsSnafu, Result, SerdeJsonSnafu, +}; #[derive(Debug, Clone)] pub struct DatabaseClient { @@ -32,6 +35,23 @@ pub struct DatabaseClient { catalog: String, auth_header: Option, timeout: Duration, + proxy: Option, +} + +pub fn parse_proxy_opts( + proxy: Option, + no_proxy: bool, +) -> std::result::Result, BoxedError> { + if no_proxy { + return Ok(None); + } + proxy + .map(|proxy| { + reqwest::Proxy::all(proxy) + .context(ParseProxyOptsSnafu) + .map_err(BoxedError::new) + }) + .transpose() } impl DatabaseClient { @@ -40,6 +60,7 @@ impl DatabaseClient { catalog: String, auth_basic: Option, timeout: Duration, + proxy: Option, ) -> Self { let auth_header = if let Some(basic) = auth_basic { let encoded = general_purpose::STANDARD.encode(basic); @@ -48,11 +69,18 @@ impl DatabaseClient { None }; + if let Some(ref proxy) = proxy { + common_telemetry::info!("Using proxy: {:?}", proxy); + } else { + common_telemetry::info!("Using system proxy(if any)"); + } + Self { addr, catalog, auth_header, timeout, + proxy, } } @@ -67,7 +95,13 @@ impl DatabaseClient { ("db", format!("{}-{}", self.catalog, schema)), ("sql", sql.to_string()), ]; - let mut request = reqwest::Client::new() + let client = self + .proxy + .clone() + .map(|proxy| reqwest::Client::builder().proxy(proxy).build()) + .unwrap_or_else(|| Ok(reqwest::Client::new())) + .context(BuildClientSnafu)?; + let mut request = client .post(&url) .form(¶ms) .header("Content-Type", "application/x-www-form-urlencoded"); diff --git a/src/cli/src/error.rs b/src/cli/src/error.rs index bf0b6342c1f9..1b79ee759be1 100644 --- a/src/cli/src/error.rs +++ b/src/cli/src/error.rs @@ -86,6 +86,22 @@ pub enum Error { location: Location, }, + #[snafu(display("Failed to parse proxy options: {}", error))] + ParseProxyOpts { + #[snafu(source)] + error: reqwest::Error, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Failed to build reqwest client: {}", error))] + BuildClient { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: reqwest::Error, + }, + #[snafu(display("Invalid REPL command: {reason}"))] InvalidReplCommand { reason: String }, @@ -278,7 +294,8 @@ impl ErrorExt for Error { | Error::InitTimezone { .. } | Error::ConnectEtcd { .. } | Error::CreateDir { .. } - | Error::EmptyResult { .. } => StatusCode::InvalidArguments, + | Error::EmptyResult { .. } + | Error::ParseProxyOpts { .. } => StatusCode::InvalidArguments, Error::StartProcedureManager { source, .. } | Error::StopProcedureManager { source, .. } => source.status_code(), @@ -298,7 +315,8 @@ impl ErrorExt for Error { Error::SerdeJson { .. } | Error::FileIo { .. } | Error::SpawnThread { .. } - | Error::InitTlsProvider { .. } => StatusCode::Unexpected, + | Error::InitTlsProvider { .. } + | Error::BuildClient { .. } => StatusCode::Unexpected, Error::Other { source, .. } => source.status_code(), diff --git a/src/cli/src/export.rs b/src/cli/src/export.rs index 91e4be22bb93..846e2a49adc6 100644 --- a/src/cli/src/export.rs +++ b/src/cli/src/export.rs @@ -28,7 +28,7 @@ use tokio::io::{AsyncWriteExt, BufWriter}; use tokio::sync::Semaphore; use tokio::time::Instant; -use crate::database::DatabaseClient; +use crate::database::{parse_proxy_opts, DatabaseClient}; use crate::error::{EmptyResultSnafu, Error, FileIoSnafu, Result, SchemaNotFoundSnafu}; use crate::{database, Tool}; @@ -91,19 +91,30 @@ pub struct ExportCommand { /// The default behavior will disable server-side default timeout(i.e. `0s`). #[clap(long, value_parser = humantime::parse_duration)] timeout: Option, + + /// The proxy server address to connect, if set, will override the system proxy. + /// + /// The default behavior will use the system proxy if neither `proxy` nor `no_proxy` is set. + #[clap(long)] + proxy: Option, + + /// Disable proxy server, if set, will not use any proxy. + #[clap(long)] + no_proxy: bool, } impl ExportCommand { pub async fn build(&self) -> std::result::Result, BoxedError> { let (catalog, schema) = database::split_database(&self.database).map_err(BoxedError::new)?; - + let proxy = parse_proxy_opts(self.proxy.clone(), self.no_proxy)?; let database_client = DatabaseClient::new( self.addr.clone(), catalog.clone(), self.auth_basic.clone(), // Treats `None` as `0s` to disable server-side default timeout. self.timeout.unwrap_or_default(), + proxy, ); Ok(Box::new(Export { diff --git a/src/cli/src/import.rs b/src/cli/src/import.rs index f76560fbcd55..7cff2fd37f24 100644 --- a/src/cli/src/import.rs +++ b/src/cli/src/import.rs @@ -25,7 +25,7 @@ use snafu::{OptionExt, ResultExt}; use tokio::sync::Semaphore; use tokio::time::Instant; -use crate::database::DatabaseClient; +use crate::database::{parse_proxy_opts, DatabaseClient}; use crate::error::{Error, FileIoSnafu, Result, SchemaNotFoundSnafu}; use crate::{database, Tool}; @@ -76,18 +76,30 @@ pub struct ImportCommand { /// The default behavior will disable server-side default timeout(i.e. `0s`). #[clap(long, value_parser = humantime::parse_duration)] timeout: Option, + + /// The proxy server address to connect, if set, will override the system proxy. + /// + /// The default behavior will use the system proxy if neither `proxy` nor `no_proxy` is set. + #[clap(long)] + proxy: Option, + + /// Disable proxy server, if set, will not use any proxy. + #[clap(long, default_value = "false")] + no_proxy: bool, } impl ImportCommand { pub async fn build(&self) -> std::result::Result, BoxedError> { let (catalog, schema) = database::split_database(&self.database).map_err(BoxedError::new)?; + let proxy = parse_proxy_opts(self.proxy.clone(), self.no_proxy)?; let database_client = DatabaseClient::new( self.addr.clone(), catalog.clone(), self.auth_basic.clone(), // Treats `None` as `0s` to disable server-side default timeout. self.timeout.unwrap_or_default(), + proxy, ); Ok(Box::new(Import {