diff --git a/docs/components/connection.md b/docs/components/connection.md index db66d32..959978b 100644 --- a/docs/components/connection.md +++ b/docs/components/connection.md @@ -61,6 +61,28 @@ async def main() -> None: dict_results: list[dict[str, Any]] = results.result() ``` +### Execute Batch + +#### Parameters: + +- `querystring`: querystrings separated by semicolons. + +Executes a sequence of SQL statements using the simple query protocol. + +Statements should be separated by semicolons. +If an error occurs, execution of the sequence will stop at that point. +This is intended for use when, for example, +initializing a database schema. + +```python +async def main() -> None: + ... + connection = await db_pool.connection() + await connection.execute_batch( + "CREATE TABLE psqlpy (name VARCHAR); CREATE TABLE psqlpy2 (name VARCHAR);", + ) +``` + ### Fetch #### Parameters: diff --git a/docs/components/transaction.md b/docs/components/transaction.md index 6b15edb..1bed17e 100644 --- a/docs/components/transaction.md +++ b/docs/components/transaction.md @@ -144,6 +144,29 @@ async def main() -> None: dict_results: list[dict[str, Any]] = results.result() ``` +### Execute Batch + +#### Parameters: + +- `querystring`: querystrings separated by semicolons. + +Executes a sequence of SQL statements using the simple query protocol. + +Statements should be separated by semicolons. +If an error occurs, execution of the sequence will stop at that point. +This is intended for use when, for example, +initializing a database schema. + +```python +async def main() -> None: + ... + connection = await db_pool.connection() + async with connection.transaction() as transaction: + await transaction.execute_batch( + "CREATE TABLE psqlpy (name VARCHAR); CREATE TABLE psqlpy2 (name VARCHAR);", + ) +``` + ### Fetch #### Parameters: diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index 854469a..353d936 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -455,6 +455,21 @@ class Transaction: await transaction.commit() ``` """ + async def execute_batch( + self: Self, + querystring: str, + ) -> None: + """ + Executes a sequence of SQL statements using the simple query protocol. + + Statements should be separated by semicolons. + If an error occurs, execution of the sequence will stop at that point. + This is intended for use when, for example, + initializing a database schema. + + ### Parameters: + - `querystring`: querystrings separated by semicolons. + """ async def execute_many( self: Self, querystring: str, @@ -885,6 +900,21 @@ class Connection: dict_result: List[Dict[Any, Any]] = query_result.result() ``` """ + async def execute_batch( + self: Self, + querystring: str, + ) -> None: + """ + Executes a sequence of SQL statements using the simple query protocol. + + Statements should be separated by semicolons. + If an error occurs, execution of the sequence will stop at that point. + This is intended for use when, for example, + initializing a database schema. + + ### Parameters: + - `querystring`: querystrings separated by semicolons. + """ async def execute_many( self: Self, querystring: str, diff --git a/python/tests/test_connection.py b/python/tests/test_connection.py index 8cf7274..9efcb41 100644 --- a/python/tests/test_connection.py +++ b/python/tests/test_connection.py @@ -236,3 +236,14 @@ async def test_binary_copy_to_table( f"SELECT COUNT(*) AS rows_count FROM {table_name}", ) assert real_table_rows.result()[0]["rows_count"] == expected_inserted_row + + +async def test_execute_batch_method(psql_pool: ConnectionPool) -> None: + """Test `execute_batch` method.""" + await psql_pool.execute(querystring="DROP TABLE IF EXISTS execute_batch") + await psql_pool.execute(querystring="DROP TABLE IF EXISTS execute_batch2") + query = "CREATE TABLE execute_batch (name VARCHAR);CREATE TABLE execute_batch2 (name VARCHAR);" + async with psql_pool.acquire() as conn: + await conn.execute_batch(querystring=query) + await conn.execute(querystring="SELECT * FROM execute_batch") + await conn.execute(querystring="SELECT * FROM execute_batch2") diff --git a/python/tests/test_transaction.py b/python/tests/test_transaction.py index c7bcbc3..a9df2d2 100644 --- a/python/tests/test_transaction.py +++ b/python/tests/test_transaction.py @@ -390,3 +390,14 @@ async def test_binary_copy_to_table( f"SELECT COUNT(*) AS rows_count FROM {table_name}", ) assert real_table_rows.result()[0]["rows_count"] == expected_inserted_row + + +async def test_execute_batch_method(psql_pool: ConnectionPool) -> None: + """Test `execute_batch` method.""" + await psql_pool.execute(querystring="DROP TABLE IF EXISTS execute_batch") + await psql_pool.execute(querystring="DROP TABLE IF EXISTS execute_batch2") + query = "CREATE TABLE execute_batch (name VARCHAR);CREATE TABLE execute_batch2 (name VARCHAR);" + async with psql_pool.acquire() as conn, conn.transaction() as transaction: + await transaction.execute_batch(querystring=query) + await transaction.execute(querystring="SELECT * FROM execute_batch") + await transaction.execute(querystring="SELECT * FROM execute_batch2") diff --git a/src/driver/connection.rs b/src/driver/connection.rs index eef3ac2..d145a59 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -260,6 +260,31 @@ impl Connection { Err(RustPSQLDriverError::ConnectionClosedError) } + /// Executes a sequence of SQL statements using the simple query protocol. + /// + /// Statements should be separated by semicolons. + /// If an error occurs, execution of the sequence will stop at that point. + /// This is intended for use when, for example, + /// initializing a database schema. + /// + /// # Errors + /// + /// May return Err Result if: + /// 1) Connection is closed. + /// 2) Cannot execute querystring. + pub async fn execute_batch( + self_: pyo3::Py, + querystring: String, + ) -> RustPSQLDriverPyResult<()> { + let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); + + if let Some(db_client) = db_client { + return Ok(db_client.batch_execute(&querystring).await?); + } + + Err(RustPSQLDriverError::ConnectionClosedError) + } + /// Execute querystring with parameters. /// /// It converts incoming parameters to rust readable diff --git a/src/driver/transaction.rs b/src/driver/transaction.rs index a18880b..535edc1 100644 --- a/src/driver/transaction.rs +++ b/src/driver/transaction.rs @@ -301,6 +301,31 @@ impl Transaction { Err(RustPSQLDriverError::TransactionClosedError) } + /// Executes a sequence of SQL statements using the simple query protocol. + /// + /// Statements should be separated by semicolons. + /// If an error occurs, execution of the sequence will stop at that point. + /// This is intended for use when, for example, + /// initializing a database schema. + /// + /// # Errors + /// + /// May return Err Result if: + /// 1) Transaction is closed. + /// 2) Cannot execute querystring. + pub async fn execute_batch(self_: Py, querystring: String) -> RustPSQLDriverPyResult<()> { + let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { + let self_ = self_.borrow(gil); + (self_.check_is_transaction_ready(), self_.db_client.clone()) + }); + is_transaction_ready?; + if let Some(db_client) = db_client { + return Ok(db_client.batch_execute(&querystring).await?); + } + + Err(RustPSQLDriverError::TransactionClosedError) + } + /// Fetch result from the database. /// /// It converts incoming parameters to rust readable