Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typehandler support #117

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Index:
- [SQL Syntax](/sqlsyntax)
- [Generated Code](/generatedcode)
- [Bulk Copy](/bulkcopy)
- [Type Handlers](/typehandlers)
- [Frequently Asked Questions](/faq)

Packages:
Expand Down
20 changes: 20 additions & 0 deletions docs/rules/DAP048.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# DAP048

Duplicate classes have been registered as type handlers for the same type,
meaning it's not possible to determine which to use when handling the type.
Note type handlers can be registered at the assembly and module level, so
ensure the type used for the `TValue` parameter in the attribute is only
specified once.

Error:

``` c#
[module: TypeHandler<MyClass, MyHandler1>]
[module: TypeHandler<MyClass, MyHandler2>]
```

Good:

``` c#
[module: TypeHandler<MyClass, MyHandler>]
```
21 changes: 21 additions & 0 deletions docs/typehandlers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Type Handlers

At times you might want to customise how a type is read from a query or how it
is saved in a parameter. In Dapper you might use a `SqlMapper.TypeHandler` for
this, which has a slightly altered interface in the AOT version and a different
way of registering them.

To register your own type handler, use either an assembly or module level
attribute to specify the mapping (you can replace `module` with `assembly`
below, it has the same effect):

``` csharp
using Dapper;

[module: TypeHandler<MyClass, MyClassTypeHandler>]
```

Your type handler must inherit from `Dapper.TypeHandler<T>` and be default
constructable. The methods are virtual, so you can override only which ones you
need (e.g. if you're just interested in reading your values and not using them
as parameters, you only need to override the `Read` method).
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ internal static readonly DiagnosticDescriptor
LanguageVersionTooLow = LibraryWarning("DAP004", "Language version too low", "Interceptors require at least C# version 11"),

CommandPropertyNotFound = LibraryWarning("DAP033", "Command property not found", "Command property {0}.{1} was not found or was not valid; attribute will be ignored"),
CommandPropertyReserved = LibraryWarning("DAP034", "Command property reserved", "Command property {1} is reserved for internal usage; attribute will be ignored");
CommandPropertyReserved = LibraryWarning("DAP034", "Command property reserved", "Command property {1} is reserved for internal usage; attribute will be ignored"),

DuplicateTypeHandlers = LibraryError("DAP048", "Duplicate type handlers", "Type {0} has multiple type handlers registered");
}
}
163 changes: 118 additions & 45 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ private void Generate(SourceProductionContext ctx, (Compilation Compilation, Imm
{
try
{
Generate(new(ctx, state));
var typeHandlers = IdentifyTypeHandlers(ctx, state.Compilation);
Generate(new(ctx, state.Compilation, state.Nodes, typeHandlers));
}
catch (Exception ex)
{
Expand Down Expand Up @@ -490,11 +491,11 @@ private static void WriteCommandFactory(in GenerateState ctx, string baseFactory
else
{
sb.Append("public override void AddParameters(in global::Dapper.UnifiedCommand cmd, ").Append(declaredType).Append(" args)").Indent().NewLine();
WriteArgs(type, sb, WriteArgsMode.Add, map, ref flags);
WriteArgs(ctx, type, sb, WriteArgsMode.Add, map, ref flags);
sb.Outdent().NewLine();

sb.Append("public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, ").Append(declaredType).Append(" args)").Indent().NewLine();
WriteArgs(type, sb, WriteArgsMode.Update, map, ref flags);
WriteArgs(ctx, type, sb, WriteArgsMode.Update, map, ref flags);
sb.Outdent().NewLine();

if ((flags & (WriteArgsFlags.NeedsRowCount | WriteArgsFlags.NeedsPostProcess)) != 0)
Expand All @@ -507,11 +508,11 @@ private static void WriteCommandFactory(in GenerateState ctx, string baseFactory
sb.Append("public override void PostProcess(in global::Dapper.UnifiedCommand cmd, ").Append(declaredType).Append(" args, int rowCount)").Indent().NewLine();
if ((flags & WriteArgsFlags.NeedsPostProcess) != 0)
{
WriteArgs(type, sb, WriteArgsMode.PostProcess, map, ref flags);
WriteArgs(ctx, type, sb, WriteArgsMode.PostProcess, map, ref flags);
}
if ((flags & WriteArgsFlags.NeedsRowCount) != 0)
{
WriteArgs(type, sb, WriteArgsMode.SetRowCount, map, ref flags);
WriteArgs(ctx, type, sb, WriteArgsMode.SetRowCount, map, ref flags);
}
if (baseFactory != DapperBaseCommandFactory)
{
Expand All @@ -524,7 +525,7 @@ private static void WriteCommandFactory(in GenerateState ctx, string baseFactory
{
sb.Append("public override global::System.Threading.CancellationToken GetCancellationToken(").Append(declaredType).Append(" args)")
.Indent().NewLine();
WriteArgs(type, sb, WriteArgsMode.GetCancellationToken, map, ref flags);
WriteArgs(ctx, type, sb, WriteArgsMode.GetCancellationToken, map, ref flags);
sb.Outdent().NewLine();
}
}
Expand Down Expand Up @@ -702,7 +703,7 @@ static bool IsReserved(string name)
}
}

private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITypeSymbol type, int index)
private static void WriteRowFactory(in GenerateState ctx, CodeWriter sb, ITypeSymbol type, int index)
{
var map = MemberMap.CreateForResults(type);
if (map is null) return;
Expand All @@ -723,6 +724,7 @@ private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITy
var hasGetOnlyMembers = members.Any(member => member is { IsGettable: true, IsSettable: false, IsInitOnly: false });
var useConstructorDeferred = map.Constructor is not null;
var useFactoryMethodDeferred = map.FactoryMethod is not null;
var typeHandlers = ctx.TypeHandlers; // Prevent ctx getting captured

// Implementation detail:
// constructor takes advantage over factory method.
Expand Down Expand Up @@ -756,18 +758,29 @@ void WriteTokenizeMethod()
.Append("var type = reader.GetFieldType(columnOffset);").NewLine()
.Append("switch (NormalizedHash(name))").Indent().NewLine();

int token = 0;
int firstToken = 0;
int secondToken = map.Members.Length;
foreach (var member in members)
{
var dbName = member.DbName;
sb.Append("case ").Append(StringHashing.NormalizedHash(dbName))
.Append(" when NormalizedEquals(name, ")
.AppendVerbatimLiteral(StringHashing.Normalize(dbName)).Append("):").Indent(false).NewLine()
.Append("token = type == typeof(").Append(Inspection.MakeNonNullable(member.CodeType)).Append(") ? ").Append(token)
.Append(" : ").Append(token + map.Members.Length).Append(";")
.Append(token == 0 ? " // two tokens for right-typed and type-flexible" : "").NewLine()
.AppendVerbatimLiteral(StringHashing.Normalize(dbName)).Append("):").Indent(false).NewLine();

if (typeHandlers.TryGetValue(member.CodeType, out var typeHandler))
{
sb.Append("token = ").Append(firstToken).Append(";");
}
else
{
sb.Append("token = type == typeof(").Append(Inspection.MakeNonNullable(member.CodeType)).Append(") ? ").Append(firstToken)
.Append(" : ").Append(secondToken).Append(";");
secondToken++;
}

sb.Append(firstToken == 0 ? " // two tokens for right-typed and type-flexible" : "").NewLine()
.Append("break;").Outdent(false).NewLine();
token++;
firstToken++;
}
sb.Outdent().NewLine()
.Append("tokens[i] = token;").NewLine()
Expand Down Expand Up @@ -825,45 +838,55 @@ void WriteReadMethod()
sb.Append("foreach (var token in tokens)").Indent().NewLine()
.Append("switch (token)").Indent().NewLine();

token = 0;
int firstToken = 0;
int secondToken = members.Length;
foreach (var member in members)
{
var memberType = member.CodeType;

member.GetDbType(out var readerMethod);
var nullCheck = Inspection.CouldBeNullable(memberType) ? $"reader.IsDBNull(columnOffset) ? ({CodeWriter.GetTypeName(memberType.WithNullableAnnotation(NullableAnnotation.Annotated))})null : " : "";
sb.Append("case ").Append(token).Append(":").NewLine().Indent(false);
sb.Append("case ").Append(firstToken).Append(":").NewLine().Indent(false);

// write `result.X = ` or `member0 = `
if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(token);
if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(firstToken);
else sb.Append("result.").Append(member.CodeName);
sb.Append(" = ");

sb.Append(nullCheck);
if (readerMethod is null)
if (typeHandlers.TryGetValue(memberType, out var handler))
{
sb.Append("reader.GetFieldValue<").Append(memberType).Append(">(columnOffset);");
sb.Append("new ").Append(handler).Append("().Read(reader, columnOffset);").NewLine()
.Append("break;").NewLine().Outdent(false);
}
else
{
sb.Append("reader.").Append(readerMethod).Append("(columnOffset);");
}
if (readerMethod is null)
{
sb.Append("reader.GetFieldValue<").Append(memberType).Append(">(columnOffset);");
}
else
{
sb.Append("reader.").Append(readerMethod).Append("(columnOffset);");
}

sb.NewLine().Append("break;").NewLine().Outdent(false)
.Append("case ").Append(secondToken).Append(":").NewLine().Indent(false);

sb.NewLine().Append("break;").NewLine().Outdent(false)
.Append("case ").Append(token + map.Members.Length).Append(":").NewLine().Indent(false);
// write `result.X = ` or `member0 = `
if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(firstToken);
else sb.Append("result.").Append(member.CodeName);

// write `result.X = ` or `member0 = `
if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(token);
else sb.Append("result.").Append(member.CodeName);
sb.Append(" = ")
.Append(nullCheck)
.Append("GetValue<")
.Append(Inspection.MakeNonNullable(memberType)).Append(">(reader, columnOffset);").NewLine()
.Append("break;").NewLine().Outdent(false);

sb.Append(" = ")
.Append(nullCheck)
.Append("GetValue<")
.Append(Inspection.MakeNonNullable(memberType)).Append(">(reader, columnOffset);").NewLine()
.Append("break;").NewLine().Outdent(false);
secondToken++;
}

token++;
firstToken++;
}

sb.Outdent().NewLine().Append("columnOffset++;").NewLine().Outdent().NewLine();
Expand Down Expand Up @@ -966,7 +989,7 @@ enum WriteArgsMode
GetCancellationToken
}

private static void WriteArgs(ITypeSymbol? parameterType, CodeWriter sb, WriteArgsMode mode, string map, ref WriteArgsFlags flags)
private static void WriteArgs(in GenerateState ctx, ITypeSymbol? parameterType, CodeWriter sb, WriteArgsMode mode, string map, ref WriteArgsFlags flags)
{
if (parameterType is null)
{
Expand Down Expand Up @@ -1130,7 +1153,7 @@ private static void WriteArgs(ITypeSymbol? parameterType, CodeWriter sb, WriteAr
}
else
{
sb.Append("p.Value = ").Append("AsValue(").Append(source).Append(".").Append(member.CodeName).Append(");").NewLine();
AppendSetValue(ctx, sb, "p", source, member);
}
break;
default:
Expand All @@ -1149,30 +1172,32 @@ private static void WriteArgs(ITypeSymbol? parameterType, CodeWriter sb, WriteAr
}
break;
case WriteArgsMode.Update:
sb.Append("ps[");
if ((flags & WriteArgsFlags.NeedsTest) != 0) sb.AppendVerbatimLiteral(member.DbName);
else sb.Append(parameterIndex);
sb.Append("].Value = ");
var parameter = GetParameterIndex(flags, member.DbName, parameterIndex);
switch (direction)
{
case ParameterDirection.Input:
case ParameterDirection.InputOutput:
sb.Append("AsValue(").Append(source).Append(".").Append(member.CodeName).Append(");").NewLine();
AppendSetValue(ctx, sb, parameter, source, member);
break;
default:
sb.Append("global::System.DBNull.Value;").NewLine();
sb.Append(parameter).Append(".Value = global::System.DBNull.Value;").NewLine();
break;

}
break;
case WriteArgsMode.PostProcess:
// we already eliminated args that we don't need to look at
sb.Append(source).Append(".").Append(member.CodeName).Append(" = Parse<")
.Append(member.CodeType).Append(">(ps[");
if ((flags & WriteArgsFlags.NeedsTest) != 0) sb.AppendVerbatimLiteral(member.DbName);
else sb.Append(parameterIndex);
sb.Append("].Value);").NewLine();

parameter = GetParameterIndex(flags, member.DbName, parameterIndex);
sb.Append(source).Append(".").Append(member.CodeName).Append(" = ");
if (ctx.TypeHandlers.TryGetValue(member.CodeType, out var handler))
{
sb.Append("new ").Append(handler).Append("().Parse(").Append(parameter).Append(");").NewLine();
}
else
{
sb.Append(source).Append(".").Append(member.CodeName).Append("Parse<")
.Append(member.CodeType).Append(">(").Append(parameter).Append(".Value);").NewLine();
}
break;
}
if (test)
Expand All @@ -1198,6 +1223,20 @@ static void AppendDbParameterSetting(CodeWriter sb, string memberName, byte? val
}
}

private static void AppendSetValue(in GenerateState ctx, CodeWriter sb, string parameter, string? source, in Inspection.ElementMember member)
{
if (ctx.TypeHandlers.TryGetValue(member.CodeType, out var handler))
{
sb.Append("new ").Append(handler).Append("().SetValue(")
.Append(parameter).Append(", ").Append(source).Append(".").Append(member.CodeName)
.Append(");").NewLine();
}
else
{
sb.Append(parameter).Append(".Value = AsValue(").Append(source).Append(".").Append(member.CodeName).Append(");").NewLine();
}
}

private static void AppendShapeLambda(CodeWriter sb, ITypeSymbol parameterType)
{
var members = parameterType.GetMembers();
Expand Down Expand Up @@ -1227,6 +1266,15 @@ private static void AppendShapeLambda(CodeWriter sb, ITypeSymbol parameterType)
}
}

private static string GetParameterIndex(WriteArgsFlags flags, string dbName, int parameterIndex)
{
string index = ((flags & WriteArgsFlags.NeedsTest) != 0)
? CodeWriter.CreateVerbatimLiteral(dbName)
: parameterIndex.ToString(CultureInfo.InvariantCulture);

return "ps[" + index + "]";
}

private static SpecialCommandFlags GetSpecialCommandFlags(ITypeSymbol type)
{
// check whether these command-types need special handling
Expand Down Expand Up @@ -1336,6 +1384,31 @@ static bool IsDerived(ITypeSymbol? type, ITypeSymbol baseType)
}
}

private static IImmutableDictionary<ITypeSymbol, ITypeSymbol> IdentifyTypeHandlers(in SourceProductionContext ctx, Compilation compilation)
{
var assembly = compilation.Assembly;
var attributes = assembly.GetAttributes()
.Concat(assembly.Modules.SelectMany(x => x.GetAttributes()))
.Where(x => Inspection.IsDapperAttribute(x) && x.AttributeClass!.Name == "TypeHandlerAttribute");

var dictionary = ImmutableDictionary.CreateBuilder<ITypeSymbol, ITypeSymbol>(SymbolEqualityComparer.Default);
foreach (var attribute in attributes)
{
var valueType = attribute.AttributeClass!.TypeArguments[0];
var typeHandler = attribute.AttributeClass!.TypeArguments[1];
if (dictionary.ContainsKey(valueType))
{
ctx.ReportDiagnostic(Diagnostic.Create(Diagnostics.DuplicateTypeHandlers, null, valueType.Name));
}
else
{
dictionary.Add(valueType, typeHandler);
}
}

return dictionary.ToImmutable();
}

internal abstract class SourceState
{
public Location? Location { get; }
Expand Down
9 changes: 6 additions & 3 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/ParseState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,21 @@ public GenerateState(GenerateContextProxy proxy)
Nodes = proxy.Nodes;
ctx = default;
this.proxy = proxy;
TypeHandlers = ImmutableDictionary<ITypeSymbol, ITypeSymbol>.Empty;
}
public GenerateState(SourceProductionContext ctx, in (Compilation Compilation, ImmutableArray<SourceState> Nodes) state)
public GenerateState(SourceProductionContext ctx, Compilation compilation, ImmutableArray<SourceState> nodes, IImmutableDictionary<ITypeSymbol, ITypeSymbol> typeHandlers)
{
Compilation = state.Compilation;
Nodes = state.Nodes;
Compilation = compilation;
Nodes = nodes;
TypeHandlers = typeHandlers;
this.ctx = ctx;
proxy = null;
}
private readonly SourceProductionContext ctx;
private readonly GenerateContextProxy? proxy;
public readonly ImmutableArray<SourceState> Nodes;
public readonly Compilation Compilation;
public readonly IImmutableDictionary<ITypeSymbol, ITypeSymbol> TypeHandlers;

internal void ReportDiagnostic(Diagnostic diagnostic)
{
Expand Down
Loading