Skip to content

Commit 56393b9

Browse files
committed
[SPARK-45091][PYTHON][CONNECT][SQL] Function floor/round/bround accept Column type scale
### What changes were proposed in this pull request? 1. `floor`: add the missing parameter `scale` in Python, which has already existed in Scala for a long time; 2. `round`/`bround`: the parameter `scale` now supports Column type, to be consistent with `floor`/`ceil`/`ceiling`. ### Why are the changes needed? To make the related functions consistent. ### Does this PR introduce _any_ user-facing change? Yes. ### How was this patch tested? Added doctests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#42833 from zhengruifeng/py_func_floor. Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 66f89f3 commit 56393b9

File tree

4 files changed

+131
-22
lines changed

4 files changed

+131
-22
lines changed

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2845,6 +2845,15 @@ object functions {
28452845
*/
28462846
def round(e: Column, scale: Int): Column = Column.fn("round", e, lit(scale))
28472847

2848+
/**
2849+
* Round the value of `e` to `scale` decimal places with HALF_UP round mode if `scale` is
2850+
* greater than or equal to 0 or at integral part when `scale` is less than 0.
2851+
*
2852+
* @group math_funcs
2853+
* @since 4.0.0
2854+
*/
2855+
def round(e: Column, scale: Column): Column = Column.fn("round", e, scale)
2856+
28482857
/**
28492858
* Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode.
28502859
*
@@ -2862,6 +2871,15 @@ object functions {
28622871
*/
28632872
def bround(e: Column, scale: Int): Column = Column.fn("bround", e, lit(scale))
28642873

2874+
/**
2875+
* Round the value of `e` to `scale` decimal places with HALF_EVEN round mode if `scale` is
2876+
* greater than or equal to 0 or at integral part when `scale` is less than 0.
2877+
*
2878+
* @group math_funcs
2879+
* @since 4.0.0
2880+
*/
2881+
def bround(e: Column, scale: Column): Column = Column.fn("bround", e, scale)
2882+
28652883
/**
28662884
* @param e
28672885
* angle in radians

python/pyspark/sql/connect/functions.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -538,8 +538,12 @@ def bin(col: "ColumnOrName") -> Column:
538538
bin.__doc__ = pysparkfuncs.bin.__doc__
539539

540540

541-
def bround(col: "ColumnOrName", scale: int = 0) -> Column:
542-
return _invoke_function("bround", _to_col(col), lit(scale))
541+
def bround(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
542+
if scale is None:
543+
return _invoke_function_over_columns("bround", col)
544+
else:
545+
scale = lit(scale) if isinstance(scale, int) else scale
546+
return _invoke_function_over_columns("bround", col, scale)
543547

544548

545549
bround.__doc__ = pysparkfuncs.bround.__doc__
@@ -644,8 +648,12 @@ def factorial(col: "ColumnOrName") -> Column:
644648
factorial.__doc__ = pysparkfuncs.factorial.__doc__
645649

646650

647-
def floor(col: "ColumnOrName") -> Column:
648-
return _invoke_function_over_columns("floor", col)
651+
def floor(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
652+
if scale is None:
653+
return _invoke_function_over_columns("floor", col)
654+
else:
655+
scale = lit(scale) if isinstance(scale, int) else scale
656+
return _invoke_function_over_columns("floor", col, scale)
649657

650658

651659
floor.__doc__ = pysparkfuncs.floor.__doc__
@@ -773,8 +781,12 @@ def rint(col: "ColumnOrName") -> Column:
773781
rint.__doc__ = pysparkfuncs.rint.__doc__
774782

775783

776-
def round(col: "ColumnOrName", scale: int = 0) -> Column:
777-
return _invoke_function("round", _to_col(col), lit(scale))
784+
def round(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
785+
if scale is None:
786+
return _invoke_function_over_columns("round", col)
787+
else:
788+
scale = lit(scale) if isinstance(scale, int) else scale
789+
return _invoke_function_over_columns("round", col, scale)
778790

779791

780792
round.__doc__ = pysparkfuncs.round.__doc__

python/pyspark/sql/functions.py

Lines changed: 73 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2346,7 +2346,7 @@ def expm1(col: "ColumnOrName") -> Column:
23462346

23472347

23482348
@try_remote_functions
2349-
def floor(col: "ColumnOrName") -> Column:
2349+
def floor(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
23502350
"""
23512351
Computes the floor of the given value.
23522352
@@ -2359,6 +2359,11 @@ def floor(col: "ColumnOrName") -> Column:
23592359
----------
23602360
col : :class:`~pyspark.sql.Column` or str
23612361
column to find floor for.
2362+
scale : :class:`~pyspark.sql.Column` or int
2363+
an optional parameter to control the rounding behavior.
2364+
2365+
.. versionadded:: 4.0.0
2366+
23622367
23632368
Returns
23642369
-------
@@ -2367,15 +2372,27 @@ def floor(col: "ColumnOrName") -> Column:
23672372
23682373
Examples
23692374
--------
2370-
>>> df = spark.range(1)
2371-
>>> df.select(floor(lit(2.5))).show()
2375+
>>> import pyspark.sql.functions as sf
2376+
>>> spark.range(1).select(sf.floor(sf.lit(2.5))).show()
23722377
+----------+
23732378
|FLOOR(2.5)|
23742379
+----------+
23752380
| 2|
23762381
+----------+
2382+
2383+
>>> import pyspark.sql.functions as sf
2384+
>>> spark.range(1).select(sf.floor(sf.lit(2.1267), sf.lit(2))).show()
2385+
+----------------+
2386+
|floor(2.1267, 2)|
2387+
+----------------+
2388+
| 2.12|
2389+
+----------------+
23772390
"""
2378-
return _invoke_function_over_columns("floor", col)
2391+
if scale is None:
2392+
return _invoke_function_over_columns("floor", col)
2393+
else:
2394+
scale = lit(scale) if isinstance(scale, int) else scale
2395+
return _invoke_function_over_columns("floor", col, scale)
23792396

23802397

23812398
@try_remote_functions
@@ -5631,7 +5648,7 @@ def randn(seed: Optional[int] = None) -> Column:
56315648

56325649

56335650
@try_remote_functions
5634-
def round(col: "ColumnOrName", scale: int = 0) -> Column:
5651+
def round(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
56355652
"""
56365653
Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0
56375654
or at integral part when `scale` < 0.
@@ -5645,8 +5662,11 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column:
56455662
----------
56465663
col : :class:`~pyspark.sql.Column` or str
56475664
input column to round.
5648-
scale : int optional default 0
5649-
scale value.
5665+
scale : :class:`~pyspark.sql.Column` or int
5666+
an optional parameter to control the rounding behavior.
5667+
5668+
.. versionchanged:: 4.0.0
5669+
Support Column type.
56505670
56515671
Returns
56525672
-------
@@ -5655,14 +5675,31 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column:
56555675
56565676
Examples
56575677
--------
5658-
>>> spark.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect()
5659-
[Row(r=3.0)]
5678+
>>> import pyspark.sql.functions as sf
5679+
>>> spark.range(1).select(sf.round(sf.lit(2.5))).show()
5680+
+-------------+
5681+
|round(2.5, 0)|
5682+
+-------------+
5683+
| 3.0|
5684+
+-------------+
5685+
5686+
>>> import pyspark.sql.functions as sf
5687+
>>> spark.range(1).select(sf.round(sf.lit(2.1267), sf.lit(2))).show()
5688+
+----------------+
5689+
|round(2.1267, 2)|
5690+
+----------------+
5691+
| 2.13|
5692+
+----------------+
56605693
"""
5661-
return _invoke_function("round", _to_java_column(col), scale)
5694+
if scale is None:
5695+
return _invoke_function_over_columns("round", col)
5696+
else:
5697+
scale = lit(scale) if isinstance(scale, int) else scale
5698+
return _invoke_function_over_columns("round", col, scale)
56625699

56635700

56645701
@try_remote_functions
5665-
def bround(col: "ColumnOrName", scale: int = 0) -> Column:
5702+
def bround(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
56665703
"""
56675704
Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0
56685705
or at integral part when `scale` < 0.
@@ -5676,8 +5713,11 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column:
56765713
----------
56775714
col : :class:`~pyspark.sql.Column` or str
56785715
input column to round.
5679-
scale : int optional default 0
5680-
scale value.
5716+
scale : :class:`~pyspark.sql.Column` or int
5717+
an optional parameter to control the rounding behavior.
5718+
5719+
.. versionchanged:: 4.0.0
5720+
Support Column type.
56815721
56825722
Returns
56835723
-------
@@ -5686,10 +5726,27 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column:
56865726
56875727
Examples
56885728
--------
5689-
>>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect()
5690-
[Row(r=2.0)]
5729+
>>> import pyspark.sql.functions as sf
5730+
>>> spark.range(1).select(sf.bround(sf.lit(2.5))).show()
5731+
+--------------+
5732+
|bround(2.5, 0)|
5733+
+--------------+
5734+
| 2.0|
5735+
+--------------+
5736+
5737+
>>> import pyspark.sql.functions as sf
5738+
>>> spark.range(1).select(sf.bround(sf.lit(2.1267), sf.lit(2))).show()
5739+
+-----------------+
5740+
|bround(2.1267, 2)|
5741+
+-----------------+
5742+
| 2.13|
5743+
+-----------------+
56915744
"""
5692-
return _invoke_function("bround", _to_java_column(col), scale)
5745+
if scale is None:
5746+
return _invoke_function_over_columns("bround", col)
5747+
else:
5748+
scale = lit(scale) if isinstance(scale, int) else scale
5749+
return _invoke_function_over_columns("bround", col, scale)
56935750

56945751

56955752
@try_remote_functions

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2881,6 +2881,17 @@ object functions {
28812881
*/
28822882
def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) }
28832883

2884+
/**
2885+
* Round the value of `e` to `scale` decimal places with HALF_UP round mode
2886+
* if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
2887+
*
2888+
* @group math_funcs
2889+
* @since 4.0.0
2890+
*/
2891+
def round(e: Column, scale: Column): Column = withExpr {
2892+
Round(e.expr, scale.expr)
2893+
}
2894+
28842895
/**
28852896
* Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode.
28862897
*
@@ -2898,6 +2909,17 @@ object functions {
28982909
*/
28992910
def bround(e: Column, scale: Int): Column = withExpr { BRound(e.expr, Literal(scale)) }
29002911

2912+
/**
2913+
* Round the value of `e` to `scale` decimal places with HALF_EVEN round mode
2914+
* if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
2915+
*
2916+
* @group math_funcs
2917+
* @since 4.0.0
2918+
*/
2919+
def bround(e: Column, scale: Column): Column = withExpr {
2920+
BRound(e.expr, scale.expr)
2921+
}
2922+
29012923
/**
29022924
* @param e angle in radians
29032925
* @return secant of the angle

0 commit comments

Comments
 (0)