Skip to content

Commit

Permalink
Added async manager protocol for transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
chandr-andr committed Feb 5, 2024
1 parent e2ab967 commit 51e1c14
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 78 deletions.
3 changes: 1 addition & 2 deletions python/psql_rust_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from ._internal import RustEngine, PyRustEngine
from ._internal import PyRustEngine

__all__ = [
"RustEngine",
"PyRustEngine",
]
24 changes: 0 additions & 24 deletions python/psql_rust_engine/_internal/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,6 @@ class RustEnginePyQueryResult:
""""""


class RustEngine:
"""Rust engine."""

def __init__(
self,
username: Optional[str],
password: Optional[str],
host: Optional[str],
port: Optional[int],
db_name: Optional[str],
) -> None:
"""Test ebana."""

async def startup(self) -> None:
...

async def execute(
self,
querystring: str,
parameters: List[Any],
) -> RustEnginePyQueryResult:
...


class PyRustEngine:
"""Aboba"""

Expand Down
270 changes: 219 additions & 51 deletions src/engine.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use deadpool_postgres::{Manager, ManagerConfig, Object, Pool, RecyclingMethod};
use pyo3::{pyclass, pymethods, IntoPy, PyAny, PyObject, Python};
use pyo3::{pyclass, pymethods, IntoPy, Py, PyAny, PyObject, PyRef, PyRefMut, Python};
use std::{future::Future, sync::Arc, vec};
use tokio_postgres::{types::ToSql, NoTls};

Expand All @@ -19,57 +19,220 @@ where
Ok(res)
}

#[pyclass()]
pub struct RustEngineTransaction {
db_client: Arc<tokio::sync::RwLock<Object>>,
is_started: Arc<tokio::sync::RwLock<bool>>,
is_done: Arc<tokio::sync::RwLock<bool>>,
}

#[pymethods]
impl RustEngineTransaction {
pub async fn inner_execute<'a>(
&'a self,
querystring: String,
parameters: Vec<PythonType>,
) -> RustEnginePyResult<RustEnginePyQueryResult> {
let db_client_arc = self.db_client.clone();
let is_started_arc = self.is_started.clone();
let is_done_arc = self.is_done.clone();

let db_client_guard = db_client_arc.read().await;
let is_started_guard = is_started_arc.read().await;
let is_done_guard = is_done_arc.read().await;

if !*is_started_guard {
return Err(RustEngineError::DBTransactionError(
"Transaction is not started, please call begin() on transaction".into(),
));
}
if *is_done_guard {
return Err(RustEngineError::DBTransactionError(
"Transaction is already committed or rolled back".into(),
));
}

let mut vec_parameters: Vec<&(dyn ToSql + Sync)> = Vec::with_capacity(parameters.len());
for param in parameters.iter() {
vec_parameters.push(param);
}

let statement: tokio_postgres::Statement =
db_client_guard.prepare_cached(&querystring).await.unwrap();

let result = db_client_guard
.query(&statement, &vec_parameters.into_boxed_slice())
.await?;

Ok(RustEnginePyQueryResult::new(result))
}

pub async fn inner_begin<'a>(&'a self) -> RustEnginePyResult<()> {
let db_client_arc = self.db_client.clone();
let is_started_arc = self.is_started.clone();
let is_done_arc = self.is_done.clone();

let started = {
let is_started_guard = is_started_arc.read().await;
is_started_guard.clone()
};
if started {
return Err(RustEngineError::DBTransactionError(
"Transaction is already started".into(),
));
}

let done = {
let is_done_guard = is_done_arc.read().await;
is_done_guard.clone()
};
if done {
return Err(RustEngineError::DBTransactionError(
"Transaction is already committed or rolled back".into(),
));
}

let db_client_guard = db_client_arc.read().await;
db_client_guard.batch_execute("BEGIN").await?;
let mut is_started_write_guard = is_started_arc.write().await;
*is_started_write_guard = true;

Ok(())
}

pub async fn inner_commit<'a>(&'a self) -> RustEnginePyResult<()> {
let db_client_arc = self.db_client.clone();
let is_started_arc = self.is_started.clone();
let is_done_arc = self.is_done.clone();

let started = {
let is_started_guard = is_started_arc.read().await;
is_started_guard.clone()
};
if !started {
return Err(RustEngineError::DBTransactionError(
"Can not commit not started transaction".into(),
));
}

let done = {
let is_done_guard = is_done_arc.read().await;
is_done_guard.clone()
};
if done {
return Err(RustEngineError::DBTransactionError(
"Transaction is already committed or rolled back".into(),
));
}

let db_client_guard = db_client_arc.read().await;
db_client_guard.batch_execute("COMMIT").await?;
let mut is_done_write_guard = is_done_arc.write().await;
*is_done_write_guard = true;

Ok(())
}
}

#[pyclass()]
pub struct PyRustEngineTransaction {
transaction: Arc<tokio::sync::RwLock<RustEngineTransaction>>,
}

#[pymethods]
impl PyRustEngineTransaction {
#[must_use]
pub fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}

pub fn __anext__(&self, py: Python<'_>) -> RustEnginePyResult<Option<PyObject>> {
let transaction_clone = self.transaction.clone();
let future = rustengine_future(py, async move {
Ok(PyRustEngineTransaction {
transaction: transaction_clone,
})
});
Ok(Some(future?.into()))
}

pub fn __await__<'a>(
slf: PyRefMut<'a, Self>,
_py: Python,
) -> RustEnginePyResult<PyRefMut<'a, Self>> {
println!("__await__");
Ok(slf)
}

fn __aenter__<'a>(slf: PyRefMut<'a, Self>, py: Python<'a>) -> RustEnginePyResult<&'a PyAny> {
let transaction_arc = slf.transaction.clone();
let transaction_arc2 = slf.transaction.clone();
rustengine_future(py, async move {
let transaction_guard = transaction_arc.read().await;
transaction_guard.inner_begin().await?;
Ok(PyRustEngineTransaction {
transaction: transaction_arc2,
})
})
}

fn __aexit__<'a>(
slf: PyRefMut<'a, Self>,
py: Python<'a>,
_exception_type: Py<PyAny>,
_exception: Py<PyAny>,
_traceback: Py<PyAny>,
) -> RustEnginePyResult<&'a PyAny> {
let transaction_arc = slf.transaction.clone();
let transaction_arc2 = slf.transaction.clone();
rustengine_future(py, async move {
let transaction_guard = transaction_arc.read().await;
transaction_guard.inner_commit().await?;
Ok(PyRustEngineTransaction {
transaction: transaction_arc2,
})
})
}

pub fn execute<'a>(
&'a self,
py: Python<'a>,
querystring: String,
parameters: Option<&'a PyAny>,
) -> RustEnginePyResult<&PyAny> {
let db_client_arc = self.db_client.clone();
let is_started_arc = self.is_started.clone();

let transaction_arc = self.transaction.clone();
let mut params: Vec<PythonType> = vec![];
if let Some(parameters) = parameters {
params = convert_parameters(parameters)?
}

rustengine_future(py, async move {
let db_client_guard = db_client_arc.read().await;
let started = {
let is_started_guard = is_started_arc.read().await;
is_started_guard.clone()
};

if !started {
let mut is_started_write_guard = is_started_arc.write().await;
println!("Called BEGIN!");
db_client_guard.batch_execute("BEGIN").await?;
*is_started_write_guard = true;
};

let mut vec_parameters: Vec<&(dyn ToSql + Sync)> = Vec::with_capacity(params.len());
for param in params.iter() {
vec_parameters.push(param);
}

let result = db_client_guard
.query(&querystring, &vec_parameters.into_boxed_slice())
.await?;

Ok(RustEnginePyQueryResult::new(result))
let transaction_guard = transaction_arc.read().await;
Ok(transaction_guard.inner_execute(querystring, params).await?)
})
}

pub fn begin<'a>(&'a self, py: Python<'a>) -> RustEnginePyResult<&PyAny> {
let transaction_arc = self.transaction.clone();

rustengine_future(py, async move {
let transaction_guard = transaction_arc.read().await;
transaction_guard.inner_begin().await?;

Ok(())
})
}

pub fn commit<'a>(&'a self, py: Python<'a>) -> RustEnginePyResult<&PyAny> {
let transaction_arc = self.transaction.clone();

rustengine_future(py, async move {
let transaction_guard = transaction_arc.read().await;
transaction_guard.inner_commit().await?;

Ok(())
})
}
}

#[pyclass()]
pub struct RustEngine {
username: Option<String>,
password: Option<String>,
Expand All @@ -79,6 +242,25 @@ pub struct RustEngine {
db_pool: Arc<tokio::sync::RwLock<Option<Pool>>>,
}

impl RustEngine {
pub fn new(
username: Option<String>,
password: Option<String>,
host: Option<String>,
port: Option<u16>,
db_name: Option<String>,
) -> Self {
RustEngine {
username,
password,
host,
port,
db_name,
db_pool: Arc::new(tokio::sync::RwLock::new(None)),
}
}
}

impl RustEngine {
pub async fn inner_execute<'a>(
&'a self,
Expand Down Expand Up @@ -111,7 +293,7 @@ impl RustEngine {
Ok(RustEnginePyQueryResult::new(result))
}

pub async fn inner_transaction<'a>(&'a self) -> RustEnginePyResult<RustEngineTransaction> {
pub async fn inner_transaction<'a>(&'a self) -> RustEnginePyResult<PyRustEngineTransaction> {
let db_pool_arc = self.db_pool.clone();
let db_pool_guard = db_pool_arc.read().await;

Expand All @@ -123,9 +305,14 @@ impl RustEngine {
.get()
.await?;

Ok(RustEngineTransaction {
let inner_transaction = RustEngineTransaction {
db_client: Arc::new(tokio::sync::RwLock::new(db_pool_manager)),
is_started: Arc::new(tokio::sync::RwLock::new(false)),
is_done: Arc::new(tokio::sync::RwLock::new(false)),
};

Ok(PyRustEngineTransaction {
transaction: Arc::new(tokio::sync::RwLock::new(inner_transaction)),
})
}

Expand Down Expand Up @@ -172,25 +359,6 @@ impl RustEngine {
}
}

impl RustEngine {
pub fn new(
username: Option<String>,
password: Option<String>,
host: Option<String>,
port: Option<u16>,
db_name: Option<String>,
) -> Self {
RustEngine {
username,
password,
host,
port,
db_name,
db_pool: Arc::new(tokio::sync::RwLock::new(None)),
}
}
}

#[pyclass()]
pub struct PyRustEngine {
engine: Arc<tokio::sync::RwLock<Option<RustEngine>>>,
Expand Down
3 changes: 3 additions & 0 deletions src/exceptions/rust_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub enum RustEngineError {
RustToPyValueConversionError(String),
#[error("Can't convert value from python to rust type: {0}")]
PyToRustValueConversionError(String),
#[error("Transaction exception: {0}")]
DBTransactionError(String),

#[error("Python exception: {0}.")]
PyError(#[from] pyo3::PyErr),
Expand All @@ -38,6 +40,7 @@ impl From<RustEngineError> for pyo3::PyErr {
RustEngineError::DatabasePoolError(_) => RustEnginePyBaseError::new_err((error_desc,)),
RustEngineError::DBEnginePoolError(_) => RustEnginePyBaseError::new_err((error_desc,)),
RustEngineError::DBEngineBuildError(_) => RustEnginePyBaseError::new_err((error_desc,)),
RustEngineError::DBTransactionError(_) => RustEnginePyBaseError::new_err((error_desc,)),
}
}
}
Loading

0 comments on commit 51e1c14

Please sign in to comment.