Skip to content

Commit

Permalink
[http-client-csharp] fix: use correct custom ctor in model factory (#…
Browse files Browse the repository at this point in the history
…4921)

This PR fixes an issue where the model factory method for a model was
using the incorrect full constructor when the full constructor was
suppressed and customized.

fixes: #4830
  • Loading branch information
jorgerangel-msft authored Oct 30, 2024
1 parent 6d168d3 commit 80f5c4c
Showing 8 changed files with 121 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -44,7 +44,11 @@ public CanonicalTypeProvider(TypeProvider generatedTypeProvider, InputType? inpu

private protected override CanonicalTypeProvider GetCanonicalView() => this;

// TODO - Implement BuildMethods, BuildConstructors, etc as needed
// TODO - Implement BuildMethods, etc as needed
protected override ConstructorProvider[] BuildConstructors()
{
return [.. _generatedTypeProvider.Constructors, .. _generatedTypeProvider.CustomCodeView?.Constructors ?? []];
}

protected override PropertyProvider[] BuildProperties()
{
Original file line number Diff line number Diff line change
@@ -79,14 +79,39 @@ protected override MethodProvider[] BuildMethods()
if (typeToInstantiate is null)
continue;

var modelCtor = modelProvider.FullConstructor;
var fullConstructor = modelProvider.FullConstructor;
var binaryDataParam = fullConstructor.Signature.Parameters.FirstOrDefault(p => p.Name.Equals(AdditionalBinaryDataParameterName));

// Use a custom constructor if the generated full constructor was suppressed or customized
if (!modelProvider.Constructors.Contains(fullConstructor))
{
foreach (var constructor in modelProvider.CanonicalView.Constructors)
{
var customCtorParamCount = constructor.Signature.Parameters.Count;
var fullCtorParamCount = fullConstructor.Signature.Parameters.Count;

if (constructor.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Internal)
&& customCtorParamCount >= fullCtorParamCount)
{
binaryDataParam = constructor.Signature.Parameters
.FirstOrDefault(p => p?.Type.Equals(typeof(IDictionary<string, BinaryData>)) == true, binaryDataParam);

if (customCtorParamCount > fullCtorParamCount)
{
fullConstructor = constructor;
break;
}
}
}
}

var signature = new MethodSignature(
modelProvider.Name,
null,
MethodSignatureModifiers.Static | MethodSignatureModifiers.Public,
modelProvider.Type,
$"A new {modelProvider.Type:C} instance for mocking.",
GetParameters(modelProvider));
GetParameters(modelProvider, fullConstructor));

var docs = new XmlDocProvider();
docs.Summary = modelProvider.XmlDocs?.Summary;
@@ -100,7 +125,7 @@ protected override MethodProvider[] BuildMethods()
[
.. GetCollectionInitialization(signature),
MethodBodyStatement.EmptyLine,
Return(New.Instance(typeToInstantiate.Type, [.. GetCtorArgs(modelProvider, signature)]))
Return(New.Instance(typeToInstantiate.Type, [.. GetCtorArgs(modelProvider, signature, fullConstructor, binaryDataParam)]))
]);

methods.Add(new MethodProvider(signature, statements, this, docs));
@@ -110,9 +135,11 @@ .. GetCollectionInitialization(signature),

private static IReadOnlyList<ValueExpression> GetCtorArgs(
ModelProvider modelProvider,
MethodSignature factoryMethodSignature)
MethodSignature factoryMethodSignature,
ConstructorProvider fullConstructor,
ParameterProvider? binaryDataParameter)
{
var modelCtorFullSignature = modelProvider.FullConstructor.Signature;
var modelCtorFullSignature = fullConstructor.Signature;
var expressions = new List<ValueExpression>(modelCtorFullSignature.Parameters.Count);

for (int i = 0; i < modelCtorFullSignature.Parameters.Count; i++)
@@ -153,10 +180,9 @@ private static IReadOnlyList<ValueExpression> GetCtorArgs(
}
}

if (modelCtorFullSignature.Parameters.Any(p => p.Name.Equals(AdditionalBinaryDataParameterName)) &&
!modelProvider.SupportsBinaryDataAdditionalProperties)
if (binaryDataParameter != null && !modelProvider.SupportsBinaryDataAdditionalProperties)
{
expressions.Add(Null);
expressions.Add(binaryDataParameter.PositionalReference(Null));
}

return [.. expressions];
@@ -175,14 +201,22 @@ private IReadOnlyList<MethodBodyStatement> GetCollectionInitialization(MethodSig
return [.. statements];
}

private static IReadOnlyList<ParameterProvider> GetParameters(ModelProvider modelProvider)
private static IReadOnlyList<ParameterProvider> GetParameters(
ModelProvider modelProvider,
ConstructorProvider fullConstructor)
{
var modelCtorParams = modelProvider.FullConstructor.Signature.Parameters;
var modelCtorParams = fullConstructor.Signature.Parameters;
var parameters = new List<ParameterProvider>(modelCtorParams.Count);
bool isCustomConstructor = fullConstructor != modelProvider.FullConstructor;

foreach (var param in modelCtorParams)
{
if (param.Name.Equals(AdditionalBinaryDataParameterName) && !modelProvider.SupportsBinaryDataAdditionalProperties)
bool isBinaryDataParam = param.Name.Equals(AdditionalBinaryDataParameterName)
|| (isCustomConstructor && param.Type.Equals(typeof(IDictionary<string, BinaryData>)));

if (isBinaryDataParam && !modelProvider.SupportsBinaryDataAdditionalProperties)
continue;

// skip discriminator parameters if the model has a discriminator value as those shouldn't be exposed in the factory methods
if (param.Property?.IsDiscriminator == true && modelProvider.DiscriminatorValue != null)
continue;
Original file line number Diff line number Diff line change
@@ -21,6 +21,8 @@ public static partial class Snippet
public static ValueExpression NullConditional(this ParameterProvider parameter) => new NullConditionalExpression(parameter);

public static ValueExpression NullCoalesce(this ParameterProvider parameter, ValueExpression value) => parameter.AsExpression.NullCoalesce(value);
public static ValueExpression PositionalReference(this ParameterProvider parameter, ValueExpression value)
=> new PositionalParameterReferenceExpression(parameter.Name, value);

public static DictionaryExpression AsDictionary(this FieldProvider field, CSharpType keyType, CSharpType valueType) => new(new KeyValuePairType(keyType, valueType), field);
public static DictionaryExpression AsDictionary(this ParameterProvider parameter, CSharpType keyType, CSharpType valueType) => new(new KeyValuePairType(keyType, valueType), parameter);
Original file line number Diff line number Diff line change
@@ -77,7 +77,7 @@ public void DerivedShouldUseItsDiscriminatorValueInModelFactory()
// ensure the signature is correct and includes the base discriminator value
// and the cat model's discriminator with literal value
Assert.IsTrue(birdModelMethod!.BodyStatements!.ToDisplayString()
.Contains("return new global::Sample.Models.Bird(\"red\", \"bird\", name, null);"));
.Contains("return new global::Sample.Models.Bird(\"red\", \"bird\", name, additionalBinaryDataProperties: null);"));
}

private static ModelFactoryProvider SetupModelFactory()
Original file line number Diff line number Diff line change
@@ -122,5 +122,45 @@ public async Task CanChangeAccessibilityOfModelFactory()
Assert.IsFalse(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public));
ValidateModelFactoryCommon(modelFactory);
}

[Test]
public async Task CanCustomizeModelFullConstructor()
{
var plugin = await MockHelpers.LoadMockPluginAsync(
inputModelTypes: [
InputFactory.Model(
"mockInputModel",
properties:
[
InputFactory.Property("Prop1", InputPrimitiveType.String, isRequired: true),
])
],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());
var csharpGen = new CSharpGen();

await csharpGen.ExecuteAsync();

// Find the model factory provider
var modelFactory = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider);
Assert.IsNotNull(modelFactory);

// The model factory should be public
Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public));
ValidateModelFactoryCommon(modelFactory);

// The model factory method should be replaced
var modelFactoryMethods = modelFactory!.Methods;
Assert.AreEqual(1, modelFactoryMethods.Count);

var modelFactoryMethod = modelFactoryMethods[0];
Assert.AreEqual("MockInputModel", modelFactoryMethod.Signature.Name);

Assert.AreEqual(2, modelFactoryMethod.Signature.Parameters.Count);
Assert.AreEqual("data", modelFactoryMethod.Signature.Parameters[0].Name);
Assert.AreEqual("prop1", modelFactoryMethod.Signature.Parameters[1].Name);

Assert.IsTrue(modelFactoryMethod.BodyStatements!.ToDisplayString()
.Contains("return new global::Sample.Models.MockInputModel(data?.ToList(), prop1, additionalBinaryDataProperties: null);"));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using Microsoft.Generator.CSharp.Customization;

#nullable disable

namespace Sample.Models;

[CodeGenSuppress("MockInputModel", typeof(string), typeof(IDictionary<string, BinaryData>))]
public partial class MockInputModel
{
private readonly IReadOnlyList<MockInputModel> _data;

internal MockInputModel(IReadOnlyList<MockInputModel> data, string prop1, IDictionary<string, BinaryData> additionalBinaryDataProperties)
{
Prop1 = prop1;
_data = data;
_additionalBinaryDataProperties = additionalBinaryDataProperties;
}
}
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ public static partial class SampleNamespaceModelFactory
public static global::Sample.Models.MockInputModel MockInputModel(global::System.ReadOnlyMemory<byte> prop1 = default)
{

return new global::Sample.Models.MockInputModel(prop1, null);
return new global::Sample.Models.MockInputModel(prop1, additionalBinaryDataProperties: null);
}
}
}
Original file line number Diff line number Diff line change
@@ -46,7 +46,7 @@ public static Thing Thing(string name = default, BinaryData requiredUnion = defa
requiredBadDescription,
optionalNullableList?.ToList(),
requiredNullableList?.ToList(),
null);
additionalBinaryDataProperties: null);
}

/// <summary> this is a roundtrip model. </summary>
@@ -113,7 +113,7 @@ public static RoundTripModel RoundTripModel(string requiredString = default, int
readOnlyOptionalRecordUnknown,
modelWithRequiredNullable,
requiredBytes,
null);
additionalBinaryDataProperties: null);
}

/// <summary> A model with a few required nullable properties. </summary>
@@ -124,7 +124,7 @@ public static RoundTripModel RoundTripModel(string requiredString = default, int
public static ModelWithRequiredNullableProperties ModelWithRequiredNullableProperties(int? requiredNullablePrimitive = default, StringExtensibleEnum? requiredExtensibleEnum = default, StringFixedEnum? requiredFixedEnum = default)
{

return new ModelWithRequiredNullableProperties(requiredNullablePrimitive, requiredExtensibleEnum, requiredFixedEnum, null);
return new ModelWithRequiredNullableProperties(requiredNullablePrimitive, requiredExtensibleEnum, requiredFixedEnum, additionalBinaryDataProperties: null);
}

/// <summary> this is not a friendly model but with a friendly name. </summary>
@@ -133,7 +133,7 @@ public static ModelWithRequiredNullableProperties ModelWithRequiredNullablePrope
public static Friend Friend(string name = default)
{

return new Friend(name, null);
return new Friend(name, additionalBinaryDataProperties: null);
}

/// <summary> this is a model with a projected name. </summary>
@@ -142,15 +142,15 @@ public static Friend Friend(string name = default)
public static ProjectedModel ProjectedModel(string name = default)
{

return new ProjectedModel(name, null);
return new ProjectedModel(name, additionalBinaryDataProperties: null);
}

/// <summary> The ReturnsAnonymousModelResponse. </summary>
/// <returns> A new <see cref="Models.ReturnsAnonymousModelResponse"/> instance for mocking. </returns>
public static ReturnsAnonymousModelResponse ReturnsAnonymousModelResponse()
{

return new ReturnsAnonymousModelResponse(null);
return new ReturnsAnonymousModelResponse(additionalBinaryDataProperties: null);
}
}
}

0 comments on commit 80f5c4c

Please sign in to comment.