Skip to content

Commit

Permalink
make unit test working
Browse files Browse the repository at this point in the history
  • Loading branch information
jadewang-db committed Mar 14, 2024
1 parent d0e8158 commit 4313cc2
Showing 1 changed file with 171 additions and 10 deletions.
181 changes: 171 additions & 10 deletions csharp/src/Drivers/Apache/Spark/SparkConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text;
Expand All @@ -36,7 +37,20 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
{
public class SparkConnection : HiveServer2Connection
{
const string userAgent = "AdbcExperimental/0.0";
const string userAgent = "MicrosoftSparkODBCDriver/2.7.6.1014";

readonly IReadOnlyList<AdbcInfoCode> infoSupportedCodes = new List<AdbcInfoCode> {
AdbcInfoCode.DriverName,
AdbcInfoCode.DriverVersion,
AdbcInfoCode.DriverArrowVersion,
AdbcInfoCode.VendorName
};

const string infoDriverName = "ADBC Spark Driver";
const string infoDriverVersion = "1.0.0";
const string infoVendorName = "Spark";
const string infoDriverArrowVersion = "1.0.0";

internal static TSparkGetDirectResults sparkGetDirectResults = new TSparkGetDirectResults(1000);

internal static readonly Dictionary<string, string> timestampConfig = new Dictionary<string, string>
Expand Down Expand Up @@ -72,7 +86,7 @@ protected override TProtocol CreateProtocol()

TConfiguration config = new TConfiguration();

THttpTransport transport = new THttpTransport(httpClient, config);
ThriftHttpTransport transport = new ThriftHttpTransport(httpClient, config);
// can switch to the one below if want to use the experimental one with IPeekableTransport
// ThriftHttpTransport transport = new ThriftHttpTransport(httpClient, config);
transport.OpenAsync(CancellationToken.None).Wait();
Expand Down Expand Up @@ -102,19 +116,150 @@ public override void Dispose()

this.transport.Close();
this.client.Dispose();

this.transport = null;
this.client = null;
}
}

public override IArrowArrayStream GetInfo(List<AdbcInfoCode> codes)
{
const int strValTypeID = 0;

UnionType infoUnionType = new UnionType(
new List<Field>()
{
new Field("string_value", StringType.Default, true),
new Field("bool_value", BooleanType.Default, true),
new Field("int64_value", Int64Type.Default, true),
new Field("int32_bitmask", Int32Type.Default, true),
new Field(
"string_list",
new ListType(
new Field("item", StringType.Default, true)
),
false
),
new Field(
"int32_to_int32_list_map",
new ListType(
new Field("entries", new StructType(
new List<Field>()
{
new Field("key", Int32Type.Default, false),
new Field("value", Int32Type.Default, true),
}
), false)
),
true
)
},
new int[] { 0, 1, 2, 3, 4, 5 },
UnionMode.Dense);

if (codes.Count == 0)
{
codes = new List<AdbcInfoCode>(infoSupportedCodes);
}

UInt32Array.Builder infoNameBuilder = new UInt32Array.Builder();
ArrowBuffer.Builder<byte> typeBuilder = new ArrowBuffer.Builder<byte>();
ArrowBuffer.Builder<int> offsetBuilder = new ArrowBuffer.Builder<int>();
StringArray.Builder stringInfoBuilder = new StringArray.Builder();
int nullCount = 0;
int arrayLength = codes.Count;

foreach (AdbcInfoCode code in codes)
{
switch (code)
{
case AdbcInfoCode.DriverName:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
stringInfoBuilder.Append(infoDriverName);
break;
case AdbcInfoCode.DriverVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
stringInfoBuilder.Append(infoDriverVersion);
break;
case AdbcInfoCode.DriverArrowVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
stringInfoBuilder.Append(infoDriverArrowVersion);
break;
case AdbcInfoCode.VendorName:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
stringInfoBuilder.Append(infoVendorName);
break;
default:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
stringInfoBuilder.AppendNull();
nullCount++;
break;
}
}

StructType entryType = new StructType(
new List<Field>(){
new Field("key", Int32Type.Default, false),
new Field("value", Int32Type.Default, true)});

StructArray entriesDataArray = new StructArray(entryType, 0,
new[] { new Int32Array.Builder().Build(), new Int32Array.Builder().Build() },
new ArrowBuffer.BitmapBuilder().Build());

List<IArrowArray> childrenArrays = new List<IArrowArray>()
{
stringInfoBuilder.Build(),
new BooleanArray.Builder().Build(),
new Int64Array.Builder().Build(),
new Int32Array.Builder().Build(),
new ListArray.Builder(StringType.Default).Build(),
CreateNestedListArray(new List<IArrowArray?>(){ entriesDataArray }, entryType)
};

DenseUnionArray infoValue = new DenseUnionArray(infoUnionType, arrayLength, childrenArrays, typeBuilder.Build(), offsetBuilder.Build(), nullCount);

List<IArrowArray> dataArrays = new List<IArrowArray>
{
infoNameBuilder.Build(),
infoValue
};

return new SparkInfoArrowStream(StandardSchemas.GetInfoSchema, dataArrays);

}

public override IArrowArrayStream GetInfo(List<int> codes) => base.GetInfo(codes);

public override IArrowArrayStream GetTableTypes()
{
StringArray.Builder tableTypesBuilder = new StringArray.Builder();
tableTypesBuilder.AppendRange(new string[] { "BASE TABLE", "VIEW" });

List<IArrowArray> dataArrays = new List<IArrowArray>
{
tableTypesBuilder.Build()
};

return new SparkInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays);
}

public override Schema GetTableSchema(string catalog, string dbSchema, string tableName)
{
TGetColumnsReq getColumnsReq = new TGetColumnsReq(this.sessionHandle);
getColumnsReq.CatalogName = catalog;
getColumnsReq.SchemaName = dbSchema;
getColumnsReq.TableName = tableName;
getColumnsReq.GetDirectResults = new TSparkGetDirectResults();
getColumnsReq.GetDirectResults = sparkGetDirectResults;

var columnsResponse = this.client.GetColumns(getColumnsReq).Result;
if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS)
Expand All @@ -125,17 +270,32 @@ public override Schema GetTableSchema(string catalog, string dbSchema, string ta
var result = columnsResponse.DirectResults;
var resultSchema = result.ResultSetMetadata.ArrowSchema;
var columns = result.ResultSet.Results.Columns;
var rowCount = columns[4].StringVal.Values.Length;
var rowCount = columns[3].StringVal.Values.Length;

Field[] fields = new Field[rowCount];
for (int i = 0; i < rowCount; i++)
{
fields[i] = new Field(columns[4].StringVal.Values.GetString(i),
SchemaParser.GetArrowType((TTypeId)columns[5].I32Val.Values.GetValue(i)),
fields[i] = new Field(columns[3].StringVal.Values.GetString(i),
SchemaParser.GetArrowType((TTypeId)columns[4].I32Val.Values.GetValue(i)),
nullable: true /* ??? */);
}
return new Schema(fields, null);
}
private static IReadOnlyList<int> ConvertSpanToReadOnlyList(Int32Array span)
{
// Initialize a list with the capacity equal to the length of the span
// to avoid resizing during the addition of elements
List<int> list = new List<int>(span.Length);

// Copy elements from the span to the list
foreach (int item in span)
{
list.Add(item);
}

// Return the list as IReadOnlyList<int>
return list;
}

public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string catalogPattern, string dbSchemaPattern, string tableNamePattern, List<string> tableTypes, string columnNamePattern)
{
Expand Down Expand Up @@ -180,8 +340,8 @@ public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string catal
}
TRowSet resp = getSchemasResp.DirectResults.ResultSet.Results;

IReadOnlyList<string> catalogList = resp.Columns[0].StringVal.Values;
IReadOnlyList<string> schemaList = resp.Columns[1].StringVal.Values;
IReadOnlyList<string> catalogList = resp.Columns[1].StringVal.Values;
IReadOnlyList<string> schemaList = resp.Columns[0].StringVal.Values;

for (int i = 0; i < catalogList.Count; i++)
{
Expand Down Expand Up @@ -220,6 +380,7 @@ public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string catal
TableInfoPair tableInfo = new TableInfoPair();
tableInfo.Type = tableType;
tableInfo.Columns = new List<string>();
tableInfo.ColType = new List<int>();
catalogMap.GetValueOrDefault(catalog).GetValueOrDefault(schemaDb).Add(tableName, tableInfo);
}
}
Expand Down Expand Up @@ -247,7 +408,7 @@ public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string catal
IReadOnlyList<string> schemaList = resp.Columns[1].StringVal.Values;
IReadOnlyList<string> tableList = resp.Columns[2].StringVal.Values;
IReadOnlyList<string> columnList = resp.Columns[3].StringVal.Values;
IReadOnlyList<int> columnTypeList = (IReadOnlyList<int>)resp.Columns[4].I32Val.Values;
IReadOnlyList<int> columnTypeList = ConvertSpanToReadOnlyList(resp.Columns[4].I32Val.Values);

for (int i = 0; i < catalogList.Count; i++)
{
Expand Down

0 comments on commit 4313cc2

Please sign in to comment.