Skip to content

Commit fa026e1

Browse files
authored
Synchronizing libasr's template instantiation with LFortran (#1474)
* Modified generic tests to not use built-in functions * Removed TemplateBinOp and built in functions for restrictions * Fixed generic tests * Removed comments * Updated reference tests for generic tests * Renamed instantiated generic functions so it synchronized with LFortran instantiation * Updated reference tests again * Modified generic instantiation's interface to match LFortran
1 parent 2c5034a commit fa026e1

22 files changed

+77
-93
lines changed

integration_tests/generics_01.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1-
from ltypes import TypeVar, restriction
1+
from ltypes import TypeVar, restriction, i32
22

33
T = TypeVar('T')
44

55
@restriction
66
def add(x: T, y: T) -> T:
77
pass
88

9+
def add_integer(x: i32, y: i32) -> i32:
10+
return x + y
11+
912
def add_string(x: str, y: str) -> str:
1013
return x + y
1114

1215
def f(x: T, y: T, **kwargs) -> T:
1316
return add(x,y)
1417

15-
print(f(1,2))
18+
print(f(1, 2, add=add_integer))
1619
print(f("a","b",add=add_string))
1720
print(f("c","d",add=add_string))

integration_tests/generics_array_01.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ltypes import TypeVar
1+
from ltypes import TypeVar, i32
22
from numpy import empty
33

44
T = TypeVar('T')

integration_tests/generics_array_02.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
def add(x: T, y: T) -> T:
1010
pass
1111

12-
def g(n: i32, a: T[n], b: T[n]):
12+
def add_integer(x: i32, y: i32) -> i32:
13+
return x + y
14+
15+
def add_float(x: f32, y: f32) -> f32:
16+
return x + y
17+
18+
def g(n: i32, a: T[n], b: T[n], **kwargs):
1319
r: T[n]
1420
r = empty(n)
1521
i: i32
@@ -22,11 +28,11 @@ def main():
2228
a_int[0] = 400
2329
b_int: i32[1] = empty(1)
2430
b_int[0] = 20
25-
g(1, a_int, b_int)
31+
g(1, a_int, b_int, add=add_integer)
2632
a_float: f32[1] = empty(1)
2733
a_float[0] = f32(400.0)
2834
b_float: f32[1] = empty(1)
2935
b_float[0] = f32(20.0)
30-
g(1, a_float ,b_float)
36+
g(1, a_float, b_float, add=add_float)
3137

3238
main()

integration_tests/generics_array_03.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@
1111
def add(x: T, y: T) -> T:
1212
pass
1313

14-
def g(n: i32, m: i32, a: T[n,m], b: T[n,m]) -> T[n,m]:
14+
def add_integer(x: i32, y: i32) -> i32:
15+
return x + y
16+
17+
def add_float(x: f32, y: f32) -> f32:
18+
return x + y
19+
20+
def g(n: i32, m: i32, a: T[n,m], b: T[n,m], **kwargs) -> T[n,m]:
1521
r: T[n,m]
1622
r = empty([n,m])
1723
i: i32
@@ -26,11 +32,11 @@ def main():
2632
a_int[0,0] = 400
2733
b_int: i32[1,1] = empty([1,1])
2834
b_int[0,0] = 20
29-
g(1, 1, a_int, b_int)
30-
a_float: i32[1,1] = empty([1,1])
31-
a_float[0,0] = 400
32-
b_float: i32[1,1] = empty([1,1])
33-
b_float[0,0] = 20
34-
g(1, 1, a_float, b_float)
35+
g(1, 1, a_int, b_int, add=add_integer)
36+
a_float: f32[1,1] = empty([1,1])
37+
a_float[0,0] = f32(400)
38+
b_float: f32[1,1] = empty([1,1])
39+
b_float[0,0] = f32(20)
40+
g(1, 1, a_float, b_float, add=add_float)
3541

3642
main()

integration_tests/generics_list_01.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,24 @@ def add(x: T, y: T) -> T:
1414
def div(x: T, k: i32) -> f64:
1515
pass
1616

17+
def empty_integer(x: i32) -> i32:
18+
return 0
19+
20+
def add_integer(x: i32, y: i32) -> i32:
21+
return x + y
22+
23+
def div_integer(x: i32, k: i32) -> f64:
24+
return x / k
25+
26+
def empty_float(x: f64) -> f64:
27+
return 0.0
28+
29+
def add_float(x: f64, y: f64) -> f64:
30+
return x + y
31+
32+
def div_float(x: f64, k: i32) -> f64:
33+
return x / k
34+
1735
def empty_string(x: str) -> str:
1836
return ""
1937

@@ -34,6 +52,6 @@ def mean(x: list[T], **kwargs) -> f64:
3452
res = add(res, x[i])
3553
return div(res, k)
3654

37-
print(mean([1,2,3]))
38-
print(mean([1.0,2.0,3.0]))
55+
print(mean([1,2,3], zero=empty_integer, add=add_integer, div=div_integer))
56+
print(mean([1.0,2.0,3.0], zero=empty_float, add=add_float, div=div_float))
3957
print(mean(["a","b","c"], zero=empty_string, add=add_string, div=div_string))

src/libasr/ASR.asdl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,6 @@ expr
249249
| LogicalNot(expr arg, ttype type, expr? value)
250250
| LogicalCompare(expr left, cmpop op, expr right, ttype type, expr? value)
251251
| LogicalBinOp(expr left, logicalbinop op, expr right, ttype type, expr? value)
252-
| TemplateBinOp(expr left, binop op, expr right, ttype type, expr? value)
253252

254253
| ListConstant(expr* args, ttype type)
255254
| ListLen(expr arg, ttype type, expr? value)

src/libasr/pass/instantiate_template.cpp

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -204,12 +204,6 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
204204
return ASR::make_Assignment_t(al, x->base.base.loc, target, value, overloaded);
205205
}
206206

207-
ASR::asr_t* duplicate_TemplateBinOp(ASR::TemplateBinOp_t *x) {
208-
ASR::expr_t *left = duplicate_expr(x->m_left);
209-
ASR::expr_t *right = duplicate_expr(x->m_right);
210-
return make_BinOp_helper(left, right, x->m_op, x->base.base.loc);
211-
}
212-
213207
ASR::asr_t* duplicate_DoLoop(ASR::DoLoop_t *x) {
214208
Vec<ASR::stmt_t*> m_body;
215209
m_body.reserve(al, x->n_body);
@@ -249,39 +243,10 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
249243
ASR::expr_t* value = duplicate_expr(x->m_value);
250244
ASR::expr_t* dt = duplicate_expr(x->m_dt);
251245
std::string call_name = ASRUtils::symbol_name(x->m_name);
252-
if ((name && ASRUtils::is_restriction_function(name) && rt_subs.find(call_name) == rt_subs.end()) ||
253-
!name) {
254-
if (call_name.compare("add") == 0) {
255-
ASR::expr_t* left_arg = duplicate_expr(x->m_args[0].m_value);
256-
ASR::expr_t* right_arg = duplicate_expr(x->m_args[1].m_value);
257-
ASR::ttype_t* left_type = substitute_type(ASRUtils::expr_type(left_arg));
258-
ASR::ttype_t* right_type = substitute_type(ASRUtils::expr_type(right_arg));
259-
if ((ASRUtils::is_integer(*left_type) && ASRUtils::is_integer(*right_type)) ||
260-
(ASRUtils::is_real(*left_type) && ASRUtils::is_real(*right_type))) {
261-
return make_BinOp_helper(left_arg, right_arg, ASR::binopType::Add, x->base.base.loc);
262-
} else {
263-
throw SemanticError("Intrinsic plus not yet supported for this type", x->base.base.loc);
264-
}
265-
} else if (call_name.compare("zero") == 0) {
266-
ASR::expr_t* arg = duplicate_expr(x->m_args[0].m_value);
267-
ASR::ttype_t* arg_type = substitute_type(ASRUtils::expr_type(arg));
268-
if (ASRUtils::is_integer(*arg_type)) {
269-
return ASR::make_IntegerConstant_t(al, x->base.base.loc, 0, arg_type);
270-
} else if (ASRUtils::is_real(*arg_type)) {
271-
return ASR::make_RealConstant_t(al, x->base.base.loc, 0, arg_type);
272-
}
273-
} else if (call_name.compare("div") == 0) {
274-
ASR::expr_t* left_arg = duplicate_expr(x->m_args[0].m_value);
275-
ASR::expr_t* right_arg = duplicate_expr(x->m_args[1].m_value);
276-
return make_BinOp_helper(left_arg, right_arg, ASR::binopType::Div, x->base.base.loc);
277-
}
278-
LCOMPILERS_ASSERT(false); // should never happen
279-
name = rt_subs[call_name];
280-
}
281246
if (ASRUtils::is_restriction_function(name)) {
282247
name = rt_subs[call_name];
283248
} else if (ASRUtils::is_generic_function(name)) {
284-
std::string nested_func_name = "__lfortran_generic_" + sym_name;
249+
std::string nested_func_name = "__asr_generic_" + sym_name;
285250
ASR::symbol_t* name2 = ASRUtils::symbol_get_past_external(name);
286251
ASR::Function_t* func = ASR::down_cast<ASR::Function_t>(name2);
287252
FunctionInstantiator nested_tf(al, subs, rt_subs, func_scope, nested_func_name);
@@ -425,7 +390,7 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
425390
};
426391

427392
ASR::symbol_t* pass_instantiate_generic_function(Allocator &al, std::map<std::string, ASR::ttype_t*> subs,
428-
std::map<std::string, ASR::symbol_t*>& rt_subs, SymbolTable *current_scope,
393+
std::map<std::string, ASR::symbol_t*> rt_subs, SymbolTable *current_scope,
429394
std::string new_func_name, ASR::symbol_t *sym) {
430395
ASR::symbol_t* sym2 = ASRUtils::symbol_get_past_external(sym);
431396
ASR::Function_t* func = ASR::down_cast<ASR::Function_t>(sym2);
@@ -434,12 +399,4 @@ ASR::symbol_t* pass_instantiate_generic_function(Allocator &al, std::map<std::st
434399
return ASR::down_cast<ASR::symbol_t>(new_function);
435400
}
436401

437-
ASR::symbol_t* pass_instantiate_generic_function(Allocator &al, std::map<std::string, ASR::ttype_t*> subs,
438-
std::map<std::string, ASR::symbol_t*>& rt_subs, SymbolTable *current_scope,
439-
std::string new_func_name, ASR::Function_t *func) {
440-
FunctionInstantiator tf(al, subs, rt_subs, current_scope, new_func_name);
441-
ASR::asr_t *new_function = tf.instantiate_Function(func);
442-
return ASR::down_cast<ASR::symbol_t>(new_function);
443-
}
444-
445402
} // namespace LCompilers

src/libasr/pass/instantiate_template.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,8 @@ namespace LCompilers {
1111
* is executed here
1212
*/
1313
ASR::symbol_t* pass_instantiate_generic_function(Allocator &al,
14-
std::map<std::string, ASR::ttype_t*> subs, std::map<std::string, ASR::symbol_t*>& rt_subs,
14+
std::map<std::string, ASR::ttype_t*> subs, std::map<std::string, ASR::symbol_t*> rt_subs,
1515
SymbolTable *current_scope, std::string new_func_name, ASR::symbol_t *sym);
16-
17-
ASR::symbol_t* pass_instantiate_generic_function(Allocator &al,
18-
std::map<std::string, ASR::ttype_t*> subs, std::map<std::string, ASR::symbol_t*>& rt_subs,
19-
SymbolTable *current_scope, std::string new_func_name, ASR::Function_t *sym);
2016
}
2117

2218
#endif // LFORTRAN_PASS_TEMPLATE_VISITOR_H

src/libasr/runtime/lfortran_intrinsics.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1304,7 +1304,7 @@ char *get_base_name(char *filename) {
13041304

13051305
#ifdef HAVE_LFORTRAN_LINK
13061306
int shared_lib_callback(struct dl_phdr_info *info,
1307-
size_t /* size */, void *_data) {
1307+
size_t size, void *_data) {
13081308
struct Stacktrace *d = (struct Stacktrace *) _data;
13091309
for (int i = 0; i < info->dlpi_phnum; i++) {
13101310
if (info->dlpi_phdr[i].p_type == PT_LOAD) {

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,8 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
10621062
check_type_restriction(subs, rt_subs, rt, loc);
10631063
}
10641064

1065-
ASR::symbol_t *t = get_generic_function(subs, rt_subs, func);
1065+
//ASR::symbol_t *t = get_generic_function(subs, rt_subs, func);
1066+
ASR::symbol_t *t = get_generic_function(subs, rt_subs, s);
10661067
std::string new_call_name = (ASR::down_cast<ASR::Function_t>(t))->m_name;
10671068

10681069
// Currently ignoring keyword arguments for generic function calls
@@ -1336,14 +1337,14 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
13361337
* arguments. If not, then instantiate a new function.
13371338
*/
13381339
ASR::symbol_t* get_generic_function(std::map<std::string, ASR::ttype_t*> subs,
1339-
std::map<std::string, ASR::symbol_t*>& rt_subs, ASR::Function_t *func) {
1340+
std::map<std::string, ASR::symbol_t*>& rt_subs, ASR::symbol_t *func) {
13401341
int new_function_num;
13411342
ASR::symbol_t *t;
1342-
std::string func_name = func->m_name;
1343+
std::string func_name = ASRUtils::symbol_name(func);
13431344
if (generic_func_nums.find(func_name) != generic_func_nums.end()) {
13441345
new_function_num = generic_func_nums[func_name];
13451346
for (int i=0; i<generic_func_nums[func_name]; i++) {
1346-
std::string generic_func_name = "__lpython_generic_" + func_name + "_" + std::to_string(i);
1347+
std::string generic_func_name = "__asr_generic_" + func_name + "_" + std::to_string(i);
13471348
if (generic_func_subs.find(generic_func_name) != generic_func_subs.end()) {
13481349
std::map<std::string, ASR::ttype_t*> subs_check = generic_func_subs[generic_func_name];
13491350
if (subs_check.size() != subs.size()) { continue; }
@@ -1368,11 +1369,11 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
13681369
new_function_num = 0;
13691370
}
13701371
generic_func_nums[func_name] = new_function_num + 1;
1371-
std::string new_func_name = "__lpython_generic_" + func_name + "_"
1372+
std::string new_func_name = "__asr_generic_" + func_name + "_"
13721373
+ std::to_string(new_function_num);
13731374
generic_func_subs[new_func_name] = subs;
1374-
t = pass_instantiate_generic_function(al, subs, rt_subs, func->m_symtab->parent,
1375-
new_func_name, func);
1375+
t = pass_instantiate_generic_function(al, subs, rt_subs,
1376+
ASRUtils::symbol_parent_symtab(func), new_func_name, func);
13761377
dependencies.erase(func_name);
13771378
dependencies.insert(new_func_name);
13781379
return t;
@@ -2105,8 +2106,6 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
21052106

21062107
tmp = ASR::make_ComplexBinOp_t(al, loc, left, op, right, dest_type, value);
21072108

2108-
} else if (ASRUtils::is_generic(*dest_type)) {
2109-
tmp = ASR::make_TemplateBinOp_t(al, loc, left, op, right, dest_type, value);
21102109
}
21112110

21122111

tests/reference/asr-generics_01-d616074.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
"basename": "asr-generics_01-d616074",
33
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
44
"infile": "tests/../integration_tests/generics_01.py",
5-
"infile_hash": "457bb6166206ba1f310ff7879cf0f56a00c50162f6f4fe3f376af6ee",
5+
"infile_hash": "cf91399f0cae8a099eec7af56ef12bc0f52e60599f0971baf5e001b5",
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-generics_01-d616074.stdout",
9-
"stdout_hash": "3639281a76fe0e25dc0c22bde4e8393f4c84e82d78ca370da6b50ffe",
9+
"stdout_hash": "83f0370d5dd0478c47e8c976c294dcde303ee66be912941e4f69c41e",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

0 commit comments

Comments
 (0)