diff --git a/python/psqlpy/__init__.py b/python/psqlpy/__init__.py index 98a7c154..ea4c0e2d 100644 --- a/python/psqlpy/__init__.py +++ b/python/psqlpy/__init__.py @@ -9,6 +9,7 @@ SingleQueryResult, Transaction, connect, + create_pool, ) __all__ = [ @@ -21,5 +22,6 @@ "ConnRecyclingMethod", "IsolationLevel", "ReadVariant", + "create_pool", "connect", ] diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index df1b05d5..3e9b86f8 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -940,7 +940,7 @@ class ConnectionPool: def close(self: Self) -> None: """Close the connection pool.""" -def connect( +def create_pool( dsn: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None, @@ -963,3 +963,23 @@ def connect( - `max_db_pool_size`: maximum size of the connection pool - `conn_recycling_method`: how a connection is recycled. """ + +async def connect( + dsn: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + db_name: Optional[str] = None, +) -> Connection: + """Create new single connection. + + ### Parameters: + - `dsn`: full dsn connection string. + `postgres://postgres:postgres@localhost:5432/postgres?target_session_attrs=read-write` + - `username`: username of the user in the PostgreSQL + - `password`: password of the user in PostgreSQL + - `host`: host of the PostgreSQL + - `port`: port of the PostgreSQL + - `db_name`: name of the database in PostgreSQL + """ diff --git a/python/tests/test_connection.py b/python/tests/test_connection.py index 5dad00ec..62feded5 100644 --- a/python/tests/test_connection.py +++ b/python/tests/test_connection.py @@ -5,7 +5,7 @@ import pytest from tests.helpers import count_rows_in_test_table -from psqlpy import ConnectionPool, QueryResult, Transaction +from psqlpy import ConnectionPool, QueryResult, Transaction, connect from psqlpy.exceptions import RustPSQLDriverPyBaseError, TransactionError pytestmark = pytest.mark.anyio @@ -113,3 +113,12 @@ async def test_connection_fetch_val_more_than_one_row( f"SELECT * FROM {table_name}", [], ) + + +async def test_connect_method() -> None: + connection = await connect( + dsn="postgres://postgres:postgres@localhost:5432/psqlpy_test", + ) + + res = await connection.execute("SELECT 1") + assert res.result() diff --git a/python/tests/test_connection_pool.py b/python/tests/test_connection_pool.py index 467b2899..5632c7c0 100644 --- a/python/tests/test_connection_pool.py +++ b/python/tests/test_connection_pool.py @@ -1,6 +1,12 @@ import pytest -from psqlpy import Connection, ConnectionPool, ConnRecyclingMethod, QueryResult, connect +from psqlpy import ( + Connection, + ConnectionPool, + ConnRecyclingMethod, + QueryResult, + create_pool, +) from psqlpy.exceptions import RustPSQLDriverPyBaseError pytestmark = pytest.mark.anyio @@ -8,7 +14,7 @@ async def test_connect_func() -> None: """Test that connect function makes new connection pool.""" - pg_pool = connect( + pg_pool = create_pool( dsn="postgres://postgres:postgres@localhost:5432/psqlpy_test", ) diff --git a/src/driver/connection.rs b/src/driver/connection.rs index 159e5da6..27a417cd 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -1,28 +1,301 @@ use deadpool_postgres::Object; -use pyo3::{pyclass, pymethods, Py, PyAny, Python}; +use postgres_types::ToSql; +use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny, Python}; use std::{collections::HashSet, sync::Arc, vec}; +use tokio_postgres::{Client, NoTls}; use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, + runtime::tokio_runtime, value_converter::{convert_parameters, postgres_to_py, PythonDTO, QueryParameter}, }; use super::{ transaction::Transaction, transaction_options::{IsolationLevel, ReadVariant}, + utils::build_connection_config, }; +/// Connect to the PostgreSQL, creating single connection. +/// +/// # Errors +/// May return Err Result if +/// 1) Connect parameters are incorrect +/// 2) Cannot connect to the PostgreSQL +/// 3) Error on the PostgreSQL side. +#[pyfunction] +#[allow(clippy::too_many_arguments)] +pub async fn connect( + dsn: Option, + username: Option, + password: Option, + host: Option, + port: Option, + db_name: Option, +) -> RustPSQLDriverPyResult { + let conn_config = build_connection_config(dsn, username, password, host, port, db_name)?; + + let (client, connection) = tokio_runtime() + .spawn(async move { conn_config.connect(NoTls).await }) + .await??; + + tokio_runtime().spawn(async move { + if let Err(connection_error) = connection.await { + eprintln!("connection error: {connection_error}"); + } + }); + + Ok(Connection::new(ConnectionVar::SingleConn(client))) +} + +#[allow(clippy::module_name_repetitions)] +pub enum ConnectionVar { + Pool(Object), + SingleConn(Client), +} + +impl ConnectionVar { + /// Make prepared statement. + /// + /// # Errors + /// May return Err Result if `query` returns Err. + pub async fn prepare_stmt_cached( + &self, + query: &str, + ) -> Result { + match self { + ConnectionVar::Pool(pool_conn) => pool_conn.prepare(query).await, + ConnectionVar::SingleConn(single_conn) => single_conn.prepare(query).await, + } + } + + /// Execute `query()` method. + /// + /// # Errors + /// May return Err Result if `query` returns Err. + pub async fn query_qs( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, tokio_postgres::Error> + where + T: ?Sized + tokio_postgres::ToStatement, + { + match self { + ConnectionVar::Pool(pool_conn) => pool_conn.query(statement, params).await, + ConnectionVar::SingleConn(single_conn) => single_conn.query(statement, params).await, + } + } + + /// Execute `query_one()` method. + /// + /// # Errors + /// May return Err Result if `query_one` returns Err. + pub async fn query_qs_one( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result + where + T: ?Sized + tokio_postgres::ToStatement, + { + match self { + ConnectionVar::Pool(pool_conn) => pool_conn.query_one(statement, params).await, + ConnectionVar::SingleConn(single_conn) => { + single_conn.query_one(statement, params).await + } + } + } + + /// Execute `batch_execute()` method. + /// + /// # Errors + /// May return Err Result if `batch_execute` returns Err. + pub async fn batch_execute_qs(&self, query: &str) -> Result<(), tokio_postgres::Error> { + match self { + ConnectionVar::Pool(pool_conn) => pool_conn.batch_execute(query).await, + ConnectionVar::SingleConn(single_conn) => single_conn.batch_execute(query).await, + } + } + + /// Execute querystring with parameters. + /// + /// # Errors + /// May return Err Result if + /// 1) Cannot convert parameters + /// 2) Cannot prepare querystring + /// 3) Cannot execute statement + pub async fn psqlpy_query( + &self, + querystring: String, + parameters: Option>, + prepared: Option, + ) -> RustPSQLDriverPyResult { + let mut params: Vec = vec![]; + if let Some(parameters) = parameters { + params = convert_parameters(parameters)?; + } + let prepared = prepared.unwrap_or(true); + + let result = if prepared { + self.query_qs( + &self.prepare_stmt_cached(&querystring).await?, + ¶ms + .iter() + .map(|param| param as &QueryParameter) + .collect::>() + .into_boxed_slice(), + ) + .await? + } else { + self.query_qs( + &querystring, + ¶ms + .iter() + .map(|param| param as &QueryParameter) + .collect::>() + .into_boxed_slice(), + ) + .await? + }; + + Ok(PSQLDriverPyQueryResult::new(result)) + } + + /// Start the transaction. + /// + /// # Errors + /// May return Err Result if cannot execute statement. + pub async fn start_transaction( + &self, + isolation_level: Option, + read_variant: Option, + deferrable: Option, + ) -> RustPSQLDriverPyResult<()> { + let mut querystring = "START TRANSACTION".to_string(); + + if let Some(level) = isolation_level { + let level = &level.to_str_level(); + querystring.push_str(format!(" ISOLATION LEVEL {level}").as_str()); + }; + + querystring.push_str(match read_variant { + Some(ReadVariant::ReadOnly) => " READ ONLY", + Some(ReadVariant::ReadWrite) => " READ WRITE", + None => "", + }); + + querystring.push_str(match deferrable { + Some(true) => " DEFERRABLE", + Some(false) => " NOT DEFERRABLE", + None => "", + }); + match self { + ConnectionVar::Pool(pool_conn) => pool_conn.batch_execute(&querystring).await?, + ConnectionVar::SingleConn(single_conn) => { + single_conn.batch_execute(&querystring).await?; + } + } + + Ok(()) + } + + /// Commit the transaction. + /// + /// # Errors + /// May return Err Result if cannot execute statement. + pub async fn commit(&self) -> RustPSQLDriverPyResult<()> { + match self { + ConnectionVar::Pool(pool_conn) => pool_conn.batch_execute("COMMIT;").await?, + ConnectionVar::SingleConn(single_conn) => single_conn.batch_execute("COMMIT;").await?, + }; + Ok(()) + } + + /// Rollback the transaction. + /// + /// # Errors + /// May return Err Result if cannot execute statement. + pub async fn rollback(&self) -> RustPSQLDriverPyResult<()> { + match self { + ConnectionVar::Pool(pool_conn) => pool_conn.batch_execute("ROLLBACK;").await?, + ConnectionVar::SingleConn(single_conn) => { + single_conn.batch_execute("ROLLBACK;").await?; + } + }; + Ok(()) + } + + /// Start the cursor. + /// + /// Execute `DECLARE` command with parameters. + /// + /// # Errors + /// May return Err Result if cannot execute querystring. + pub async fn cursor_start( + &self, + cursor_name: &str, + scroll: &Option, + querystring: &str, + prepared: &Option, + parameters: &Option>, + ) -> RustPSQLDriverPyResult<()> { + let mut cursor_init_query = format!("DECLARE {cursor_name}"); + if let Some(scroll) = scroll { + if *scroll { + cursor_init_query.push_str(" SCROLL"); + } else { + cursor_init_query.push_str(" NO SCROLL"); + } + } + + cursor_init_query.push_str(format!(" CURSOR FOR {querystring}").as_str()); + + self.psqlpy_query(cursor_init_query, parameters.clone(), *prepared) + .await?; + + Ok(()) + } + + /// Close the cursor. + /// + /// Execute `CLOSE` command. + /// + /// # Errors + /// May return Err Result if cannot execute querystring. + pub async fn cursor_close( + &self, + closed: &bool, + cursor_name: &str, + ) -> RustPSQLDriverPyResult<()> { + if *closed { + return Err(RustPSQLDriverError::DataBaseCursorError( + "Cursor is already closed".into(), + )); + } + + self.psqlpy_query( + format!("CLOSE {cursor_name}"), + Option::default(), + Some(false), + ) + .await?; + + Ok(()) + } +} + #[pyclass] pub struct Connection { - db_client: Arc, + connection: Arc, } impl Connection { #[must_use] - pub fn new(db_client: Object) -> Self { + pub fn new(connection: ConnectionVar) -> Self { Connection { - db_client: Arc::new(db_client), + connection: Arc::new(connection), } } } @@ -43,7 +316,7 @@ impl Connection { parameters: Option>, prepared: Option, ) -> RustPSQLDriverPyResult { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); + let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).connection.clone()); let mut params: Vec = vec![]; if let Some(parameters) = parameters { @@ -53,8 +326,8 @@ impl Connection { let result = if prepared { db_client - .query( - &db_client.prepare_cached(&querystring).await?, + .query_qs( + &db_client.prepare_stmt_cached(&querystring).await?, ¶ms .iter() .map(|param| param as &QueryParameter) @@ -64,7 +337,7 @@ impl Connection { .await? } else { db_client - .query( + .query_qs( &querystring, ¶ms .iter() @@ -94,7 +367,7 @@ impl Connection { parameters: Option>>, prepared: Option, ) -> RustPSQLDriverPyResult<()> { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); + let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).connection.clone()); let mut params: Vec> = vec![]; if let Some(parameters) = parameters { for vec_of_py_any in parameters { @@ -103,22 +376,22 @@ impl Connection { } let prepared = prepared.unwrap_or(true); - db_client.batch_execute("BEGIN;").await.map_err(|err| { + db_client.batch_execute_qs("BEGIN;").await.map_err(|err| { RustPSQLDriverError::DataBaseTransactionError(format!( "Cannot start transaction to run execute_many: {err}" )) })?; for param in params { let querystring_result = if prepared { - let prepared_stmt = &db_client.prepare_cached(&querystring).await; + let prepared_stmt = &db_client.prepare_stmt_cached(&querystring).await; if let Err(error) = prepared_stmt { return Err(RustPSQLDriverError::DataBaseTransactionError(format!( "Cannot prepare statement in execute_many, operation rolled back {error}", ))); } db_client - .query( - &db_client.prepare_cached(&querystring).await?, + .query_qs( + &db_client.prepare_stmt_cached(&querystring).await?, ¶m .iter() .map(|param| param as &QueryParameter) @@ -128,7 +401,7 @@ impl Connection { .await } else { db_client - .query( + .query_qs( &querystring, ¶m .iter() @@ -140,14 +413,14 @@ impl Connection { }; if let Err(error) = querystring_result { - db_client.batch_execute("ROLLBACK;").await?; + db_client.batch_execute_qs("ROLLBACK;").await?; return Err(RustPSQLDriverError::DataBaseTransactionError(format!( "Error occured in `execute_many` statement, transaction is rolled back: {error}" ))); } } - db_client.batch_execute("COMMIT;").await?; + db_client.batch_execute_qs("COMMIT;").await?; Ok(()) } @@ -172,7 +445,7 @@ impl Connection { parameters: Option>, prepared: Option, ) -> RustPSQLDriverPyResult { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); + let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).connection.clone()); let mut params: Vec = vec![]; if let Some(parameters) = parameters { @@ -182,8 +455,8 @@ impl Connection { let result = if prepared { db_client - .query_one( - &db_client.prepare_cached(&querystring).await?, + .query_qs_one( + &db_client.prepare_stmt_cached(&querystring).await?, ¶ms .iter() .map(|param| param as &QueryParameter) @@ -193,7 +466,7 @@ impl Connection { .await? } else { db_client - .query_one( + .query_qs_one( &querystring, ¶ms .iter() @@ -224,7 +497,7 @@ impl Connection { parameters: Option>, prepared: Option, ) -> RustPSQLDriverPyResult> { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); + let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).connection.clone()); let mut params: Vec = vec![]; if let Some(parameters) = parameters { @@ -234,8 +507,8 @@ impl Connection { let result = if prepared { db_client - .query_one( - &db_client.prepare_cached(&querystring).await?, + .query_qs_one( + &db_client.prepare_stmt_cached(&querystring).await?, ¶ms .iter() .map(|param| param as &QueryParameter) @@ -245,7 +518,7 @@ impl Connection { .await? } else { db_client - .query_one( + .query_qs_one( &querystring, ¶ms .iter() @@ -270,7 +543,7 @@ impl Connection { deferrable: Option, ) -> Transaction { Transaction::new( - self.db_client.clone(), + self.connection.clone(), false, false, isolation_level, diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index fad7cb9e..7d8f90cc 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -1,7 +1,7 @@ use crate::runtime::tokio_runtime; use deadpool_postgres::{Manager, ManagerConfig, Object, Pool, RecyclingMethod}; use pyo3::{pyclass, pyfunction, pymethods, PyAny}; -use std::{str::FromStr, vec}; +use std::vec; use tokio_postgres::{NoTls, Row}; use crate::{ @@ -10,7 +10,11 @@ use crate::{ value_converter::{convert_parameters, PythonDTO, QueryParameter}, }; -use super::{common_options::ConnRecyclingMethod, connection::Connection}; +use super::{ + common_options::ConnRecyclingMethod, + connection::{Connection, ConnectionVar}, + utils::build_connection_config, +}; /// Make new connection pool. /// @@ -18,7 +22,7 @@ use super::{common_options::ConnRecyclingMethod, connection::Connection}; /// May return error if cannot build new connection pool. #[pyfunction] #[allow(clippy::too_many_arguments)] -pub fn connect( +pub fn create_pool( dsn: Option, username: Option, password: Option, @@ -36,27 +40,7 @@ pub fn connect( } } - let mut pg_config: tokio_postgres::Config; - if let Some(dsn_string) = dsn { - pg_config = tokio_postgres::Config::from_str(&dsn_string)?; - } else { - pg_config = tokio_postgres::Config::new(); - if let (Some(password), Some(username)) = (password, username) { - pg_config.password(&password); - pg_config.user(&username); - } - if let Some(host) = host { - pg_config.host(&host); - } - - if let Some(port) = port { - pg_config.port(port); - } - - if let Some(db_name) = db_name { - pg_config.dbname(&db_name); - } - } + let conn_config = build_connection_config(dsn, username, password, host, port, db_name)?; let mgr_config: ManagerConfig; if let Some(conn_recycling_method) = conn_recycling_method { @@ -68,7 +52,7 @@ pub fn connect( recycling_method: RecyclingMethod::Fast, }; } - let mgr = Manager::from_config(pg_config, NoTls, mgr_config); + let mgr = Manager::from_config(conn_config, NoTls, mgr_config); let mut db_pool_builder = Pool::builder(mgr); if let Some(max_db_pool_size) = max_db_pool_size { @@ -101,7 +85,7 @@ impl ConnectionPool { max_db_pool_size: Option, conn_recycling_method: Option, ) -> RustPSQLDriverPyResult { - connect( + create_pool( dsn, username, password, @@ -186,7 +170,7 @@ impl ConnectionPool { }) .await??; - Ok(Connection::new(db_connection)) + Ok(Connection::new(ConnectionVar::Pool(db_connection))) } /// Return new single connection. diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index 29265b40..aeb02ac1 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -12,6 +12,8 @@ use crate::{ runtime::rustdriver_future, }; +use super::connection::ConnectionVar; + /// Additional implementation for the `Object` type. trait CursorObjectTrait { async fn cursor_start( @@ -84,7 +86,7 @@ impl CursorObjectTrait for Object { #[pyclass] pub struct Cursor { - db_transaction: Arc, + db_transaction: Arc, querystring: String, parameters: Option>, cursor_name: String, @@ -98,7 +100,7 @@ pub struct Cursor { impl Cursor { #[must_use] pub fn new( - db_transaction: Arc, + db_transaction: Arc, querystring: String, parameters: Option>, cursor_name: String, diff --git a/src/driver/mod.rs b/src/driver/mod.rs index aec33d5b..1ba0c203 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -4,3 +4,4 @@ pub mod connection_pool; pub mod cursor; pub mod transaction; pub mod transaction_options; +pub mod utils; diff --git a/src/driver/transaction.rs b/src/driver/transaction.rs index 8ed00cd7..2982e3e4 100644 --- a/src/driver/transaction.rs +++ b/src/driver/transaction.rs @@ -1,4 +1,3 @@ -use deadpool_postgres::Object; use futures_util::future; use pyo3::{ prelude::*, @@ -13,66 +12,15 @@ use crate::{ }; use super::{ + connection::ConnectionVar, cursor::Cursor, transaction_options::{IsolationLevel, ReadVariant}, }; -use crate::common::ObjectQueryTrait; use std::{collections::HashSet, sync::Arc}; -#[allow(clippy::module_name_repetitions)] -pub trait TransactionObjectTrait { - fn start_transaction( - &self, - isolation_level: Option, - read_variant: Option, - defferable: Option, - ) -> impl std::future::Future> + Send; - fn commit(&self) -> impl std::future::Future> + Send; - fn rollback(&self) -> impl std::future::Future> + Send; -} - -impl TransactionObjectTrait for Object { - async fn start_transaction( - &self, - isolation_level: Option, - read_variant: Option, - deferrable: Option, - ) -> RustPSQLDriverPyResult<()> { - let mut querystring = "START TRANSACTION".to_string(); - - if let Some(level) = isolation_level { - let level = &level.to_str_level(); - querystring.push_str(format!(" ISOLATION LEVEL {level}").as_str()); - }; - - querystring.push_str(match read_variant { - Some(ReadVariant::ReadOnly) => " READ ONLY", - Some(ReadVariant::ReadWrite) => " READ WRITE", - None => "", - }); - - querystring.push_str(match deferrable { - Some(true) => " DEFERRABLE", - Some(false) => " NOT DEFERRABLE", - None => "", - }); - self.batch_execute(&querystring).await?; - - Ok(()) - } - async fn commit(&self) -> RustPSQLDriverPyResult<()> { - self.batch_execute("COMMIT;").await?; - Ok(()) - } - async fn rollback(&self) -> RustPSQLDriverPyResult<()> { - self.batch_execute("ROLLBACK;").await?; - Ok(()) - } -} - #[pyclass] pub struct Transaction { - pub db_client: Arc, + pub db_client: Arc, is_started: bool, is_done: bool, @@ -191,7 +139,7 @@ impl Transaction { /// 3) Can not execute ROLLBACK command pub async fn rollback(&mut self) -> RustPSQLDriverPyResult<()> { self.check_is_transaction_ready()?; - self.db_client.batch_execute("ROLLBACK").await?; + self.db_client.batch_execute_qs("ROLLBACK;").await?; self.is_done = true; Ok(()) } @@ -254,8 +202,8 @@ impl Transaction { let result = if prepared.unwrap_or(true) { db_client - .query_one( - &db_client.prepare_cached(&querystring).await?, + .query_qs_one( + &db_client.prepare_stmt_cached(&querystring).await?, ¶ms .iter() .map(|param| param as &QueryParameter) @@ -265,7 +213,7 @@ impl Transaction { .await? } else { db_client - .query_one( + .query_qs_one( &querystring, ¶ms .iter() @@ -308,8 +256,8 @@ impl Transaction { let result = if prepared.unwrap_or(true) { db_client - .query_one( - &db_client.prepare_cached(&querystring).await?, + .query_qs_one( + &db_client.prepare_stmt_cached(&querystring).await?, ¶ms .iter() .map(|param| param as &QueryParameter) @@ -319,7 +267,7 @@ impl Transaction { .await? } else { db_client - .query_one( + .query_qs_one( &querystring, ¶ms .iter() @@ -367,15 +315,15 @@ impl Transaction { for param in params { let is_query_result_ok = if prepared { - let prepared_stmt = &db_client.prepare_cached(&querystring).await; + let prepared_stmt = &db_client.prepare_stmt_cached(&querystring).await; if let Err(error) = prepared_stmt { return Err(RustPSQLDriverError::DataBaseTransactionError(format!( "Cannot prepare statement in execute_many, operation rolled back {error}", ))); } db_client - .query( - &db_client.prepare_cached(&querystring).await?, + .query_qs( + &db_client.prepare_stmt_cached(&querystring).await?, ¶m .iter() .map(|param| param as &QueryParameter) @@ -385,7 +333,7 @@ impl Transaction { .await } else { db_client - .query( + .query_qs( &querystring, ¶m .iter() @@ -478,7 +426,7 @@ impl Transaction { } let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); db_client - .batch_execute(format!("SAVEPOINT {savepoint_name}").as_str()) + .batch_execute_qs(format!("SAVEPOINT {savepoint_name}").as_str()) .await?; pyo3::Python::with_gil(|gil| { @@ -516,7 +464,7 @@ impl Transaction { } let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); db_client - .batch_execute(format!("RELEASE SAVEPOINT {savepoint_name}").as_str()) + .batch_execute_qs(format!("RELEASE SAVEPOINT {savepoint_name}").as_str()) .await?; pyo3::Python::with_gil(|gil| { @@ -554,7 +502,7 @@ impl Transaction { } let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); db_client - .batch_execute(format!("ROLLBACK TO SAVEPOINT {savepoint_name}").as_str()) + .batch_execute_qs(format!("ROLLBACK TO SAVEPOINT {savepoint_name}").as_str()) .await?; pyo3::Python::with_gil(|gil| { @@ -639,7 +587,7 @@ impl Transaction { #[allow(clippy::too_many_arguments)] #[must_use] pub fn new( - db_client: Arc, + db_client: Arc, is_started: bool, is_done: bool, isolation_level: Option, diff --git a/src/driver/utils.rs b/src/driver/utils.rs new file mode 100644 index 00000000..3e69bd1f --- /dev/null +++ b/src/driver/utils.rs @@ -0,0 +1,40 @@ +use std::str::FromStr; + +use crate::exceptions::rust_errors::RustPSQLDriverPyResult; + +/// Build new config for making connection pool or single connection +/// +/// # Errors +/// May return Err Result if cannot build config from dsn. +pub fn build_connection_config( + dsn: Option, + username: Option, + password: Option, + host: Option, + port: Option, + db_name: Option, +) -> RustPSQLDriverPyResult { + let mut pg_config: tokio_postgres::Config; + if let Some(dsn_string) = dsn { + pg_config = tokio_postgres::Config::from_str(&dsn_string)?; + } else { + pg_config = tokio_postgres::Config::new(); + if let (Some(password), Some(username)) = (password, username) { + pg_config.password(&password); + pg_config.user(&username); + } + if let Some(host) = host { + pg_config.host(&host); + } + + if let Some(port) = port { + pg_config.port(port); + } + + if let Some(db_name) = db_name { + pg_config.dbname(&db_name); + } + } + + Ok(pg_config) +} diff --git a/src/exceptions/python_errors.rs b/src/exceptions/python_errors.rs index c8ceafe8..f28867dc 100644 --- a/src/exceptions/python_errors.rs +++ b/src/exceptions/python_errors.rs @@ -50,6 +50,11 @@ create_exception!( ); create_exception!(psqlpy.exceptions, CursorError, RustPSQLDriverPyBaseError); +create_exception!( + psqlpy.exceptions, + ConnectionError, + RustPSQLDriverPyBaseError +); #[allow(clippy::missing_errors_doc)] pub fn python_exceptions_module(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { @@ -84,5 +89,6 @@ pub fn python_exceptions_module(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> "RustRuntimeJoinError", py.get_type_bound::(), )?; + pymod.add("ConnectionError", py.get_type_bound::())?; Ok(()) } diff --git a/src/exceptions/rust_errors.rs b/src/exceptions/rust_errors.rs index 2cd01d67..90d2c257 100644 --- a/src/exceptions/rust_errors.rs +++ b/src/exceptions/rust_errors.rs @@ -6,7 +6,7 @@ use crate::exceptions::python_errors::{ RustToPyValueMappingError, TransactionError, }; -use super::python_errors::{CursorError, UUIDValueConvertError}; +use super::python_errors::{ConnectionError, CursorError, UUIDValueConvertError}; pub type RustPSQLDriverPyResult = Result; @@ -24,6 +24,8 @@ pub enum RustPSQLDriverError { DataBasePoolConfigurationError(String), #[error("Cursor error: {0}")] DataBaseCursorError(String), + #[error("Connection problem: {0}")] + ConnectionError(String), #[error("Python exception: {0}.")] PyError(#[from] pyo3::PyErr), @@ -68,6 +70,7 @@ impl From for pyo3::PyErr { } RustPSQLDriverError::UUIDConvertError(_) => UUIDValueConvertError::new_err(error_desc), RustPSQLDriverError::DataBaseCursorError(_) => CursorError::new_err(error_desc), + RustPSQLDriverError::ConnectionError(_) => ConnectionError::new_err(error_desc), } } } diff --git a/src/lib.rs b/src/lib.rs index 2392f098..7b077e42 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,8 +16,12 @@ use pyo3::{pymodule, types::PyModule, wrap_pyfunction, Bound, PyResult, Python}; #[pyo3(name = "_internal")] fn psqlpy(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { pymod.add_class::()?; - pymod.add_function(wrap_pyfunction!(driver::connection_pool::connect, pymod)?)?; + pymod.add_function(wrap_pyfunction!( + driver::connection_pool::create_pool, + pymod + )?)?; pymod.add_class::()?; + pymod.add_function(wrap_pyfunction!(driver::connection::connect, pymod)?)?; pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?;