diff --git a/python/psql_rust_engine/__init__.py b/python/psql_rust_engine/__init__.py index 118574c2..8fb49001 100644 --- a/python/psql_rust_engine/__init__.py +++ b/python/psql_rust_engine/__init__.py @@ -1,6 +1,5 @@ -from ._internal import RustEngine, PyRustEngine +from ._internal import PyRustEngine __all__ = [ - "RustEngine", "PyRustEngine", ] \ No newline at end of file diff --git a/python/psql_rust_engine/_internal/__init__.pyi b/python/psql_rust_engine/_internal/__init__.pyi index 4cb78042..e4ff8531 100644 --- a/python/psql_rust_engine/_internal/__init__.pyi +++ b/python/psql_rust_engine/_internal/__init__.pyi @@ -13,30 +13,6 @@ class RustEnginePyQueryResult: """""" -class RustEngine: - """Rust engine.""" - - def __init__( - self, - username: Optional[str], - password: Optional[str], - host: Optional[str], - port: Optional[int], - db_name: Optional[str], - ) -> None: - """Test ebana.""" - - async def startup(self) -> None: - ... - - async def execute( - self, - querystring: str, - parameters: List[Any], - ) -> RustEnginePyQueryResult: - ... - - class PyRustEngine: """Aboba""" diff --git a/src/engine.rs b/src/engine.rs index eb1ea7e8..3b6672e7 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1,5 +1,5 @@ use deadpool_postgres::{Manager, ManagerConfig, Object, Pool, RecyclingMethod}; -use pyo3::{pyclass, pymethods, IntoPy, PyAny, PyObject, Python}; +use pyo3::{pyclass, pymethods, IntoPy, Py, PyAny, PyObject, PyRef, PyRefMut, Python}; use std::{future::Future, sync::Arc, vec}; use tokio_postgres::{types::ToSql, NoTls}; @@ -19,57 +19,220 @@ where Ok(res) } -#[pyclass()] pub struct RustEngineTransaction { db_client: Arc>, is_started: Arc>, + is_done: Arc>, } -#[pymethods] impl RustEngineTransaction { + pub async fn inner_execute<'a>( + &'a self, + querystring: String, + parameters: Vec, + ) -> RustEnginePyResult { + let db_client_arc = self.db_client.clone(); + let is_started_arc = self.is_started.clone(); + let is_done_arc = self.is_done.clone(); + + let db_client_guard = db_client_arc.read().await; + let is_started_guard = is_started_arc.read().await; + let is_done_guard = is_done_arc.read().await; + + if !*is_started_guard { + return Err(RustEngineError::DBTransactionError( + "Transaction is not started, please call begin() on transaction".into(), + )); + } + if *is_done_guard { + return Err(RustEngineError::DBTransactionError( + "Transaction is already committed or rolled back".into(), + )); + } + + let mut vec_parameters: Vec<&(dyn ToSql + Sync)> = Vec::with_capacity(parameters.len()); + for param in parameters.iter() { + vec_parameters.push(param); + } + + let statement: tokio_postgres::Statement = + db_client_guard.prepare_cached(&querystring).await.unwrap(); + + let result = db_client_guard + .query(&statement, &vec_parameters.into_boxed_slice()) + .await?; + + Ok(RustEnginePyQueryResult::new(result)) + } + + pub async fn inner_begin<'a>(&'a self) -> RustEnginePyResult<()> { + let db_client_arc = self.db_client.clone(); + let is_started_arc = self.is_started.clone(); + let is_done_arc = self.is_done.clone(); + + let started = { + let is_started_guard = is_started_arc.read().await; + is_started_guard.clone() + }; + if started { + return Err(RustEngineError::DBTransactionError( + "Transaction is already started".into(), + )); + } + + let done = { + let is_done_guard = is_done_arc.read().await; + is_done_guard.clone() + }; + if done { + return Err(RustEngineError::DBTransactionError( + "Transaction is already committed or rolled back".into(), + )); + } + + let db_client_guard = db_client_arc.read().await; + db_client_guard.batch_execute("BEGIN").await?; + let mut is_started_write_guard = is_started_arc.write().await; + *is_started_write_guard = true; + + Ok(()) + } + + pub async fn inner_commit<'a>(&'a self) -> RustEnginePyResult<()> { + let db_client_arc = self.db_client.clone(); + let is_started_arc = self.is_started.clone(); + let is_done_arc = self.is_done.clone(); + + let started = { + let is_started_guard = is_started_arc.read().await; + is_started_guard.clone() + }; + if !started { + return Err(RustEngineError::DBTransactionError( + "Can not commit not started transaction".into(), + )); + } + + let done = { + let is_done_guard = is_done_arc.read().await; + is_done_guard.clone() + }; + if done { + return Err(RustEngineError::DBTransactionError( + "Transaction is already committed or rolled back".into(), + )); + } + + let db_client_guard = db_client_arc.read().await; + db_client_guard.batch_execute("COMMIT").await?; + let mut is_done_write_guard = is_done_arc.write().await; + *is_done_write_guard = true; + + Ok(()) + } +} + +#[pyclass()] +pub struct PyRustEngineTransaction { + transaction: Arc>, +} + +#[pymethods] +impl PyRustEngineTransaction { + #[must_use] + pub fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + pub fn __anext__(&self, py: Python<'_>) -> RustEnginePyResult> { + let transaction_clone = self.transaction.clone(); + let future = rustengine_future(py, async move { + Ok(PyRustEngineTransaction { + transaction: transaction_clone, + }) + }); + Ok(Some(future?.into())) + } + + pub fn __await__<'a>( + slf: PyRefMut<'a, Self>, + _py: Python, + ) -> RustEnginePyResult> { + println!("__await__"); + Ok(slf) + } + + fn __aenter__<'a>(slf: PyRefMut<'a, Self>, py: Python<'a>) -> RustEnginePyResult<&'a PyAny> { + let transaction_arc = slf.transaction.clone(); + let transaction_arc2 = slf.transaction.clone(); + rustengine_future(py, async move { + let transaction_guard = transaction_arc.read().await; + transaction_guard.inner_begin().await?; + Ok(PyRustEngineTransaction { + transaction: transaction_arc2, + }) + }) + } + + fn __aexit__<'a>( + slf: PyRefMut<'a, Self>, + py: Python<'a>, + _exception_type: Py, + _exception: Py, + _traceback: Py, + ) -> RustEnginePyResult<&'a PyAny> { + let transaction_arc = slf.transaction.clone(); + let transaction_arc2 = slf.transaction.clone(); + rustengine_future(py, async move { + let transaction_guard = transaction_arc.read().await; + transaction_guard.inner_commit().await?; + Ok(PyRustEngineTransaction { + transaction: transaction_arc2, + }) + }) + } + pub fn execute<'a>( &'a self, py: Python<'a>, querystring: String, parameters: Option<&'a PyAny>, ) -> RustEnginePyResult<&PyAny> { - let db_client_arc = self.db_client.clone(); - let is_started_arc = self.is_started.clone(); - + let transaction_arc = self.transaction.clone(); let mut params: Vec = vec![]; if let Some(parameters) = parameters { params = convert_parameters(parameters)? } rustengine_future(py, async move { - let db_client_guard = db_client_arc.read().await; - let started = { - let is_started_guard = is_started_arc.read().await; - is_started_guard.clone() - }; - - if !started { - let mut is_started_write_guard = is_started_arc.write().await; - println!("Called BEGIN!"); - db_client_guard.batch_execute("BEGIN").await?; - *is_started_write_guard = true; - }; - - let mut vec_parameters: Vec<&(dyn ToSql + Sync)> = Vec::with_capacity(params.len()); - for param in params.iter() { - vec_parameters.push(param); - } - - let result = db_client_guard - .query(&querystring, &vec_parameters.into_boxed_slice()) - .await?; - - Ok(RustEnginePyQueryResult::new(result)) + let transaction_guard = transaction_arc.read().await; + Ok(transaction_guard.inner_execute(querystring, params).await?) + }) + } + + pub fn begin<'a>(&'a self, py: Python<'a>) -> RustEnginePyResult<&PyAny> { + let transaction_arc = self.transaction.clone(); + + rustengine_future(py, async move { + let transaction_guard = transaction_arc.read().await; + transaction_guard.inner_begin().await?; + + Ok(()) + }) + } + + pub fn commit<'a>(&'a self, py: Python<'a>) -> RustEnginePyResult<&PyAny> { + let transaction_arc = self.transaction.clone(); + + rustengine_future(py, async move { + let transaction_guard = transaction_arc.read().await; + transaction_guard.inner_commit().await?; + + Ok(()) }) } } -#[pyclass()] pub struct RustEngine { username: Option, password: Option, @@ -79,6 +242,25 @@ pub struct RustEngine { db_pool: Arc>>, } +impl RustEngine { + pub fn new( + username: Option, + password: Option, + host: Option, + port: Option, + db_name: Option, + ) -> Self { + RustEngine { + username, + password, + host, + port, + db_name, + db_pool: Arc::new(tokio::sync::RwLock::new(None)), + } + } +} + impl RustEngine { pub async fn inner_execute<'a>( &'a self, @@ -111,7 +293,7 @@ impl RustEngine { Ok(RustEnginePyQueryResult::new(result)) } - pub async fn inner_transaction<'a>(&'a self) -> RustEnginePyResult { + pub async fn inner_transaction<'a>(&'a self) -> RustEnginePyResult { let db_pool_arc = self.db_pool.clone(); let db_pool_guard = db_pool_arc.read().await; @@ -123,9 +305,14 @@ impl RustEngine { .get() .await?; - Ok(RustEngineTransaction { + let inner_transaction = RustEngineTransaction { db_client: Arc::new(tokio::sync::RwLock::new(db_pool_manager)), is_started: Arc::new(tokio::sync::RwLock::new(false)), + is_done: Arc::new(tokio::sync::RwLock::new(false)), + }; + + Ok(PyRustEngineTransaction { + transaction: Arc::new(tokio::sync::RwLock::new(inner_transaction)), }) } @@ -172,25 +359,6 @@ impl RustEngine { } } -impl RustEngine { - pub fn new( - username: Option, - password: Option, - host: Option, - port: Option, - db_name: Option, - ) -> Self { - RustEngine { - username, - password, - host, - port, - db_name, - db_pool: Arc::new(tokio::sync::RwLock::new(None)), - } - } -} - #[pyclass()] pub struct PyRustEngine { engine: Arc>>, diff --git a/src/exceptions/rust_errors.rs b/src/exceptions/rust_errors.rs index fda9444b..9afef3ee 100644 --- a/src/exceptions/rust_errors.rs +++ b/src/exceptions/rust_errors.rs @@ -12,6 +12,8 @@ pub enum RustEngineError { RustToPyValueConversionError(String), #[error("Can't convert value from python to rust type: {0}")] PyToRustValueConversionError(String), + #[error("Transaction exception: {0}")] + DBTransactionError(String), #[error("Python exception: {0}.")] PyError(#[from] pyo3::PyErr), @@ -38,6 +40,7 @@ impl From for pyo3::PyErr { RustEngineError::DatabasePoolError(_) => RustEnginePyBaseError::new_err((error_desc,)), RustEngineError::DBEnginePoolError(_) => RustEnginePyBaseError::new_err((error_desc,)), RustEngineError::DBEngineBuildError(_) => RustEnginePyBaseError::new_err((error_desc,)), + RustEngineError::DBTransactionError(_) => RustEnginePyBaseError::new_err((error_desc,)), } } } diff --git a/src/lib.rs b/src/lib.rs index f5520937..818d5a58 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,8 +8,8 @@ use pyo3::{pymodule, types::PyModule, PyResult, Python}; #[pymodule] #[pyo3(name = "_internal")] fn psql_rust_engine(_py: Python<'_>, pymod: &PyModule) -> PyResult<()> { - pymod.add_class::()?; pymod.add_class::()?; + pymod.add_class::()?; pymod.add_class::()?; Ok(()) }