Skip to content

Commit

Permalink
add special single-pattern match!() handling
Browse files Browse the repository at this point in the history
For improved error messages, fixes pbackus#81
  • Loading branch information
WebFreak001 committed Mar 7, 2022
1 parent bb7e2ae commit e9d7fc6
Showing 1 changed file with 119 additions and 1 deletion.
120 changes: 119 additions & 1 deletion src/sumtype.d
Original file line number Diff line number Diff line change
Expand Up @@ -1544,7 +1544,15 @@ template match(handlers...)
auto ref match(SumTypes...)(auto ref SumTypes args)
if (allSatisfy!(isSumType, SumTypes) && args.length > 0)
{
return matchImpl!(Yes.exhaustive, handlers)(args);
static if (handlers.length == 1)
{
// try to call the handler with any type
return matchSingleImpl!(handlers[0])(args);
}
else
{
return matchImpl!(Yes.exhaustive, handlers)(args);
}
}
}

Expand Down Expand Up @@ -1964,6 +1972,94 @@ private template matchImpl(Flag!"exhaustive" exhaustive, handlers...)
}
}

// copy of matchImpl, but only using a single handler without any canMatch checks.
// this is useful to improve error messages on single-handler match invocations.
private template matchSingleImpl(alias handler)
{
auto ref matchSingleImpl(SumTypes...)(auto ref SumTypes args)
if (allSatisfy!(isSumType, SumTypes) && args.length > 0)
{
enum typeCount(SumType) = SumType.Types.length;
alias stride(size_t i) = .stride!(i, Map!(typeCount, SumTypes));

static struct TagTuple
{
size_t[SumTypes.length] tags;
alias tags this;

invariant {
static foreach (i; 0 .. tags.length) {
assert(tags[i] < SumTypes[i].Types.length);
}
}

this(ref const(SumTypes) args)
{
static foreach (i; 0 .. tags.length) {
tags[i] = args[i].tag;
}
}

static TagTuple fromCaseId(size_t caseId)
{
TagTuple result;

// Most-significant to least-significant
static foreach_reverse (i; 0 .. result.length) {
result[i] = caseId / stride!i;
caseId %= stride!i;
}

return result;
}

size_t toCaseId()
{
size_t result;

static foreach (i; 0 .. tags.length) {
result += tags[i] * stride!i;
}

return result;
}
}

/*
* A list of arguments to be passed to a handler needed for the case
* labeled with `caseId`.
*/
template handlerArgs(size_t caseId)
{
enum tags = TagTuple.fromCaseId(caseId);
enum argsFrom(size_t i: tags.length) = "";
enum argsFrom(size_t i) = "args[" ~ toCtString!i ~ "].get!(SumTypes[" ~ toCtString!i ~ "]" ~
".Types[" ~ toCtString!(tags[i]) ~ "])(), " ~ argsFrom!(i + 1);
enum handlerArgs = argsFrom!0;
}

/* The total number of cases is
*
* Π SumTypes[i].Types.length for 0 ≤ i < SumTypes.length
*
* Conveniently, this is equal to stride!(SumTypes.length), so we can
* use that function to compute it.
*/
enum numCases = stride!(SumTypes.length);

immutable argsId = TagTuple(args).toCaseId;

final switch (argsId) {
static foreach (caseId; 0 .. numCases) {
case caseId:
return mixin("handler(", handlerArgs!caseId, ")");
}
}

assert(false, "unreachable");
}
}

// Matching
@safe unittest {
alias MySum = SumType!(int, float);
Expand Down Expand Up @@ -2419,3 +2515,25 @@ private void destroyIfOwner(T)(ref T value)
destroy(value);
}
}

@safe unittest {
int someFun(T)(T v)
{
static assert(!is(T == int));
return 3;
}

SumType!(int, string, This[]) v;
static assert(!__traits(compiles, v.match!(all => someFun(all))));
// v.match!(all => someFun(all)); // uncomment to test error message

/* should say something like:
src/sumtype.d(2564,3): Error: static assert: `!is(int == int)` is false
src/sumtype.d(2570,25): instantiated from here: `someFun!int`
src/sumtype.d-mixin-2097(2097,8): instantiated from here: `__lambda4!int`
src/sumtype.d(1550,40): instantiated from here: `matchSingleImpl!(SumType!(int, string, This[]))`
src/sumtype.d(2570,3): instantiated from here: `match!(SumType!(int, string, This[]))`
and not say handlers[0] doesn't match anything
*/
}

0 comments on commit e9d7fc6

Please sign in to comment.