Skip to content

Commit

Permalink
Refactor substituteType
Browse files Browse the repository at this point in the history
  • Loading branch information
gardspirito committed Aug 18, 2023
1 parent cd70563 commit 47ff526
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 151 deletions.
36 changes: 18 additions & 18 deletions src/Grace/Infer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -326,28 +326,28 @@ subtype _A0 _B0 = do
-- <:∃R
(_, Type.Exists{ domain = Domain.Type, .. }) -> do
scopedUnsolvedType nameLocation \a ->
subtype _A0 (Type.substituteType name 0 a type_)
subtype _A0 (Type.substituteType name a type_)

(_, Type.Exists{ domain = Domain.Fields, .. }) -> do
scopedUnsolvedFields \a -> do
subtype _A0 (Type.substituteFields name 0 a type_)
subtype _A0 (Type.substituteFields name a type_)

(_, Type.Exists{ domain = Domain.Alternatives, .. }) -> do
scopedUnsolvedAlternatives \a -> do
subtype _A0 (Type.substituteAlternatives name 0 a type_)
subtype _A0 (Type.substituteAlternatives name a type_)

-- <:∀L
(Type.Forall{ domain = Domain.Type, .. }, _) -> do
scopedUnsolvedType nameLocation \a -> do
subtype (Type.substituteType name 0 a type_) _B0
subtype (Type.substituteType name a type_) _B0

(Type.Forall{ domain = Domain.Fields, .. }, _) -> do
scopedUnsolvedFields \a -> do
subtype (Type.substituteFields name 0 a type_) _B0
subtype (Type.substituteFields name a type_) _B0

(Type.Forall{ domain = Domain.Alternatives, .. }, _) -> do
scopedUnsolvedAlternatives \a -> do
subtype (Type.substituteAlternatives name 0 a type_) _B0
subtype (Type.substituteAlternatives name a type_) _B0

(Type.Scalar{ scalar = s0 }, Type.Scalar{ scalar = s1 })
| s0 == s1 -> do
Expand Down Expand Up @@ -788,13 +788,13 @@ instantiateTypeL a _A0 = do
-- InstLExt
Type.Exists{ domain = Domain.Type, .. } -> do
scopedUnsolvedType nameLocation \b -> do
instantiateTypeR (Type.substituteType name 0 b type_) a
instantiateTypeR (Type.substituteType name b type_) a
Type.Exists{ domain = Domain.Fields, .. } -> do
scopedUnsolvedFields \b -> do
instantiateTypeR (Type.substituteFields name 0 b type_) a
instantiateTypeR (Type.substituteFields name b type_) a
Type.Exists{ domain = Domain.Alternatives, .. } -> do
scopedUnsolvedAlternatives \b -> do
instantiateTypeR (Type.substituteAlternatives name 0 b type_) a
instantiateTypeR (Type.substituteAlternatives name b type_) a

-- InstLArr
Type.Function{..} -> do
Expand Down Expand Up @@ -963,13 +963,13 @@ instantiateTypeR _A0 a = do
-- InstRAllL
Type.Forall{ domain = Domain.Type, .. } -> do
scopedUnsolvedType nameLocation \b -> do
instantiateTypeR (Type.substituteType name 0 b type_) a
instantiateTypeR (Type.substituteType name b type_) a
Type.Forall{ domain = Domain.Fields, .. } -> do
scopedUnsolvedFields \b -> do
instantiateTypeR (Type.substituteFields name 0 b type_) a
instantiateTypeR (Type.substituteFields name b type_) a
Type.Forall{ domain = Domain.Alternatives, .. } -> do
scopedUnsolvedAlternatives \b -> do
instantiateTypeR (Type.substituteAlternatives name 0 b type_) a
instantiateTypeR (Type.substituteAlternatives name b type_) a

Type.Optional{..} -> do
let _ΓL =
Expand Down Expand Up @@ -1887,13 +1887,13 @@ check Syntax.Lambda{ location = _, ..} Type.Function{..} = do
-- ∃I
check e Type.Exists{ domain = Domain.Type, .. } = do
scopedUnsolvedType nameLocation \a -> do
check e (Type.substituteType name 0 a type_)
check e (Type.substituteType name a type_)
check e Type.Exists{ domain = Domain.Fields, .. } = do
scopedUnsolvedFields \a -> do
check e (Type.substituteFields name 0 a type_)
check e (Type.substituteFields name a type_)
check e Type.Exists{ domain = Domain.Alternatives, .. } = do
scopedUnsolvedAlternatives \a -> do
check e (Type.substituteAlternatives name 0 a type_)
check e (Type.substituteAlternatives name a type_)

-- ∀I
check e Type.Forall{..} = do
Expand Down Expand Up @@ -2009,23 +2009,23 @@ inferApplication Type.Forall{ domain = Domain.Type, .. } e = do

let a' = Type.UnsolvedType{ location = nameLocation, existential = a}

inferApplication (Type.substituteType name 0 a' type_) e
inferApplication (Type.substituteType name a' type_) e
inferApplication Type.Forall{ domain = Domain.Fields, .. } e = do
a <- fresh

push (Context.UnsolvedFields a)

let a' = Type.Fields [] (Monotype.UnsolvedFields a)

inferApplication (Type.substituteFields name 0 a' type_) e
inferApplication (Type.substituteFields name a' type_) e
inferApplication Type.Forall{ domain = Domain.Alternatives, .. } e = do
a <- fresh

push (Context.UnsolvedAlternatives a)

let a' = Type.Alternatives [] (Monotype.UnsolvedAlternatives a)

inferApplication (Type.substituteAlternatives name 0 a' type_) e
inferApplication (Type.substituteAlternatives name a' type_) e

-- ∃App
inferApplication Type.Exists{..} e = do
Expand Down
182 changes: 49 additions & 133 deletions src/Grace/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -262,166 +262,82 @@ solveAlternatives unsolved (Monotype.Alternatives alternativeMonotypes alternati
transformType type_ =
type_

{-| Replace all occurrences of a variable within one `Type` with another `Type`,
given the variable's label and index
{-| Helper function for traversing the tree during `Type` substitutions.
-}
substituteType :: Text -> Int -> Type s -> Type s -> Type s
substituteType a n _A type_ =
case type_ of
VariableType{..}
| a == name && n == 0 -> _A
| otherwise -> VariableType{..}

UnsolvedType{..} ->
UnsolvedType{..}
substitute :: Text -> Domain -> ((Type s -> Type s) -> Type s -> Type s) -> Type s -> Type s
substitute a aDomain substituter = sub
where
sub = substituter def

Exists{ type_ = oldType, .. } -> Exists{ type_ = newType, .. }
where
newType = substituteType a n' _A oldType
def VariableType{..} =
VariableType{..}

n' | a == name && domain == Domain.Type = n + 1
| otherwise = n
def UnsolvedType{..} =
UnsolvedType{..}

Forall{ type_ = oldType, .. } -> Forall{ type_ = newType, .. }
def Exists{ type_ = oldType, .. } = Exists{ type_ = newType, .. }
where
newType = substituteType a n' _A oldType
newType | a == name && domain == aDomain = oldType
| otherwise = sub oldType

n' | a == name && domain == Domain.Type = n + 1
| otherwise = n

Function{ input = oldInput, output = oldOutput, .. } ->
Function{ input = newInput, output = newOutput, .. }
def Forall{ type_ = oldType, .. } = Forall{ type_ = newType, .. }
where
newInput = substituteType a n _A oldInput
newType | a == name && domain == aDomain = oldType
| otherwise = sub oldType

newOutput = substituteType a n _A oldOutput
def Function{ input = oldInput, output = oldOutput, .. } =
Function{ input = sub oldInput, output = sub oldOutput, .. }

Optional{ type_ = oldType, .. } -> Optional{ type_ = newType, .. }
where
newType = substituteType a n _A oldType
def Optional{ type_ = oldType, .. } = Optional{ type_ = sub oldType, .. }

List{ type_ = oldType, .. } -> List{ type_ = newType, .. }
where
newType = substituteType a n _A oldType
def List{ type_ = oldType, .. } = List{ type_ = sub oldType, .. }

Record{ fields = Fields kAs ρ, .. } ->
Record{ fields = Fields (map (second (substituteType a n _A)) kAs) ρ, .. }
def Record{ fields = Fields kAs ρ, .. } =
Record{ fields = Fields (map (second sub) kAs) ρ, .. }

Union{ alternatives = Alternatives kAs ρ, .. } ->
Union{ alternatives = Alternatives (map (second (substituteType a n _A)) kAs) ρ, .. }
def Union{ alternatives = Alternatives kAs ρ, .. } =
Union{ alternatives = Alternatives (map (second sub) kAs) ρ, .. }

Scalar{..} ->
def Scalar{..} =
Scalar{..}

{-| Replace all occurrences of a variable within one `Type` with another `Type`,
given the variable's label and index
given the variable's label
-}
substituteFields :: Text -> Int -> Record s -> Type s -> Type s
substituteFields ρ0 n r@(Fields kτs ρ1) type_ =
case type_ of
VariableType{..} ->
VariableType{..}

UnsolvedType{..} ->
UnsolvedType{..}

Exists{ type_ = oldType, .. } -> Exists{ type_ = newType, .. }
where
newType = substituteFields ρ0 n' r oldType

n' | ρ0 == name && domain == Domain.Fields = n + 1
| otherwise = n

Forall{ type_ = oldType, .. } -> Forall{ type_ = newType, .. }
where
newType = substituteFields ρ0 n' r oldType

n' | ρ0 == name && domain == Domain.Fields = n + 1
| otherwise = n

Function{ input = oldInput, output = oldOutput, .. } ->
Function{ input = newInput, output = newOutput, .. }
where
newInput = substituteFields ρ0 n r oldInput

newOutput = substituteFields ρ0 n r oldOutput

Optional{ type_ = oldType, .. } -> Optional{ type_ = newType, .. }
where
newType = substituteFields ρ0 n r oldType
substituteType :: Text -> Type s -> Type s -> Type s
substituteType a _A = substitute a Domain.Type substituteType'
where
substituteType' _ VariableType{..}
| a == name = _A
substituteType' skipLevel type_ = skipLevel type_

List{ type_ = oldType, .. } -> List{ type_ = newType, .. }
where
newType = substituteFields ρ0 n r oldType

Record{ fields = Fields kAs0 ρ, .. }
| VariableFields ρ0 == ρ && n == 0 ->
Record{ fields = Fields (map (second (substituteFields ρ0 n r)) kAs1) ρ1, .. }
| otherwise ->
Record{ fields = Fields (map (second (substituteFields ρ0 n r)) kAs0) ρ, .. }
{-| Replace all occurrences of a variable within one `Type` with another `Type`,
given the variable's label
-}
substituteFields :: Text -> Record s -> Type s -> Type s
substituteFields ρ0 (Fields kτs ρ1) = substitute ρ0 Domain.Fields substituteFields'
where
substituteFields' skipLevel Record{ fields = Fields kAs0 ρ, .. }
| VariableFields ρ0 == ρ =
Record{ fields = Fields (map (second (substituteFields' skipLevel)) kAs1) ρ1, .. }
where
kAs1 = kAs0 <> map (second (fmap (\_ -> location))) kτs
substituteFields' skipLevel type_ = skipLevel type_

Union{ alternatives = Alternatives kAs ρ, .. } ->
Union{ alternatives = Alternatives (map (second (substituteFields ρ0 n r)) kAs) ρ, .. }

Scalar{..} ->
Scalar{..}

{-| Replace all occurrences of a variable within one `Type` with another `Type`,
given the variable's label and index
-}
substituteAlternatives :: Text -> Int -> Union s -> Type s -> Type s
substituteAlternatives ρ0 n r@(Alternatives kτs ρ1) type_ =
case type_ of
VariableType{..} ->
VariableType{..}

UnsolvedType{..} ->
UnsolvedType{..}

Exists{ type_ = oldType, .. } -> Exists{ type_ = newType, .. }
where
newType = substituteAlternatives ρ0 n' r oldType

n' | ρ0 == name && domain == Domain.Alternatives = n + 1
| otherwise = n

Forall{ type_ = oldType, .. } -> Forall{ type_ = newType, .. }
where
newType = substituteAlternatives ρ0 n' r oldType

n' | ρ0 == name && domain == Domain.Alternatives = n + 1
| otherwise = n

Function{ input = oldInput, output = oldOutput, .. } ->
Function{ input = newInput, output = newOutput, .. }
where
newInput = substituteAlternatives ρ0 n r oldInput

newOutput = substituteAlternatives ρ0 n r oldOutput

Optional{ type_ = oldType, .. } -> Optional{ type_ = newType, .. }
where
newType = substituteAlternatives ρ0 n r oldType

List{ type_ = oldType, .. } -> List{ type_ = newType, .. }
where
newType = substituteAlternatives ρ0 n r oldType

Record{ fields = Fields kAs ρ, .. } ->
Record{ fields = Fields (map (second (substituteAlternatives ρ0 n r)) kAs) ρ, .. }

Union{ alternatives = Alternatives kAs0 ρ, .. }
| Monotype.VariableAlternatives ρ0 == ρ && n == 0 ->
Union{ alternatives = Alternatives (map (second (substituteAlternatives ρ0 n r)) kAs1) ρ1, .. }
| otherwise ->
Union{ alternatives = Alternatives (map (second (substituteAlternatives ρ0 n r)) kAs0) ρ, .. }
substituteAlternatives :: Text -> Union s -> Type s -> Type s
substituteAlternatives ρ0 (Alternatives kτs ρ1) = substitute ρ0 Domain.Alternatives substituteAlternatives'
where
substituteAlternatives' skipLevel Union{ alternatives = Alternatives kAs0 ρ, .. }
| Monotype.VariableAlternatives ρ0 == ρ =
Union{ alternatives = Alternatives (map (second (substituteAlternatives' skipLevel)) kAs1) ρ1, .. }
where
kAs1 = kAs0 <> map (second (fmap (\_ -> location))) kτs

Scalar{..} ->
Scalar{..}
substituteAlternatives' skipLevel type_ = skipLevel type_

{-| Count how many times the given `Existential` `Type` variable appears within
a `Type`
Expand Down

0 comments on commit 47ff526

Please sign in to comment.