Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature-2924: Add an option to suppress server identification headers #3770

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/cli/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ pub struct Config {
pub key: Option<PathBuf>,
pub tick_interval: Duration,
pub engine: Option<EngineOptions>,
pub no_identification_headers: bool,
}
7 changes: 6 additions & 1 deletion src/cli/start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ pub struct StartCommandArguments {
#[arg(env = "SURREAL_BIND", short = 'b', long = "bind")]
#[arg(default_value = "127.0.0.1:8000")]
listen_addresses: Vec<SocketAddr>,

#[arg(help = "Whether to suppress the server name and version headers")]
#[arg(env = "SURREAL_NO_IDENTIFICATION_HEADERS", long)]
#[arg(default_value_t = false)]
no_identification_headers: bool,
//
// Database options
//
Expand Down Expand Up @@ -142,6 +145,7 @@ pub async fn init(
log,
tick_interval,
no_banner,
no_identification_headers,
..
}: StartCommandArguments,
) -> Result<(), Error> {
Expand Down Expand Up @@ -171,6 +175,7 @@ pub async fn init(
user,
pass,
tick_interval,
no_identification_headers,
crt: web.as_ref().and_then(|x| x.web_crt.clone()),
key: web.as_ref().and_then(|x| x.web_key.clone()),
engine: None,
Expand Down
22 changes: 17 additions & 5 deletions src/net/headers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,25 @@ pub use db::SurrealDatabase;
pub use id::SurrealId;
pub use ns::SurrealNamespace;

pub fn add_version_header() -> SetResponseHeaderLayer<HeaderValue> {
let val = format!("{PKG_NAME}-{}", *PKG_VERSION);
SetResponseHeaderLayer::if_not_present(VERSION.to_owned(), HeaderValue::try_from(val).unwrap())
pub fn add_version_header(enabled: bool) -> SetResponseHeaderLayer<Option<HeaderValue>> {
let header_value = if enabled {
let val = format!("{PKG_NAME}-{}", *PKG_VERSION);
Some(HeaderValue::try_from(val).unwrap())
} else {
None
};

SetResponseHeaderLayer::if_not_present(VERSION.to_owned(), header_value)
}

pub fn add_server_header() -> SetResponseHeaderLayer<HeaderValue> {
SetResponseHeaderLayer::if_not_present(SERVER, HeaderValue::try_from(SERVER_NAME).unwrap())
pub fn add_server_header(enabled: bool) -> SetResponseHeaderLayer<Option<HeaderValue>> {
let header_value = if enabled {
Some(HeaderValue::try_from(SERVER_NAME).unwrap())
} else {
None
};

SetResponseHeaderLayer::if_not_present(SERVER, header_value)
}

// Parse a TypedHeader, returning None if the header is missing and an error if the header is invalid.
Expand Down
4 changes: 2 additions & 2 deletions src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ pub async fn init(ct: CancellationToken) -> Result<(), Error> {
.layer(HttpMetricsLayer)
.layer(SetSensitiveResponseHeadersLayer::from_shared(headers))
.layer(AsyncRequireAuthorizationLayer::new(auth::SurrealAuth))
.layer(headers::add_server_header())
.layer(headers::add_version_header())
.layer(headers::add_server_header(!opt.no_identification_headers))
.layer(headers::add_version_header(!opt.no_identification_headers))
.layer(
CorsLayer::new()
.allow_methods([
Expand Down
29 changes: 28 additions & 1 deletion tests/http_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod http_integration {
use test_log::test;
use ulid::Ulid;

use super::common::{self, PASS, USER};
use super::common::{self, StartServerArguments, PASS, USER};

#[test(tokio::test)]
async fn basic_auth() -> Result<(), Box<dyn std::error::Error>> {
Expand Down Expand Up @@ -352,6 +352,33 @@ mod http_integration {
Ok(())
}

#[test(tokio::test)]
async fn no_server_id_headers() -> Result<(), Box<dyn std::error::Error>> {
// default server has the id headers
{
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
let url = &format!("http://{addr}/health");

let res = Client::default().get(url).send().await?;
assert!(res.headers().contains_key("server"));
assert!(res.headers().contains_key("surreal-version"));
}

// turn on the no-identification-headers option to suppress headers
{
let mut start_server_arguments = StartServerArguments::default();
start_server_arguments.args.push_str(" --no-identification-headers");
let (addr, _server) = common::start_server(start_server_arguments).await.unwrap();
let url = &format!("http://{addr}/health");

let res = Client::default().get(url).send().await?;
assert!(!res.headers().contains_key("server"));
assert!(!res.headers().contains_key("surreal-version"));
}

Ok(())
}

byarr marked this conversation as resolved.
Show resolved Hide resolved
#[test(tokio::test)]
async fn import_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
Expand Down
Loading