From 0152e368ab259dd908cb3165c6237dc28bfb1277 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 15 Jul 2023 12:29:07 +0900 Subject: [PATCH] feat: support `Array`, `Str`, `Bytes` slice --- crates/erg_compiler/codegen.rs | 4 +++ crates/erg_compiler/context/eval.rs | 24 +++++++------ .../context/initialize/classes.rs | 34 +++++++++++++++++-- .../context/initialize/const_func.rs | 2 +- crates/erg_compiler/context/inquire.rs | 19 +++++++++-- crates/erg_compiler/context/instantiate.rs | 20 ++++++++++- crates/erg_compiler/lib/std/_erg_array.py | 9 +++++ crates/erg_compiler/lib/std/_erg_bytes.py | 9 +++++ crates/erg_compiler/lib/std/_erg_range.py | 21 ++++++++++++ crates/erg_compiler/lib/std/_erg_str.py | 10 +++++- crates/erg_compiler/ty/const_subr.rs | 13 +++++++ crates/erg_compiler/ty/mod.rs | 8 +++++ tests/should_ok/index.er | 11 ++++++ tests/test.rs | 5 +++ 14 files changed, 170 insertions(+), 19 deletions(-) create mode 100644 tests/should_ok/index.er diff --git a/crates/erg_compiler/codegen.rs b/crates/erg_compiler/codegen.rs index f5e754715..0a0f22093 100644 --- a/crates/erg_compiler/codegen.rs +++ b/crates/erg_compiler/codegen.rs @@ -2717,6 +2717,10 @@ impl PyCodeGenerator { self.emit_load_name_instr(Identifier::public("Str")); } other => match &other.qual_name()[..] { + "Bytes" => { + self.emit_push_null(); + self.emit_load_name_instr(Identifier::public("Bytes")); + } "Array" => { self.emit_push_null(); self.emit_load_name_instr(Identifier::public("Array")); diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index 7fddf93a0..070434707 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -1508,24 +1508,29 @@ impl Context { } } - fn convert_type_to_array(&self, ty: Type) -> Result, ()> { + fn convert_type_to_array(&self, ty: Type) -> Result, Type> { match ty { Type::Poly { name, params } if &name[..] == "Array" || &name[..] == "Array!" => { - let t = self - .convert_tp_into_type(params[0].clone()) - .map_err(|_| ())?; + let Ok(t) = self.convert_tp_into_type(params[0].clone()) else { + return Err(poly(name, params)); + }; let TyParam::Value(ValueObj::Nat(len)) = params[1] else { unreachable!() }; Ok(vec![ValueObj::builtin_type(t); len as usize]) } - _ => Err(()), + _ => Err(ty), } } - pub(crate) fn convert_value_into_array(&self, val: ValueObj) -> Result, ()> { + pub(crate) fn convert_value_into_array( + &self, + val: ValueObj, + ) -> Result, ValueObj> { match val { ValueObj::Array(arr) => Ok(arr.to_vec()), - ValueObj::Type(t) => self.convert_type_to_array(t.into_typ()), - _ => Err(()), + ValueObj::Type(t) => self + .convert_type_to_array(t.into_typ()) + .map_err(ValueObj::builtin_type), + _ => Err(val), } } @@ -1814,9 +1819,8 @@ impl Context { })? { if let Ok(obj) = ty_ctx.get_const_local(&Token::symbol(&attr_name), &self.name) { if let ValueObj::Subr(subr) = obj { - let is_method = subr.sig_t().self_t().is_some(); let mut pos_args = vec![]; - if is_method { + if subr.sig_t().is_method() { match ValueObj::try_from(lhs) { Ok(value) => { pos_args.push(value); diff --git a/crates/erg_compiler/context/initialize/classes.rs b/crates/erg_compiler/context/initialize/classes.rs index 3f14bade4..727adecbd 100644 --- a/crates/erg_compiler/context/initialize/classes.rs +++ b/crates/erg_compiler/context/initialize/classes.rs @@ -897,7 +897,7 @@ impl Context { Immutable, Visibility::BUILTIN_PUBLIC, ); - let str_getitem_t = fn1_kw_met(Str, kw(KW_IDX, Nat), Str); + let str_getitem_t = fn1_kw_met(Str, kw(KW_IDX, Nat | poly(RANGE, vec![ty_tp(Int)])), Str); str_.register_builtin_erg_impl( FUNDAMENTAL_GETITEM, str_getitem_t, @@ -1240,11 +1240,16 @@ impl Context { Predicate::le(var, N.clone() - value(1usize)), ); // __getitem__: |T, N|(self: [T; N], _: {I: Nat | I <= N}) -> T - let array_getitem_t = fn1_kw_met( + // and (self: [T; N], _: Range(Int)) -> [T; _] + let array_getitem_t = (fn1_kw_met( array_t(T.clone(), N.clone()), anon(input.clone()), T.clone(), - ) + ) & fn1_kw_met( + array_t(T.clone(), N.clone()), + anon(poly(RANGE, vec![ty_tp(Int)])), + unknown_len_array_t(T.clone()), + )) .quantify(); let get_item = ValueObj::Subr(ConstSubr::Builtin(BuiltinConstSubr::new( FUNDAMENTAL_GETITEM, @@ -1540,6 +1545,29 @@ impl Context { Str, ); bytes.register_py_builtin(FUNC_DECODE, decode_t, Some(FUNC_DECODE), 6); + let bytes_getitem_t = fn1_kw_met(mono(BYTES), kw(KW_IDX, Nat), Int) + & fn1_kw_met( + mono(BYTES), + kw(KW_IDX, poly(RANGE, vec![ty_tp(Int)])), + mono(BYTES), + ); + bytes.register_builtin_erg_impl( + FUNDAMENTAL_GETITEM, + bytes_getitem_t, + Immutable, + Visibility::BUILTIN_PUBLIC, + ); + bytes + .register_marker_trait(self, poly(INDEXABLE, vec![ty_tp(Nat), ty_tp(Int)])) + .unwrap(); + let mut bytes_eq = Self::builtin_methods(Some(mono(EQ)), 2); + bytes_eq.register_builtin_erg_impl( + OP_EQ, + fn1_met(mono(BYTES), mono(BYTES), Bool), + Const, + Visibility::BUILTIN_PUBLIC, + ); + bytes.register_trait(mono(BYTES), bytes_eq); /* GenericTuple */ let mut generic_tuple = Self::builtin_mono_class(GENERIC_TUPLE, 1); generic_tuple.register_superclass(Obj, &obj); diff --git a/crates/erg_compiler/context/initialize/const_func.rs b/crates/erg_compiler/context/initialize/const_func.rs index 09ef3716c..a10c708b4 100644 --- a/crates/erg_compiler/context/initialize/const_func.rs +++ b/crates/erg_compiler/context/initialize/const_func.rs @@ -242,7 +242,7 @@ pub(crate) fn structural_func(mut args: ValueArgs, ctx: &Context) -> EvalValueRe pub(crate) fn __array_getitem__(mut args: ValueArgs, ctx: &Context) -> EvalValueResult { let slf = ctx .convert_value_into_array(args.remove_left_or_key("Self").unwrap()) - .unwrap(); + .unwrap_or_else(|err| panic!("{err}, {args}")); let index = enum_unwrap!(args.remove_left_or_key("Index").unwrap(), ValueObj::Nat); if let Some(v) = slf.get(index as usize) { Ok(v.clone()) diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index f46046629..2c3fb9dd1 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -831,10 +831,10 @@ impl Context { if let Some(attr_name) = attr_name.as_ref() { let mut vi = self.search_method_info(obj, attr_name, pos_args, kw_args, input, namespace)?; - vi.t = self.resolve_overload(vi.t, pos_args, kw_args, attr_name)?; + vi.t = self.resolve_overload(obj, vi.t, pos_args, kw_args, attr_name)?; Ok(vi) } else { - let t = self.resolve_overload(obj.t(), pos_args, kw_args, obj)?; + let t = self.resolve_overload(obj, obj.t(), pos_args, kw_args, obj)?; Ok(VarInfo { t, ..VarInfo::default() @@ -844,6 +844,7 @@ impl Context { fn resolve_overload( &self, + obj: &hir::Expr, instance: Type, pos_args: &[hir::PosArg], kw_args: &[hir::KwArg], @@ -853,7 +854,7 @@ impl Context { if intersecs.len() == 1 { Ok(instance) } else { - let input_t = subr_t( + let mut input_t = subr_t( SubrKind::Proc, pos_args .iter() @@ -867,6 +868,18 @@ impl Context { Obj, ); for ty in intersecs.iter() { + match (ty.is_method(), input_t.is_method()) { + (true, false) => { + let Type::Subr(sub) = &mut input_t else { unreachable!() }; + sub.non_default_params + .insert(0, ParamTy::kw(Str::ever("self"), obj.t())); + } + (false, true) => { + let Type::Subr(sub) = &mut input_t else { unreachable!() }; + sub.non_default_params.remove(0); + } + _ => {} + } if self.subtype_of(ty, &input_t) { return Ok(ty.clone()); } diff --git a/crates/erg_compiler/context/instantiate.rs b/crates/erg_compiler/context/instantiate.rs index c3fdee650..9f7a5b8e5 100644 --- a/crates/erg_compiler/context/instantiate.rs +++ b/crates/erg_compiler/context/instantiate.rs @@ -602,7 +602,25 @@ impl Context { )?; } } - _ => unreachable!(), + Type::And(l, r) => { + if let Some(self_t) = l.self_t() { + self.sub_unify( + callee.ref_t(), + self_t, + callee, + Some(&Str::ever("self")), + )?; + } + if let Some(self_t) = r.self_t() { + self.sub_unify( + callee.ref_t(), + self_t, + callee, + Some(&Str::ever("self")), + )?; + } + } + other => unreachable!("{other}"), } Ok(t) } diff --git a/crates/erg_compiler/lib/std/_erg_array.py b/crates/erg_compiler/lib/std/_erg_array.py index cde763c39..6ccaa0afc 100644 --- a/crates/erg_compiler/lib/std/_erg_array.py +++ b/crates/erg_compiler/lib/std/_erg_array.py @@ -1,4 +1,5 @@ from _erg_control import then__ +from _erg_range import Range class Array(list): def dedup(self, same_bucket=None): @@ -24,3 +25,11 @@ def partition(self, f): def __mul__(self, n): return then__(list.__mul__(self, n), Array) + + def __getitem__(self, index_or_slice): + if isinstance(index_or_slice, slice): + return Array(list.__getitem__(self, index_or_slice)) + elif isinstance(index_or_slice, Range): + return Array(list.__getitem__(self, index_or_slice.into_slice())) + else: + return list.__getitem__(self, index_or_slice) diff --git a/crates/erg_compiler/lib/std/_erg_bytes.py b/crates/erg_compiler/lib/std/_erg_bytes.py index e6ec1b2c1..76df15ea7 100644 --- a/crates/erg_compiler/lib/std/_erg_bytes.py +++ b/crates/erg_compiler/lib/std/_erg_bytes.py @@ -1,3 +1,12 @@ class Bytes(bytes): def try_new(*b): # -> Result[Nat] return Bytes(bytes(*b)) + + def __getitem__(self, index_or_slice): + from _erg_range import Range + if isinstance(index_or_slice, slice): + return Bytes(bytes.__getitem__(self, index_or_slice)) + elif isinstance(index_or_slice, Range): + return Bytes(bytes.__getitem__(self, index_or_slice.into_slice())) + else: + return bytes.__getitem__(self, index_or_slice) diff --git a/crates/erg_compiler/lib/std/_erg_range.py b/crates/erg_compiler/lib/std/_erg_range.py index 51cfa52cc..56014a4e0 100644 --- a/crates/erg_compiler/lib/std/_erg_range.py +++ b/crates/erg_compiler/lib/std/_erg_range.py @@ -12,6 +12,13 @@ def __init__(self, start, end): def __contains__(self, item): pass + @staticmethod + def from_slice(slice): + pass + + def into_slice(self): + pass + def __getitem__(self, item): res = self.start + item if res in self: @@ -56,6 +63,13 @@ class RightOpenRange(Range): def __contains__(self, item): return self.start <= item < self.end + @staticmethod + def from_slice(slice): + return Range(slice.start, slice.stop) + + def into_slice(self): + return slice(self.start, self.end) + # represents `start<.., } +impl fmt::Display for ValueArgs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut args = Vec::new(); + for arg in &self.pos_args { + args.push(arg.to_string()); + } + for (key, arg) in self.kw_args.iter() { + args.push(format!("{key} := {arg}")); + } + write!(f, "({})", args.join(", ")) + } +} + impl From for Vec { fn from(args: ValueArgs) -> Self { // TODO: kw_args diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index f325614d7..f6fde7a53 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -1717,6 +1717,7 @@ impl Type { Self::Quantified(t) => t.is_procedure(), Self::Subr(subr) if subr.kind == SubrKind::Proc => true, Self::Refinement(refine) => refine.t.is_procedure(), + Self::And(lhs, rhs) => lhs.is_procedure() && rhs.is_procedure(), _ => false, } } @@ -1819,6 +1820,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_refinement(), Self::Refinement(_) => true, + Self::And(l, r) => l.is_refinement() && r.is_refinement(), _ => false, } } @@ -1860,6 +1862,7 @@ impl Type { Self::Refinement(refine) => refine.t.is_method(), Self::Subr(subr) => subr.is_method(), Self::Quantified(quant) => quant.is_method(), + Self::And(l, r) => l.is_method() && r.is_method(), _ => false, } } @@ -2271,6 +2274,11 @@ impl Type { match self { Type::FreeVar(fv) if fv.is_linked() => fv.crack().intersection_types(), Type::Refinement(refine) => refine.t.intersection_types(), + Type::Quantified(tys) => tys + .intersection_types() + .into_iter() + .map(|t| t.quantify()) + .collect(), Type::And(t1, t2) => { let mut types = t1.intersection_types(); types.extend(t2.intersection_types()); diff --git a/tests/should_ok/index.er b/tests/should_ok/index.er new file mode 100644 index 000000000..2e2adcb27 --- /dev/null +++ b/tests/should_ok/index.er @@ -0,0 +1,11 @@ +a = [1, 2, 3] +assert a[0] == 1 +assert a[1..2] == [2, 3] + +s = "abcd" +assert s[0] == "a" +assert s[1..2] == "bc" + +b = bytes("abcd", "utf-8") +assert b[0] == 97 +assert b[1..2] == bytes("bc", "utf-8") diff --git a/tests/test.rs b/tests/test.rs index 09a2b308e..a28582270 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -123,6 +123,11 @@ fn exec_import_cyclic() -> Result<(), ()> { expect_success("tests/should_ok/cyclic/import.er", 0) } +#[test] +fn exec_index() -> Result<(), ()> { + expect_success("tests/should_ok/index.er", 0) +} + #[test] fn exec_inherit() -> Result<(), ()> { expect_success("tests/should_ok/inherit.er", 0)