Skip to content

Commit

Permalink
Adjust decimal tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke committed Feb 20, 2025
1 parent 84d88bf commit b6ea9c7
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions python/cudf/cudf/tests/test_decimal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
# Copyright (c) 2021-2025, NVIDIA CORPORATION.

import decimal
from decimal import Decimal
Expand Down Expand Up @@ -31,25 +31,25 @@
[None, None, None],
[],
]
typ_ = [
pa.decimal128(precision=4, scale=2),
pa.decimal128(precision=5, scale=3),
pa.decimal128(precision=6, scale=4),
params = [
dict(precision=4, scale=2),
dict(precision=5, scale=3),
dict(precision=6, scale=4),
]


@pytest.mark.parametrize("data_", data_)
@pytest.mark.parametrize("typ_", typ_)
def test_round_trip_decimal64_column(data_, typ_):
pa_arr = pa.array(data_, type=typ_)
@pytest.mark.parametrize("dec_params", params)
def test_round_trip_decimal64_column(data_, dec_params):
pa_arr = pa.array(data_, type=pa.decimal64(**dec_params))
col_64 = Decimal64Column.from_arrow(pa_arr)
assert pa_arr.equals(col_64.to_arrow())


@pytest.mark.parametrize("data_", data_)
@pytest.mark.parametrize("typ_", typ_)
def test_round_trip_decimal32_column(data_, typ_):
pa_arr = pa.array(data_, type=typ_)
@pytest.mark.parametrize("dec_params", params)
def test_round_trip_decimal32_column(data_, dec_params):
pa_arr = pa.array(data_, type=pa.decimal32(**dec_params))
col_32 = Decimal32Column.from_arrow(pa_arr)
assert pa_arr.equals(col_32.to_arrow())

Expand Down Expand Up @@ -104,7 +104,7 @@ def test_typecast_from_float_to_decimal(request, data, from_dtype, to_dtype):
got = data.astype(from_dtype)

pa_arr = got.to_arrow().cast(
pa.decimal128(to_dtype.precision, to_dtype.scale)
pa.decimal64(to_dtype.precision, to_dtype.scale)
)
expected = cudf.Series._from_column(Decimal64Column.from_arrow(pa_arr))

Expand Down Expand Up @@ -144,7 +144,7 @@ def test_typecast_from_int_to_decimal(data, from_dtype, to_dtype):
pa_arr = (
got.to_arrow()
.cast("float64")
.cast(pa.decimal128(to_dtype.precision, to_dtype.scale))
.cast(pa.decimal64(to_dtype.precision, to_dtype.scale))
)
expected = cudf.Series._from_column(Decimal64Column.from_arrow(pa_arr))

Expand Down Expand Up @@ -202,9 +202,7 @@ def test_typecast_to_from_decimal(data, from_dtype, to_dtype):
)
s = data.astype(from_dtype)

pa_arr = s.to_arrow().cast(
pa.decimal128(to_dtype.precision, to_dtype.scale), safe=False
)
pa_arr = s.to_arrow().cast(to_dtype.to_arrow(), safe=False)
if isinstance(to_dtype, Decimal32Dtype):
expected = cudf.Series._from_column(Decimal32Column.from_arrow(pa_arr))
elif isinstance(to_dtype, Decimal64Dtype):
Expand Down

0 comments on commit b6ea9c7

Please sign in to comment.