Skip to content

Commit d35f779

Browse files
Lordwormsfindepi
authored andcommitted
Minor: Add more support for ScalarValue::Float16 (apache#11156)
1 parent 9fbbc36 commit d35f779

File tree

1 file changed

+35
-1
lines changed
  • datafusion/common/src/scalar

1 file changed

+35
-1
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,7 @@ impl ScalarValue {
982982
DataType::UInt16 => ScalarValue::UInt16(Some(0)),
983983
DataType::UInt32 => ScalarValue::UInt32(Some(0)),
984984
DataType::UInt64 => ScalarValue::UInt64(Some(0)),
985+
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))),
985986
DataType::Float32 => ScalarValue::Float32(Some(0.0)),
986987
DataType::Float64 => ScalarValue::Float64(Some(0.0)),
987988
DataType::Timestamp(TimeUnit::Second, tz) => {
@@ -1035,6 +1036,7 @@ impl ScalarValue {
10351036
DataType::UInt16 => ScalarValue::UInt16(Some(1)),
10361037
DataType::UInt32 => ScalarValue::UInt32(Some(1)),
10371038
DataType::UInt64 => ScalarValue::UInt64(Some(1)),
1039+
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))),
10381040
DataType::Float32 => ScalarValue::Float32(Some(1.0)),
10391041
DataType::Float64 => ScalarValue::Float64(Some(1.0)),
10401042
_ => {
@@ -1053,6 +1055,7 @@ impl ScalarValue {
10531055
DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)),
10541056
DataType::Int32 | DataType::UInt32 => ScalarValue::Int32(Some(-1)),
10551057
DataType::Int64 | DataType::UInt64 => ScalarValue::Int64(Some(-1)),
1058+
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))),
10561059
DataType::Float32 => ScalarValue::Float32(Some(-1.0)),
10571060
DataType::Float64 => ScalarValue::Float64(Some(-1.0)),
10581061
_ => {
@@ -1074,6 +1077,7 @@ impl ScalarValue {
10741077
DataType::UInt16 => ScalarValue::UInt16(Some(10)),
10751078
DataType::UInt32 => ScalarValue::UInt32(Some(10)),
10761079
DataType::UInt64 => ScalarValue::UInt64(Some(10)),
1080+
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(10.0))),
10771081
DataType::Float32 => ScalarValue::Float32(Some(10.0)),
10781082
DataType::Float64 => ScalarValue::Float64(Some(10.0)),
10791083
_ => {
@@ -1181,8 +1185,12 @@ impl ScalarValue {
11811185
| ScalarValue::Int16(None)
11821186
| ScalarValue::Int32(None)
11831187
| ScalarValue::Int64(None)
1188+
| ScalarValue::Float16(None)
11841189
| ScalarValue::Float32(None)
11851190
| ScalarValue::Float64(None) => Ok(self.clone()),
1191+
ScalarValue::Float16(Some(v)) => {
1192+
Ok(ScalarValue::Float16(Some(f16::from_f32(-v.to_f32()))))
1193+
}
11861194
ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))),
11871195
ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))),
11881196
ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))),
@@ -1435,6 +1443,9 @@ impl ScalarValue {
14351443
(Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as _),
14361444
(Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r) as _),
14371445
// TODO: we might want to look into supporting ceil/floor here for floats.
1446+
(Self::Float16(Some(l)), Self::Float16(Some(r))) => {
1447+
Some((f16::to_f32(*l) - f16::to_f32(*r)).abs().round() as _)
1448+
}
14381449
(Self::Float32(Some(l)), Self::Float32(Some(r))) => {
14391450
Some((l - r).abs().round() as _)
14401451
}
@@ -2452,6 +2463,7 @@ impl ScalarValue {
24522463
DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean)?,
24532464
DataType::Float64 => typed_cast!(array, index, Float64Array, Float64)?,
24542465
DataType::Float32 => typed_cast!(array, index, Float32Array, Float32)?,
2466+
DataType::Float16 => typed_cast!(array, index, Float16Array, Float16)?,
24552467
DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64)?,
24562468
DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32)?,
24572469
DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16)?,
@@ -5635,7 +5647,6 @@ mod tests {
56355647
}
56365648

56375649
#[test]
5638-
#[should_panic(expected = "Can not run arithmetic negative on scalar value Float16")]
56395650
fn f16_test_overflow() {
56405651
// TODO: if negate supports f16, add these cases to `test_scalar_negative_overflows` test case
56415652
let cases = [
@@ -5805,6 +5816,21 @@ mod tests {
58055816
ScalarValue::UInt64(Some(10)),
58065817
5,
58075818
),
5819+
(
5820+
ScalarValue::Float16(Some(f16::from_f32(1.1))),
5821+
ScalarValue::Float16(Some(f16::from_f32(1.9))),
5822+
1,
5823+
),
5824+
(
5825+
ScalarValue::Float16(Some(f16::from_f32(-5.3))),
5826+
ScalarValue::Float16(Some(f16::from_f32(-9.2))),
5827+
4,
5828+
),
5829+
(
5830+
ScalarValue::Float16(Some(f16::from_f32(-5.3))),
5831+
ScalarValue::Float16(Some(f16::from_f32(-9.7))),
5832+
4,
5833+
),
58085834
(
58095835
ScalarValue::Float32(Some(1.0)),
58105836
ScalarValue::Float32(Some(2.0)),
@@ -5877,6 +5903,14 @@ mod tests {
58775903
// Different type
58785904
(ScalarValue::Int8(Some(1)), ScalarValue::Int16(Some(1))),
58795905
(ScalarValue::Int8(Some(1)), ScalarValue::Float32(Some(1.0))),
5906+
(
5907+
ScalarValue::Float16(Some(f16::from_f32(1.0))),
5908+
ScalarValue::Float32(Some(1.0)),
5909+
),
5910+
(
5911+
ScalarValue::Float16(Some(f16::from_f32(1.0))),
5912+
ScalarValue::Int32(Some(1)),
5913+
),
58805914
(
58815915
ScalarValue::Float64(Some(1.1)),
58825916
ScalarValue::Float32(Some(2.2)),

0 commit comments

Comments
 (0)