Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add-support-datetime-zone-info #111

Merged
merged 2 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion python/tests/test_value_converter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import sys
import uuid
from decimal import Decimal
from enum import Enum
Expand Down Expand Up @@ -57,6 +58,7 @@

from tests.conftest import DefaultPydanticModel, DefaultPythonModelClass

uuid_ = uuid.uuid4()
pytestmark = pytest.mark.anyio
now_datetime = datetime.datetime.now() # noqa: DTZ005
now_datetime_with_tz = datetime.datetime(
Expand All @@ -69,7 +71,30 @@
142574,
tzinfo=datetime.timezone.utc,
)
uuid_ = uuid.uuid4()

now_datetime_with_tz_in_asia_jakarta = datetime.datetime(
2024,
4,
13,
17,
3,
46,
142574,
tzinfo=datetime.timezone.utc,
)
if sys.version_info >= (3, 9):
import zoneinfo

now_datetime_with_tz_in_asia_jakarta = datetime.datetime(
2024,
4,
13,
17,
3,
46,
142574,
tzinfo=zoneinfo.ZoneInfo(key="Asia/Jakarta"),
)


async def test_as_class(
Expand Down Expand Up @@ -125,6 +150,7 @@ async def test_as_class(
("TIME", now_datetime.time(), now_datetime.time()),
("TIMESTAMP", now_datetime, now_datetime),
("TIMESTAMPTZ", now_datetime_with_tz, now_datetime_with_tz),
("TIMESTAMPTZ", now_datetime_with_tz_in_asia_jakarta, now_datetime_with_tz_in_asia_jakarta),
("UUID", uuid_, str(uuid_)),
("INET", IPv4Address("192.0.0.1"), IPv4Address("192.0.0.1")),
(
Expand Down Expand Up @@ -287,6 +313,11 @@ async def test_as_class(
[now_datetime_with_tz, now_datetime_with_tz],
[now_datetime_with_tz, now_datetime_with_tz],
),
(
"TIMESTAMPTZ ARRAY",
[now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta],
[now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta],
),
(
"TIMESTAMPTZ ARRAY",
[[now_datetime_with_tz], [now_datetime_with_tz]],
Expand Down
9 changes: 0 additions & 9 deletions src/driver/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,6 @@ impl Connection {

#[pymethods]
impl Connection {
#[must_use]
pub fn __aiter__(self_: Py<Self>) -> Py<Self> {
self_
}

fn __await__(self_: Py<Self>) -> Py<Self> {
self_
}

async fn __aenter__<'a>(self_: Py<Self>) -> RustPSQLDriverPyResult<Py<Self>> {
let (db_client, db_pool) = pyo3::Python::with_gil(|gil| {
let self_ = self_.borrow(gil);
Expand Down
89 changes: 85 additions & 4 deletions src/value_converter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeZone};
use chrono_tz::Tz;
use geo_types::{coord, Coord, Line as LineSegment, LineString, Point, Rect};
use itertools::Itertools;
use macaddr::{MacAddr6, MacAddr8};
Expand Down Expand Up @@ -626,8 +627,7 @@ impl ToSql for PythonDTO {
#[allow(clippy::needless_pass_by_value)]
pub fn convert_parameters(parameters: Py<PyAny>) -> RustPSQLDriverPyResult<Vec<PythonDTO>> {
let mut result_vec: Vec<PythonDTO> = vec![];

result_vec = Python::with_gil(|gil| {
Python::with_gil(|gil| {
let params = parameters.extract::<Vec<Py<PyAny>>>(gil).map_err(|_| {
RustPSQLDriverError::PyToRustValueConversionError(
"Cannot convert you parameters argument into Rust type, please use List/Tuple"
Expand All @@ -637,8 +637,9 @@ pub fn convert_parameters(parameters: Py<PyAny>) -> RustPSQLDriverPyResult<Vec<P
for parameter in params {
result_vec.push(py_to_rust(parameter.bind(gil))?);
}
Ok::<Vec<PythonDTO>, RustPSQLDriverError>(result_vec)
Ok::<(), RustPSQLDriverError>(())
})?;

Ok(result_vec)
}

Expand Down Expand Up @@ -744,6 +745,81 @@ pub fn py_sequence_into_postgres_array(
}
}

/// Extract a value from a Python object, raising an error if missing or invalid
///
/// # Errors
/// This function will return `Err` in the following cases:
/// - The Python object does not have the specified attribute
/// - The attribute exists but cannot be extracted into the specified Rust type
fn extract_value_from_python_object_or_raise<'py, T>(
parameter: &'py pyo3::Bound<'_, PyAny>,
attr_name: &str,
) -> Result<T, RustPSQLDriverError>
where
T: FromPyObject<'py>,
{
parameter
.getattr(attr_name)
.ok()
.and_then(|attr| attr.extract::<T>().ok())
.ok_or_else(|| {
RustPSQLDriverError::PyToRustValueConversionError("Invalid attribute".into())
})
}

/// Extract a timezone-aware datetime from a Python object.
/// This function retrieves various datetime components (`year`, `month`, `day`, etc.)
/// from a Python object and constructs a `DateTime<FixedOffset>`
///
/// # Errors
/// This function will return `Err` in the following cases:
/// - The Python object does not contain or support one or more required datetime attributes
/// - The retrieved values are invalid for constructing a date, time, or datetime (e.g., invalid month or day)
/// - The timezone information (`tzinfo`) is not available or cannot be parsed
/// - The resulting datetime is ambiguous or invalid (e.g., due to DST transitions)
fn extract_datetime_from_python_object_attrs(
parameter: &pyo3::Bound<'_, PyAny>,
) -> Result<DateTime<FixedOffset>, RustPSQLDriverError> {
let year = extract_value_from_python_object_or_raise::<i32>(parameter, "year")?;
let month = extract_value_from_python_object_or_raise::<u32>(parameter, "month")?;
let day = extract_value_from_python_object_or_raise::<u32>(parameter, "day")?;
let hour = extract_value_from_python_object_or_raise::<u32>(parameter, "hour")?;
let minute = extract_value_from_python_object_or_raise::<u32>(parameter, "minute")?;
let second = extract_value_from_python_object_or_raise::<u32>(parameter, "second")?;
let microsecond = extract_value_from_python_object_or_raise::<u32>(parameter, "microsecond")?;

let date = NaiveDate::from_ymd_opt(year, month, day)
.ok_or_else(|| RustPSQLDriverError::PyToRustValueConversionError("Invalid date".into()))?;
let time = NaiveTime::from_hms_micro_opt(hour, minute, second, microsecond)
.ok_or_else(|| RustPSQLDriverError::PyToRustValueConversionError("Invalid time".into()))?;
let naive_datetime = NaiveDateTime::new(date, time);

let raw_timestamp_tz = parameter
.getattr("tzinfo")
.ok()
.and_then(|tzinfo| tzinfo.getattr("key").ok())
.and_then(|key| key.extract::<String>().ok())
.ok_or_else(|| {
RustPSQLDriverError::PyToRustValueConversionError("Invalid timezone info".into())
})?;

let fixed_offset_datetime = raw_timestamp_tz
.parse::<Tz>()
.map_err(|_| {
RustPSQLDriverError::PyToRustValueConversionError("Failed to parse TZ".into())
})?
.from_local_datetime(&naive_datetime)
.single()
.ok_or_else(|| {
RustPSQLDriverError::PyToRustValueConversionError(
"Ambiguous or invalid datetime".into(),
)
})?
.fixed_offset();

Ok(fixed_offset_datetime)
}

/// Convert single python parameter to `PythonDTO` enum.
///
/// # Errors
Expand Down Expand Up @@ -849,6 +925,11 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
return Ok(PythonDTO::PyDateTime(pydatetime_no_tz));
}

let timestamp_tz = extract_datetime_from_python_object_attrs(parameter);
if let Ok(pydatetime_tz) = timestamp_tz {
chandr-andr marked this conversation as resolved.
Show resolved Hide resolved
return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz));
}

return Err(RustPSQLDriverError::PyToRustValueConversionError(
"Can not convert you datetime to rust type".into(),
));
Expand Down
Loading