Skip to content

Commit

Permalink
feat: support Array, Str, Bytes slice
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Jul 15, 2023
1 parent 579615d commit 0152e36
Show file tree
Hide file tree
Showing 14 changed files with 170 additions and 19 deletions.
4 changes: 4 additions & 0 deletions crates/erg_compiler/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2717,6 +2717,10 @@ impl PyCodeGenerator {
self.emit_load_name_instr(Identifier::public("Str"));
}
other => match &other.qual_name()[..] {
"Bytes" => {
self.emit_push_null();
self.emit_load_name_instr(Identifier::public("Bytes"));
}
"Array" => {
self.emit_push_null();
self.emit_load_name_instr(Identifier::public("Array"));
Expand Down
24 changes: 14 additions & 10 deletions crates/erg_compiler/context/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1508,24 +1508,29 @@ impl Context {
}
}

fn convert_type_to_array(&self, ty: Type) -> Result<Vec<ValueObj>, ()> {
fn convert_type_to_array(&self, ty: Type) -> Result<Vec<ValueObj>, Type> {
match ty {
Type::Poly { name, params } if &name[..] == "Array" || &name[..] == "Array!" => {
let t = self
.convert_tp_into_type(params[0].clone())
.map_err(|_| ())?;
let Ok(t) = self.convert_tp_into_type(params[0].clone()) else {
return Err(poly(name, params));
};
let TyParam::Value(ValueObj::Nat(len)) = params[1] else { unreachable!() };
Ok(vec![ValueObj::builtin_type(t); len as usize])
}
_ => Err(()),
_ => Err(ty),
}
}

pub(crate) fn convert_value_into_array(&self, val: ValueObj) -> Result<Vec<ValueObj>, ()> {
pub(crate) fn convert_value_into_array(
&self,
val: ValueObj,
) -> Result<Vec<ValueObj>, ValueObj> {
match val {
ValueObj::Array(arr) => Ok(arr.to_vec()),
ValueObj::Type(t) => self.convert_type_to_array(t.into_typ()),
_ => Err(()),
ValueObj::Type(t) => self
.convert_type_to_array(t.into_typ())
.map_err(ValueObj::builtin_type),
_ => Err(val),
}
}

Expand Down Expand Up @@ -1814,9 +1819,8 @@ impl Context {
})? {
if let Ok(obj) = ty_ctx.get_const_local(&Token::symbol(&attr_name), &self.name) {
if let ValueObj::Subr(subr) = obj {
let is_method = subr.sig_t().self_t().is_some();
let mut pos_args = vec![];
if is_method {
if subr.sig_t().is_method() {
match ValueObj::try_from(lhs) {
Ok(value) => {
pos_args.push(value);
Expand Down
34 changes: 31 additions & 3 deletions crates/erg_compiler/context/initialize/classes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ impl Context {
Immutable,
Visibility::BUILTIN_PUBLIC,
);
let str_getitem_t = fn1_kw_met(Str, kw(KW_IDX, Nat), Str);
let str_getitem_t = fn1_kw_met(Str, kw(KW_IDX, Nat | poly(RANGE, vec![ty_tp(Int)])), Str);
str_.register_builtin_erg_impl(
FUNDAMENTAL_GETITEM,
str_getitem_t,
Expand Down Expand Up @@ -1240,11 +1240,16 @@ impl Context {
Predicate::le(var, N.clone() - value(1usize)),
);
// __getitem__: |T, N|(self: [T; N], _: {I: Nat | I <= N}) -> T
let array_getitem_t = fn1_kw_met(
// and (self: [T; N], _: Range(Int)) -> [T; _]
let array_getitem_t = (fn1_kw_met(
array_t(T.clone(), N.clone()),
anon(input.clone()),
T.clone(),
)
) & fn1_kw_met(
array_t(T.clone(), N.clone()),
anon(poly(RANGE, vec![ty_tp(Int)])),
unknown_len_array_t(T.clone()),
))
.quantify();
let get_item = ValueObj::Subr(ConstSubr::Builtin(BuiltinConstSubr::new(
FUNDAMENTAL_GETITEM,
Expand Down Expand Up @@ -1540,6 +1545,29 @@ impl Context {
Str,
);
bytes.register_py_builtin(FUNC_DECODE, decode_t, Some(FUNC_DECODE), 6);
let bytes_getitem_t = fn1_kw_met(mono(BYTES), kw(KW_IDX, Nat), Int)
& fn1_kw_met(
mono(BYTES),
kw(KW_IDX, poly(RANGE, vec![ty_tp(Int)])),
mono(BYTES),
);
bytes.register_builtin_erg_impl(
FUNDAMENTAL_GETITEM,
bytes_getitem_t,
Immutable,
Visibility::BUILTIN_PUBLIC,
);
bytes
.register_marker_trait(self, poly(INDEXABLE, vec![ty_tp(Nat), ty_tp(Int)]))
.unwrap();
let mut bytes_eq = Self::builtin_methods(Some(mono(EQ)), 2);
bytes_eq.register_builtin_erg_impl(
OP_EQ,
fn1_met(mono(BYTES), mono(BYTES), Bool),
Const,
Visibility::BUILTIN_PUBLIC,
);
bytes.register_trait(mono(BYTES), bytes_eq);
/* GenericTuple */
let mut generic_tuple = Self::builtin_mono_class(GENERIC_TUPLE, 1);
generic_tuple.register_superclass(Obj, &obj);
Expand Down
2 changes: 1 addition & 1 deletion crates/erg_compiler/context/initialize/const_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ pub(crate) fn structural_func(mut args: ValueArgs, ctx: &Context) -> EvalValueRe
pub(crate) fn __array_getitem__(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<ValueObj> {
let slf = ctx
.convert_value_into_array(args.remove_left_or_key("Self").unwrap())
.unwrap();
.unwrap_or_else(|err| panic!("{err}, {args}"));
let index = enum_unwrap!(args.remove_left_or_key("Index").unwrap(), ValueObj::Nat);
if let Some(v) = slf.get(index as usize) {
Ok(v.clone())
Expand Down
19 changes: 16 additions & 3 deletions crates/erg_compiler/context/inquire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -831,10 +831,10 @@ impl Context {
if let Some(attr_name) = attr_name.as_ref() {
let mut vi =
self.search_method_info(obj, attr_name, pos_args, kw_args, input, namespace)?;
vi.t = self.resolve_overload(vi.t, pos_args, kw_args, attr_name)?;
vi.t = self.resolve_overload(obj, vi.t, pos_args, kw_args, attr_name)?;
Ok(vi)
} else {
let t = self.resolve_overload(obj.t(), pos_args, kw_args, obj)?;
let t = self.resolve_overload(obj, obj.t(), pos_args, kw_args, obj)?;
Ok(VarInfo {
t,
..VarInfo::default()
Expand All @@ -844,6 +844,7 @@ impl Context {

fn resolve_overload(
&self,
obj: &hir::Expr,
instance: Type,
pos_args: &[hir::PosArg],
kw_args: &[hir::KwArg],
Expand All @@ -853,7 +854,7 @@ impl Context {
if intersecs.len() == 1 {
Ok(instance)
} else {
let input_t = subr_t(
let mut input_t = subr_t(
SubrKind::Proc,
pos_args
.iter()
Expand All @@ -867,6 +868,18 @@ impl Context {
Obj,
);
for ty in intersecs.iter() {
match (ty.is_method(), input_t.is_method()) {
(true, false) => {
let Type::Subr(sub) = &mut input_t else { unreachable!() };
sub.non_default_params
.insert(0, ParamTy::kw(Str::ever("self"), obj.t()));
}
(false, true) => {
let Type::Subr(sub) = &mut input_t else { unreachable!() };
sub.non_default_params.remove(0);
}
_ => {}
}
if self.subtype_of(ty, &input_t) {
return Ok(ty.clone());
}
Expand Down
20 changes: 19 additions & 1 deletion crates/erg_compiler/context/instantiate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,25 @@ impl Context {
)?;
}
}
_ => unreachable!(),
Type::And(l, r) => {
if let Some(self_t) = l.self_t() {
self.sub_unify(
callee.ref_t(),
self_t,
callee,
Some(&Str::ever("self")),
)?;
}
if let Some(self_t) = r.self_t() {
self.sub_unify(
callee.ref_t(),
self_t,
callee,
Some(&Str::ever("self")),
)?;
}
}
other => unreachable!("{other}"),
}
Ok(t)
}
Expand Down
9 changes: 9 additions & 0 deletions crates/erg_compiler/lib/std/_erg_array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from _erg_control import then__
from _erg_range import Range

class Array(list):
def dedup(self, same_bucket=None):
Expand All @@ -24,3 +25,11 @@ def partition(self, f):

def __mul__(self, n):
return then__(list.__mul__(self, n), Array)

def __getitem__(self, index_or_slice):
if isinstance(index_or_slice, slice):
return Array(list.__getitem__(self, index_or_slice))
elif isinstance(index_or_slice, Range):
return Array(list.__getitem__(self, index_or_slice.into_slice()))
else:
return list.__getitem__(self, index_or_slice)
9 changes: 9 additions & 0 deletions crates/erg_compiler/lib/std/_erg_bytes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
class Bytes(bytes):
def try_new(*b): # -> Result[Nat]
return Bytes(bytes(*b))

def __getitem__(self, index_or_slice):
from _erg_range import Range
if isinstance(index_or_slice, slice):
return Bytes(bytes.__getitem__(self, index_or_slice))
elif isinstance(index_or_slice, Range):
return Bytes(bytes.__getitem__(self, index_or_slice.into_slice()))
else:
return bytes.__getitem__(self, index_or_slice)
21 changes: 21 additions & 0 deletions crates/erg_compiler/lib/std/_erg_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ def __init__(self, start, end):
def __contains__(self, item):
pass

@staticmethod
def from_slice(slice):
pass

def into_slice(self):
pass

def __getitem__(self, item):
res = self.start + item
if res in self:
Expand Down Expand Up @@ -56,6 +63,13 @@ class RightOpenRange(Range):
def __contains__(self, item):
return self.start <= item < self.end

@staticmethod
def from_slice(slice):
return Range(slice.start, slice.stop)

def into_slice(self):
return slice(self.start, self.end)


# represents `start<..<end`
class OpenRange(Range):
Expand All @@ -68,6 +82,13 @@ class ClosedRange(Range):
def __contains__(self, item):
return self.start <= item <= self.end

@staticmethod
def from_slice(slice):
return Range(slice.start, slice.stop - 1)

def into_slice(self):
return slice(self.start, self.end + 1)


class RangeIterator:
def __init__(self, rng):
Expand Down
10 changes: 9 additions & 1 deletion crates/erg_compiler/lib/std/_erg_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from _erg_int import Int
from _erg_control import then__


class Str(str):
def __instancecheck__(cls, obj):
return isinstance(obj, str)
Expand Down Expand Up @@ -40,6 +39,15 @@ def __mul__(self, other):
def __mod__(self, other):
return then__(str.__mod__(other, self), Str)

def __getitem__(self, index_or_slice):
from _erg_range import Range
if isinstance(index_or_slice, slice):
return Str(str.__getitem__(self, index_or_slice))
elif isinstance(index_or_slice, Range):
return Str(str.__getitem__(self, index_or_slice.into_slice()))
else:
return str.__getitem__(self, index_or_slice)


class StrMut: # Inherits Str
value: Str
Expand Down
13 changes: 13 additions & 0 deletions crates/erg_compiler/ty/const_subr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ pub struct ValueArgs {
pub kw_args: Dict<Str, ValueObj>,
}

impl fmt::Display for ValueArgs {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut args = Vec::new();
for arg in &self.pos_args {
args.push(arg.to_string());
}
for (key, arg) in self.kw_args.iter() {
args.push(format!("{key} := {arg}"));
}
write!(f, "({})", args.join(", "))
}
}

impl From<ValueArgs> for Vec<TyParam> {
fn from(args: ValueArgs) -> Self {
// TODO: kw_args
Expand Down
8 changes: 8 additions & 0 deletions crates/erg_compiler/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1717,6 +1717,7 @@ impl Type {
Self::Quantified(t) => t.is_procedure(),
Self::Subr(subr) if subr.kind == SubrKind::Proc => true,
Self::Refinement(refine) => refine.t.is_procedure(),
Self::And(lhs, rhs) => lhs.is_procedure() && rhs.is_procedure(),
_ => false,
}
}
Expand Down Expand Up @@ -1819,6 +1820,7 @@ impl Type {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_refinement(),
Self::Refinement(_) => true,
Self::And(l, r) => l.is_refinement() && r.is_refinement(),
_ => false,
}
}
Expand Down Expand Up @@ -1860,6 +1862,7 @@ impl Type {
Self::Refinement(refine) => refine.t.is_method(),
Self::Subr(subr) => subr.is_method(),
Self::Quantified(quant) => quant.is_method(),
Self::And(l, r) => l.is_method() && r.is_method(),
_ => false,
}
}
Expand Down Expand Up @@ -2271,6 +2274,11 @@ impl Type {
match self {
Type::FreeVar(fv) if fv.is_linked() => fv.crack().intersection_types(),
Type::Refinement(refine) => refine.t.intersection_types(),
Type::Quantified(tys) => tys
.intersection_types()
.into_iter()
.map(|t| t.quantify())
.collect(),
Type::And(t1, t2) => {
let mut types = t1.intersection_types();
types.extend(t2.intersection_types());
Expand Down
11 changes: 11 additions & 0 deletions tests/should_ok/index.er
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
a = [1, 2, 3]
assert a[0] == 1
assert a[1..2] == [2, 3]

s = "abcd"
assert s[0] == "a"
assert s[1..2] == "bc"

b = bytes("abcd", "utf-8")
assert b[0] == 97
assert b[1..2] == bytes("bc", "utf-8")
5 changes: 5 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ fn exec_import_cyclic() -> Result<(), ()> {
expect_success("tests/should_ok/cyclic/import.er", 0)
}

#[test]
fn exec_index() -> Result<(), ()> {
expect_success("tests/should_ok/index.er", 0)
}

#[test]
fn exec_inherit() -> Result<(), ()> {
expect_success("tests/should_ok/inherit.er", 0)
Expand Down

0 comments on commit 0152e36

Please sign in to comment.