From 5345b07791aca535ab70f456d9213168307c0229 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 20 Mar 2024 16:53:03 +0900 Subject: [PATCH] feat: implement the built-in trait entities --- crates/erg_compiler/codegen.rs | 13 ++ .../erg_compiler/context/initialize/traits.rs | 42 ++++-- crates/erg_compiler/lib/core/_erg_traits.py | 131 ++++++++++++++++++ 3 files changed, 177 insertions(+), 9 deletions(-) create mode 100644 crates/erg_compiler/lib/core/_erg_traits.py diff --git a/crates/erg_compiler/codegen.rs b/crates/erg_compiler/codegen.rs index 9b461ad47..85babbe3b 100644 --- a/crates/erg_compiler/codegen.rs +++ b/crates/erg_compiler/codegen.rs @@ -220,6 +220,7 @@ pub struct PyCodeGenerator { module_type_loaded: bool, control_loaded: bool, convertors_loaded: bool, + traits_loaded: bool, operators_loaded: bool, union_loaded: bool, fake_generic_loaded: bool, @@ -248,6 +249,7 @@ impl PyCodeGenerator { module_type_loaded: false, control_loaded: false, convertors_loaded: false, + traits_loaded: false, operators_loaded: false, union_loaded: false, fake_generic_loaded: false, @@ -271,6 +273,7 @@ impl PyCodeGenerator { module_type_loaded: false, control_loaded: false, convertors_loaded: false, + traits_loaded: false, operators_loaded: false, union_loaded: false, fake_generic_loaded: false, @@ -298,6 +301,7 @@ impl PyCodeGenerator { self.module_type_loaded = false; self.control_loaded = false; self.convertors_loaded = false; + self.traits_loaded = false; self.operators_loaded = false; self.union_loaded = false; self.fake_generic_loaded = false; @@ -840,6 +844,9 @@ impl PyCodeGenerator { | "invert" | "is_" | "is_not" | "call" => { self.load_operators(); } + "Eq" | "Ord" | "Hash" | "Add" | "Sub" | "Mul" | "Div" | "Pos" | "Neg" => { + self.load_traits(); + } "CodeType" => { self.emit_global_import_items( Identifier::static_public("types"), @@ -3860,6 +3867,12 @@ impl PyCodeGenerator { self.convertors_loaded = true; } + fn load_traits(&mut self) { + let mod_name = Identifier::static_public("_erg_traits"); + self.emit_import_all_instr(mod_name); + self.traits_loaded = true; + } + fn load_operators(&mut self) { let mod_name = Identifier::static_public("operator"); self.emit_import_all_instr(mod_name); diff --git a/crates/erg_compiler/context/initialize/traits.rs b/crates/erg_compiler/context/initialize/traits.rs index 13dfa0a53..ace98bceb 100644 --- a/crates/erg_compiler/context/initialize/traits.rs +++ b/crates/erg_compiler/context/initialize/traits.rs @@ -539,12 +539,12 @@ impl Context { Const, None, ); - self.register_builtin_type(mono(EQ), eq, vis.clone(), Const, None); + self.register_builtin_type(mono(EQ), eq, vis.clone(), Const, Some(EQ)); self.register_builtin_type(mono(IRREGULAR_EQ), irregular_eq, vis.clone(), Const, None); - self.register_builtin_type(mono(HASH), hash, vis.clone(), Const, None); + self.register_builtin_type(mono(HASH), hash, vis.clone(), Const, Some(HASH)); self.register_builtin_type(mono(EQ_HASH), eq_hash, vis.clone(), Const, None); self.register_builtin_type(mono(PARTIAL_ORD), partial_ord, vis.clone(), Const, None); - self.register_builtin_type(mono(ORD), ord, vis.clone(), Const, None); + self.register_builtin_type(mono(ORD), ord, vis.clone(), Const, Some(ORD)); self.register_builtin_type(mono(NUM), num, vis.clone(), Const, None); self.register_builtin_type(mono(TO_BOOL), to_bool, vis.clone(), Const, None); self.register_builtin_type(mono(TO_INT), to_int, vis.clone(), Const, None); @@ -641,10 +641,34 @@ impl Context { Const, None, ); - self.register_builtin_type(poly(ADD, ty_params.clone()), add, vis.clone(), Const, None); - self.register_builtin_type(poly(SUB, ty_params.clone()), sub, vis.clone(), Const, None); - self.register_builtin_type(poly(MUL, ty_params.clone()), mul, vis.clone(), Const, None); - self.register_builtin_type(poly(DIV, ty_params.clone()), div, vis.clone(), Const, None); + self.register_builtin_type( + poly(ADD, ty_params.clone()), + add, + vis.clone(), + Const, + Some(ADD), + ); + self.register_builtin_type( + poly(SUB, ty_params.clone()), + sub, + vis.clone(), + Const, + Some(SUB), + ); + self.register_builtin_type( + poly(MUL, ty_params.clone()), + mul, + vis.clone(), + Const, + Some(MUL), + ); + self.register_builtin_type( + poly(DIV, ty_params.clone()), + div, + vis.clone(), + Const, + Some(DIV), + ); self.register_builtin_type( poly(FLOOR_DIV, ty_params), floor_div, @@ -652,8 +676,8 @@ impl Context { Const, None, ); - self.register_builtin_type(mono(POS), pos, vis.clone(), Const, None); - self.register_builtin_type(mono(NEG), neg, vis, Const, None); + self.register_builtin_type(mono(POS), pos, vis.clone(), Const, Some(POS)); + self.register_builtin_type(mono(NEG), neg, vis, Const, Some(NEG)); self.register_const_param_defaults( ADD, vec![ConstTemplate::Obj(ValueObj::builtin_type(Slf.clone()))], diff --git a/crates/erg_compiler/lib/core/_erg_traits.py b/crates/erg_compiler/lib/core/_erg_traits.py new file mode 100644 index 000000000..1ac76702c --- /dev/null +++ b/crates/erg_compiler/lib/core/_erg_traits.py @@ -0,0 +1,131 @@ +from abc import ABC, abstractmethod + +from _erg_float import Float + +class Eq(ABC): + @abstractmethod + def __eq__(self, other): pass + + @classmethod + def __subclasshook__(cls, C): + if cls is Eq: + if any("__eq__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + +class Ord(ABC): + @abstractmethod + def __lt__(self, other): pass + def __gt__(self, other): pass + def __le__(self, other): pass + def __ge__(self, other): pass + + @classmethod + def __subclasshook__(cls, C): + if cls is Ord: + # TODO: adhoc + if C == float or C == Float: + return False + if any("__lt__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + +class Hash(ABC): + @abstractmethod + def __hash__(self): pass + + @classmethod + def __subclasshook__(cls, C): + if cls is Hash: + if any("__hash__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + +class Sized(ABC): + @abstractmethod + def __len__(self): pass + + @classmethod + def __subclasshook__(cls, C): + if cls is Sized: + if any("__len__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + +class Add(ABC): + Output: type + + @abstractmethod + def __add__(self, other): pass + + @classmethod + def __subclasshook__(cls, C): + if cls is Add: + if any("__add__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + +class Sub(ABC): + Output: type + + @abstractmethod + def __sub__(self, other): pass + + @classmethod + def __subclasshook__(cls, C): + if cls is Sub: + if any("__sub__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + +class Mul(ABC): + Output: type + + @abstractmethod + def __mul__(self, other): pass + + @classmethod + def __subclasshook__(cls, C): + if cls is Mul: + if any("__mul__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + +class Div(ABC): + Output: type + + @abstractmethod + def __truediv__(self, other): pass + + @classmethod + def __subclasshook__(cls, C): + if cls is Div: + if any("__truediv__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + +class Pos(ABC): + Output: type + + @abstractmethod + def __pos__(self): pass + + @classmethod + def __subclasshook__(cls, C): + if cls is Pos: + if any("__pos__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + +class Neg(ABC): + Output: type + + @abstractmethod + def __neg__(self): pass + + @classmethod + def __subclasshook__(cls, C): + if cls is Neg: + if any("__neg__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented