Skip to content

Commit

Permalink
Merge pull request #9087 from pbackus/sumtype-template-overhead
Browse files Browse the repository at this point in the history
sumtype: reduce template overhead of match
  • Loading branch information
pbackus authored Nov 21, 2024
2 parents edf6fb9 + f3d92d9 commit 08638dd
Showing 1 changed file with 106 additions and 76 deletions.
182 changes: 106 additions & 76 deletions std/sumtype.d
Original file line number Diff line number Diff line change
Expand Up @@ -1860,88 +1860,65 @@ private template Iota(size_t n)
assert(Iota!3 == AliasSeq!(0, 1, 2));
}

/* The number that the dim-th argument's tag is multiplied by when
* converting TagTuples to and from case indices ("caseIds").
*
* Named by analogy to the stride that the dim-th index into a
* multidimensional static array is multiplied by to calculate the
* offset of a specific element.
*/
private size_t stride(size_t dim, lengths...)()
{
import core.checkedint : mulu;

size_t result = 1;
bool overflow = false;

static foreach (i; 0 .. dim)
{
result = mulu(result, lengths[i], overflow);
}

/* The largest number matchImpl uses, numCases, is calculated with
* stride!(SumTypes.length), so as long as this overflow check
* passes, we don't need to check for overflow anywhere else.
*/
assert(!overflow, "Integer overflow");
return result;
}

private template matchImpl(Flag!"exhaustive" exhaustive, handlers...)
{
auto ref matchImpl(SumTypes...)(auto ref SumTypes args)
if (allSatisfy!(isSumType, SumTypes) && args.length > 0)
{
alias stride(size_t i) = .stride!(i, Map!(typeCount, SumTypes));
alias TagTuple = .TagTuple!(SumTypes);

/*
* A list of arguments to be passed to a handler needed for the case
* labeled with `caseId`.
*/
template handlerArgs(size_t caseId)
// Single dispatch (fast path)
static if (args.length == 1)
{
enum tags = TagTuple.fromCaseId(caseId);
enum argsFrom(size_t i : tags.length) = "";
enum argsFrom(size_t i) = "args[" ~ toCtString!i ~ "].get!(SumTypes[" ~ toCtString!i ~ "]" ~
".Types[" ~ toCtString!(tags[i]) ~ "])(), " ~ argsFrom!(i + 1);
enum handlerArgs = argsFrom!0;
}
/* When there's only one argument, the caseId is just that
* argument's tag, so there's no need for TagTuple.
*/
enum handlerArgs(size_t caseId) =
"args[0].get!(SumTypes[0].Types[" ~ toCtString!caseId ~ "])()";

/* An AliasSeq of the types of the member values in the argument list
* returned by `handlerArgs!caseId`.
*
* Note that these are the actual (that is, qualified) types of the
* member values, which may not be the same as the types listed in
* the arguments' `.Types` properties.
*/
template valueTypes(size_t caseId)
alias valueTypes(size_t caseId) =
typeof(args[0].get!(SumTypes[0].Types[caseId])());

enum numCases = SumTypes[0].Types.length;
}
// Multiple dispatch (slow path)
else
{
enum tags = TagTuple.fromCaseId(caseId);
alias typeCounts = Map!(typeCount, SumTypes);
alias stride(size_t i) = .stride!(i, typeCounts);
alias TagTuple = .TagTuple!typeCounts;

alias handlerArgs(size_t caseId) = .handlerArgs!(caseId, typeCounts);

template getType(size_t i)
/* An AliasSeq of the types of the member values in the argument list
* returned by `handlerArgs!caseId`.
*
* Note that these are the actual (that is, qualified) types of the
* member values, which may not be the same as the types listed in
* the arguments' `.Types` properties.
*/
template valueTypes(size_t caseId)
{
enum tid = tags[i];
alias T = SumTypes[i].Types[tid];
alias getType = typeof(args[i].get!T());
enum tags = TagTuple.fromCaseId(caseId);

template getType(size_t i)
{
enum tid = tags[i];
alias T = SumTypes[i].Types[tid];
alias getType = typeof(args[i].get!T());
}

alias valueTypes = Map!(getType, Iota!(tags.length));
}

alias valueTypes = Map!(getType, Iota!(tags.length));
/* The total number of cases is
*
* Π SumTypes[i].Types.length for 0 ≤ i < SumTypes.length
*
* Conveniently, this is equal to stride!(SumTypes.length), so we can
* use that function to compute it.
*/
enum numCases = stride!(SumTypes.length);
}

/* The total number of cases is
*
* Π SumTypes[i].Types.length for 0 ≤ i < SumTypes.length
*
* Or, equivalently,
*
* ubyte[SumTypes[0].Types.length]...[SumTypes[$-1].Types.length].sizeof
*
* Conveniently, this is equal to stride!(SumTypes.length), so we can
* use that function to compute it.
*/
enum numCases = stride!(SumTypes.length);

/* Guaranteed to never be a valid handler index, since
* handlers.length <= size_t.max.
*/
Expand Down Expand Up @@ -1998,7 +1975,12 @@ private template matchImpl(Flag!"exhaustive" exhaustive, handlers...)
mixin("alias ", handlerName!hid, " = handler;");
}

immutable argsId = TagTuple(args).toCaseId;
// Single dispatch (fast path)
static if (args.length == 1)
immutable argsId = args[0].tag;
// Multiple dispatch (slow path)
else
immutable argsId = TagTuple(args).toCaseId;

final switch (argsId)
{
Expand Down Expand Up @@ -2029,10 +2011,11 @@ private template matchImpl(Flag!"exhaustive" exhaustive, handlers...)
}
}

// Predicate for staticMap
private enum typeCount(SumType) = SumType.Types.length;

/* A TagTuple represents a single possible set of tags that `args`
* could have at runtime.
/* A TagTuple represents a single possible set of tags that the arguments to
* `matchImpl` could have at runtime.
*
* Because D does not allow a struct to be the controlling expression
* of a switch statement, we cannot dispatch on the TagTuple directly.
Expand All @@ -2054,22 +2037,23 @@ private enum typeCount(SumType) = SumType.Types.length;
* When there is only one argument, the caseId is equal to that
* argument's tag.
*/
private struct TagTuple(SumTypes...)
private struct TagTuple(typeCounts...)
{
size_t[SumTypes.length] tags;
size_t[typeCounts.length] tags;
alias tags this;

alias stride(size_t i) = .stride!(i, Map!(typeCount, SumTypes));
alias stride(size_t i) = .stride!(i, typeCounts);

invariant
{
static foreach (i; 0 .. tags.length)
{
assert(tags[i] < SumTypes[i].Types.length, "Invalid tag");
assert(tags[i] < typeCounts[i], "Invalid tag");
}
}

this(ref const(SumTypes) args)
this(SumTypes...)(ref const SumTypes args)
if (allSatisfy!(isSumType, SumTypes) && args.length == typeCounts.length)
{
static foreach (i; 0 .. tags.length)
{
Expand Down Expand Up @@ -2104,6 +2088,52 @@ private struct TagTuple(SumTypes...)
}
}

/* The number that the dim-th argument's tag is multiplied by when
* converting TagTuples to and from case indices ("caseIds").
*
* Named by analogy to the stride that the dim-th index into a
* multidimensional static array is multiplied by to calculate the
* offset of a specific element.
*/
private size_t stride(size_t dim, lengths...)()
{
import core.checkedint : mulu;

size_t result = 1;
bool overflow = false;

static foreach (i; 0 .. dim)
{
result = mulu(result, lengths[i], overflow);
}

/* The largest number matchImpl uses, numCases, is calculated with
* stride!(SumTypes.length), so as long as this overflow check
* passes, we don't need to check for overflow anywhere else.
*/
assert(!overflow, "Integer overflow");
return result;
}

/* A list of arguments to be passed to a handler needed for the case
* labeled with `caseId`.
*/
private template handlerArgs(size_t caseId, typeCounts...)
{
enum tags = TagTuple!typeCounts.fromCaseId(caseId);

alias handlerArgs = AliasSeq!();

static foreach (i; 0 .. tags.length)
{
handlerArgs = AliasSeq!(
handlerArgs,
"args[" ~ toCtString!i ~ "].get!(SumTypes[" ~ toCtString!i ~ "]" ~
".Types[" ~ toCtString!(tags[i]) ~ "])(), "
);
}
}

// Matching
@safe unittest
{
Expand Down

0 comments on commit 08638dd

Please sign in to comment.