Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add-support-datetime-zone-info #107

Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from typing import AsyncGenerator

import pytest
from pydantic import BaseModel

from psqlpy import ConnectionPool, Cursor
from psqlpy._internal import SslMode
from pydantic import BaseModel


class DefaultPydanticModel(BaseModel):
Expand Down
3 changes: 1 addition & 2 deletions python/tests/test_binary_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

import pytest
from pgpq import ArrowToPostgresBinaryEncoder
from pyarrow import parquet

from psqlpy import ConnectionPool
from pyarrow import parquet

pytestmark = pytest.mark.anyio

Expand Down
3 changes: 1 addition & 2 deletions python/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
import typing

import pytest
from tests.helpers import count_rows_in_test_table

from psqlpy import ConnectionPool, Cursor, QueryResult, Transaction
from psqlpy.exceptions import (
ConnectionClosedError,
ConnectionExecuteError,
TransactionExecuteError,
)
from tests.helpers import count_rows_in_test_table

pytestmark = pytest.mark.anyio

Expand Down
1 change: 0 additions & 1 deletion python/tests/test_connection_pool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest

from psqlpy import (
Connection,
ConnectionPool,
Expand Down
1 change: 0 additions & 1 deletion python/tests/test_connection_pool_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest

from psqlpy import (
ConnectionPoolBuilder,
ConnRecyclingMethod,
Expand Down
1 change: 0 additions & 1 deletion python/tests/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import math

import pytest

from psqlpy import ConnectionPool, Cursor, QueryResult, Transaction

pytestmark = pytest.mark.anyio
Expand Down
1 change: 0 additions & 1 deletion python/tests/test_row_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, Callable, Dict, Type

import pytest

from psqlpy import ConnectionPool
from psqlpy.row_factories import class_row, tuple_row

Expand Down
1 change: 0 additions & 1 deletion python/tests/test_ssl_mode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest

from psqlpy import ConnectionPool, SslMode
from psqlpy._internal import ConnectionPoolBuilder

Expand Down
3 changes: 1 addition & 2 deletions python/tests/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import typing

import pytest
from tests.helpers import count_rows_in_test_table

from psqlpy import (
ConnectionPool,
Cursor,
Expand All @@ -18,6 +16,7 @@
TransactionExecuteError,
TransactionSavepointError,
)
from tests.helpers import count_rows_in_test_table

pytestmark = pytest.mark.anyio

Expand Down
24 changes: 20 additions & 4 deletions python/tests/test_value_converter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import datetime
import uuid
import zoneinfo
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address
from typing import Any, Dict, List, Tuple, Union

import pytest
from pydantic import BaseModel
from tests.conftest import DefaultPydanticModel, DefaultPythonModelClass
from typing_extensions import Annotated

from psqlpy import ConnectionPool
from psqlpy.exceptions import PyToRustValueMappingError
from psqlpy.extra_types import (
Expand Down Expand Up @@ -56,6 +53,9 @@
UUIDArray,
VarCharArray,
)
from pydantic import BaseModel
from tests.conftest import DefaultPydanticModel, DefaultPythonModelClass
from typing_extensions import Annotated

pytestmark = pytest.mark.anyio
now_datetime = datetime.datetime.now()
Expand All @@ -69,6 +69,16 @@
142574,
tzinfo=datetime.timezone.utc,
)
now_datetime_with_tz_in_asia_jakarta = datetime.datetime(
2024,
4,
13,
17,
3,
46,
142574,
tzinfo=zoneinfo.ZoneInfo(key="Asia/Jakarta"),
)
uuid_ = uuid.uuid4()


Expand Down Expand Up @@ -125,6 +135,7 @@ async def test_as_class(
("TIME", now_datetime.time(), now_datetime.time()),
("TIMESTAMP", now_datetime, now_datetime),
("TIMESTAMPTZ", now_datetime_with_tz, now_datetime_with_tz),
("TIMESTAMPTZ", now_datetime_with_tz_in_asia_jakarta, now_datetime_with_tz_in_asia_jakarta),
("UUID", uuid_, str(uuid_)),
("INET", IPv4Address("192.0.0.1"), IPv4Address("192.0.0.1")),
(
Expand Down Expand Up @@ -287,6 +298,11 @@ async def test_as_class(
[now_datetime_with_tz, now_datetime_with_tz],
[now_datetime_with_tz, now_datetime_with_tz],
),
(
"TIMESTAMPTZ ARRAY",
[now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta],
[now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta],
),
(
"TIMESTAMPTZ ARRAY",
[[now_datetime_with_tz], [now_datetime_with_tz]],
Expand Down
9 changes: 0 additions & 9 deletions src/driver/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,6 @@ impl Connection {

#[pymethods]
impl Connection {
#[must_use]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels unnecessary.
aenter and aexit here the main methods, as I see

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, It needs to be checked)

pub fn __aiter__(self_: Py<Self>) -> Py<Self> {
self_
}

fn __await__(self_: Py<Self>) -> Py<Self> {
self_
}

async fn __aenter__<'a>(self_: Py<Self>) -> RustPSQLDriverPyResult<Py<Self>> {
let (db_client, db_pool) = pyo3::Python::with_gil(|gil| {
let self_ = self_.borrow(gil);
Expand Down
94 changes: 89 additions & 5 deletions src/value_converter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeZone};
use chrono_tz::Tz;
use geo_types::{coord, Coord, Line as LineSegment, LineString, Point, Rect};
use itertools::Itertools;
use macaddr::{MacAddr6, MacAddr8};
Expand Down Expand Up @@ -626,8 +627,7 @@ impl ToSql for PythonDTO {
#[allow(clippy::needless_pass_by_value)]
pub fn convert_parameters(parameters: Py<PyAny>) -> RustPSQLDriverPyResult<Vec<PythonDTO>> {
let mut result_vec: Vec<PythonDTO> = vec![];

result_vec = Python::with_gil(|gil| {
Python::with_gil(|gil| {
let params = parameters.extract::<Vec<Py<PyAny>>>(gil).map_err(|_| {
RustPSQLDriverError::PyToRustValueConversionError(
"Cannot convert you parameters argument into Rust type, please use List/Tuple"
Expand All @@ -637,8 +637,9 @@ pub fn convert_parameters(parameters: Py<PyAny>) -> RustPSQLDriverPyResult<Vec<P
for parameter in params {
result_vec.push(py_to_rust(parameter.bind(gil))?);
}
Ok::<Vec<PythonDTO>, RustPSQLDriverError>(result_vec)
Ok::<(), RustPSQLDriverError>(())
})?;

Ok(result_vec)
}

Expand Down Expand Up @@ -744,6 +745,84 @@ pub fn py_sequence_into_postgres_array(
}
}

/// Extract a value from a Python object, raising an error if missing or invalid
///
/// # Type Parameters
/// - `T`: The type to which the attribute's value will be converted. This type must implement the `FromPyObject` trait
///
/// # Errors
/// This function will return `Err` in the following cases:
/// - The Python object does not have the specified attribute
/// - The attribute exists but cannot be extracted into the specified Rust type
fn extract_value_from_python_object_or_raise<'py, T>(
parameter: &'py pyo3::Bound<'_, PyAny>,
attr_name: &str,
) -> Result<T, RustPSQLDriverError>
where
T: FromPyObject<'py>,
{
parameter
.getattr(attr_name)
.ok()
.and_then(|attr| attr.extract::<T>().ok())
.ok_or_else(|| {
RustPSQLDriverError::PyToRustValueConversionError("Invalid attribute".into())
})
}

/// Extract a timezone-aware datetime from a Python object.
/// This function retrieves various datetime components (`year`, `month`, `day`, etc.)
/// from a Python object and constructs a `DateTime<FixedOffset>`
///
/// # Errors
/// This function will return `Err` in the following cases:
/// - The Python object does not contain or support one or more required datetime attributes
/// - The retrieved values are invalid for constructing a date, time, or datetime (e.g., invalid month or day)
/// - The timezone information (`tzinfo`) is not available or cannot be parsed
/// - The resulting datetime is ambiguous or invalid (e.g., due to DST transitions)
fn extract_datetime_from_python_object_attrs(
parameter: &pyo3::Bound<'_, PyAny>,
) -> Result<DateTime<FixedOffset>, RustPSQLDriverError> {
let year = extract_value_from_python_object_or_raise::<i32>(parameter, "year")?;
let month = extract_value_from_python_object_or_raise::<u32>(parameter, "month")?;
let day = extract_value_from_python_object_or_raise::<u32>(parameter, "day")?;
let hour = extract_value_from_python_object_or_raise::<u32>(parameter, "hour")?;
let minute = extract_value_from_python_object_or_raise::<u32>(parameter, "minute")?;
let second = extract_value_from_python_object_or_raise::<u32>(parameter, "second")?;
let microsecond = extract_value_from_python_object_or_raise::<u32>(parameter, "microsecond")?;

let date = NaiveDate::from_ymd_opt(year, month, day)
.ok_or_else(|| RustPSQLDriverError::PyToRustValueConversionError("Invalid date".into()))?;
let time = NaiveTime::from_hms_micro_opt(hour, minute, second, microsecond)
.ok_or_else(|| RustPSQLDriverError::PyToRustValueConversionError("Invalid time".into()))?;
let naive_datetime = NaiveDateTime::new(date, time);

let raw_timestamp_tz = parameter
.getattr("tzinfo")
.ok()
.and_then(|tzinfo| tzinfo.getattr("key").ok())
.and_then(|key| key.extract::<String>().ok())
.ok_or_else(|| {
RustPSQLDriverError::PyToRustValueConversionError("Invalid timezone info".into())
})?;

let fixed_offset_datetime = raw_timestamp_tz
.parse::<Tz>()
.map_err(|_| {
RustPSQLDriverError::PyToRustValueConversionError("Failed to parse TZ".into())
})?
.from_local_datetime(&naive_datetime)
.single()
.ok_or_else(|| {
RustPSQLDriverError::PyToRustValueConversionError(
"Ambiguous or invalid datetime".into(),
)
})?
.fixed_offset();

Ok(fixed_offset_datetime)
}

/// Convert single python parameter to `PythonDTO` enum.
///
/// # Errors
Expand Down Expand Up @@ -849,6 +928,11 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
return Ok(PythonDTO::PyDateTime(pydatetime_no_tz));
}

let timestamp_tz = extract_datetime_from_python_object_attrs(parameter);
if let Ok(pydatetime_tz) = timestamp_tz {
return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz));
}

return Err(RustPSQLDriverError::PyToRustValueConversionError(
"Can not convert you datetime to rust type".into(),
));
Expand Down Expand Up @@ -1663,7 +1747,7 @@ pub fn other_postgres_bytes_to_py(
}

Err(RustPSQLDriverError::RustToPyValueConversionError(
format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.")
format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.")
))
}

Expand Down
Loading