Skip to content

Commit

Permalink
Fix: Models with nested Discriminators (#4596)
Browse files Browse the repository at this point in the history
fixes: #4597
  • Loading branch information
jorgerangel-msft authored Oct 3, 2024
1 parent fa10eac commit 56c3d62
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,22 @@ private List<MethodBodyStatement> BuildDeserializePropertiesStatements(ScopedApi
CreateDeserializeAdditionalPropsValueKindCheck(jsonProperty, additionalPropsValueKindBodyStatements));
}

// deserialize the raw binary data for the model
var rawBinaryData = _rawDataField
?? _model.BaseModelProvider?.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName);
// deserialize the raw binary data for the model by searching for the raw binary data field in the model and any base models.
var rawBinaryData = _rawDataField;
if (rawBinaryData == null)
{
var baseModelProvider = _model.BaseModelProvider;
while (baseModelProvider != null)
{
var field = baseModelProvider.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName);
if (field != null)
{
rawBinaryData = field;
break;
}
baseModelProvider = baseModelProvider.BaseModelProvider;
}
}

if (_additionalBinaryDataProperty != null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,40 @@ public void TestBuildDeserializationMethod()
Assert.IsNotNull(methodBody);
}

[Test]
public void TestBuildDeserializationMethodNestedSARD()
{
var baseModel = InputFactory.Model("BaseModel");
var nestedModel = InputFactory.Model("NestedModel", baseModel: baseModel);
var inputModel = InputFactory.Model("mockInputModel", baseModel: nestedModel);
var (baseModelProvider, baseSerialization) = CreateModelAndSerialization(baseModel);
var (nestedModelProvider, nestedSerialization) = CreateModelAndSerialization(nestedModel);
var (model, serialization) = CreateModelAndSerialization(inputModel);

Assert.AreEqual(0, model.Fields.Count);
Assert.AreEqual(0, nestedModelProvider.Fields.Count);
Assert.AreEqual(1, baseModelProvider.Fields.Count);

var deserializationMethod = serialization.BuildDeserializationMethod();
Assert.IsNotNull(deserializationMethod);

var signature = deserializationMethod?.Signature;
Assert.IsNotNull(signature);
Assert.AreEqual($"Deserialize{model.Name}", signature?.Name);
Assert.AreEqual(2, signature?.Parameters.Count);
Assert.AreEqual(new CSharpType(typeof(JsonElement)), signature?.Parameters[0].Type);
Assert.AreEqual(new CSharpType(typeof(ModelReaderWriterOptions)), signature?.Parameters[1].Type);
Assert.AreEqual(model.Type, signature?.ReturnType);
Assert.AreEqual(MethodSignatureModifiers.Internal | MethodSignatureModifiers.Static, signature?.Modifiers);

var methodBody = deserializationMethod?.BodyStatements;
Assert.IsNotNull(methodBody);
// validate that only one SARD variable is created.
var methodBodyString = methodBody!.ToDisplayString();
var sardDeclaration = "global::System.Collections.Generic.IDictionary<string, global::System.BinaryData> additionalBinaryDataProperties";
Assert.AreEqual(1, methodBodyString.Split(sardDeclaration).Length - 1);
}

[Test]
public void TestBuildImplicitToBinaryContent()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using Microsoft.Generator.CSharp.Input;
using Microsoft.Generator.CSharp.Primitives;
using Microsoft.Generator.CSharp.Snippets;
using Microsoft.Generator.CSharp.SourceInput;
using Microsoft.Generator.CSharp.Statements;
using static Microsoft.Generator.CSharp.Snippets.Snippet;

Expand Down Expand Up @@ -331,12 +330,12 @@ protected override PropertyProvider[] BuildProperties()
var properties = new List<PropertyProvider>(propertiesCount + 1);

Dictionary<string, InputModelProperty> baseProperties = _inputModel.BaseModel?.Properties.ToDictionary(p => p.Name) ?? [];

var baseModelDiscriminator = _inputModel.BaseModel?.DiscriminatorProperty;
for (int i = 0; i < propertiesCount; i++)
{
var property = _inputModel.Properties[i];

if (property.IsDiscriminator && Type.BaseType is not null)
if (property.IsDiscriminator && property.Name == baseModelDiscriminator?.Name)
continue;

var outputProperty = CodeModelPlugin.Instance.TypeFactory.CreateProperty(property, this);
Expand Down Expand Up @@ -458,9 +457,8 @@ private ConstructorProvider BuildFullConstructor()

if (isPrimaryConstructor)
{
baseProperties = _inputModel.GetAllBaseModels()
.Reverse()
.SelectMany(model => CodeModelPlugin.Instance.TypeFactory.CreateModel(model)?.Properties ?? []);
// the primary ctor should only include the properties of the direct base model
baseProperties = BaseModelProvider?.Properties ?? [];
}
else if (BaseModelProvider?.FullConstructor.Signature != null)
{
Expand Down Expand Up @@ -515,24 +513,7 @@ p.Property is null
var type = discriminator.Type;
if (IsUnknownDiscriminatorModel)
{
var discriminatorExpression = discriminator.AsParameter.AsExpression;
if (!type.IsFrameworkType && type.IsEnum)
{
if (type.IsStruct)
{
/* kind != default ? kind : "unknown" */
return new TernaryConditionalExpression(discriminatorExpression.NotEqual(Default), discriminatorExpression, Literal(_inputModel.DiscriminatorValue));
}
else
{
return discriminatorExpression;
}
}
else
{
/* kind ?? "unknown" */
return discriminatorExpression.NullCoalesce(Literal(_inputModel.DiscriminatorValue));
}
return GetUnknownDiscriminatorExpression(discriminator);
}
else
{
Expand All @@ -558,17 +539,52 @@ p.Property is null

private ValueExpression GetExpressionForCtor(ParameterProvider parameter, HashSet<PropertyProvider> overriddenProperties, bool isPrimaryConstructor)
{
if (parameter.Property is not null && parameter.Property.IsDiscriminator && _inputModel.DiscriminatorValue != null &&
(isPrimaryConstructor || !isPrimaryConstructor && IsUnknownDiscriminatorModel))
if (parameter.Property is not null && parameter.Property.IsDiscriminator && _inputModel.DiscriminatorValue != null)
{
return DiscriminatorValueExpression ?? throw new InvalidOperationException($"invalid discriminator {_inputModel.DiscriminatorValue} for property {parameter.Property.Name}");
if (isPrimaryConstructor)
{
return DiscriminatorValueExpression ?? throw new InvalidOperationException($"invalid discriminator {_inputModel.DiscriminatorValue} for property {parameter.Property.Name}");
}
else if (IsUnknownDiscriminatorModel)
{
return GetUnknownDiscriminatorExpression(parameter.Property) ?? throw new InvalidOperationException($"invalid discriminator {_inputModel.DiscriminatorValue} for property {parameter.Property.Name}");
}
}

var paramToUse = parameter.Property is not null && overriddenProperties.Contains(parameter.Property) ? Properties.First(p => p.Name == parameter.Property.Name).AsParameter : parameter;

return paramToUse.Property is not null ? GetConversion(paramToUse.Property) : paramToUse;
}

private ValueExpression? GetUnknownDiscriminatorExpression(PropertyProvider property)
{
if (!property.IsDiscriminator || _inputModel.DiscriminatorValue == null)
{
return null;
}

var discriminatorExpression = property.AsParameter.AsExpression;
var type = property.Type;

if (!type.IsFrameworkType && type.IsEnum)
{
if (type.IsStruct)
{
/* kind != default ? kind : "unknown" */
return new TernaryConditionalExpression(discriminatorExpression.NotEqual(Default), discriminatorExpression, Literal(_inputModel.DiscriminatorValue));
}
else
{
return discriminatorExpression;
}
}
else
{
/* kind ?? "unknown" */
return discriminatorExpression.NullCoalesce(Literal(_inputModel.DiscriminatorValue));
}
}

private static void AddInitializationParameterForCtor(
List<ParameterProvider> parameters,
PropertyProvider property,
Expand Down Expand Up @@ -697,10 +713,15 @@ private ValueExpression GetConversion(PropertyProvider property)
/// <returns>The constructed <see cref="FieldProvider"/> if the model should generate the field.</returns>
private FieldProvider? BuildRawDataField()
{
// check if there is a raw data field on my base, if so, we do not have to have one here
if (BaseModelProvider?.RawDataField != null)
// check if there is a raw data field on any of the base models, if so, we do not have to have one here.
var baseModelProvider = BaseModelProvider;
while (baseModelProvider != null)
{
return null;
if (baseModelProvider.RawDataField != null)
{
return null;
}
baseModelProvider = baseModelProvider.BaseModelProvider;
}

var modifiers = FieldModifiers.Private;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,9 @@ public void DiscriminatorEnumParamShape()
Assert.IsNotNull(method);
foreach (var property in model!.Properties.Where(p => p.Type.IsEnum))
{
// enum discriminator properties are not included in the factory method
var parameter = method!.Signature.Parameters.FirstOrDefault(p => p.Name == property.Name.ToVariableName());
Assert.IsNotNull(parameter);
Assert.IsTrue(parameter!.Type.IsFrameworkType);
Assert.AreEqual(typeof(int), parameter!.Type.FrameworkType);
Assert.IsNull(parameter);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,21 @@ public class DiscriminatorTests
InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true),
InputFactory.Property("likesBones", InputPrimitiveType.Boolean, isRequired: true)
]);

private static readonly InputModelType _anotherAnimal = InputFactory.Model("anotherAnimal", discriminatedKind: "dog", properties:
[
InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true),
InputFactory.Property("other", InputPrimitiveType.String, isRequired: true, isDiscriminator: true)
]);
private static readonly InputModelType _baseModel = InputFactory.Model(
"pet",
properties: [InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true)],
discriminatedModels: new Dictionary<string, InputModelType>() { { "cat", _catModel }, { "dog", _dogModel } });
discriminatedModels: new Dictionary<string, InputModelType>()
{
{ "cat", _catModel },
{ "dog", _dogModel },
{ "otherAnimal", _anotherAnimal }
});

private static readonly InputEnumType _petEnum = InputFactory.Enum("pet", InputPrimitiveType.String, isExtensible: true, values:
[
Expand Down Expand Up @@ -206,5 +217,30 @@ public void DerivedHasNoKindProperty()
var kindProperty = catModel!.Properties.FirstOrDefault(p => p.Name == "Kind");
Assert.IsNull(kindProperty);
}

[Test]
public void ModelWithNestedDiscriminators()
{
MockHelpers.LoadMockPlugin(inputModelTypes: [_baseEnumModel, _dogEnumModel, _anotherAnimal]);
var outputLibrary = CodeModelPlugin.Instance.OutputLibrary;
var anotherDogModel = outputLibrary.TypeProviders.OfType<ModelProvider>().FirstOrDefault(t => t.Name == "AnotherAnimal");
Assert.IsNotNull(anotherDogModel);

var serializationCtor = anotherDogModel!.Constructors.FirstOrDefault(c => c.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Internal));
Assert.IsNotNull(serializationCtor);
Assert.AreEqual(3, serializationCtor!.Signature.Parameters.Count);

// ensure both discriminators are present
var kindParam = serializationCtor!.Signature.Parameters.FirstOrDefault(p => p.Name == "kind");
Assert.IsNotNull(kindParam);
var otherParam = serializationCtor!.Signature.Parameters.FirstOrDefault(p => p.Name == "other");
Assert.IsNotNull(otherParam);

// the primary ctor should only have the model's own discriminator
var publicCtor = anotherDogModel.Constructors.FirstOrDefault(c => c.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Public));
Assert.IsNotNull(publicCtor);
Assert.AreEqual(1, publicCtor!.Signature.Parameters.Count);
Assert.AreEqual("other", publicCtor.Signature.Parameters[0].Name);
}
}
}

0 comments on commit 56c3d62

Please sign in to comment.