Skip to content

Commit

Permalink
fix: fast-value codegen bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Oct 15, 2023
1 parent 14729c9 commit 2d2c5dc
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 41 deletions.
78 changes: 48 additions & 30 deletions crates/erg_compiler/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use erg_parser::token::EQUAL;
use erg_parser::token::{Token, TokenKind};

use crate::compile::{AccessKind, Name, StoreLoadKind};
use crate::context::ControlKind;
use crate::error::CompileError;
use crate::hir::ArrayWithLength;
use crate::hir::{
Expand All @@ -47,16 +48,27 @@ use crate::varinfo::VarInfo;
use AccessKind::*;
use Type::*;

#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ControlKind {
If,
While,
For,
Match,
With,
Discard,
Assert,
#[derive(Debug)]
pub enum RegisterNameKind {
Import,
Fast,
NonFast,
}

use RegisterNameKind::*;

impl RegisterNameKind {
pub const fn is_fast(&self) -> bool {
matches!(self, Fast)
}

pub fn from_ident(ident: &Identifier) -> Self {
if ident.vi.is_fast_value() {
Fast
} else {
NonFast
}
}
}

/// patch method -> function
Expand Down Expand Up @@ -589,7 +601,6 @@ impl PyCodeGenerator {
}

fn local_search(&self, name: &str, _acc_kind: AccessKind) -> Option<Name> {
let current_is_toplevel = self.cur_block() == self.toplevel_block();
if let Some(idx) = self
.cur_block_codeobj()
.names
Expand All @@ -603,9 +614,7 @@ impl PyCodeGenerator {
.iter()
.position(|v| &**v == name)
{
if current_is_toplevel {
Some(Name::local(idx))
} else if self.cur_block().captured_vars.contains(&Str::rc(name)) {
if self.cur_block().captured_vars.contains(&Str::rc(name)) {
Some(Name::deref(idx))
} else {
Some(Name::fast(idx))
Expand Down Expand Up @@ -644,17 +653,22 @@ impl PyCodeGenerator {
Some(StoreLoadKind::Global)
}

fn register_name(&mut self, name: Str) -> Name {
fn register_name(&mut self, name: Str, kind: RegisterNameKind) -> Name {
let current_is_toplevel = self.cur_block() == self.toplevel_block();
match self.rec_search(&name) {
Some(st @ (StoreLoadKind::Local | StoreLoadKind::Global)) => {
let st = if current_is_toplevel {
StoreLoadKind::Local
if kind.is_fast() {
self.mut_cur_block_codeobj().varnames.push(name);
Name::fast(self.cur_block_codeobj().varnames.len() - 1)
} else {
st
};
self.mut_cur_block_codeobj().names.push(name);
Name::new(st, self.cur_block_codeobj().names.len() - 1)
let st = if current_is_toplevel {
StoreLoadKind::Local
} else {
st
};
self.mut_cur_block_codeobj().names.push(name);
Name::new(st, self.cur_block_codeobj().names.len() - 1)
}
}
Some(StoreLoadKind::Deref) => {
self.mut_cur_block_codeobj().freevars.push(name.clone());
Expand Down Expand Up @@ -743,6 +757,7 @@ impl PyCodeGenerator {
self.load_module_type();
self.module_type_loaded = true;
}
let kind = RegisterNameKind::from_ident(&ident);
let escaped = escape_ident(ident);
match &escaped[..] {
"if__" | "for__" | "while__" | "with__" | "discard__" => {
Expand All @@ -769,7 +784,7 @@ impl PyCodeGenerator {
}
let name = self
.local_search(&escaped, Name)
.unwrap_or_else(|| self.register_name(escaped));
.unwrap_or_else(|| self.register_name(escaped, kind));
let instr = self.select_load_instr(name.kind, Name);
self.write_instr(instr);
self.write_arg(name.idx);
Expand All @@ -785,7 +800,7 @@ impl PyCodeGenerator {
let escaped = escape_ident(ident);
let name = self
.local_search(&escaped, Name)
.unwrap_or_else(|| self.register_name(escaped));
.unwrap_or_else(|| self.register_name(escaped, NonFast));
let instr = LOAD_GLOBAL;
self.write_instr(instr);
self.write_arg(name.idx);
Expand All @@ -797,7 +812,7 @@ impl PyCodeGenerator {
let escaped = escape_ident(ident);
let name = self
.local_search(&escaped, Name)
.unwrap_or_else(|| self.register_name(escaped));
.unwrap_or_else(|| self.register_name(escaped, Import));
self.write_instr(IMPORT_NAME);
self.write_arg(name.idx);
self.stack_inc_n(items_len);
Expand All @@ -809,7 +824,7 @@ impl PyCodeGenerator {
let escaped = escape_ident(ident);
let name = self
.local_search(&escaped, Name)
.unwrap_or_else(|| self.register_name(escaped));
.unwrap_or_else(|| self.register_name(escaped, Import));
self.write_instr(IMPORT_FROM);
self.write_arg(name.idx);
// self.stack_inc(); (module object) -> attribute
Expand All @@ -822,7 +837,7 @@ impl PyCodeGenerator {
let escaped = escape_ident(ident);
let name = self
.local_search(&escaped, Name)
.unwrap_or_else(|| self.register_name(escaped));
.unwrap_or_else(|| self.register_name(escaped, Import));
self.write_instr(IMPORT_NAME);
self.write_arg(name.idx);
self.stack_inc();
Expand Down Expand Up @@ -891,10 +906,11 @@ impl PyCodeGenerator {

fn emit_store_instr(&mut self, ident: Identifier, acc_kind: AccessKind) {
log!(info "entered {} ({ident})", fn_name!());
let kind = RegisterNameKind::from_ident(&ident);
let escaped = escape_ident(ident);
let name = self.local_search(&escaped, acc_kind).unwrap_or_else(|| {
if acc_kind.is_local() {
self.register_name(escaped)
self.register_name(escaped, kind)
} else {
self.register_attr(escaped)
}
Expand All @@ -908,6 +924,8 @@ impl PyCodeGenerator {
self.write_bytes(&[0; 8]);
}
self.stack_dec();
} else if instr == STORE_FAST as u8 {
self.mut_cur_block_codeobj().nlocals += 1;
}
}

Expand All @@ -918,7 +936,7 @@ impl PyCodeGenerator {
let escaped = escape_ident(ident);
let name = self
.local_search(&escaped, Name)
.unwrap_or_else(|| self.register_name(escaped));
.unwrap_or_else(|| self.register_name(escaped, NonFast));
let instr = STORE_GLOBAL;
self.write_instr(instr);
self.write_arg(name.idx);
Expand Down Expand Up @@ -1905,7 +1923,7 @@ impl PyCodeGenerator {
let escaped = escape_ident(ident);
let name = self
.local_search(&escaped, Name)
.unwrap_or_else(|| self.register_name(escaped));
.unwrap_or_else(|| self.register_name(escaped, Fast));
self.write_instr(DELETE_NAME);
self.write_arg(name.idx);
self.emit_load_const(ValueObj::None);
Expand Down Expand Up @@ -2990,7 +3008,7 @@ impl PyCodeGenerator {
);
let name = self
.local_search(&full_name, Name)
.unwrap_or_else(|| self.register_name(full_name));
.unwrap_or_else(|| self.register_name(full_name, Import));
self.write_instr(IMPORT_NAME);
self.write_arg(name.idx);
let root = Self::get_root(&acc);
Expand Down
47 changes: 43 additions & 4 deletions crates/erg_compiler/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,25 @@ pub enum ControlKind {
Match,
Try,
With,
Discard,
Assert,
}

impl fmt::Display for ControlKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::If => write!(f, "if"),
Self::While => write!(f, "while"),
Self::For => write!(f, "for"),
Self::Match => write!(f, "match"),
Self::Try => write!(f, "try"),
Self::With => write!(f, "with"),
Self::Discard => write!(f, "discard"),
Self::Assert => write!(f, "assert"),
}
}
}

impl TryFrom<&str> for ControlKind {
type Error = ();
fn try_from(s: &str) -> Result<Self, Self::Error> {
Expand All @@ -80,6 +96,7 @@ impl TryFrom<&str> for ControlKind {
"match" | "match!" => Ok(ControlKind::Match),
"try" | "try!" => Ok(ControlKind::Try),
"with" | "with!" => Ok(ControlKind::With),
"discard" => Ok(ControlKind::Discard),
"assert" => Ok(ControlKind::Assert),
_ => Err(()),
}
Expand Down Expand Up @@ -311,6 +328,8 @@ impl From<&ParamSpec> for ParamTy {

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextKind {
LambdaFunc(Option<ControlKind>),
LambdaProc(Option<ControlKind>),
Func,
Proc,
Class,
Expand Down Expand Up @@ -354,6 +373,8 @@ impl From<&Def> for ContextKind {
impl fmt::Display for ContextKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::LambdaFunc(kind) => write!(f, "LambdaFunc({})", fmt_option!(kind)),
Self::LambdaProc(kind) => write!(f, "LambdaProc({})", fmt_option!(kind)),
Self::Func => write!(f, "Func"),
Self::Proc => write!(f, "Proc"),
Self::Class => write!(f, "Class"),
Expand Down Expand Up @@ -386,7 +407,10 @@ impl ContextKind {
}

pub const fn is_subr(&self) -> bool {
matches!(self, Self::Func | Self::Proc)
matches!(
self,
Self::Func | Self::Proc | Self::LambdaFunc(_) | Self::LambdaProc(_)
)
}

pub const fn is_class(&self) -> bool {
Expand All @@ -404,6 +428,17 @@ impl ContextKind {
pub const fn is_module(&self) -> bool {
matches!(self, Self::Module)
}

pub const fn is_instant(&self) -> bool {
matches!(self, Self::Instant)
}

pub const fn control_kind(&self) -> Option<ControlKind> {
match self {
Self::LambdaFunc(kind) | Self::LambdaProc(kind) => *kind,
_ => None,
}
}
}

/// Indicates the mode registered in the Context
Expand Down Expand Up @@ -1129,6 +1164,7 @@ impl Context {
};
self.cfg = self.get_outer().unwrap().cfg.clone();
self.shared = self.get_outer().unwrap().shared.clone();
self.higher_order_caller = self.get_outer().unwrap().higher_order_caller.clone();
self.tv_cache = tv_cache;
self.name = name.into();
self.kind = kind;
Expand Down Expand Up @@ -1266,9 +1302,12 @@ impl Context {
}

pub fn control_kind(&self) -> Option<ControlKind> {
self.higher_order_caller
.last()
.and_then(|caller| ControlKind::try_from(&caller[..]).ok())
for caller in self.higher_order_caller.iter().rev() {
if let Ok(control) = ControlKind::try_from(&caller[..]) {
return Some(control);
}
}
None
}

pub(crate) fn check_types(&self) {
Expand Down
23 changes: 16 additions & 7 deletions crates/erg_compiler/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1141,9 +1141,12 @@ impl ASTLowerer {
expect: Option<&Type>,
) -> LowerResult<hir::Call> {
log!(info "entered {}({}{}(...))", fn_name!(), call.obj, fmt_option!(call.attr_name));
if let (Some(name), None) = (call.obj.get_name(), &call.attr_name) {
let pushed = if let (Some(name), None) = (call.obj.get_name(), &call.attr_name) {
self.module.context.higher_order_caller.push(name.clone());
}
true
} else {
false
};
let mut errs = LowerErrors::empty();
let guard = if let (
ast::Expr::Accessor(ast::Accessor::Ident(ident)),
Expand All @@ -1163,7 +1166,9 @@ impl ASTLowerer {
let mut obj = match self.lower_expr(*call.obj, None) {
Ok(obj) => obj,
Err(es) => {
self.module.context.higher_order_caller.pop();
if pushed {
self.module.context.higher_order_caller.pop();
}
errs.extend(es);
return Err(errs);
}
Expand All @@ -1190,7 +1195,9 @@ impl ASTLowerer {
) {
Ok(vi) => vi,
Err((vi, es)) => {
self.module.context.higher_order_caller.pop();
if pushed {
self.module.context.higher_order_caller.pop();
}
errs.extend(es);
vi.unwrap_or(VarInfo::ILLEGAL)
}
Expand Down Expand Up @@ -1231,7 +1238,9 @@ impl ASTLowerer {
self.errs.extend(errs);
}
}*/
self.module.context.higher_order_caller.pop();
if pushed {
self.module.context.higher_order_caller.pop();
}
if errs.is_empty() {
self.exec_additional_op(&mut call)?;
}
Expand Down Expand Up @@ -1454,9 +1463,9 @@ impl ASTLowerer {
let id = lambda.id.0;
let name = format!("<lambda_{id}>");
let kind = if is_procedural {
ContextKind::Proc
ContextKind::LambdaProc(self.module.context.control_kind())
} else {
ContextKind::Func
ContextKind::LambdaFunc(self.module.context.control_kind())
};
let tv_cache = self
.module
Expand Down
7 changes: 7 additions & 0 deletions crates/erg_compiler/varinfo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,4 +448,11 @@ impl VarInfo {
pub fn is_toplevel(&self) -> bool {
self.vis.def_namespace.split_with(&[".", "::"]).len() == 1
}

pub fn is_fast_value(&self) -> bool {
!self.is_toplevel()
&& !self.is_parameter()
&& self.ctx.control_kind().is_none()
&& (self.ctx.is_subr() || self.ctx.is_instant())
}
}
16 changes: 16 additions & 0 deletions tests/should_ok/fast_value.er
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#[
Python's `STORE_NAME` was changed incompatibly in CPython 3.10.
`STORE_NAME` was originally an instruction to branch to `STORE_FAST` or `STORE_GLOBAL` depending on the context,
but as of 3.10, it now uses `STORE_GLOBAL` even when the value should clearly be fast.
Currently the Erg code generator has already addressed this issue,
but we added a test to ensure that future changes will not cause the same issue again.
]#
f(x: Int): Int =
y = x
if x == 0, do:
f::return 0
_ = f x - 1
y

# if `y` is global, here `f(1)` will return 0
assert f(1) == 1
Loading

0 comments on commit 2d2c5dc

Please sign in to comment.