Skip to content

Commit

Permalink
stricter schema validation on deserialization (#573)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloneguid committed Dec 4, 2024
1 parent 2a1bead commit 876df34
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/rn/5.0.3.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

- Class-reflected schemas for `map` complex types will generate key/value properties with "key" and "value" names, respectively. This is required in order to deserialise externally generated parquet files with dictionaries.
- Updated dependent packages.
- Class deserializer now checks that each field's type in the file is compatible with the corresponding class property type and throws `InvalidCastException` on a mismatch, which prevents accidental data loss or narrowing. Thanks to @dkotov in #573.

# Floor

Expand Down
43 changes: 30 additions & 13 deletions src/Parquet.Test/Serialisation/ParquetSerializerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1147,23 +1147,40 @@ class EdgeCaseInt32 {
public int Id { get; set; }
}

/// <summary>
/// Test model with a single nullable 32-bit integer property named <c>Id</c>.
/// Used as the target class when deserializing in the edge-case tests.
/// </summary>
class EdgeCaseInt32Optional {
/// <summary>Optional identifier, declared as <c>int?</c>.</summary>
public int? Id { get; set; }
}

/// <summary>
/// Test model with a single nullable 64-bit integer property named <c>Id</c>.
/// Used as the target class when deserializing in the edge-case tests.
/// </summary>
class EdgeCaseInt64Optional {
/// <summary>Optional identifier, declared as <c>long?</c>.</summary>
public long? Id { get; set; }
}

/// <summary>
/// This should throw InvalidCastException as the raw Int64 is being cast to Int32
/// </summary>
/// <returns></returns>
[Fact]
public async Task EdgeCase_rawint64_to_classInt32() {
var schema = new ParquetSchema(new DataField<long>("Id"));
using var ms = new MemoryStream();
using(ParquetWriter writer = await ParquetWriter.CreateAsync(schema, ms)) {
using(ParquetRowGroupWriter rg = writer.CreateRowGroup()) {
await rg.WriteColumnAsync(new DataColumn(schema.DataFields[0], new long[] { 1, 2, 3 }));
}
}
ms.Position = 0;
public async Task EdgeCase_RawInt64_to_Int32() {

using Stream testFile = OpenTestFile("special/no-logical-type.parquet");

await Assert.ThrowsAsync<InvalidCastException>(async () => {
await ParquetSerializer.DeserializeAsync<EdgeCaseInt32Optional>(testFile);
});
}

IList<EdgeCaseInt32> data = await ParquetSerializer.DeserializeAsync<EdgeCaseInt32>(ms);
[Fact]
public async Task EdgeCase_Int64() {

Assert.Equal(1, data[0].Id);
Assert.Equal(2, data[1].Id);
Assert.Equal(3, data[2].Id);
IList<EdgeCaseInt64Optional> r = await ParquetSerializer.DeserializeAsync<EdgeCaseInt64Optional>(
OpenTestFile("special/no-logical-type.parquet"));

Assert.NotNull(r);
Assert.Equal(3, r.Count);
Assert.Equal(1, r[0].Id);
Assert.Equal(2, r[1].Id);
Assert.Equal(3, r[2].Id);
}

}
}
Binary file not shown.
17 changes: 17 additions & 0 deletions src/Parquet/Serialization/ParquetSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ public static class ParquetSerializer {
private static readonly ConcurrentDictionary<ParquetSchema, object> _schemaToStriper = new();
private static readonly ConcurrentDictionary<Type, object> _typeToAssembler = new();
private static readonly ConcurrentDictionary<ParquetSchema, object> _schemaToAssembler = new();
// Whitelist of CLR type mismatches tolerated during deserialization.
// Key: the type declared on the class; value: the set of file types that may be read into it.
// The table is empty on targets where NET8_0_OR_GREATER is not defined.
private static readonly Dictionary<Type, HashSet<Type>> AllowedDeserializerConversions = new() {
#if NET8_0_OR_GREATER
[typeof(DateOnly)] = new HashSet<Type> { typeof(DateTime) },
[typeof(TimeOnly)] = new HashSet<Type> { typeof(TimeSpan) },
#endif
};

private static async Task SerializeRowGroupAsync<T>(ParquetWriter writer, Striper<T> striper,
IEnumerable<T> objectInstances,
Expand Down Expand Up @@ -513,6 +519,17 @@ private static async Task DeserializeRowGroupAsync(ParquetReader reader, int rgi
if(fileField.MaxRepetitionLevel != assemblerField.MaxRepetitionLevel)
throw new InvalidDataException($"class repetition level ({assemblerField.MaxRepetitionLevel}) does not match file's repetition level ({fileField.MaxRepetitionLevel}) in field '{assemblerField.Path}'. This usually means collection in class definition is incompatible.");

if(fileField.ClrType != assemblerField.ClrType) {

// check if this is one of the allowed conversions
bool isStillAllowed =
AllowedDeserializerConversions.TryGetValue(assemblerField.ClrType, out HashSet<Type>? allowedConversions) &&
allowedConversions.Contains(fileField.ClrType);

if(!isStillAllowed)
throw new InvalidCastException($"class type ({assemblerField.ClrType}) does not match file's type ({fileField.ClrType}) in field '{assemblerField.Path}'");
}


// make final result
DataField r = (DataField)assemblerField.Clone();
Expand Down

0 comments on commit 876df34

Please sign in to comment.