[SPARK-49530][PYTHON][CONNECT] Support kde/density plots
### What changes were proposed in this pull request?
Support kde/density plots with plotly backend on both Spark Connect and Spark classic.

### Why are the changes needed?
While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments.

See more in the in-progress [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing).

Part of https://issues.apache.org/jira/browse/SPARK-49530.

### Does this PR introduce _any_ user-facing change?
Yes. kde/density plots are supported as shown below.

```py
>>> data = [
...     (1.0, 4.0),
...     (2.0, 4.0),
...     (2.5, 4.5),
...     (3.0, 5.0),
...     (3.5, 5.5),
...     (4.0, 6.0),
...     (5.0, 6.0)
... ]
>>> columns = ["x", "y"]
>>> df = spark.createDataFrame(data, columns)
>>> fig1 = df.plot.kde(column=["x", "y"], bw_method=0.3, ind=100)
>>> fig1.show()  # see below
>>> fig2 = df.plot(kind="kde", column="x", bw_method=0.3, ind=20)
>>> fig2.show()  # see below
```

fig1:
![newplot (23)](https://github.com/user-attachments/assets/2cb84a78-7d92-43b5-afec-df83e5c55f5c)

fig2:
![newplot (22)](https://github.com/user-attachments/assets/90a70770-8d05-4c81-8f02-4f954c2c689e)

### How was this patch tested?
Unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48492 from xinrong-meng/kde.

Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
xinrong-meng authored and zhengruifeng committed Oct 20, 2024
1 parent 25b03f9 commit d5550f6
Showing 5 changed files with 247 additions and 3 deletions.
30 changes: 30 additions & 0 deletions python/pyspark/sql/pandas/utils.py
@@ -94,3 +94,33 @@ def require_minimum_pyarrow_version() -> None:
             errorClass="ARROW_LEGACY_IPC_FORMAT",
             messageParameters={},
         )
+
+
+def require_minimum_numpy_version() -> None:
+    """Raise ImportError if minimum version of NumPy is not installed"""
+    minimum_numpy_version = "1.21"
+
+    try:
+        import numpy
+
+        have_numpy = True
+    except ImportError as error:
+        have_numpy = False
+        raised_error = error
+    if not have_numpy:
+        raise PySparkImportError(
+            errorClass="PACKAGE_NOT_INSTALLED",
+            messageParameters={
+                "package_name": "NumPy",
+                "minimum_version": str(minimum_numpy_version),
+            },
+        ) from raised_error
+    if LooseVersion(numpy.__version__) < LooseVersion(minimum_numpy_version):
+        raise PySparkImportError(
+            errorClass="UNSUPPORTED_PACKAGE_VERSION",
+            messageParameters={
+                "package_name": "NumPy",
+                "minimum_version": str(minimum_numpy_version),
+                "current_version": str(numpy.__version__),
+            },
+        )
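For context, this guard is intended to run before any NumPy import at a call site (as `get_ind` does later in this patch). A minimal sketch, assuming PySpark is on the path:

```py
from pyspark.sql.pandas.utils import require_minimum_numpy_version

require_minimum_numpy_version()  # raises PySparkImportError if NumPy is absent or < 1.21
import numpy as np

ind = np.linspace(0.0, 10.0, 1000)  # e.g. a default evaluation grid for the KDE
```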
131 changes: 130 additions & 1 deletion python/pyspark/sql/plot/core.py
@@ -15,18 +15,26 @@
 # limitations under the License.
 #
 
+import math
+
 from typing import Any, TYPE_CHECKING, List, Optional, Union
 from types import ModuleType
-from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
+from pyspark.errors import (
+    PySparkRuntimeError,
+    PySparkTypeError,
+    PySparkValueError,
+)
 from pyspark.sql import Column, functions as F
 from pyspark.sql.types import NumericType
 from pyspark.sql.utils import is_remote, require_minimum_plotly_version
 from pandas.core.dtypes.inference import is_integer
 
 
 if TYPE_CHECKING:
     from pyspark.sql import DataFrame, Row
     from pyspark.sql._typing import ColumnOrName
+    import pandas as pd
+    import numpy as np
     from plotly.graph_objs import Figure
 
 
@@ -398,6 +406,127 @@ def box(
         """
         return self(kind="box", column=column, precision=precision, **kwargs)
 
+    def kde(
+        self,
+        column: Union[str, List[str]],
+        bw_method: Union[int, float],
+        ind: Union["np.ndarray", int, None] = None,
+        **kwargs: Any,
+    ) -> "Figure":
+        """
+        Generate Kernel Density Estimate plot using Gaussian kernels.
+
+        In statistics, kernel density estimation (KDE) is a non-parametric way to
+        estimate the probability density function (PDF) of a random variable. This
+        function uses Gaussian kernels and includes automatic bandwidth determination.
+
+        Parameters
+        ----------
+        column: str or list of str
+            Column name or list of names to be used for creating the kde plot.
+        bw_method : int or float
+            The method used to calculate the estimator bandwidth.
+            See KernelDensity in PySpark for more information.
+        ind : NumPy array or integer, optional
+            Evaluation points for the estimated PDF. If None (default),
+            1000 equally spaced points are used. If `ind` is a NumPy array, the
+            KDE is evaluated at the points passed. If `ind` is an integer,
+            `ind` number of equally spaced points are used.
+        **kwargs : optional
+            Additional keyword arguments.
+
+        Returns
+        -------
+        :class:`plotly.graph_objs.Figure`
+
+        Examples
+        --------
+        >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
+        >>> columns = ["length", "width", "species"]
+        >>> df = spark.createDataFrame(data, columns)
+        >>> df.plot.kde(column=["length", "width"], bw_method=0.3)  # doctest: +SKIP
+        >>> df.plot.kde(column="length", bw_method=0.3)  # doctest: +SKIP
+        """
+        return self(kind="kde", column=column, bw_method=bw_method, ind=ind, **kwargs)
+
+
+class PySparkKdePlotBase:
+    @staticmethod
+    def get_ind(sdf: "DataFrame", ind: Union["np.ndarray", int, None]) -> "np.ndarray":
+        from pyspark.sql.pandas.utils import require_minimum_numpy_version
+
+        require_minimum_numpy_version()
+        import numpy as np
+
+        def calc_min_max() -> "Row":
+            if len(sdf.columns) > 1:
+                min_col = F.least(*map(F.min, sdf))  # type: ignore
+                max_col = F.greatest(*map(F.max, sdf))  # type: ignore
+            else:
+                min_col = F.min(sdf.columns[-1])
+                max_col = F.max(sdf.columns[-1])
+            return sdf.select(min_col, max_col).first()  # type: ignore
+
+        if ind is None:
+            min_val, max_val = calc_min_max()
+            sample_range = max_val - min_val
+            ind = np.linspace(
+                min_val - 0.5 * sample_range,
+                max_val + 0.5 * sample_range,
+                1000,
+            )
+        elif is_integer(ind):
+            min_val, max_val = calc_min_max()
+            sample_range = max_val - min_val
+            ind = np.linspace(
+                min_val - 0.5 * sample_range,
+                max_val + 0.5 * sample_range,
+                ind,
+            )
+        return ind  # type: ignore
+
+    @staticmethod
+    def compute_kde_col(
+        input_col: Column,
+        bw_method: Union[int, float],
+        ind: "np.ndarray",
+    ) -> Column:
+        # refers to org.apache.spark.mllib.stat.KernelDensity
+        assert bw_method is not None and isinstance(
+            bw_method, (int, float)
+        ), "'bw_method' must be set as a scalar number."
+
+        assert ind is not None, "'ind' must be a scalar array."
+
+        bandwidth = float(bw_method)
+        points = [float(i) for i in ind]
+        log_std_plus_half_log2_pi = math.log(bandwidth) + 0.5 * math.log(2 * math.pi)
+
+        def norm_pdf(
+            mean: Column,
+            std: Column,
+            log_std_plus_half_log2_pi: Column,
+            x: Column,
+        ) -> Column:
+            x0 = x - mean
+            x1 = x0 / std
+            log_density = -0.5 * x1 * x1 - log_std_plus_half_log2_pi
+            return F.exp(log_density)
+
+        return F.array(
+            [
+                F.avg(
+                    norm_pdf(
+                        input_col.cast("double"),
+                        F.lit(bandwidth),
+                        F.lit(log_std_plus_half_log2_pi),
+                        F.lit(point),
+                    )
+                )
+                for point in points
+            ]
+        )
+
 
 class PySparkBoxPlotBase:
     @staticmethod
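For intuition, `compute_kde_col` evaluates a Gaussian KDE by averaging, over the rows of the column, the kernel density at each evaluation point, working in log space exactly as `norm_pdf` does above. A standalone NumPy cross-check of the same math (sample data borrowed from the PR description, not part of the patch):

```py
import numpy as np

def kde_reference(samples: np.ndarray, points: np.ndarray, bandwidth: float) -> np.ndarray:
    # log density of N(sample, bandwidth) at each point, mirroring norm_pdf
    log_std_plus_half_log2_pi = np.log(bandwidth) + 0.5 * np.log(2 * np.pi)
    x1 = (points[:, None] - samples[None, :]) / bandwidth  # shape (n_points, n_samples)
    log_density = -0.5 * x1 * x1 - log_std_plus_half_log2_pi
    return np.exp(log_density).mean(axis=1)  # the F.avg over rows

samples = np.array([1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0])  # the "x" column
points = np.linspace(samples.min(), samples.max(), 5)
print(kde_reference(samples, points, bandwidth=0.3))
```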
47 changes: 46 additions & 1 deletion python/pyspark/sql/plot/plotly.py
@@ -18,7 +18,7 @@
 from typing import TYPE_CHECKING, Any
 
 from pyspark.errors import PySparkValueError
-from pyspark.sql.plot import PySparkPlotAccessor, PySparkBoxPlotBase
+from pyspark.sql.plot import PySparkPlotAccessor, PySparkBoxPlotBase, PySparkKdePlotBase
 
 if TYPE_CHECKING:
     from pyspark.sql import DataFrame
@@ -32,6 +32,8 @@ def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
         return plot_pie(data, **kwargs)
     if kind == "box":
         return plot_box(data, **kwargs)
+    if kind == "kde" or kind == "density":
+        return plot_kde(data, **kwargs)
 
     return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)
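Note the dispatcher treats "density" as an alias for "kde", so both spellings reach `plot_kde`; a sketch reusing `df` from the description above:

```py
>>> fig3 = df.plot(kind="density", column="x", bw_method=0.3, ind=20)  # same as kind="kde"
```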

@@ -118,3 +120,46 @@ def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure":
 
     fig["layout"]["yaxis"]["title"] = "value"
     return fig
+
+
+def plot_kde(data: "DataFrame", **kwargs: Any) -> "Figure":
+    from pyspark.sql.pandas.utils import require_minimum_pandas_version
+
+    require_minimum_pandas_version()
+
+    import pandas as pd
+    from plotly import express
+
+    if "color" not in kwargs:
+        kwargs["color"] = "names"
+
+    bw_method = kwargs.pop("bw_method", None)
+    colnames = kwargs.pop("column", None)
+    if isinstance(colnames, str):
+        colnames = [colnames]
+    ind = PySparkKdePlotBase.get_ind(data.select(*colnames), kwargs.pop("ind", None))
+
+    kde_cols = [
+        PySparkKdePlotBase.compute_kde_col(
+            input_col=data[col_name],
+            ind=ind,
+            bw_method=bw_method,
+        ).alias(f"kde_{i}")
+        for i, col_name in enumerate(colnames)
+    ]
+    kde_results = data.select(*kde_cols).first()
+    pdf = pd.concat(
+        [
+            pd.DataFrame(  # type: ignore
+                {
+                    "Density": kde_result,
+                    "names": col_name,
+                    "index": ind,
+                }
+            )
+            for col_name, kde_result in zip(colnames, list(kde_results))  # type: ignore[arg-type]
+        ]
+    )
+    fig = express.line(pdf, x="index", y="Density", **kwargs)
+    fig["layout"]["xaxis"]["title"] = None
+    return fig
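`plot_kde` flattens the per-column density arrays into one long-format pandas frame so that `express.line` with `color="names"` draws one trace per source column; a rough standalone illustration with made-up densities:

```py
import pandas as pd
from plotly import express

ind = [1.0, 2.0, 3.0]
pdf = pd.concat(
    [
        pd.DataFrame({"Density": [0.1, 0.4, 0.2], "names": "x", "index": ind}),
        pd.DataFrame({"Density": [0.2, 0.3, 0.3], "names": "y", "index": ind}),
    ]
)
fig = express.line(pdf, x="index", y="Density", color="names")  # one line per column
```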
34 changes: 33 additions & 1 deletion python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -20,7 +20,13 @@
 
 import pyspark.sql.plot  # noqa: F401
 from pyspark.errors import PySparkTypeError, PySparkValueError
-from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+    have_plotly,
+    have_numpy,
+    plotly_requirement_message,
+    numpy_requirement_message,
+)
 
 
 @unittest.skipIf(not have_plotly, plotly_requirement_message)
@@ -375,6 +381,32 @@ def test_box_plot(self):
             },
         )
 
+    @unittest.skipIf(not have_numpy, numpy_requirement_message)
+    def test_kde_plot(self):
+        fig = self.sdf4.plot.kde(column="math_score", bw_method=0.3, ind=5)
+        expected_fig_data = {
+            "mode": "lines",
+            "name": "math_score",
+            "orientation": "v",
+            "xaxis": "x",
+            "yaxis": "y",
+            "type": "scatter",
+        }
+        self._check_fig_data(fig["data"][0], **expected_fig_data)
+
+        fig = self.sdf4.plot.kde(column=["math_score", "english_score"], bw_method=0.3, ind=5)
+        self._check_fig_data(fig["data"][0], **expected_fig_data)
+        expected_fig_data = {
+            "mode": "lines",
+            "name": "english_score",
+            "orientation": "v",
+            "xaxis": "x",
+            "yaxis": "y",
+            "type": "scatter",
+        }
+        self._check_fig_data(fig["data"][1], **expected_fig_data)
+        self.assertEqual(list(fig["data"][0]["x"]), list(fig["data"][1]["x"]))
+
 
 class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
     pass
8 changes: 8 additions & 0 deletions python/pyspark/testing/sqlutils.py
@@ -55,6 +55,13 @@
     plotly_requirement_message = str(e)
 have_plotly = plotly_requirement_message is None
 
+numpy_requirement_message = None
+try:
+    import numpy
+except ImportError as e:
+    numpy_requirement_message = str(e)
+have_numpy = numpy_requirement_message is None
+
 from pyspark.sql import SparkSession
 from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
 from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils
@@ -63,6 +70,7 @@
 have_pandas = pandas_requirement_message is None
 have_pyarrow = pyarrow_requirement_message is None
 test_compiled = test_not_compiled_message is None
+have_numpy = numpy_requirement_message is None
 
 
 class UTCOffsetTimezone(datetime.tzinfo):
