feat: support casting to and from spark-like structs #1991

Merged: 11 commits, Feb 16, 2025
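For context, a minimal usage sketch of what this PR enables (hypothetical usage, not code from the diff; it assumes a PySpark-backed native frame `native_df` with a struct-compatible column "a", mirroring the test further down):

import narwhals as nw

dtype = nw.Struct([nw.Field("movie ", nw.String()), nw.Field("rating", nw.Float64())])
# Casting to a struct dtype through the spark-like backend now works:
result = nw.from_native(native_df).select(nw.col("a").cast(dtype))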
65 changes: 44 additions & 21 deletions narwhals/_spark_like/utils.py
@@ -58,8 +58,7 @@ def native_to_narwhals_dtype(
     if isinstance(dtype, spark_types.ByteType):
         return dtypes.Int8()
     if isinstance(
-        dtype,
-        (spark_types.StringType, spark_types.VarcharType, spark_types.CharType),
+        dtype, (spark_types.StringType, spark_types.VarcharType, spark_types.CharType)
     ):
         return dtypes.String()
     if isinstance(dtype, spark_types.BooleanType):
@@ -70,15 +69,26 @@
         return dtypes.Datetime()
     if isinstance(dtype, spark_types.TimestampType):
         return dtypes.Datetime(time_zone="UTC")
-    if isinstance(dtype, spark_types.DecimalType):  # pragma: no cover
-        # TODO(unassigned): cover this in dtypes_test.py
+    if isinstance(dtype, spark_types.DecimalType):
         return dtypes.Decimal()
-    if isinstance(dtype, spark_types.ArrayType):  # pragma: no cover
+    if isinstance(dtype, spark_types.ArrayType):
         return dtypes.List(
             inner=native_to_narwhals_dtype(
                 dtype.elementType, version=version, spark_types=spark_types
             )
         )
+    if isinstance(dtype, spark_types.StructType):
+        return dtypes.Struct(
+            fields=[
+                dtypes.Field(
+                    name=name,
+                    dtype=native_to_narwhals_dtype(
+                        dtype[name], version=version, spark_types=spark_types
+                    ),
+                )
+                for name in dtype.fieldNames()
+            ]
+        )
     return dtypes.Unknown()


@@ -113,28 +123,41 @@ def narwhals_to_native_dtype(
             msg = f"Only UTC time zone is supported for PySpark, got: {dt_time_zone}"
             raise ValueError(msg)
         return spark_types.TimestampType()
-    if isinstance_or_issubclass(dtype, dtypes.List):  # pragma: no cover
-        inner = narwhals_to_native_dtype(
-            dtype.inner,  # type: ignore[union-attr]
-            version=version,
-            spark_types=spark_types,
+    if isinstance_or_issubclass(dtype, (dtypes.List, dtypes.Array)):
+        return spark_types.ArrayType(
+            elementType=narwhals_to_native_dtype(
+                dtype.inner,  # type: ignore[union-attr]
+                version=version,
+                spark_types=spark_types,
+            )
Comment on lines +126 to +132

Member (@dangotbanned):
The # type: ignore here is an example of this issue (#1807 (comment)).
Off-topic-ish, but should I spin that out into a new issue? I think it might get lost in that PR.

Member Author:
Thanks @dangotbanned - I'd say let's keep track in a dedicated issue, as that's not even introduced in this specific PR.
         )
-        return spark_types.ArrayType(elementType=inner)
-    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
-        msg = "Converting to Struct dtype is not supported yet"
-        raise NotImplementedError(msg)
-    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
-        inner = narwhals_to_native_dtype(
-            dtype.inner,  # type: ignore[union-attr]
-            version=version,
-            spark_types=spark_types,
-        )
-        return spark_types.ArrayType(elementType=inner)
+    if isinstance_or_issubclass(dtype, dtypes.Struct):
+        return spark_types.StructType(
+            fields=[
+                spark_types.StructField(
+                    name=field.name,
+                    dataType=narwhals_to_native_dtype(
+                        field.dtype,
+                        version=version,
+                        spark_types=spark_types,
+                    ),
+                )
+                for field in dtype.fields  # type: ignore[union-attr]
+            ]
+        )

     if isinstance_or_issubclass(
-        dtype, (dtypes.UInt64, dtypes.UInt32, dtypes.UInt16, dtypes.UInt8)
+        dtype,
+        (
+            dtypes.UInt64,
+            dtypes.UInt32,
+            dtypes.UInt16,
+            dtypes.UInt8,
+            dtypes.Enum,
+            dtypes.Categorical,
+        ),
     ):  # pragma: no cover
-        msg = "Unsigned integer types are not supported by PySpark"
+        msg = "Unsigned integer, Enum and Categorical types are not supported by spark-like backend"
         raise UnsupportedDTypeError(msg)

     msg = f"Unknown dtype: {dtype}"  # pragma: no cover
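To make the new Struct branches concrete, here is a sketch of the expected mapping (illustrative only, assuming pyspark is installed; the narwhals dtype is the one used in the test below):

import pyspark.sql.types as T

# nw.Struct([nw.Field("movie ", nw.String()), nw.Field("rating", nw.Float64())])
# is expected to translate, per narwhals_to_native_dtype above, into a
# StructType with one StructField per nw.Field:
expected = T.StructType(
    [
        T.StructField("movie ", T.StringType()),
        T.StructField("rating", T.DoubleType()),
    ]
)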
24 changes: 18 additions & 6 deletions tests/expr_and_series/cast_test.py
@@ -236,9 +236,7 @@ def test_cast_datetime_tz_aware(


 def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -> None:
-    if any(
-        backend in str(constructor) for backend in ("dask", "modin", "cudf", "pyspark")
-    ):
+    if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")):
         request.applymarker(pytest.mark.xfail)

     if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
@@ -251,10 +249,24 @@ def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -> None:
         ]
     }

+    native_df = constructor(data)
+
+    if "spark" in str(constructor):  # pragma: no cover
+        # Special handling for pyspark as it natively maps the input to
+        # a column of type MAP<STRING, STRING>
+        import pyspark.sql.functions as F  # noqa: N812
+        import pyspark.sql.types as T  # noqa: N812
+
+        native_df = native_df.withColumn(  # type: ignore[union-attr]
+            "a",
+            F.struct(
+                F.col("a.movie ").alias("movie ").cast(T.StringType()),
+                F.col("a.rating").alias("rating").cast(T.DoubleType()),
+            ),
+        )
+
     dtype = nw.Struct([nw.Field("movie ", nw.String()), nw.Field("rating", nw.Float64())])
-    result = (
-        nw.from_native(constructor(data)).select(nw.col("a").cast(dtype)).lazy().collect()
-    )
+    result = nw.from_native(native_df).select(nw.col("a").cast(dtype)).lazy().collect()

     assert result.schema == {"a": dtype}

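For context on the special case above (an illustrative sketch, not part of the PR; it assumes a local SparkSession and uses string values so the inferred map type is unambiguous), PySpark infers a Python dict as a MapType rather than a StructType, which is why the test rebuilds column "a" with F.struct before casting:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# A python dict value is inferred as map<string,string>, not as a struct:
df = spark.createDataFrame([({"movie ": "Cars", "rating": "4.5"},)], ["a"])
print(df.schema["a"].dataType)  # MapType(StringType(), StringType(), True)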