Skip to content

Commit

Permalink
[SPARK-50030][PYTHON][CONNECT] API compatibility check for Window
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR proposes to add API compatibility check for Spark SQL Window functions

### Why are the changes needed?

To guarantee of the same behavior between Spark Classic and Spark Connect

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added UTs

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #48541 from itholic/SPARK-50030.

Authored-by: Haejoon Lee <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
itholic authored and HyukjinKwon committed Oct 19, 2024
1 parent f8d9222 commit 14ed86e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 10 deletions.
18 changes: 8 additions & 10 deletions python/pyspark/sql/connect/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,21 @@ def __init__(
self._orderSpec = orderSpec
self._frame = frame

def partitionBy(
self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]
) -> ParentWindowSpec:
def partitionBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
return WindowSpec(
partitionSpec=[c._expr for c in _to_cols(cols)], # type: ignore[misc]
orderSpec=self._orderSpec,
frame=self._frame,
)

def orderBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
def orderBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
return WindowSpec(
partitionSpec=self._partitionSpec,
orderSpec=[cast(SortOrder, F._sort_col(c)._expr) for c in _to_cols(cols)],
frame=self._frame,
)

def rowsBetween(self, start: int, end: int) -> ParentWindowSpec:
def rowsBetween(self, start: int, end: int) -> "WindowSpec":
if start <= Window._PRECEDING_THRESHOLD:
start = Window.unboundedPreceding
if end >= Window._FOLLOWING_THRESHOLD:
Expand All @@ -112,7 +110,7 @@ def rowsBetween(self, start: int, end: int) -> ParentWindowSpec:
frame=WindowFrame(isRowFrame=True, start=start, end=end),
)

def rangeBetween(self, start: int, end: int) -> ParentWindowSpec:
def rangeBetween(self, start: int, end: int) -> "WindowSpec":
if start <= Window._PRECEDING_THRESHOLD:
start = Window.unboundedPreceding
if end >= Window._FOLLOWING_THRESHOLD:
Expand Down Expand Up @@ -141,19 +139,19 @@ class Window(ParentWindow):
_spec = WindowSpec(partitionSpec=[], orderSpec=[], frame=None)

@staticmethod
def partitionBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
def partitionBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
return Window._spec.partitionBy(*cols)

@staticmethod
def orderBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
def orderBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
return Window._spec.orderBy(*cols)

@staticmethod
def rowsBetween(start: int, end: int) -> ParentWindowSpec:
def rowsBetween(start: int, end: int) -> "WindowSpec":
return Window._spec.rowsBetween(start, end)

@staticmethod
def rangeBetween(start: int, end: int) -> ParentWindowSpec:
def rangeBetween(start: int, end: int) -> "WindowSpec":
return Window._spec.rangeBetween(start, end)


Expand Down
36 changes: 36 additions & 0 deletions python/pyspark/sql/tests/test_connect_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from pyspark.sql.readwriter import DataFrameReader as ClassicDataFrameReader
from pyspark.sql.readwriter import DataFrameWriter as ClassicDataFrameWriter
from pyspark.sql.readwriter import DataFrameWriterV2 as ClassicDataFrameWriterV2
from pyspark.sql.window import Window as ClassicWindow
from pyspark.sql.window import WindowSpec as ClassicWindowSpec

if should_test_connect:
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
Expand All @@ -37,6 +39,8 @@
from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
from pyspark.sql.connect.readwriter import DataFrameWriter as ConnectDataFrameWriter
from pyspark.sql.connect.readwriter import DataFrameWriterV2 as ConnectDataFrameWriterV2
from pyspark.sql.connect.window import Window as ConnectWindow
from pyspark.sql.connect.window import WindowSpec as ConnectWindowSpec


class ConnectCompatibilityTestsMixin:
Expand Down Expand Up @@ -303,6 +307,38 @@ def test_dataframe_writer_v2_compatibility(self):
expected_missing_classic_methods,
)

def test_window_compatibility(self):
"""Test Window compatibility between classic and connect."""
expected_missing_connect_properties = set()
expected_missing_classic_properties = set()
expected_missing_connect_methods = set()
expected_missing_classic_methods = set()
self.check_compatibility(
ClassicWindow,
ConnectWindow,
"Window",
expected_missing_connect_properties,
expected_missing_classic_properties,
expected_missing_connect_methods,
expected_missing_classic_methods,
)

def test_window_spec_compatibility(self):
"""Test WindowSpec compatibility between classic and connect."""
expected_missing_connect_properties = set()
expected_missing_classic_properties = set()
expected_missing_connect_methods = set()
expected_missing_classic_methods = set()
self.check_compatibility(
ClassicWindowSpec,
ConnectWindowSpec,
"WindowSpec",
expected_missing_connect_properties,
expected_missing_classic_properties,
expected_missing_connect_methods,
expected_missing_classic_methods,
)


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase):
Expand Down

0 comments on commit 14ed86e

Please sign in to comment.