
Commit 739688e

GoEddie authored and imback82 committed
Support for Bucketizer (#378)
1 parent 3b72c48 commit 739688e

File tree

7 files changed: +406 -0 lines changed
src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BucketizerTests.cs

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Security.Cryptography;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
using Xunit;

namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
{
    [Collection("Spark E2E Tests")]
    public class BucketizerTests
    {
        private readonly SparkSession _spark;

        public BucketizerTests(SparkFixture fixture)
        {
            _spark = fixture.Spark;
        }

        [Fact]
        public void TestBucketizer()
        {
            var expectedSplits =
                new double[] { Double.MinValue, 0.0, 10.0, 50.0, Double.MaxValue };

            var expectedHandle = "skip";
            var expectedUid = "uid";
            var expectedInputCol = "input_col";
            var expectedOutputCol = "output_col";

            var bucketizer = new Bucketizer(expectedUid)
                .SetInputCol(expectedInputCol)
                .SetOutputCol(expectedOutputCol)
                .SetHandleInvalid(expectedHandle)
                .SetSplits(expectedSplits);

            Assert.Equal(expectedHandle, bucketizer.GetHandleInvalid());

            Assert.Equal(expectedUid, bucketizer.Uid());

            DataFrame input = _spark.Sql("SELECT ID as input_col from range(100)");

            DataFrame output = bucketizer.Transform(input);
            Assert.Contains(output.Schema().Fields, (f => f.Name == expectedOutputCol));

            Assert.Equal(expectedInputCol, bucketizer.GetInputCol());
            Assert.Equal(expectedOutputCol, bucketizer.GetOutputCol());
            Assert.Equal(expectedSplits, bucketizer.GetSplits());
        }

        [Fact]
        public void TestBucketizer_MultipleColumns()
        {
            double[][] expectedSplitsArray = new[]
            {
                new[] { Double.MinValue, 0.0, 10.0, 50.0, Double.MaxValue },
                new[] { Double.MinValue, 0.0, 10000.0, Double.MaxValue }
            };

            var expectedHandle = "keep";

            var expectedInputCols = new List<string>() { "input_col_a", "input_col_b" };
            var expectedOutputCols = new List<string>() { "output_col_a", "output_col_b" };

            var bucketizer = new Bucketizer()
                .SetInputCols(expectedInputCols)
                .SetOutputCols(expectedOutputCols)
                .SetHandleInvalid(expectedHandle)
                .SetSplitsArray(expectedSplitsArray);

            Assert.Equal(expectedHandle, bucketizer.GetHandleInvalid());

            DataFrame input =
                _spark.Sql("SELECT ID as input_col_a, ID as input_col_b from range(100)");

            DataFrame output = bucketizer.Transform(input);
            Assert.Contains(output.Schema().Fields, (f => f.Name == "output_col_a"));
            Assert.Contains(output.Schema().Fields, (f => f.Name == "output_col_b"));

            Assert.Equal(expectedInputCols, bucketizer.GetInputCols());
            Assert.Equal(expectedOutputCols, bucketizer.GetOutputCols());
            Assert.Equal(expectedSplitsArray, bucketizer.GetSplitsArray());
        }
    }
}
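Note: Transform leaves the input column(s) in place and appends the configured output column(s); each output value is the index of the bucket the input fell into, stored as a double. With the single-column splits above, for example, an input of 7.0 lies in [0.0, 10.0) and is therefore assigned bucket 1.0.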

src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs

Lines changed: 8 additions & 0 deletions
@@ -365,6 +365,14 @@ private object ReadCollection(Stream s)
                     }
                     returnValue = doubleArray;
                     break;
+                case 'A':
+                    var doubleArrayArray = new double[numOfItemsInList][];
+                    for (int itemIndex = 0; itemIndex < numOfItemsInList; ++itemIndex)
+                    {
+                        doubleArrayArray[itemIndex] = ReadCollection(s) as double[];
+                    }
+                    returnValue = doubleArrayArray;
+                    break;
                 case 'b':
                     var boolArray = new bool[numOfItemsInList];
                     for (int itemIndex = 0; itemIndex < numOfItemsInList; ++itemIndex)

src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs

Lines changed: 15 additions & 0 deletions
@@ -24,6 +24,7 @@ internal class PayloadHelper
         private static readonly byte[] s_doubleTypeId = new[] { (byte)'d' };
         private static readonly byte[] s_jvmObjectTypeId = new[] { (byte)'j' };
         private static readonly byte[] s_byteArrayTypeId = new[] { (byte)'r' };
+        private static readonly byte[] s_doubleArrayArrayTypeId = new[] { (byte)'A' };
         private static readonly byte[] s_arrayTypeId = new[] { (byte)'l' };
         private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' };
         private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' };

@@ -135,6 +136,19 @@ internal static void ConvertArgsToBytes(
                         SerDe.Write(destination, d);
                     }
                     break;
+
+                case double[][] argDoubleArrayArray:
+                    SerDe.Write(destination, s_doubleArrayArrayTypeId);
+                    SerDe.Write(destination, argDoubleArrayArray.Length);
+                    foreach (double[] doubleArray in argDoubleArrayArray)
+                    {
+                        SerDe.Write(destination, doubleArray.Length);
+                        foreach (double d in doubleArray)
+                        {
+                            SerDe.Write(destination, d);
+                        }
+                    }
+                    break;

                 case IEnumerable<byte[]> argByteArrayEnumerable:
                     SerDe.Write(destination, s_byteArrayTypeId);

@@ -286,6 +300,7 @@ internal static byte[] GetTypeId(Type type)
             if (type == typeof(int[]) ||
                 type == typeof(long[]) ||
                 type == typeof(double[]) ||
+                type == typeof(double[][]) ||
                 typeof(IEnumerable<byte[]>).IsAssignableFrom(type) ||
                 typeof(IEnumerable<string>).IsAssignableFrom(type))
             {
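Read together, the two interop changes are symmetric: PayloadHelper.cs registers the 'A' type id and frames a double[][] argument as (tag, outer length, then each inner array as length-prefixed doubles), and the new 'A' case in JvmBridge.ReadCollection decodes the same layout coming back. Below is a minimal, self-contained round-trip sketch of that framing; BinaryWriter/BinaryReader stand in for the internal SerDe helpers, and byte order is ignored here (the real bridge writes in the byte order the JVM side expects), so this illustrates the layout rather than the exact wire bytes.

using System;
using System.IO;

class DoubleArrayArrayFramingSketch
{
    // Mirrors the write side in PayloadHelper: type id 'A', outer count,
    // then each inner array as a length prefix followed by its doubles.
    static void Write(BinaryWriter writer, double[][] value)
    {
        writer.Write((byte)'A');
        writer.Write(value.Length);
        foreach (double[] inner in value)
        {
            writer.Write(inner.Length);
            foreach (double d in inner)
            {
                writer.Write(d);
            }
        }
    }

    // Mirrors the new 'A' case in JvmBridge.ReadCollection.
    static double[][] Read(BinaryReader reader)
    {
        if (reader.ReadByte() != (byte)'A')
        {
            throw new InvalidDataException("Expected the double[][] type id.");
        }
        var result = new double[reader.ReadInt32()][];
        for (int i = 0; i < result.Length; ++i)
        {
            result[i] = new double[reader.ReadInt32()];
            for (int j = 0; j < result[i].Length; ++j)
            {
                result[i][j] = reader.ReadDouble();
            }
        }
        return result;
    }

    static void Main()
    {
        var splitsArray = new[]
        {
            new[] { double.MinValue, 0.0, 10.0, 50.0, double.MaxValue },
            new[] { double.MinValue, 0.0, 10000.0, double.MaxValue }
        };
        using var stream = new MemoryStream();
        Write(new BinaryWriter(stream), splitsArray);
        stream.Position = 0;
        Console.WriteLine(Read(new BinaryReader(stream))[1][2]); // prints 10000
    }
}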
src/csharp/Microsoft.Spark/ML/Feature/Bucketizer.cs

Lines changed: 247 additions & 0 deletions

@@ -0,0 +1,247 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;

namespace Microsoft.Spark.ML.Feature
{
    /// <summary>
    /// <see cref="Bucketizer"/> maps a column of continuous features to a column of feature
    /// buckets.
    ///
    /// <see cref="Bucketizer"/> can map multiple columns at once by setting the inputCols
    /// parameter. Note that when both the inputCol and inputCols parameters are set, an
    /// Exception will be thrown. The splits parameter is only used for single column usage,
    /// and splitsArray is for multiple columns.
    /// </summary>
    public class Bucketizer : IJvmObjectReferenceProvider
    {
        internal Bucketizer(JvmObjectReference jvmObject)
        {
            _jvmObject = jvmObject;
        }

        /// <summary>
        /// Create a <see cref="Bucketizer"/> without any parameters
        /// </summary>
        public Bucketizer()
        {
            _jvmObject = SparkEnvironment.JvmBridge.CallConstructor(
                "org.apache.spark.ml.feature.Bucketizer");
        }

        /// <summary>
        /// Create a <see cref="Bucketizer"/> with a UID that is used to give the
        /// <see cref="Bucketizer"/> a unique ID
        /// </summary>
        /// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
        public Bucketizer(string uid)
        {
            _jvmObject = SparkEnvironment.JvmBridge.CallConstructor(
                "org.apache.spark.ml.feature.Bucketizer", uid);
        }

        private readonly JvmObjectReference _jvmObject;
        JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;

        /// <summary>
        /// Gets the splits that were set using SetSplits
        /// </summary>
        /// <returns>double[], the splits to be used to bucket the input column</returns>
        public double[] GetSplits()
        {
            return (double[])_jvmObject.Invoke("getSplits");
        }

        /// <summary>
        /// Split points for splitting a single column into buckets. To split multiple
        /// columns use SetSplitsArray. You cannot use both SetSplits and SetSplitsArray at
        /// the same time.
        /// </summary>
        /// <param name="value">
        /// Split points for mapping continuous features into buckets. With n+1 splits, there
        /// are n buckets. A bucket defined by splits x,y holds values in the range [x,y)
        /// except the last bucket, which also includes y. The splits should be of length
        /// &gt;= 3 and strictly increasing. Values outside the splits specified will be
        /// treated as errors.
        /// </param>
        /// <returns><see cref="Bucketizer"/></returns>
        public Bucketizer SetSplits(double[] value)
        {
            return WrapAsBucketizer(_jvmObject.Invoke("setSplits", value));
        }

        /// <summary>
        /// Gets the splits that were set by SetSplitsArray
        /// </summary>
        /// <returns>double[][], the splits to be used to bucket the input columns</returns>
        public double[][] GetSplitsArray()
        {
            return (double[][])_jvmObject.Invoke("getSplitsArray");
        }

        /// <summary>
        /// Split points for splitting multiple columns into buckets. To split a single
        /// column use SetSplits. You cannot use both SetSplits and SetSplitsArray at the
        /// same time.
        /// </summary>
        /// <param name="value">
        /// The array of split points for mapping continuous features into buckets for
        /// multiple columns. For each input column, with n+1 splits, there are n buckets. A
        /// bucket defined by splits x,y holds values in the range [x,y) except the last
        /// bucket, which also includes y. The splits should be of length &gt;= 3 and
        /// strictly increasing. Values outside the splits specified will be treated as
        /// errors.</param>
        /// <returns><see cref="Bucketizer"/></returns>
        public Bucketizer SetSplitsArray(double[][] value)
        {
            return WrapAsBucketizer(_jvmObject.Invoke("setSplitsArray", (object)value));
        }

        /// <summary>
        /// Gets the column that the <see cref="Bucketizer"/> should read from and convert
        /// into buckets. This would have been set by SetInputCol
        /// </summary>
        /// <returns>string, the input column</returns>
        public string GetInputCol()
        {
            return (string)_jvmObject.Invoke("getInputCol");
        }

        /// <summary>
        /// Sets the column that the <see cref="Bucketizer"/> should read from and convert
        /// into buckets
        /// </summary>
        /// <param name="value">The name of the column to use as the source of the buckets</param>
        /// <returns><see cref="Bucketizer"/></returns>
        public Bucketizer SetInputCol(string value)
        {
            return WrapAsBucketizer(_jvmObject.Invoke("setInputCol", value));
        }

        /// <summary>
        /// Gets the columns that the <see cref="Bucketizer"/> should read from and convert
        /// into buckets. This is set by SetInputCols
        /// </summary>
        /// <returns>IEnumerable&lt;string&gt;, list of input columns</returns>
        public IEnumerable<string> GetInputCols()
        {
            return ((string[])(_jvmObject.Invoke("getInputCols"))).ToList();
        }

        /// <summary>
        /// Sets the columns that the <see cref="Bucketizer"/> should read from and convert
        /// into buckets.
        ///
        /// Each column is one set of buckets so if you have two input columns you can have
        /// two sets of buckets and two output columns.
        /// </summary>
        /// <param name="value">List of input columns to use as sources for buckets</param>
        /// <returns><see cref="Bucketizer"/></returns>
        public Bucketizer SetInputCols(IEnumerable<string> value)
        {
            return WrapAsBucketizer(_jvmObject.Invoke("setInputCols", value));
        }

        /// <summary>
        /// Gets the name of the column the output data will be written to. This is set by
        /// SetOutputCol
        /// </summary>
        /// <returns>string, the output column</returns>
        public string GetOutputCol()
        {
            return (string)_jvmObject.Invoke("getOutputCol");
        }

        /// <summary>
        /// The <see cref="Bucketizer"/> will create a new column in the DataFrame, this is
        /// the name of the new column.
        /// </summary>
        /// <param name="value">The name of the new column which contains the bucket ID</param>
        /// <returns><see cref="Bucketizer"/></returns>
        public Bucketizer SetOutputCol(string value)
        {
            return WrapAsBucketizer(_jvmObject.Invoke("setOutputCol", value));
        }

        /// <summary>
        /// The list of columns that the <see cref="Bucketizer"/> will create in the
        /// DataFrame. This is set by SetOutputCols
        /// </summary>
        /// <returns>IEnumerable&lt;string&gt;, list of output columns</returns>
        public IEnumerable<string> GetOutputCols()
        {
            return ((string[])_jvmObject.Invoke("getOutputCols")).ToList();
        }

        /// <summary>
        /// The list of columns that the <see cref="Bucketizer"/> will create in the
        /// DataFrame.
        /// </summary>
        /// <param name="value">List of column names which will contain the bucket ID</param>
        /// <returns><see cref="Bucketizer"/></returns>
        public Bucketizer SetOutputCols(List<string> value)
        {
            return WrapAsBucketizer(_jvmObject.Invoke("setOutputCols", value));
        }

        /// <summary>
        /// Executes the <see cref="Bucketizer"/> and transforms the DataFrame to include
        /// the new column or columns with the bucketed data.
        /// </summary>
        /// <param name="source">The DataFrame to add the bucketed data to</param>
        /// <returns><see cref="DataFrame"/> containing the original data and the new
        /// bucketed columns</returns>
        public DataFrame Transform(DataFrame source)
        {
            return new DataFrame((JvmObjectReference)_jvmObject.Invoke("transform", source));
        }

        /// <summary>
        /// The reference we get back from each call isn't usable unless we wrap it in a new
        /// dotnet <see cref="Bucketizer"/>
        /// </summary>
        /// <param name="obj">The <see cref="JvmObjectReference"/> to convert into a dotnet
        /// <see cref="Bucketizer"/></param>
        /// <returns><see cref="Bucketizer"/></returns>
        private static Bucketizer WrapAsBucketizer(object obj)
        {
            return new Bucketizer((JvmObjectReference)obj);
        }

        /// <summary>
        /// The UID that was used to create the <see cref="Bucketizer"/>. If no UID is
        /// passed in when creating the <see cref="Bucketizer"/> then a random UID is
        /// created when the <see cref="Bucketizer"/> is created.
        /// </summary>
        /// <returns>string UID identifying the <see cref="Bucketizer"/></returns>
        public string Uid()
        {
            return (string)_jvmObject.Invoke("uid");
        }

        /// <summary>
        /// How should the <see cref="Bucketizer"/> handle invalid data, choices are "skip",
        /// "error" or "keep"
        /// </summary>
        /// <returns>string showing the way Spark will handle invalid data</returns>
        public string GetHandleInvalid()
        {
            return (string)_jvmObject.Invoke("getHandleInvalid");
        }

        /// <summary>
        /// Tells the <see cref="Bucketizer"/> what to do with invalid data.
        ///
        /// Choices are "skip", "error" or "keep". Default is "error"
        /// </summary>
        /// <param name="value">"skip", "error" or "keep"</param>
        /// <returns><see cref="Bucketizer"/></returns>
        public Bucketizer SetHandleInvalid(string value)
        {
            return WrapAsBucketizer(_jvmObject.Invoke("setHandleInvalid", value));
        }
    }
}
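For completeness, a minimal end-to-end usage sketch of the new API, assuming a working .NET for Spark deployment; the column names and splits are illustrative and mirror the tests above:

using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;

class BucketizerExample
{
    static void Main()
    {
        SparkSession spark = SparkSession.Builder().GetOrCreate();

        // 100 rows with a single numeric column to bucket.
        DataFrame input = spark.Sql("SELECT ID AS input_col FROM range(100)");

        Bucketizer bucketizer = new Bucketizer()
            .SetInputCol("input_col")
            .SetOutputCol("output_col")
            .SetHandleInvalid("error")
            .SetSplits(new[] { double.MinValue, 0.0, 10.0, 50.0, double.MaxValue });

        // output_col holds the bucket index for each row: [0, 10) maps to
        // bucket 1, [10, 50) to bucket 2, and [50, MaxValue] to bucket 3.
        bucketizer.Transform(input).Show();
    }
}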
