Skip to content

Commit

Permalink
Add support for spread of aliases & models (#4199)
Browse files Browse the repository at this point in the history
This PR adds support for spread of aliases & models within a client
operation's parameters.

fixes: #3831

---------

Co-authored-by: m-nash <[email protected]>
  • Loading branch information
jorgerangel-msft and m-nash authored Aug 21, 2024
1 parent e8c493e commit 371d5c1
Show file tree
Hide file tree
Showing 12 changed files with 660 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public static ParameterProvider ClientOptions(CSharpType clientOptionsType)
public static readonly ParameterProvider TokenAuth = new("tokenCredential", $"The token credential to copy", ClientModelPlugin.Instance.TypeFactory.TokenCredentialType());
public static readonly ParameterProvider MatchConditionsParameter = new("matchConditions", $"The content to send as the request conditions of the request.", ClientModelPlugin.Instance.TypeFactory.MatchConditionsType(), DefaultOf(ClientModelPlugin.Instance.TypeFactory.MatchConditionsType()));
public static readonly ParameterProvider RequestOptions = new("options", $"The request options, which can override default behaviors of the client pipeline on a per-call basis.", typeof(RequestOptions));
public static readonly ParameterProvider BinaryContent = new("content", $"The content to send as the body of the request.", typeof(BinaryContent)) { Validation = ParameterValidationType.AssertNotNull };
public static readonly ParameterProvider BinaryContent = new("content", $"The content to send as the body of the request.", typeof(BinaryContent), location: ParameterLocation.Body) { Validation = ParameterValidationType.AssertNotNull };

// Known header parameters
public static readonly ParameterProvider RepeatabilityRequestId = new("repeatabilityRequestId", FormattableStringHelpers.Empty, typeof(Guid))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,43 @@ private static void GetParamInfo(Dictionary<string, ParameterProvider> paramMap,
}
}

private static IReadOnlyList<ParameterProvider> BuildSpreadParametersForModel(InputModelType inputModel)
{
var builtParameters = new ParameterProvider[inputModel.Properties.Count];

int index = 0;
foreach (var property in inputModel.Properties)
{
// convert the property to a parameter
var inputParameter = new InputParameter(
property.Name,
property.SerializedName,
property.Description,
property.Type,
RequestLocation.Body,
null,
InputOperationParameterKind.Method,
property.IsRequired,
false,
false,
false,
false,
false,
false,
null,
null);

var paramProvider = ClientModelPlugin.Instance.TypeFactory.CreateParameter(inputParameter);
paramProvider.DefaultValue = !inputParameter.IsRequired ? Default : null;
paramProvider.SpreadSource = ClientModelPlugin.Instance.TypeFactory.CreateModel(inputModel);
paramProvider.Type = paramProvider.Type.InputType;

builtParameters[index++] = paramProvider;
}

return builtParameters;
}

private static bool TryGetSpecialHeaderParam(InputParameter inputParameter, [NotNullWhen(true)] out ParameterProvider? parameterProvider)
{
if (inputParameter.Location == RequestLocation.Header)
Expand All @@ -375,12 +412,22 @@ internal MethodProvider GetCreateRequestMethod(InputOperation operation)

internal static List<ParameterProvider> GetMethodParameters(InputOperation operation, bool isProtocol = false)
{
List<ParameterProvider> methodParameters = new();
SortedList<int, ParameterProvider> sortedParams = [];
int path = 0;
int required = 100;
int bodyRequired = 200;
int bodyOptional = 300;
int contentType = 400;
int optional = 500;

foreach (InputParameter inputParam in operation.Parameters)
{
if (inputParam.Kind != InputOperationParameterKind.Method || TryGetSpecialHeaderParam(inputParam, out var _))
if ((inputParam.Kind != InputOperationParameterKind.Method && inputParam.Kind != InputOperationParameterKind.Spread)
|| TryGetSpecialHeaderParam(inputParam, out var _))
continue;

var spreadInputModel = inputParam.Kind == InputOperationParameterKind.Spread ? GetSpreadParameterModel(inputParam) : null;

ParameterProvider? parameter = ClientModelPlugin.Instance.TypeFactory.CreateParameter(inputParam);

if (isProtocol)
Expand All @@ -394,11 +441,66 @@ internal static List<ParameterProvider> GetMethodParameters(InputOperation opera
parameter.Type = parameter.Type.IsEnum ? parameter.Type.UnderlyingEnumType : parameter.Type;
}
}
else if (spreadInputModel != null)
{
foreach (var bodyParam in BuildSpreadParametersForModel(spreadInputModel))
{
if (bodyParam.DefaultValue is null)
{
sortedParams.Add(bodyRequired++, bodyParam);
}
else
{
sortedParams.Add(bodyOptional++, bodyParam);
}
}
continue;
}

if (parameter is null)
continue;

if (parameter is not null)
methodParameters.Add(parameter);
switch (parameter.Location)
{
case ParameterLocation.Path:
case ParameterLocation.Uri:
sortedParams.Add(path++, parameter);
break;
case ParameterLocation.Query:
case ParameterLocation.Header:
if (inputParam.IsContentType)
{
sortedParams.Add(contentType++, parameter);
}
else if (parameter.Validation != ParameterValidationType.None)
{
sortedParams.Add(required++, parameter);
}
else
{
sortedParams.Add(optional++, parameter);
}
break;
case ParameterLocation.Body:
sortedParams.Add(bodyRequired++, parameter);
break;
default:
sortedParams.Add(optional++, parameter);
break;
}
}

return [.. sortedParams.Values];
}

internal static InputModelType GetSpreadParameterModel(InputParameter inputParam)
{
if (inputParam.Kind.HasFlag(InputOperationParameterKind.Spread) && inputParam.Type is InputModelType model)
{
return model;
}
return methodParameters;

throw new InvalidOperationException($"inputParam `{inputParam.Name}` is `Spread` but not a model type");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ private MethodProvider BuildConvenienceMethod(MethodProvider protocolMethod, boo
{
methodModifier |= MethodSignatureModifiers.Async;
}

var methodSignature = new MethodSignature(
isAsync ? _cleanOperationName + "Async" : _cleanOperationName,
FormattableStringHelpers.FromString(Operation.Description),
Expand All @@ -74,12 +75,13 @@ private MethodProvider BuildConvenienceMethod(MethodProvider protocolMethod, boo
var processMessageName = isAsync ? "ProcessMessageAsync" : "ProcessMessage";

MethodBodyStatement[] methodBody;

if (responseBodyType is null)
{
methodBody =
[
.. GetStackVariablesForProtocolParamConversion(ConvenienceMethodParameters, out var paramDeclarations),
Return(This.Invoke(protocolMethod.Signature, [.. GetParamConversions(ConvenienceMethodParameters, paramDeclarations), Null], isAsync))
.. GetStackVariablesForProtocolParamConversion(ConvenienceMethodParameters, out var declarations),
Return(This.Invoke(protocolMethod.Signature, [.. GetParamConversions(ConvenienceMethodParameters, declarations), Null], isAsync))
];
}
else
Expand All @@ -88,11 +90,11 @@ .. GetStackVariablesForProtocolParamConversion(ConvenienceMethodParameters, out
[
.. GetStackVariablesForProtocolParamConversion(ConvenienceMethodParameters, out var paramDeclarations),
Declare("result", This.Invoke(protocolMethod.Signature, [.. GetParamConversions(ConvenienceMethodParameters, paramDeclarations), Null], isAsync).As<ClientResult>(), out ScopedApi<ClientResult> result),
.. GetStackVariablesForReturnValueConversion(result, responseBodyType, isAsync, out var declarations),
.. GetStackVariablesForReturnValueConversion(result, responseBodyType, isAsync, out var resultDeclarations),
Return(Static<ClientResult>().Invoke(
nameof(ClientResult.FromValue),
[
GetResultConversion(result, responseBodyType, declarations),
GetResultConversion(result, responseBodyType, resultDeclarations),
result.Invoke("GetRawResponse")
])),
];
Expand All @@ -109,6 +111,9 @@ private IEnumerable<MethodBodyStatement> GetStackVariablesForProtocolParamConver
declarations = new Dictionary<string, ValueExpression>();
foreach (var parameter in convenienceMethodParameters)
{
if (parameter.SpreadSource is not null)
continue;

if (parameter.Location == ParameterLocation.Body)
{
if (parameter.Type.IsReadOnlyMemory)
Expand Down Expand Up @@ -136,9 +141,53 @@ private IEnumerable<MethodBodyStatement> GetStackVariablesForProtocolParamConver
}
}
}

// add spread parameter model variable declaration
var spreadSource = convenienceMethodParameters.FirstOrDefault(p => p.SpreadSource is not null)?.SpreadSource;
if (spreadSource is not null)
{
statements.Add(Declare("spreadModel", New.Instance(spreadSource.Type, [.. GetSpreadConversion(spreadSource)]).As(spreadSource.Type), out var spread));
declarations["spread"] = spread;
}

return statements;
}

private List<ValueExpression> GetSpreadConversion(TypeProvider spreadSource)
{
var convenienceMethodParams = ConvenienceMethodParameters.ToDictionary(p => p.Name);
List<ValueExpression> expressions = new(spreadSource.Properties.Count);
// we should make this find more deterministic
var ctor = spreadSource.Constructors.First(c => c.Signature.Parameters.Count == spreadSource.Properties.Count + 1 &&
c.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Internal));

foreach (var param in ctor.Signature.Parameters)
{
if (convenienceMethodParams.TryGetValue(param.Name, out var convenienceParam))
{
if (convenienceParam.Type.IsList)
{
var interfaceType = param.Property!.WireInfo?.IsReadOnly == true
? new CSharpType(typeof(IReadOnlyList<>), convenienceParam.Type.Arguments)
: new CSharpType(typeof(IList<>), convenienceParam.Type.Arguments);
expressions.Add(NullCoalescing(
new AsExpression(convenienceParam.NullConditional().ToList(), interfaceType),
New.Instance(convenienceParam.Type.PropertyInitializationType, [])));
}
else
{
expressions.Add(convenienceParam);
}
}
else
{
expressions.Add(Null);
}
}

return expressions;
}

private IEnumerable<MethodBodyStatement> GetStackVariablesForReturnValueConversion(ScopedApi<ClientResult> result, CSharpType responseBodyType, bool isAsync, out Dictionary<string, ValueExpression> declarations)
{
if (responseBodyType.IsList)
Expand Down Expand Up @@ -216,9 +265,18 @@ private ValueExpression GetResultConversion(ScopedApi<ClientResult> result, CSha
private IReadOnlyList<ValueExpression> GetParamConversions(IReadOnlyList<ParameterProvider> convenienceMethodParameters, Dictionary<string, ValueExpression> declarations)
{
List<ValueExpression> conversions = new List<ValueExpression>();
bool addedSpreadSource = false;
foreach (var param in convenienceMethodParameters)
{
if (param.Location == ParameterLocation.Body)
if (param.SpreadSource is not null)
{
if (!addedSpreadSource)
{
conversions.Add(declarations["spread"]);
addedSpreadSource = true;
}
}
else if (param.Location == ParameterLocation.Body)
{
if (param.Type.IsReadOnlyMemory || param.Type.IsList)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Linq;
Expand All @@ -25,6 +26,13 @@ public class ClientProviderTests
private static readonly InputClient _animalClient = new("animal", "AnimalClient description", [], [], TestClientName);
private static readonly InputClient _dogClient = new("dog", "DogClient description", [], [], _animalClient.Name);
private static readonly InputClient _huskyClient = new("husky", "HuskyClient description", [], [], _dogClient.Name);
private static readonly InputModelType _spreadModel = InputFactory.Model(
"spreadModel",
usage: InputModelTypeUsage.Spread,
properties:
[
InputFactory.Property("p1", InputPrimitiveType.String, isRequired: true),
]);

[SetUp]
public void SetUp()
Expand Down Expand Up @@ -365,6 +373,33 @@ public void ValidateQueryParamWriterDiff()
Assert.AreEqual(Helpers.GetExpectedFromFile(), codeFile.Content);
}

[TestCaseSource(nameof(ValidateClientWithSpreadTestCases))]
public void ValidateClientWithSpread(InputClient inputClient)
{
var clientProvider = new ClientProvider(inputClient);
var methods = clientProvider.Methods;

Assert.AreEqual(4, methods.Count);

var protocolMethods = methods.Where(m => m.Signature.Parameters.Any(p => p.Type.Equals(typeof(BinaryContent)))).ToList();
Assert.AreEqual(2, protocolMethods.Count);
Assert.AreEqual(2, protocolMethods[0].Signature.Parameters.Count);
Assert.AreEqual(2, protocolMethods[1].Signature.Parameters.Count);

Assert.AreEqual(new CSharpType(typeof(BinaryContent)), protocolMethods[0].Signature.Parameters[0].Type);
Assert.AreEqual(new CSharpType(typeof(RequestOptions)), protocolMethods[0].Signature.Parameters[1].Type);
Assert.AreEqual(new CSharpType(typeof(BinaryContent)), protocolMethods[1].Signature.Parameters[0].Type);
Assert.AreEqual(new CSharpType(typeof(RequestOptions)), protocolMethods[1].Signature.Parameters[1].Type);

var convenienceMethods = methods.Where(m => m.Signature.Parameters.Any(p => p.Type.Equals(typeof(string)))).ToList();
Assert.AreEqual(2, convenienceMethods.Count);
Assert.AreEqual(1, convenienceMethods[0].Signature.Parameters.Count);

Assert.AreEqual(new CSharpType(typeof(string)), convenienceMethods[0].Signature.Parameters[0].Type);
Assert.AreEqual("p1", convenienceMethods[0].Signature.Parameters[0].Name);

}

private static InputClient GetEnumQueryParamClient()
=> InputFactory.Client(
TestClientName,
Expand All @@ -386,6 +421,7 @@ private static InputClient GetEnumQueryParamClient()
InputFactory.EnumMember.String("value1", "value1"),
InputFactory.EnumMember.String("value2", "value2")
]),
isRequired: true,
location: RequestLocation.Query)
])
]);
Expand Down Expand Up @@ -469,6 +505,29 @@ public static IEnumerable<TestCaseData> SubClientTestCases
}
}

public static IEnumerable<TestCaseData> ValidateClientWithSpreadTestCases
{
get
{
yield return new TestCaseData(InputFactory.Client(
TestClientName,
operations:
[
InputFactory.Operation(
"CreateMessage",
parameters:
[
InputFactory.Parameter(
"spread",
_spreadModel,
location: RequestLocation.Body,
isRequired: true,
kind: InputOperationParameterKind.Spread),
])
]));
}
}

public static IEnumerable<TestCaseData> BuildConstructorsTestCases
{
get
Expand Down
Loading

0 comments on commit 371d5c1

Please sign in to comment.