diff --git a/src/Kit/Ast/ExprType.hs b/src/Kit/Ast/ExprType.hs index f9bd7df..ada8fe6 100644 --- a/src/Kit/Ast/ExprType.hs +++ b/src/Kit/Ast/ExprType.hs @@ -100,6 +100,7 @@ data ExprType a b | Yield a | Tokens Str | Defined (Identifier b) + | TagExpr a deriving (Eq, Generic, Show) instance (Hashable a, Hashable b) => Hashable (ExprType a b) @@ -162,6 +163,7 @@ exprDiscriminant et = case et of Defined _ -> 53 ArrayWrite _ _ _ -> 54 FieldWrite _ _ _ -> 55 + TagExpr _ -> 56 x -> throwk $ InternalError ("Expression has no discriminant: " ++ show x) Nothing @@ -203,6 +205,7 @@ exprChildren et = case et of Yield x -> [x] ArrayWrite x y z -> [x, y, z] FieldWrite x s y -> [x, y] + TagExpr x -> [x] _ -> [] exprMapReduce :: (a -> c) -> (c -> d -> d) -> (a -> ExprType a b) -> d -> a -> d diff --git a/src/Kit/Compiler/Ir/ExprToIr.hs b/src/Kit/Compiler/Ir/ExprToIr.hs index ba1732f..f5de286 100644 --- a/src/Kit/Compiler/Ir/ExprToIr.hs +++ b/src/Kit/Compiler/Ir/ExprToIr.hs @@ -507,6 +507,9 @@ typedToIr ctx ictx mod e@(TypedExpr { tExpr = et, tPos = pos, inferredType = t } [IrIdentifier ([], x), IrType f] ( VarArgListCopy x) -> return $ IrIdentifier ([], x) InlineCExpr s t -> return $ IrInlineC s + (TagExpr x) -> do + x <- r x + return $ IrField x discriminantFieldName t -> do throwk $ InternalError ("Unexpected expression in typed AST:\n\n" ++ show t) diff --git a/src/Kit/Compiler/Typers/ConvertExpr.hs b/src/Kit/Compiler/Typers/ConvertExpr.hs index 9c257bd..f72ca88 100644 --- a/src/Kit/Compiler/Typers/ConvertExpr.hs +++ b/src/Kit/Compiler/Typers/ConvertExpr.hs @@ -211,6 +211,9 @@ convertExpr ctx tctx mod params e = do return $ m (Defined id) TypeBool VarArgListCopy s -> do return $ m (VarArgListCopy s) TypeVaList + TagExpr x -> do + x <- r x + return $ m (TagExpr x) (TypeInt 0) _ -> throwk $ InternalError ("Can't convert expression: " ++ show (expr e)) (Just pos') diff --git a/src/Kit/Compiler/Typers/TypeExpression.hs b/src/Kit/Compiler/Typers/TypeExpression.hs index 3c36134..5853335 100644 --- a/src/Kit/Compiler/Typers/TypeExpression.hs +++ b/src/Kit/Compiler/Typers/TypeExpression.hs @@ -330,6 +330,10 @@ typeExpr ctx tctx mod ex@(TypedExpr { tExpr = et, tPos = pos }) = do (tPos ex) return ex + (TagExpr x) -> do + x <- r x + return $ makeExprTyped (TagExpr x) (inferredType ex) pos + _ -> return $ ex t' <- follow ctx tctx $ inferredType result diff --git a/src/Kit/Compiler/Typers/TypeExpression/TypeField.hs b/src/Kit/Compiler/Typers/TypeExpression/TypeField.hs index 59c972f..8236c57 100644 --- a/src/Kit/Compiler/Typers/TypeExpression/TypeField.hs +++ b/src/Kit/Compiler/Typers/TypeExpression/TypeField.hs @@ -229,33 +229,36 @@ typeField (TyperUtils { _r = r, _tryRewrite = tryRewrite, _resolve = resolve, _t pos Enum { enumVariants = variants } -> do - case - find (((==) fieldName) . tpName . variantName) - variants - of - Just v -> do - resolve $ TypeEq - (TypeEnumVariant tp - (tpName $ variantName v) - params - ) - (inferredType ex) - "Struct field access must match the field's type" - (tPos r1) - return $ r1 - { inferredType = (TypeEnumVariant - tp + if fieldName == "tag" + then return $ makeExprTyped (TagExpr r1) (TypeInt 0) pos + else + case + find (((==) fieldName) . tpName . variantName) + variants + of + Just v -> do + resolve $ TypeEq + (TypeEnumVariant tp (tpName $ variantName v) params - ) - } - Nothing -> throwk $ TypingError - ( "Enum " - ++ (s_unpack $ showTypePath tp) - ++ " doesn't have a variant called " - ++ s_unpack fieldName - ) - pos + ) + (inferredType ex) + "Enum field access must match the field's type" + (tPos r1) + return $ r1 + { inferredType = (TypeEnumVariant + tp + (tpName $ variantName v) + params + ) + } + Nothing -> throwk $ TypingError + ( "Enum " + ++ (s_unpack $ showTypePath tp) + ++ " doesn't have a variant called " + ++ s_unpack fieldName + ) + pos Abstract { abstractUnderlyingType = u } -> -- forward to parent diff --git a/tests/functional/enums.kit b/tests/functional/enums.kit index 5fb51b0..9a10af3 100644 --- a/tests/functional/enums.kit +++ b/tests/functional/enums.kit @@ -60,10 +60,15 @@ function main() { // simple enum equality if (a != b) { - printf("hello!\n"); + printf("simple enums: equal\n"); } - // TODO: complex enum equality (not yet implemented) + if (c.tag == Apple2(1).tag) { + printf("complex enums: tags should be equal\n"); + } + if (c.tag != Banana2(1).tag) { + printf("complex enums: tags shouldn't be equal\n"); + } // simple enum methods + match a.print(); diff --git a/tests/functional/enums.stdout b/tests/functional/enums.stdout index 7efa9e8..a7a26d1 100644 --- a/tests/functional/enums.stdout +++ b/tests/functional/enums.stdout @@ -1,5 +1,7 @@ hello -hello! +simple enums: equal +complex enums: tags should be equal +complex enums: tags shouldn't be equal apple banana Apple2: 1 diff --git a/tests/functional/issues/issue134.kit b/tests/functional/issues/issue134.kit new file mode 100644 index 0000000..cadf5d6 --- /dev/null +++ b/tests/functional/issues/issue134.kit @@ -0,0 +1,13 @@ +struct MyStruct[A] { + var value: A; +} + +enum MyEnum[A] { + MyVariant(a: MyStruct[A]); +} + +function main() { + var e: MyEnum[Int]; + // var s: MyStruct[Int]; + puts("hi"); +} diff --git a/tests/functional/issues/issue134.stdout b/tests/functional/issues/issue134.stdout new file mode 100644 index 0000000..45b983b --- /dev/null +++ b/tests/functional/issues/issue134.stdout @@ -0,0 +1 @@ +hi