From c998823cd89ef73193e3c74ea8f527259af4e784 Mon Sep 17 00:00:00 2001 From: Vladislav Yashkov Date: Sun, 1 Dec 2024 23:42:40 +0300 Subject: [PATCH 1/5] WIP# 1 --- python/tests/conftest.py | 3 +- python/tests/test_binary_copy.py | 3 +- python/tests/test_connection.py | 3 +- python/tests/test_connection_pool.py | 1 - python/tests/test_connection_pool_builder.py | 1 - python/tests/test_cursor.py | 1 - python/tests/test_row_factories.py | 1 - python/tests/test_ssl_mode.py | 1 - python/tests/test_transaction.py | 3 +- python/tests/test_value_converter.py | 24 ++++- src/driver/connection.rs | 9 -- src/value_converter.rs | 94 ++++++++++++++++++-- 12 files changed, 113 insertions(+), 31 deletions(-) diff --git a/python/tests/conftest.py b/python/tests/conftest.py index a798958e..c62737bf 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -3,10 +3,9 @@ from typing import AsyncGenerator import pytest -from pydantic import BaseModel - from psqlpy import ConnectionPool, Cursor from psqlpy._internal import SslMode +from pydantic import BaseModel class DefaultPydanticModel(BaseModel): diff --git a/python/tests/test_binary_copy.py b/python/tests/test_binary_copy.py index e1c3d473..93cc1335 100644 --- a/python/tests/test_binary_copy.py +++ b/python/tests/test_binary_copy.py @@ -4,9 +4,8 @@ import pytest from pgpq import ArrowToPostgresBinaryEncoder -from pyarrow import parquet - from psqlpy import ConnectionPool +from pyarrow import parquet pytestmark = pytest.mark.anyio diff --git a/python/tests/test_connection.py b/python/tests/test_connection.py index f9f72d9a..77fa0cd7 100644 --- a/python/tests/test_connection.py +++ b/python/tests/test_connection.py @@ -3,14 +3,13 @@ import typing import pytest -from tests.helpers import count_rows_in_test_table - from psqlpy import ConnectionPool, Cursor, QueryResult, Transaction from psqlpy.exceptions import ( ConnectionClosedError, ConnectionExecuteError, TransactionExecuteError, ) +from tests.helpers import count_rows_in_test_table pytestmark = pytest.mark.anyio diff --git a/python/tests/test_connection_pool.py b/python/tests/test_connection_pool.py index 6bf9a936..cdf2fa48 100644 --- a/python/tests/test_connection_pool.py +++ b/python/tests/test_connection_pool.py @@ -1,5 +1,4 @@ import pytest - from psqlpy import ( Connection, ConnectionPool, diff --git a/python/tests/test_connection_pool_builder.py b/python/tests/test_connection_pool_builder.py index b14ed3ad..f937bec3 100644 --- a/python/tests/test_connection_pool_builder.py +++ b/python/tests/test_connection_pool_builder.py @@ -1,5 +1,4 @@ import pytest - from psqlpy import ( ConnectionPoolBuilder, ConnRecyclingMethod, diff --git a/python/tests/test_cursor.py b/python/tests/test_cursor.py index 08644c59..ce1a9393 100644 --- a/python/tests/test_cursor.py +++ b/python/tests/test_cursor.py @@ -3,7 +3,6 @@ import math import pytest - from psqlpy import ConnectionPool, Cursor, QueryResult, Transaction pytestmark = pytest.mark.anyio diff --git a/python/tests/test_row_factories.py b/python/tests/test_row_factories.py index cd0220e1..75d03e5a 100644 --- a/python/tests/test_row_factories.py +++ b/python/tests/test_row_factories.py @@ -2,7 +2,6 @@ from typing import Any, Callable, Dict, Type import pytest - from psqlpy import ConnectionPool from psqlpy.row_factories import class_row, tuple_row diff --git a/python/tests/test_ssl_mode.py b/python/tests/test_ssl_mode.py index 99f6b9b6..72efde5b 100644 --- a/python/tests/test_ssl_mode.py +++ b/python/tests/test_ssl_mode.py @@ -1,5 +1,4 @@ import pytest - from psqlpy import ConnectionPool, SslMode from psqlpy._internal import ConnectionPoolBuilder diff --git a/python/tests/test_transaction.py b/python/tests/test_transaction.py index 6e34a3d6..2a07227b 100644 --- a/python/tests/test_transaction.py +++ b/python/tests/test_transaction.py @@ -3,8 +3,6 @@ import typing import pytest -from tests.helpers import count_rows_in_test_table - from psqlpy import ( ConnectionPool, Cursor, @@ -18,6 +16,7 @@ TransactionExecuteError, TransactionSavepointError, ) +from tests.helpers import count_rows_in_test_table pytestmark = pytest.mark.anyio diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index 585d30b6..3dfff770 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -1,15 +1,12 @@ import datetime import uuid +import zoneinfo from decimal import Decimal from enum import Enum from ipaddress import IPv4Address from typing import Any, Dict, List, Tuple, Union import pytest -from pydantic import BaseModel -from tests.conftest import DefaultPydanticModel, DefaultPythonModelClass -from typing_extensions import Annotated - from psqlpy import ConnectionPool from psqlpy.exceptions import PyToRustValueMappingError from psqlpy.extra_types import ( @@ -56,6 +53,9 @@ UUIDArray, VarCharArray, ) +from pydantic import BaseModel +from tests.conftest import DefaultPydanticModel, DefaultPythonModelClass +from typing_extensions import Annotated pytestmark = pytest.mark.anyio now_datetime = datetime.datetime.now() @@ -69,6 +69,16 @@ 142574, tzinfo=datetime.timezone.utc, ) +now_datetime_with_tz_in_asia_jakarta = datetime.datetime( + 2024, + 4, + 13, + 17, + 3, + 46, + 142574, + tzinfo=zoneinfo.ZoneInfo(key="Asia/Jakarta"), +) uuid_ = uuid.uuid4() @@ -125,6 +135,7 @@ async def test_as_class( ("TIME", now_datetime.time(), now_datetime.time()), ("TIMESTAMP", now_datetime, now_datetime), ("TIMESTAMPTZ", now_datetime_with_tz, now_datetime_with_tz), + ("TIMESTAMPTZ", now_datetime_with_tz_in_asia_jakarta, now_datetime_with_tz_in_asia_jakarta), ("UUID", uuid_, str(uuid_)), ("INET", IPv4Address("192.0.0.1"), IPv4Address("192.0.0.1")), ( @@ -287,6 +298,11 @@ async def test_as_class( [now_datetime_with_tz, now_datetime_with_tz], [now_datetime_with_tz, now_datetime_with_tz], ), + ( + "TIMESTAMPTZ ARRAY", + [now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta], + [now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta], + ), ( "TIMESTAMPTZ ARRAY", [[now_datetime_with_tz], [now_datetime_with_tz]], diff --git a/src/driver/connection.rs b/src/driver/connection.rs index 23e86a44..97dc66a1 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -127,15 +127,6 @@ impl Connection { #[pymethods] impl Connection { - #[must_use] - pub fn __aiter__(self_: Py) -> Py { - self_ - } - - fn __await__(self_: Py) -> Py { - self_ - } - async fn __aenter__<'a>(self_: Py) -> RustPSQLDriverPyResult> { let (db_client, db_pool) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); diff --git a/src/value_converter.rs b/src/value_converter.rs index 3efeaf64..6569f9ec 100644 --- a/src/value_converter.rs +++ b/src/value_converter.rs @@ -1,4 +1,5 @@ -use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; +use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeZone}; +use chrono_tz::Tz; use geo_types::{coord, Coord, Line as LineSegment, LineString, Point, Rect}; use itertools::Itertools; use macaddr::{MacAddr6, MacAddr8}; @@ -626,8 +627,7 @@ impl ToSql for PythonDTO { #[allow(clippy::needless_pass_by_value)] pub fn convert_parameters(parameters: Py) -> RustPSQLDriverPyResult> { let mut result_vec: Vec = vec![]; - - result_vec = Python::with_gil(|gil| { + Python::with_gil(|gil| { let params = parameters.extract::>>(gil).map_err(|_| { RustPSQLDriverError::PyToRustValueConversionError( "Cannot convert you parameters argument into Rust type, please use List/Tuple" @@ -637,8 +637,9 @@ pub fn convert_parameters(parameters: Py) -> RustPSQLDriverPyResult, RustPSQLDriverError>(result_vec) + Ok::<(), RustPSQLDriverError>(()) })?; + Ok(result_vec) } @@ -744,6 +745,84 @@ pub fn py_sequence_into_postgres_array( } } +/// Extract a value from a Python object, raising an error if missing or invalid +/// +/// # Type Parameters +/// - `T`: The type to which the attribute's value will be converted. This type must implement the `FromPyObject` trait +/// +/// # Errors +/// This function will return `Err` in the following cases: +/// - The Python object does not have the specified attribute +/// - The attribute exists but cannot be extracted into the specified Rust type +fn extract_value_from_python_object_or_raise<'py, T>( + parameter: &'py pyo3::Bound<'_, PyAny>, + attr_name: &str, +) -> Result +where + T: FromPyObject<'py>, +{ + parameter + .getattr(attr_name) + .ok() + .and_then(|attr| attr.extract::().ok()) + .ok_or_else(|| { + RustPSQLDriverError::PyToRustValueConversionError("Invalid attribute".into()) + }) +} + +/// Extract a timezone-aware datetime from a Python object. +/// This function retrieves various datetime components (`year`, `month`, `day`, etc.) +/// from a Python object and constructs a `DateTime` +/// +/// # Errors +/// This function will return `Err` in the following cases: +/// - The Python object does not contain or support one or more required datetime attributes +/// - The retrieved values are invalid for constructing a date, time, or datetime (e.g., invalid month or day) +/// - The timezone information (`tzinfo`) is not available or cannot be parsed +/// - The resulting datetime is ambiguous or invalid (e.g., due to DST transitions) +fn extract_datetime_from_python_object_attrs( + parameter: &pyo3::Bound<'_, PyAny>, +) -> Result, RustPSQLDriverError> { + let year = extract_value_from_python_object_or_raise::(parameter, "year")?; + let month = extract_value_from_python_object_or_raise::(parameter, "month")?; + let day = extract_value_from_python_object_or_raise::(parameter, "day")?; + let hour = extract_value_from_python_object_or_raise::(parameter, "hour")?; + let minute = extract_value_from_python_object_or_raise::(parameter, "minute")?; + let second = extract_value_from_python_object_or_raise::(parameter, "second")?; + let microsecond = extract_value_from_python_object_or_raise::(parameter, "microsecond")?; + + let date = NaiveDate::from_ymd_opt(year, month, day) + .ok_or_else(|| RustPSQLDriverError::PyToRustValueConversionError("Invalid date".into()))?; + let time = NaiveTime::from_hms_micro_opt(hour, minute, second, microsecond) + .ok_or_else(|| RustPSQLDriverError::PyToRustValueConversionError("Invalid time".into()))?; + let naive_datetime = NaiveDateTime::new(date, time); + + let raw_timestamp_tz = parameter + .getattr("tzinfo") + .ok() + .and_then(|tzinfo| tzinfo.getattr("key").ok()) + .and_then(|key| key.extract::().ok()) + .ok_or_else(|| { + RustPSQLDriverError::PyToRustValueConversionError("Invalid timezone info".into()) + })?; + + let fixed_offset_datetime = raw_timestamp_tz + .parse::() + .map_err(|_| { + RustPSQLDriverError::PyToRustValueConversionError("Failed to parse TZ".into()) + })? + .from_local_datetime(&naive_datetime) + .single() + .ok_or_else(|| { + RustPSQLDriverError::PyToRustValueConversionError( + "Ambiguous or invalid datetime".into(), + ) + })? + .fixed_offset(); + + Ok(fixed_offset_datetime) +} + /// Convert single python parameter to `PythonDTO` enum. /// /// # Errors @@ -849,6 +928,11 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult< return Ok(PythonDTO::PyDateTime(pydatetime_no_tz)); } + let timestamp_tz = extract_datetime_from_python_object_attrs(parameter); + if let Ok(pydatetime_tz) = timestamp_tz { + return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz)); + } + return Err(RustPSQLDriverError::PyToRustValueConversionError( "Can not convert you datetime to rust type".into(), )); @@ -1663,7 +1747,7 @@ pub fn other_postgres_bytes_to_py( } Err(RustPSQLDriverError::RustToPyValueConversionError( - format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.") + format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.") )) } From d2bd3864ff84171fb1c892fa8d616a7449c136d7 Mon Sep 17 00:00:00 2001 From: Vladislav Yashkov Date: Mon, 2 Dec 2024 16:40:00 +0300 Subject: [PATCH 2/5] WIP# 2 --- python/tests/test_value_converter.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index 3dfff770..91d5a557 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -1,10 +1,10 @@ import datetime import uuid -import zoneinfo from decimal import Decimal from enum import Enum from ipaddress import IPv4Address from typing import Any, Dict, List, Tuple, Union +import sys import pytest from psqlpy import ConnectionPool @@ -57,6 +57,7 @@ from tests.conftest import DefaultPydanticModel, DefaultPythonModelClass from typing_extensions import Annotated +uuid_ = uuid.uuid4() pytestmark = pytest.mark.anyio now_datetime = datetime.datetime.now() now_datetime_with_tz = datetime.datetime( @@ -69,6 +70,7 @@ 142574, tzinfo=datetime.timezone.utc, ) + now_datetime_with_tz_in_asia_jakarta = datetime.datetime( 2024, 4, @@ -77,9 +79,21 @@ 3, 46, 142574, - tzinfo=zoneinfo.ZoneInfo(key="Asia/Jakarta"), + tzinfo=datetime.timezone.utc, ) -uuid_ = uuid.uuid4() +if sys.version_info >= (3, 9): + import zoneinfo + + now_datetime_with_tz_in_asia_jakarta = datetime.datetime( + 2024, + 4, + 13, + 17, + 3, + 46, + 142574, + tzinfo=zoneinfo.ZoneInfo(key="Asia/Jakarta"), + ) async def test_as_class( From 6ae6688ca714160991600967b1939fb9e8dd16db Mon Sep 17 00:00:00 2001 From: Vladislav Yashkov Date: Mon, 2 Dec 2024 16:45:52 +0300 Subject: [PATCH 3/5] WIP# 3 --- python/tests/test_value_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index 91d5a557..1e0e4673 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -1,10 +1,10 @@ import datetime +import sys import uuid from decimal import Decimal from enum import Enum from ipaddress import IPv4Address from typing import Any, Dict, List, Tuple, Union -import sys import pytest from psqlpy import ConnectionPool From 93069eb1acfd7cb49a9ea820d6f7b669e2a1b1df Mon Sep 17 00:00:00 2001 From: Vladislav Yashkov Date: Tue, 3 Dec 2024 11:42:52 +0300 Subject: [PATCH 4/5] WIP# 4 --- python/tests/test_connection.py | 1 - python/tests/test_cursor.py | 3 +-- python/tests/test_transaction.py | 1 - src/driver/cursor.rs | 4 ++-- src/query_result.rs | 2 +- 5 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/tests/test_connection.py b/python/tests/test_connection.py index 4469bdb6..3c15991a 100644 --- a/python/tests/test_connection.py +++ b/python/tests/test_connection.py @@ -9,7 +9,6 @@ ConnectionExecuteError, TransactionExecuteError, ) -from tests.helpers import count_rows_in_test_table from tests.helpers import count_rows_in_test_table diff --git a/python/tests/test_cursor.py b/python/tests/test_cursor.py index fb9fe433..b9546f22 100644 --- a/python/tests/test_cursor.py +++ b/python/tests/test_cursor.py @@ -168,8 +168,7 @@ async def test_cursor_as_async_manager( querystring=f"SELECT * FROM {table_name}", fetch_number=fetch_number, ) as cursor: - async for result in cursor: - all_results.append(result) + all_results.extend([result async for result in cursor]) assert len(all_results) == expected_num_results diff --git a/python/tests/test_transaction.py b/python/tests/test_transaction.py index 52fe3636..7704393b 100644 --- a/python/tests/test_transaction.py +++ b/python/tests/test_transaction.py @@ -16,7 +16,6 @@ TransactionExecuteError, TransactionSavepointError, ) -from tests.helpers import count_rows_in_test_table from tests.helpers import count_rows_in_test_table diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index 3f8008be..b04ed8e6 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -13,7 +13,7 @@ use crate::{ }; /// Additional implementation for the `Object` type. -#[allow(clippy::ref_option)] +#[allow(clippy::ref_option_ref)] trait CursorObjectTrait { async fn cursor_start( &self, @@ -34,7 +34,7 @@ impl CursorObjectTrait for Object { /// /// # Errors /// May return Err Result if cannot execute querystring. - #[allow(clippy::ref_option)] + #[allow(clippy::ref_option_ref)] async fn cursor_start( &self, cursor_name: &str, diff --git a/src/query_result.rs b/src/query_result.rs index 06299b86..68e8eb3f 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -10,7 +10,7 @@ use crate::{exceptions::rust_errors::RustPSQLDriverPyResult, value_converter::po /// May return Err Result if can not convert /// postgres type to python or set new key-value pair /// in python dict. -#[allow(clippy::ref_option)] +#[allow(clippy::ref_option_ref)] fn row_to_dict<'a>( py: Python<'a>, postgres_row: &'a Row, From cdfa4099924500f4ad3eeab2db3b0432de35478e Mon Sep 17 00:00:00 2001 From: Vladislav Yashkov Date: Tue, 3 Dec 2024 11:51:54 +0300 Subject: [PATCH 5/5] WIP# 5 --- src/driver/cursor.rs | 12 ++++++------ src/query_result.rs | 2 +- src/value_converter.rs | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index b04ed8e6..74e353b2 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -18,10 +18,10 @@ trait CursorObjectTrait { async fn cursor_start( &self, cursor_name: &str, - scroll: &Option, + scroll: Option<&bool>, querystring: &str, - prepared: &Option, - parameters: &Option>, + prepared: Option<&bool>, + parameters: Option<&Py>, ) -> RustPSQLDriverPyResult<()>; async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> RustPSQLDriverPyResult<()>; @@ -38,10 +38,10 @@ impl CursorObjectTrait for Object { async fn cursor_start( &self, cursor_name: &str, - scroll: &Option, + scroll: Option<&bool>, querystring: &str, - prepared: &Option, - parameters: &Option>, + prepared: Option<&bool>, + parameters: Option<&Py>, ) -> RustPSQLDriverPyResult<()> { let mut cursor_init_query = format!("DECLARE {cursor_name}"); if let Some(scroll) = scroll { diff --git a/src/query_result.rs b/src/query_result.rs index 68e8eb3f..c4025ee3 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -14,7 +14,7 @@ use crate::{exceptions::rust_errors::RustPSQLDriverPyResult, value_converter::po fn row_to_dict<'a>( py: Python<'a>, postgres_row: &'a Row, - custom_decoders: &Option>, + custom_decoders: Option<&Py>, ) -> RustPSQLDriverPyResult> { let python_dict = PyDict::new_bound(py); for (column_idx, column) in postgres_row.columns().iter().enumerate() { diff --git a/src/value_converter.rs b/src/value_converter.rs index 418e7998..b3b252fe 100644 --- a/src/value_converter.rs +++ b/src/value_converter.rs @@ -1752,7 +1752,7 @@ pub fn composite_postgres_to_py( py: Python<'_>, fields: &Vec, buf: &mut &[u8], - custom_decoders: &Option>, + custom_decoders: Option<&Py>, ) -> RustPSQLDriverPyResult> { let result_py_dict: Bound<'_, PyDict> = PyDict::new_bound(py); @@ -1821,7 +1821,7 @@ pub fn raw_bytes_data_process( raw_bytes_data: &mut &[u8], column_name: &str, column_type: &Type, - custom_decoders: &Option>, + custom_decoders: Option<&Py>, ) -> RustPSQLDriverPyResult> { if let Some(custom_decoders) = custom_decoders { let py_encoder_func = custom_decoders @@ -1860,7 +1860,7 @@ pub fn postgres_to_py( row: &Row, column: &Column, column_i: usize, - custom_decoders: &Option>, + custom_decoders: Option<&Py>, ) -> RustPSQLDriverPyResult> { let raw_bytes_data = row.col_buffer(column_i); if let Some(mut raw_bytes_data) = raw_bytes_data {