diff --git a/std/sumtype.d b/std/sumtype.d index 69c2a49dd7f..ad2942873a6 100644 --- a/std/sumtype.d +++ b/std/sumtype.d @@ -1860,88 +1860,65 @@ private template Iota(size_t n) assert(Iota!3 == AliasSeq!(0, 1, 2)); } -/* The number that the dim-th argument's tag is multiplied by when - * converting TagTuples to and from case indices ("caseIds"). - * - * Named by analogy to the stride that the dim-th index into a - * multidimensional static array is multiplied by to calculate the - * offset of a specific element. - */ -private size_t stride(size_t dim, lengths...)() -{ - import core.checkedint : mulu; - - size_t result = 1; - bool overflow = false; - - static foreach (i; 0 .. dim) - { - result = mulu(result, lengths[i], overflow); - } - - /* The largest number matchImpl uses, numCases, is calculated with - * stride!(SumTypes.length), so as long as this overflow check - * passes, we don't need to check for overflow anywhere else. - */ - assert(!overflow, "Integer overflow"); - return result; -} - private template matchImpl(Flag!"exhaustive" exhaustive, handlers...) { auto ref matchImpl(SumTypes...)(auto ref SumTypes args) if (allSatisfy!(isSumType, SumTypes) && args.length > 0) { - alias stride(size_t i) = .stride!(i, Map!(typeCount, SumTypes)); - alias TagTuple = .TagTuple!(SumTypes); - - /* - * A list of arguments to be passed to a handler needed for the case - * labeled with `caseId`. - */ - template handlerArgs(size_t caseId) + // Single dispatch (fast path) + static if (args.length == 1) { - enum tags = TagTuple.fromCaseId(caseId); - enum argsFrom(size_t i : tags.length) = ""; - enum argsFrom(size_t i) = "args[" ~ toCtString!i ~ "].get!(SumTypes[" ~ toCtString!i ~ "]" ~ - ".Types[" ~ toCtString!(tags[i]) ~ "])(), " ~ argsFrom!(i + 1); - enum handlerArgs = argsFrom!0; - } + /* When there's only one argument, the caseId is just that + * argument's tag, so there's no need for TagTuple. + */ + enum handlerArgs(size_t caseId) = + "args[0].get!(SumTypes[0].Types[" ~ toCtString!caseId ~ "])()"; - /* An AliasSeq of the types of the member values in the argument list - * returned by `handlerArgs!caseId`. - * - * Note that these are the actual (that is, qualified) types of the - * member values, which may not be the same as the types listed in - * the arguments' `.Types` properties. - */ - template valueTypes(size_t caseId) + alias valueTypes(size_t caseId) = + typeof(args[0].get!(SumTypes[0].Types[caseId])()); + + enum numCases = SumTypes[0].Types.length; + } + // Multiple dispatch (slow path) + else { - enum tags = TagTuple.fromCaseId(caseId); + alias typeCounts = Map!(typeCount, SumTypes); + alias stride(size_t i) = .stride!(i, typeCounts); + alias TagTuple = .TagTuple!typeCounts; + + alias handlerArgs(size_t caseId) = .handlerArgs!(caseId, typeCounts); - template getType(size_t i) + /* An AliasSeq of the types of the member values in the argument list + * returned by `handlerArgs!caseId`. + * + * Note that these are the actual (that is, qualified) types of the + * member values, which may not be the same as the types listed in + * the arguments' `.Types` properties. + */ + template valueTypes(size_t caseId) { - enum tid = tags[i]; - alias T = SumTypes[i].Types[tid]; - alias getType = typeof(args[i].get!T()); + enum tags = TagTuple.fromCaseId(caseId); + + template getType(size_t i) + { + enum tid = tags[i]; + alias T = SumTypes[i].Types[tid]; + alias getType = typeof(args[i].get!T()); + } + + alias valueTypes = Map!(getType, Iota!(tags.length)); } - alias valueTypes = Map!(getType, Iota!(tags.length)); + /* The total number of cases is + * + * Π SumTypes[i].Types.length for 0 ≤ i < SumTypes.length + * + * Conveniently, this is equal to stride!(SumTypes.length), so we can + * use that function to compute it. + */ + enum numCases = stride!(SumTypes.length); } - /* The total number of cases is - * - * Π SumTypes[i].Types.length for 0 ≤ i < SumTypes.length - * - * Or, equivalently, - * - * ubyte[SumTypes[0].Types.length]...[SumTypes[$-1].Types.length].sizeof - * - * Conveniently, this is equal to stride!(SumTypes.length), so we can - * use that function to compute it. - */ - enum numCases = stride!(SumTypes.length); - /* Guaranteed to never be a valid handler index, since * handlers.length <= size_t.max. */ @@ -1998,7 +1975,12 @@ private template matchImpl(Flag!"exhaustive" exhaustive, handlers...) mixin("alias ", handlerName!hid, " = handler;"); } - immutable argsId = TagTuple(args).toCaseId; + // Single dispatch (fast path) + static if (args.length == 1) + immutable argsId = args[0].tag; + // Multiple dispatch (slow path) + else + immutable argsId = TagTuple(args).toCaseId; final switch (argsId) { @@ -2029,10 +2011,11 @@ private template matchImpl(Flag!"exhaustive" exhaustive, handlers...) } } +// Predicate for staticMap private enum typeCount(SumType) = SumType.Types.length; -/* A TagTuple represents a single possible set of tags that `args` - * could have at runtime. +/* A TagTuple represents a single possible set of tags that the arguments to + * `matchImpl` could have at runtime. * * Because D does not allow a struct to be the controlling expression * of a switch statement, we cannot dispatch on the TagTuple directly. @@ -2054,22 +2037,23 @@ private enum typeCount(SumType) = SumType.Types.length; * When there is only one argument, the caseId is equal to that * argument's tag. */ -private struct TagTuple(SumTypes...) +private struct TagTuple(typeCounts...) { - size_t[SumTypes.length] tags; + size_t[typeCounts.length] tags; alias tags this; - alias stride(size_t i) = .stride!(i, Map!(typeCount, SumTypes)); + alias stride(size_t i) = .stride!(i, typeCounts); invariant { static foreach (i; 0 .. tags.length) { - assert(tags[i] < SumTypes[i].Types.length, "Invalid tag"); + assert(tags[i] < typeCounts[i], "Invalid tag"); } } - this(ref const(SumTypes) args) + this(SumTypes...)(ref const SumTypes args) + if (allSatisfy!(isSumType, SumTypes) && args.length == typeCounts.length) { static foreach (i; 0 .. tags.length) { @@ -2104,6 +2088,52 @@ private struct TagTuple(SumTypes...) } } +/* The number that the dim-th argument's tag is multiplied by when + * converting TagTuples to and from case indices ("caseIds"). + * + * Named by analogy to the stride that the dim-th index into a + * multidimensional static array is multiplied by to calculate the + * offset of a specific element. + */ +private size_t stride(size_t dim, lengths...)() +{ + import core.checkedint : mulu; + + size_t result = 1; + bool overflow = false; + + static foreach (i; 0 .. dim) + { + result = mulu(result, lengths[i], overflow); + } + + /* The largest number matchImpl uses, numCases, is calculated with + * stride!(SumTypes.length), so as long as this overflow check + * passes, we don't need to check for overflow anywhere else. + */ + assert(!overflow, "Integer overflow"); + return result; +} + +/* A list of arguments to be passed to a handler needed for the case + * labeled with `caseId`. + */ +private template handlerArgs(size_t caseId, typeCounts...) +{ + enum tags = TagTuple!typeCounts.fromCaseId(caseId); + + alias handlerArgs = AliasSeq!(); + + static foreach (i; 0 .. tags.length) + { + handlerArgs = AliasSeq!( + handlerArgs, + "args[" ~ toCtString!i ~ "].get!(SumTypes[" ~ toCtString!i ~ "]" ~ + ".Types[" ~ toCtString!(tags[i]) ~ "])(), " + ); + } +} + // Matching @safe unittest {