Skip to content

Commit

Permalink
feat: implement the built-in trait entities
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Mar 20, 2024
1 parent 05fedc3 commit 5345b07
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 9 deletions.
13 changes: 13 additions & 0 deletions crates/erg_compiler/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ pub struct PyCodeGenerator {
module_type_loaded: bool,
control_loaded: bool,
convertors_loaded: bool,
traits_loaded: bool,
operators_loaded: bool,
union_loaded: bool,
fake_generic_loaded: bool,
Expand Down Expand Up @@ -248,6 +249,7 @@ impl PyCodeGenerator {
module_type_loaded: false,
control_loaded: false,
convertors_loaded: false,
traits_loaded: false,
operators_loaded: false,
union_loaded: false,
fake_generic_loaded: false,
Expand All @@ -271,6 +273,7 @@ impl PyCodeGenerator {
module_type_loaded: false,
control_loaded: false,
convertors_loaded: false,
traits_loaded: false,
operators_loaded: false,
union_loaded: false,
fake_generic_loaded: false,
Expand Down Expand Up @@ -298,6 +301,7 @@ impl PyCodeGenerator {
self.module_type_loaded = false;
self.control_loaded = false;
self.convertors_loaded = false;
self.traits_loaded = false;
self.operators_loaded = false;
self.union_loaded = false;
self.fake_generic_loaded = false;
Expand Down Expand Up @@ -840,6 +844,9 @@ impl PyCodeGenerator {
| "invert" | "is_" | "is_not" | "call" => {
self.load_operators();
}
"Eq" | "Ord" | "Hash" | "Add" | "Sub" | "Mul" | "Div" | "Pos" | "Neg" => {
self.load_traits();
}
"CodeType" => {
self.emit_global_import_items(
Identifier::static_public("types"),
Expand Down Expand Up @@ -3860,6 +3867,12 @@ impl PyCodeGenerator {
self.convertors_loaded = true;
}

fn load_traits(&mut self) {
let mod_name = Identifier::static_public("_erg_traits");
self.emit_import_all_instr(mod_name);
self.traits_loaded = true;
}

fn load_operators(&mut self) {
let mod_name = Identifier::static_public("operator");
self.emit_import_all_instr(mod_name);
Expand Down
42 changes: 33 additions & 9 deletions crates/erg_compiler/context/initialize/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,12 +539,12 @@ impl Context {
Const,
None,
);
self.register_builtin_type(mono(EQ), eq, vis.clone(), Const, None);
self.register_builtin_type(mono(EQ), eq, vis.clone(), Const, Some(EQ));
self.register_builtin_type(mono(IRREGULAR_EQ), irregular_eq, vis.clone(), Const, None);
self.register_builtin_type(mono(HASH), hash, vis.clone(), Const, None);
self.register_builtin_type(mono(HASH), hash, vis.clone(), Const, Some(HASH));
self.register_builtin_type(mono(EQ_HASH), eq_hash, vis.clone(), Const, None);
self.register_builtin_type(mono(PARTIAL_ORD), partial_ord, vis.clone(), Const, None);
self.register_builtin_type(mono(ORD), ord, vis.clone(), Const, None);
self.register_builtin_type(mono(ORD), ord, vis.clone(), Const, Some(ORD));
self.register_builtin_type(mono(NUM), num, vis.clone(), Const, None);
self.register_builtin_type(mono(TO_BOOL), to_bool, vis.clone(), Const, None);
self.register_builtin_type(mono(TO_INT), to_int, vis.clone(), Const, None);
Expand Down Expand Up @@ -641,19 +641,43 @@ impl Context {
Const,
None,
);
self.register_builtin_type(poly(ADD, ty_params.clone()), add, vis.clone(), Const, None);
self.register_builtin_type(poly(SUB, ty_params.clone()), sub, vis.clone(), Const, None);
self.register_builtin_type(poly(MUL, ty_params.clone()), mul, vis.clone(), Const, None);
self.register_builtin_type(poly(DIV, ty_params.clone()), div, vis.clone(), Const, None);
self.register_builtin_type(
poly(ADD, ty_params.clone()),
add,
vis.clone(),
Const,
Some(ADD),
);
self.register_builtin_type(
poly(SUB, ty_params.clone()),
sub,
vis.clone(),
Const,
Some(SUB),
);
self.register_builtin_type(
poly(MUL, ty_params.clone()),
mul,
vis.clone(),
Const,
Some(MUL),
);
self.register_builtin_type(
poly(DIV, ty_params.clone()),
div,
vis.clone(),
Const,
Some(DIV),
);
self.register_builtin_type(
poly(FLOOR_DIV, ty_params),
floor_div,
vis.clone(),
Const,
None,
);
self.register_builtin_type(mono(POS), pos, vis.clone(), Const, None);
self.register_builtin_type(mono(NEG), neg, vis, Const, None);
self.register_builtin_type(mono(POS), pos, vis.clone(), Const, Some(POS));
self.register_builtin_type(mono(NEG), neg, vis, Const, Some(NEG));
self.register_const_param_defaults(
ADD,
vec![ConstTemplate::Obj(ValueObj::builtin_type(Slf.clone()))],
Expand Down
131 changes: 131 additions & 0 deletions crates/erg_compiler/lib/core/_erg_traits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from abc import ABC, abstractmethod

from _erg_float import Float

class Eq(ABC):
@abstractmethod
def __eq__(self, other): pass

@classmethod
def __subclasshook__(cls, C):
if cls is Eq:
if any("__eq__" in B.__dict__ for B in C.__mro__):
return True
return NotImplemented

class Ord(ABC):
@abstractmethod
def __lt__(self, other): pass
def __gt__(self, other): pass
def __le__(self, other): pass
def __ge__(self, other): pass

@classmethod
def __subclasshook__(cls, C):
if cls is Ord:
# TODO: adhoc
if C == float or C == Float:
return False
if any("__lt__" in B.__dict__ for B in C.__mro__):
return True
return NotImplemented

class Hash(ABC):
@abstractmethod
def __hash__(self): pass

@classmethod
def __subclasshook__(cls, C):
if cls is Hash:
if any("__hash__" in B.__dict__ for B in C.__mro__):
return True
return NotImplemented

class Sized(ABC):
@abstractmethod
def __len__(self): pass

@classmethod
def __subclasshook__(cls, C):
if cls is Sized:
if any("__len__" in B.__dict__ for B in C.__mro__):
return True
return NotImplemented

class Add(ABC):
Output: type

@abstractmethod
def __add__(self, other): pass

@classmethod
def __subclasshook__(cls, C):
if cls is Add:
if any("__add__" in B.__dict__ for B in C.__mro__):
return True
return NotImplemented

class Sub(ABC):
Output: type

@abstractmethod
def __sub__(self, other): pass

@classmethod
def __subclasshook__(cls, C):
if cls is Sub:
if any("__sub__" in B.__dict__ for B in C.__mro__):
return True
return NotImplemented

class Mul(ABC):
Output: type

@abstractmethod
def __mul__(self, other): pass

@classmethod
def __subclasshook__(cls, C):
if cls is Mul:
if any("__mul__" in B.__dict__ for B in C.__mro__):
return True
return NotImplemented

class Div(ABC):
Output: type

@abstractmethod
def __truediv__(self, other): pass

@classmethod
def __subclasshook__(cls, C):
if cls is Div:
if any("__truediv__" in B.__dict__ for B in C.__mro__):
return True
return NotImplemented

class Pos(ABC):
Output: type

@abstractmethod
def __pos__(self): pass

@classmethod
def __subclasshook__(cls, C):
if cls is Pos:
if any("__pos__" in B.__dict__ for B in C.__mro__):
return True
return NotImplemented

class Neg(ABC):
Output: type

@abstractmethod
def __neg__(self): pass

@classmethod
def __subclasshook__(cls, C):
if cls is Neg:
if any("__neg__" in B.__dict__ for B in C.__mro__):
return True
return NotImplemented

0 comments on commit 5345b07

Please sign in to comment.