diff --git a/RELEASES.md b/RELEASES.md index 378c113a649..6b026c360ef 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,6 +7,7 @@ None ### Internal Improvements - Improve live engines error logging (will now log all exceptions rather than just `RuntimeError`) +- Refined `HttpClient` for use directly from Rust - Upgraded `datafusion` crate to v43.0.0 (#2056), thanks @twitu ### Breaking Changes diff --git a/nautilus_core/network/src/http.rs b/nautilus_core/network/src/http.rs index 1f110273e62..0a2bc503e48 100644 --- a/nautilus_core/network/src/http.rs +++ b/nautilus_core/network/src/http.rs @@ -18,7 +18,6 @@ use std::{collections::HashMap, hash::Hash, sync::Arc, time::Duration}; use bytes::Bytes; -use futures_util::{stream, StreamExt}; use reqwest::{ header::{HeaderMap, HeaderName}, Method, Response, Url, @@ -111,10 +110,10 @@ impl From for HttpClientError { pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network") )] pub struct HttpClient { + /// The underlying HTTP client used to make requests. + pub(crate) client: InnerHttpClient, /// The rate limiter to control the request rate. pub(crate) rate_limiter: Arc>, - /// The underlying HTTP client used to make requests. - pub(crate) client: Arc, } impl HttpClient { @@ -125,15 +124,15 @@ impl HttpClient { keyed_quotas: Vec<(String, Quota)>, default_quota: Option, ) -> Self { - let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas)); let client = InnerHttpClient { client: reqwest::Client::new(), - header_keys, + header_keys: Arc::new(header_keys), }; + let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas)); Self { + client, rate_limiter, - client: Arc::new(client), } } @@ -156,24 +155,14 @@ impl HttpClient { method: Method, url: String, headers: Option>, - body: Option, + body: Option>, keys: Option>, timeout_secs: Option, ) -> Result { - let client = self.client.clone(); let rate_limiter = self.rate_limiter.clone(); - // Check keys for rate limiting quota - let keys = keys.unwrap_or_default(); - let tasks = keys.iter().map(|key| rate_limiter.until_key_ready(key)); - - stream::iter(tasks) - .for_each(|key| async move { - key.await; - }) - .await; - - client + rate_limiter.await_keys_ready(keys).await; + self.client .send_request(method, url, headers, body, timeout_secs) .await } @@ -190,7 +179,7 @@ impl HttpClient { #[derive(Clone)] pub struct InnerHttpClient { pub(crate) client: reqwest::Client, - pub(crate) header_keys: Vec, + pub(crate) header_keys: Arc>, } impl InnerHttpClient { @@ -206,7 +195,7 @@ impl InnerHttpClient { method: Method, url: String, headers: Option>, - body: Option, + body: Option>, timeout_secs: Option, ) -> Result { let headers = headers.unwrap_or_default(); @@ -233,7 +222,7 @@ impl InnerHttpClient { let request = match body { Some(b) => request_builder - .body(b.to_vec()) + .body(b) .build() .map_err(HttpClientError::from)?, None => request_builder.build().map_err(HttpClientError::from)?, @@ -387,7 +376,7 @@ mod tests { ); let body_string = serde_json::to_string(&body).unwrap(); - let body_bytes = Bytes::from(body_string.into_bytes()); + let body_bytes = body_string.into_bytes(); let response = client .send_request( diff --git a/nautilus_core/network/src/python/http.rs b/nautilus_core/network/src/python/http.rs index 7c37865dd72..9101daf1964 100644 --- a/nautilus_core/network/src/python/http.rs +++ b/nautilus_core/network/src/python/http.rs @@ -19,8 +19,7 @@ use std::{ }; use bytes::Bytes; -use futures_util::{stream, StreamExt}; -use pyo3::{create_exception, exceptions::PyException, prelude::*, types::PyBytes}; +use pyo3::{create_exception, exceptions::PyException, prelude::*}; use crate::{ http::{HttpClient, HttpClientError, HttpMethod, HttpResponse}, @@ -138,24 +137,16 @@ impl HttpClient { method: HttpMethod, url: String, headers: Option>, - body: Option>, + body: Option>, keys: Option>, timeout_secs: Option, py: Python<'py>, ) -> PyResult> { let client = self.client.clone(); let rate_limiter = self.rate_limiter.clone(); - let keys = keys.unwrap_or_default(); - let body = body.map(|py_bytes| Bytes::from(py_bytes.as_bytes().to_vec())); pyo3_async_runtimes::tokio::future_into_py(py, async move { - // TODO: Consolidate rate limiting - let tasks = keys.iter().map(|key| rate_limiter.until_key_ready(key)); - stream::iter(tasks) - .for_each(|key| async move { - key.await; - }) - .await; + rate_limiter.await_keys_ready(keys).await; client .send_request(method.into(), url, headers, body, timeout_secs) .await diff --git a/nautilus_core/network/src/ratelimiter/mod.rs b/nautilus_core/network/src/ratelimiter/mod.rs index 724087ee8b8..d2687e751ab 100644 --- a/nautilus_core/network/src/ratelimiter/mod.rs +++ b/nautilus_core/network/src/ratelimiter/mod.rs @@ -29,6 +29,7 @@ use std::{ }; use dashmap::DashMap; +use futures_util::StreamExt; use tokio::time::sleep; use self::{ @@ -191,6 +192,17 @@ where } } } + + pub async fn await_keys_ready(&self, keys: Option>) { + let keys = keys.unwrap_or_default(); + let tasks = keys.iter().map(|key| self.until_key_ready(key)); + + futures::stream::iter(tasks) + .for_each_concurrent(None, |key_future| async move { + key_future.await; + }) + .await; + } } ////////////////////////////////////////////////////////////////////////////////