Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
conf.get(PYTHON_DAEMON_KILL_WORKER_ON_FLUSH_FAILURE)
protected val hideTraceback: Boolean = false
protected val simplifiedTraceback: Boolean = false
protected val sessionLocalTimeZone = conf.getOption("spark.sql.session.timeZone")

// All the Python functions should have the same exec, version and envvars.
protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars
Expand Down Expand Up @@ -282,6 +283,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
if (simplifiedTraceback) {
envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
}
if (sessionLocalTimeZone.isDefined) {
envVars.put("SPARK_SESSION_LOCAL_TIMEZONE", sessionLocalTimeZone.get)
}
// SPARK-30299 this could be wrong with standalone mode when executor
// cores might not be correct because it defaults to all cores on the box.
val execCores = execCoresProp.map(_.toInt).getOrElse(conf.get(EXECUTOR_CORES))
Expand Down
5 changes: 2 additions & 3 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,8 @@ def needConversion(self) -> bool:

def toInternal(self, dt: datetime.datetime) -> int:
if dt is not None:
seconds = (
calendar.timegm(dt.utctimetuple()) if dt.tzinfo else time.mktime(dt.timetuple())
)
tzinfo = dt.tzinfo if dt.tzinfo else self.tz_info
seconds = calendar.timegm(dt.utctimetuple()) if tzinfo else time.mktime(dt.timetuple())
return int(seconds) * 1000000 + dt.microsecond

def fromInternal(self, ts: int) -> datetime.datetime:
Expand Down
9 changes: 7 additions & 2 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import inspect
import itertools
import json
import zoneinfo
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple

from pyspark.accumulators import (
Expand Down Expand Up @@ -3304,8 +3305,12 @@ def main(infile, outfile):
sys.exit(-1)
start_faulthandler_periodic_traceback()

# Use the local timezone to convert the timestamp
tz = datetime.datetime.now().astimezone().tzinfo
tzname = os.environ.get("SPARK_SESSION_LOCAL_TIMEZONE", None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To confirm, we will hit this branch for every udf execution, not just once per python worker initialization, right?

if tzname:
tz = zoneinfo.ZoneInfo(tzname)
else:
# Use the local timezone to convert the timestamp
tz = datetime.datetime.now().astimezone().tzinfo
TimestampType.tz_info = tz

check_python_version(infile)
Expand Down