From 83ce37fc850d1eab25d21ab51636dc4695b59dde Mon Sep 17 00:00:00 2001 From: swamishiju Date: Wed, 26 Mar 2025 13:29:01 +0530 Subject: [PATCH 1/4] Added Bytes type --- src/libasr/ASR.asdl | 2 ++ src/libasr/asdl_cpp.py | 3 ++- src/libasr/asr_utils.h | 11 ++++++++- src/libasr/codegen/asr_to_llvm.cpp | 26 ++++++++++++++++++++- src/lpython/semantics/python_ast_to_asr.cpp | 13 ++++++++++- 5 files changed, 51 insertions(+), 4 deletions(-) diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 5ea9a482b5..9b8da38482 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -140,6 +140,7 @@ expr | StringOrd(expr arg, ttype type, expr? value) | StringChr(expr arg, ttype type, expr? value) | StringFormat(expr fmt, expr* args, string_format_kind kind, ttype type, expr? value) + | BytesConstant(string s, ttype type) | CPtrCompare(expr left, cmpop op, expr right, ttype type, expr? value) | SymbolicCompare(expr left, cmpop op, expr right, ttype type, expr? value) | DictConstant(expr* keys, expr* values, ttype type) @@ -198,6 +199,7 @@ ttype | Real(int kind) | Complex(int kind) | Character(int kind, int len, expr? len_expr) + | Byte(int kind, int len, expr? len_expr) | Logical(int kind) | Set(ttype type) | List(ttype type) diff --git a/src/libasr/asdl_cpp.py b/src/libasr/asdl_cpp.py index 9463cb9d1f..05dcde949f 100644 --- a/src/libasr/asdl_cpp.py +++ b/src/libasr/asdl_cpp.py @@ -2,8 +2,9 @@ Generate C++ AST node definitions from an ASDL description. """ -import sys import os +import sys + import asdl diff --git a/src/libasr/asr_utils.h b/src/libasr/asr_utils.h index 42d64c4c1c..9035867293 100644 --- a/src/libasr/asr_utils.h +++ b/src/libasr/asr_utils.h @@ -990,7 +990,8 @@ static inline bool is_value_constant(ASR::expr_t *a_value) { case ASR::exprType::ImpliedDoLoop: case ASR::exprType::PointerNullConstant: case ASR::exprType::ArrayConstant: - case ASR::exprType::StringConstant: { + case ASR::exprType::StringConstant: + case ASR::exprType::BytesConstant: { return true; } case ASR::exprType::RealBinOp: @@ -1608,6 +1609,9 @@ static inline std::string type_to_str_python(const ASR::ttype_t *t, case ASR::ttypeType::Character: { return "str"; } + case ASR::ttypeType::Byte: { + return "bytes"; + } case ASR::ttypeType::Tuple: { ASR::Tuple_t *tup = ASR::down_cast(t); std::string result = "tuple["; @@ -3123,6 +3127,11 @@ inline bool types_equal(ASR::ttype_t *a, ASR::ttype_t *b, ASR::Character_t *b2 = ASR::down_cast(b); return (a2->m_kind == b2->m_kind); } + case (ASR::ttypeType::Byte) : { + ASR::Byte_t *a2 = ASR::down_cast(a); + ASR::Byte_t *b2 = ASR::down_cast(b); + return (a2->m_kind == b2->m_kind); + } case (ASR::ttypeType::List) : { ASR::List_t *a2 = ASR::down_cast(a); ASR::List_t *b2 = ASR::down_cast(b); diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index ec8a8b0205..4022527f36 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -143,7 +143,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor bool prototype_only; llvm::StructType *complex_type_4, *complex_type_8; llvm::StructType *complex_type_4_ptr, *complex_type_8_ptr; - llvm::PointerType *character_type; + llvm::PointerType *character_type, *byte_type; llvm::PointerType *list_type; std::vector struct_type_stack; @@ -910,6 +910,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor complex_type_4_ptr = llvm_utils->complex_type_4_ptr; complex_type_8_ptr = llvm_utils->complex_type_8_ptr; character_type = llvm_utils->character_type; + byte_type = llvm_utils->character_type; list_type = llvm::Type::getInt8PtrTy(context); llvm::Type* bound_arg = static_cast(arr_descr->get_dimension_descriptor_type(true)); @@ -2879,6 +2880,25 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } llvm_symtab[h] = ptr; + } else if (x.m_type->type == ASR::ttypeType::Byte) { + llvm::Constant *ptr = module->getOrInsertGlobal(x.m_name, + character_type); + if (!external) { + if (init_value) { + module->getNamedGlobal(x.m_name)->setInitializer( + init_value); + } else { + module->getNamedGlobal(x.m_name)->setInitializer( + llvm::Constant::getNullValue(character_type) + ); + ASR::Byte_t *t = down_cast(x.m_type); + if( t->m_len >= 0 ) { + strings_to_be_allocated.insert(std::pair(ptr, llvm::ConstantInt::get( + context, llvm::APInt(32, t->m_len+1)))); + } + } + } + llvm_symtab[h] = ptr; } else if( x.m_type->type == ASR::ttypeType::CPtr ) { llvm::Type* void_ptr = llvm::Type::getVoidTy(context)->getPointerTo(); llvm::Constant *ptr = module->getOrInsertGlobal(x.m_name, @@ -7072,6 +7092,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = builder->CreateGlobalStringPtr(x.m_s); } + void visit_BytesConstant(const ASR::BytesConstant_t &x) { + tmp = builder->CreateGlobalStringPtr(x.m_s); + } + inline void fetch_ptr(ASR::Variable_t* x) { uint32_t x_h = get_hash((ASR::asr_t*)x); LCOMPILERS_ASSERT(llvm_symtab.find(x_h) != llvm_symtab.end()); diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index fac917eaf4..2ef2a14551 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -886,6 +886,9 @@ class CommonVisitor : public AST::BaseVisitor { } else if (var_annotation == "c64") { type = ASRUtils::TYPE(ASR::make_Complex_t(al, loc, 8)); type = ASRUtils::make_Array_t_util(al, loc, type, dims.p, dims.size(), abi, is_argument); + } else if (var_annotation == "bytes") { + type = ASRUtils::TYPE(ASR::make_Byte_t(al, loc, 1, -2, nullptr)); + type = ASRUtils::make_Array_t_util(al, loc, type, dims.p, dims.size(), abi, is_argument); } else if (var_annotation == "str") { type = ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)); type = ASRUtils::make_Array_t_util(al, loc, type, dims.p, dims.size(), abi, is_argument); @@ -2899,6 +2902,7 @@ class CommonVisitor : public AST::BaseVisitor { } else { type = ast_expr_to_asr_type(x.m_annotation->base.loc, *x.m_annotation, is_allocatable, is_const, true, abi); } + if (ASR::is_a(*type)) { ASR::FunctionType_t* fn_type = ASR::down_cast(type); handle_lambda_function_declaration(var_name, fn_type, x.m_value, x.base.base.loc); @@ -2956,6 +2960,7 @@ class CommonVisitor : public AST::BaseVisitor { } else { cast_helper(type, init_expr, init_expr->base.loc); } + if (!inside_struct || is_const) { process_variable_init_val(current_scope->get_symbol(var_name), x.base.base.loc, init_expr); @@ -3567,6 +3572,13 @@ class CommonVisitor : public AST::BaseVisitor { 1, s_size, nullptr)); tmp = ASR::make_StringConstant_t(al, x.base.base.loc, s, type); } + void visit_ConstantBytes(const AST::ConstantBytes_t &x) { + char *s = x.m_value; + size_t s_size = std::string(s).size(); + ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_Byte_t(al, x.base.base.loc, + 1, s_size, nullptr)); + tmp = ASR::make_BytesConstant_t(al, x.base.base.loc, s, type); + } void visit_ConstantBool(const AST::ConstantBool_t &x) { bool b = x.m_value; @@ -9121,7 +9133,6 @@ Result python_ast_to_asr(Allocator &al, LocationManager }; #endif } - return tu; } From c731a3e1ef1af0cf7c4a707c2c6143ca0e95db20 Mon Sep 17 00:00:00 2001 From: swamishiju Date: Wed, 26 Mar 2025 17:02:59 +0530 Subject: [PATCH 2/4] Byte variables in functions --- src/libasr/asr_utils.h | 27 +++++++++++++++++ src/libasr/codegen/asr_to_llvm.cpp | 47 ++++++++++++++++++++++++++++++ src/libasr/codegen/llvm_utils.cpp | 6 ++++ 3 files changed, 80 insertions(+) diff --git a/src/libasr/asr_utils.h b/src/libasr/asr_utils.h index 9035867293..ea83f7c9e9 100644 --- a/src/libasr/asr_utils.h +++ b/src/libasr/asr_utils.h @@ -207,6 +207,9 @@ static inline int extract_kind_from_ttype_t(const ASR::ttype_t* type) { case ASR::ttypeType::Character: { return ASR::down_cast(type)->m_kind; } + case ASR::ttypeType::Byte: { + return ASR::down_cast(type)->m_kind; + } case ASR::ttypeType::Logical: { return ASR::down_cast(type)->m_kind; } @@ -542,6 +545,9 @@ static inline std::string type_to_str(const ASR::ttype_t *t) case ASR::ttypeType::Character: { return "character"; } + case ASR::ttypeType::Byte: { + return "byte"; + } case ASR::ttypeType::Tuple: { return "tuple"; } @@ -1422,6 +1428,9 @@ static inline std::string get_type_code(const ASR::ttype_t *t, bool use_undersco case ASR::ttypeType::Character: { return "str"; } + case ASR::ttypeType::Byte: { + return "bytes"; + } case ASR::ttypeType::Tuple: { ASR::Tuple_t *tup = ASR::down_cast(t); std::string result = "tuple"; @@ -2152,6 +2161,7 @@ inline size_t extract_dimensions_from_ttype(ASR::ttype_t *x, case ASR::ttypeType::Real: case ASR::ttypeType::Complex: case ASR::ttypeType::Character: + case ASR::ttypeType::Byte: case ASR::ttypeType::Logical: case ASR::ttypeType::StructType: case ASR::ttypeType::Enum: @@ -2423,6 +2433,7 @@ inline bool ttype_set_dimensions(ASR::ttype_t** x, case ASR::ttypeType::Real: case ASR::ttypeType::Complex: case ASR::ttypeType::Character: + case ASR::ttypeType::Byte: case ASR::ttypeType::Logical: case ASR::ttypeType::StructType: case ASR::ttypeType::Enum: @@ -2544,6 +2555,12 @@ static inline ASR::ttype_t* duplicate_type(Allocator& al, const ASR::ttype_t* t, tnew->m_kind, tnew->m_len, tnew->m_len_expr)); break; } + case ASR::ttypeType::Byte: { + ASR::Byte_t* tnew = ASR::down_cast(t); + t_ = ASRUtils::TYPE(ASR::make_Byte_t(al, t->base.loc, + tnew->m_kind, tnew->m_len, tnew->m_len_expr)); + break; + } case ASR::ttypeType::StructType: { ASR::StructType_t* tnew = ASR::down_cast(t); t_ = ASRUtils::TYPE(ASR::make_StructType_t(al, t->base.loc, @@ -2700,6 +2717,11 @@ static inline ASR::ttype_t* duplicate_type_without_dims(Allocator& al, const ASR return ASRUtils::TYPE(ASR::make_Character_t(al, loc, tnew->m_kind, tnew->m_len, tnew->m_len_expr)); } + case ASR::ttypeType::Byte: { + ASR::Byte_t* tnew = ASR::down_cast(t); + return ASRUtils::TYPE(ASR::make_Byte_t(al, loc, + tnew->m_kind, tnew->m_len, tnew->m_len_expr)); + } case ASR::ttypeType::StructType: { ASR::StructType_t* tstruct = ASR::down_cast(t); return ASRUtils::TYPE(ASR::make_StructType_t(al, t->base.loc, @@ -3315,6 +3337,11 @@ inline bool types_equal_with_substitution(ASR::ttype_t *a, ASR::ttype_t *b, ASR::Character_t *b2 = ASR::down_cast(b); return (a2->m_kind == b2->m_kind); } + case (ASR::ttypeType::Byte) : { + ASR::Byte_t *a2 = ASR::down_cast(a); + ASR::Byte_t *b2 = ASR::down_cast(b); + return (a2->m_kind == b2->m_kind); + } case (ASR::ttypeType::List) : { ASR::List_t *a2 = ASR::down_cast(a); ASR::List_t *b2 = ASR::down_cast(b); diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 4022527f36..5b45424983 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -82,6 +82,23 @@ void string_init(llvm::LLVMContext &context, llvm::Module &module, builder.CreateCall(fn, args); } +void bytes_init(llvm::LLVMContext &context, llvm::Module &module, + llvm::IRBuilder<> &builder, llvm::Value* arg_size, llvm::Value* arg_bytes) { + std::string func_name = "_lfortran_bytes_init"; + llvm::Function *fn = module.getFunction(func_name); + if (!fn) { + llvm::FunctionType *function_type = llvm::FunctionType::get( + llvm::Type::getVoidTy(context), { + llvm::Type::getInt32Ty(context), + llvm::Type::getInt8PtrTy(context) + }, true); + fn = llvm::Function::Create(function_type, + llvm::Function::ExternalLinkage, func_name, module); + } + std::vector args = {arg_size, arg_bytes}; + builder.CreateCall(fn, args); +} + class ASRToLLVMVisitor : public ASR::BaseVisitor { private: @@ -3909,6 +3926,36 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } else { throw CodeGenError("Unsupported len value in ASR " + std::to_string(strlen)); } + } else if (is_a(*v->m_type) && !is_array_type && !is_list) { + ASR::Byte_t *t = down_cast(v->m_type); + target_var = ptr; + int byte_len = t->m_len; + if (byte_len >= 0 || byte_len == -3) { + llvm::Value *arg_size; + if (byte_len == -3) { + LCOMPILERS_ASSERT(t->m_len_expr) + this->visit_expr(*t->m_len_expr); + arg_size = builder->CreateAdd(builder->CreateSExtOrTrunc(tmp, + llvm::Type::getInt32Ty(context)), + llvm::ConstantInt::get(context, llvm::APInt(32, 1)) ); + } else { + // Compile time length + arg_size = llvm::ConstantInt::get(context, + llvm::APInt(32, byte_len+1)); + } + llvm::Value *init_value = LLVM::lfortran_malloc(context, *module, *builder, arg_size); + string_init(context, *module, *builder, arg_size, init_value); + builder->CreateStore(init_value, target_var); + if (v->m_intent == intent_local) { + strings_to_be_deallocated.push_back(al, CreateLoad(target_var)); + } + } else if (byte_len == -2) { + // Allocatable string. Initialize to `nullptr` (unallocated) + llvm::Value *init_value = llvm::Constant::getNullValue(type); + builder->CreateStore(init_value, target_var); + } else { + throw CodeGenError("Unsupported bytes len value in ASR " + std::to_string(byte_len)); + } } else if (is_list) { ASR::List_t* asr_list = ASR::down_cast(v->m_type); std::string type_code = ASRUtils::get_type_code(asr_list->m_type); diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index dde91aa6d9..220beb01c5 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1422,6 +1422,12 @@ namespace LCompilers { llvm_type = character_type; break; } + case (ASR::ttypeType::Byte) : { + ASR::Byte_t* v_type = ASR::down_cast(asr_type); + a_kind = v_type->m_kind; + llvm_type = character_type; + break; + } case (ASR::ttypeType::Logical) : { ASR::Logical_t* v_type = ASR::down_cast(asr_type); a_kind = v_type->m_kind; From ed168a17e25427a159d2e729a5182c4f8918604e Mon Sep 17 00:00:00 2001 From: swamishiju Date: Wed, 26 Mar 2025 21:14:34 +0530 Subject: [PATCH 3/4] Implemented bytes variables with functions --- src/libasr/asr_utils.h | 4 + src/libasr/codegen/asr_to_llvm.cpp | 19 +- src/libasr/codegen/llvm_utils.cpp | 265 ++++++++++++++++++++ src/libasr/pass/global_stmts.cpp | 1 + src/lpython/semantics/python_ast_to_asr.cpp | 11 + 5 files changed, 299 insertions(+), 1 deletion(-) diff --git a/src/libasr/asr_utils.h b/src/libasr/asr_utils.h index ea83f7c9e9..dae63f6833 100644 --- a/src/libasr/asr_utils.h +++ b/src/libasr/asr_utils.h @@ -254,6 +254,10 @@ static inline void set_kind_to_ttype_t(ASR::ttype_t* type, int kind) { ASR::down_cast(type)->m_kind = kind; break; } + case ASR::ttypeType::Byte: { + ASR::down_cast(type)->m_kind = kind; + break; + } case ASR::ttypeType::Logical: { ASR::down_cast(type)->m_kind = kind; break; diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 5b45424983..cf139cf906 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -966,7 +966,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor prototype_only = false; // TODO: handle dependencies across modules and main program - +; // Then do all the modules in the right order std::vector build_order = determine_module_dependencies(x); @@ -974,6 +974,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor LCOMPILERS_ASSERT(x.m_symtab->get_symbol(item) != nullptr); ASR::symbol_t *mod = x.m_symtab->get_symbol(item); + std::cout << mod->type << "unit"; visit_symbol(*mod); } @@ -7199,6 +7200,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor case ASR::ttypeType::Complex: case ASR::ttypeType::StructType: case ASR::ttypeType::Character: + case ASR::ttypeType::Byte: case ASR::ttypeType::Logical: case ASR::ttypeType::Class: { if( t2->type == ASR::ttypeType::StructType ) { @@ -8919,6 +8921,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor target_type = character_type; break; } + case (ASR::ttypeType::Byte) : { + ASR::Variable_t *orig_arg = nullptr; + if( func_subrout->type == ASR::symbolType::Function ) { + ASR::Function_t* func = down_cast(func_subrout); + orig_arg = ASRUtils::EXPR2VAR(func->m_args[i]); + } else { + throw CodeGenError("ICE: expected func_subrout->type == ASR::symbolType::Function."); + } + if (orig_arg->m_abi == ASR::abiType::BindC) { + character_bindc = true; + } + + target_type = character_type; + break; + } case (ASR::ttypeType::Logical) : target_type = llvm::Type::getInt1Ty(context); break; diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 220beb01c5..f2813dc4e2 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -192,6 +192,10 @@ namespace LCompilers { llvm_mem_type = character_type; break; } + case ASR::ttypeType::Byte: { + llvm_mem_type = character_type; + break; + } case ASR::ttypeType::CPtr: { llvm_mem_type = llvm::Type::getVoidTy(context)->getPointerTo(); break; @@ -509,6 +513,10 @@ namespace LCompilers { el_type = character_type; break; } + case ASR::ttypeType::Byte: { + el_type = character_type; + break; + } default: LCOMPILERS_ASSERT(false); break; @@ -753,6 +761,16 @@ namespace LCompilers { } break; } + case (ASR::ttypeType::Byte) : { + ASR::Byte_t* v_type = ASR::down_cast(asr_type); + a_kind = v_type->m_kind; + if (arg_m_abi == ASR::abiType::BindC) { + type = character_type; + } else { + type = character_type->getPointerTo(); + } + break; + } case (ASR::ttypeType::Logical) : { ASR::Logical_t* v_type = ASR::down_cast(asr_type); a_kind = v_type->m_kind; @@ -1005,6 +1023,9 @@ namespace LCompilers { case (ASR::ttypeType::Character) : return_type = character_type; break; + case (ASR::ttypeType::Byte) : + return_type = character_type; + break; case (ASR::ttypeType::Logical) : return_type = llvm::Type::getInt1Ty(context); break; @@ -1203,6 +1224,9 @@ namespace LCompilers { case (ASR::ttypeType::Character) : return_type = character_type; break; + case (ASR::ttypeType::Byte) : + return_type = character_type; + break; case (ASR::ttypeType::Logical) : return_type = llvm::Type::getInt1Ty(context); break; @@ -1714,6 +1738,51 @@ namespace LCompilers { llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); return builder->CreateICmpEQ(l, r); } + case ASR::ttypeType::Byte: { + get_builder0() + str_cmp_itr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + llvm::Value* null_char = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), + llvm::APInt(8, '\0')); + llvm::Value* idx = str_cmp_itr; + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)), + idx); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + start_new_block(loophead); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* l = LLVM::CreateLoad(*builder, create_ptr_gep(left, i)); + llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); + llvm::Value *cond = builder->CreateAnd( + builder->CreateICmpNE(l, null_char), + builder->CreateICmpNE(r, null_char) + ); + cond = builder->CreateAnd(cond, builder->CreateICmpEQ(l, r)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + start_new_block(loopbody); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, i, idx); + } + + builder->CreateBr(loophead); + + // end + start_new_block(loopend); + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* l = LLVM::CreateLoad(*builder, create_ptr_gep(left, i)); + llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); + return builder->CreateICmpEQ(l, r); + } case ASR::ttypeType::Tuple: { ASR::Tuple_t* tuple_type = ASR::down_cast(asr_type); return tuple_api->check_tuple_equality(left, right, tuple_type, context, @@ -1863,6 +1932,72 @@ namespace LCompilers { llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); return builder->CreateICmpULT(l, r); } + case ASR::ttypeType::Byte: { + get_builder0() + str_cmp_itr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + llvm::Value* null_char = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), + llvm::APInt(8, '\0')); + llvm::Value* idx = str_cmp_itr; + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)), + idx); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + start_new_block(loophead); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* l = LLVM::CreateLoad(*builder, create_ptr_gep(left, i)); + llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); + llvm::Value *cond = builder->CreateAnd( + builder->CreateICmpNE(l, null_char), + builder->CreateICmpNE(r, null_char) + ); + switch( overload_id ) { + case 0: { + pred = llvm::CmpInst::Predicate::ICMP_ULT; + break; + } + case 1: { + pred = llvm::CmpInst::Predicate::ICMP_ULE; + break; + } + case 2: { + pred = llvm::CmpInst::Predicate::ICMP_UGT; + break; + } + case 3: { + pred = llvm::CmpInst::Predicate::ICMP_UGE; + break; + } + default: { + throw CodeGenError("Un-recognized overload-id: " + std::to_string(overload_id)); + } + } + cond = builder->CreateAnd(cond, builder->CreateICmp(pred, l, r)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + start_new_block(loopbody); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, i, idx); + } + + builder->CreateBr(loophead); + + // end + start_new_block(loopend); + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* l = LLVM::CreateLoad(*builder, create_ptr_gep(left, i)); + llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); + return builder->CreateICmpULT(l, r); + } case ASR::ttypeType::Tuple: { ASR::Tuple_t* tuple_type = ASR::down_cast(asr_type); return tuple_api->check_tuple_inequality(left, right, tuple_type, context, @@ -1923,6 +2058,7 @@ namespace LCompilers { break ; }; case ASR::ttypeType::Character: + case ASR::ttypeType::Byte: case ASR::ttypeType::FunctionType: case ASR::ttypeType::CPtr: { LLVM::CreateStore(*builder, src, dest); @@ -2008,6 +2144,7 @@ namespace LCompilers { case ASR::ttypeType::Logical: case ASR::ttypeType::Complex: case ASR::ttypeType::Character: + case ASR::ttypeType::Byte: case ASR::ttypeType::FunctionType: case ASR::ttypeType::CPtr: case ASR::ttypeType::Allocatable: { @@ -3744,6 +3881,70 @@ namespace LCompilers { hash = builder->CreateTrunc(hash, llvm::Type::getInt32Ty(context)); return builder->CreateSRem(hash, capacity); } + case ASR::ttypeType::Byte: { + // Polynomial rolling hash function for strings + llvm::Value* null_char = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), + llvm::APInt(8, '\0')); + llvm::Value* p = llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 31)); + llvm::Value* m = llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 100000009)); + get_builder0() + hash_value = builder0.CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "hash_value"); + hash_iter = builder0.CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "hash_iter"); + polynomial_powers = builder0.CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "p_pow"); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 0)), + hash_value); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1)), + polynomial_powers); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 0)), + hash_iter); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value* i = LLVM::CreateLoad(*builder, hash_iter); + llvm::Value* c = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key, i)); + llvm::Value *cond = builder->CreateICmpNE(c, null_char); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + // for c in key: + // hash_value = (hash_value + (ord(c) + 1) * p_pow) % m + // p_pow = (p_pow * p) % m + llvm::Value* i = LLVM::CreateLoad(*builder, hash_iter); + llvm::Value* c = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key, i)); + llvm::Value* p_pow = LLVM::CreateLoad(*builder, polynomial_powers); + llvm::Value* hash = LLVM::CreateLoad(*builder, hash_value); + c = builder->CreateZExt(c, llvm::Type::getInt64Ty(context)); + c = builder->CreateAdd(c, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1))); + c = builder->CreateMul(c, p_pow); + c = builder->CreateSRem(c, m); + hash = builder->CreateAdd(hash, c); + hash = builder->CreateSRem(hash, m); + LLVM::CreateStore(*builder, hash, hash_value); + p_pow = builder->CreateMul(p_pow, p); + p_pow = builder->CreateSRem(p_pow, m); + LLVM::CreateStore(*builder, p_pow, polynomial_powers); + i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1))); + LLVM::CreateStore(*builder, i, hash_iter); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + llvm::Value* hash = LLVM::CreateLoad(*builder, hash_value); + hash = builder->CreateTrunc(hash, llvm::Type::getInt32Ty(context)); + return builder->CreateSRem(hash, capacity); + } case ASR::ttypeType::Tuple: { llvm::Value* tuple_hash = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); ASR::Tuple_t* asr_tuple = ASR::down_cast(key_asr_type); @@ -5941,6 +6142,70 @@ namespace LCompilers { hash = builder->CreateTrunc(hash, llvm::Type::getInt32Ty(context)); return builder->CreateSRem(hash, capacity); } + case ASR::ttypeType::Byte: { + // Polynomial rolling hash function for bytes + llvm::Value* null_char = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), + llvm::APInt(8, '\0')); + llvm::Value* p = llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 31)); + llvm::Value* m = llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 100000009)); + get_builder0() + hash_value = builder0.CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "hash_value"); + hash_iter = builder0.CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "hash_iter"); + polynomial_powers = builder0.CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "p_pow"); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 0)), + hash_value); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1)), + polynomial_powers); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 0)), + hash_iter); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value* i = LLVM::CreateLoad(*builder, hash_iter); + llvm::Value* c = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(el, i)); + llvm::Value *cond = builder->CreateICmpNE(c, null_char); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + // for c in el: + // hash_value = (hash_value + (ord(c) + 1) * p_pow) % m + // p_pow = (p_pow * p) % m + llvm::Value* i = LLVM::CreateLoad(*builder, hash_iter); + llvm::Value* c = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(el, i)); + llvm::Value* p_pow = LLVM::CreateLoad(*builder, polynomial_powers); + llvm::Value* hash = LLVM::CreateLoad(*builder, hash_value); + c = builder->CreateZExt(c, llvm::Type::getInt64Ty(context)); + c = builder->CreateAdd(c, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1))); + c = builder->CreateMul(c, p_pow); + c = builder->CreateSRem(c, m); + hash = builder->CreateAdd(hash, c); + hash = builder->CreateSRem(hash, m); + LLVM::CreateStore(*builder, hash, hash_value); + p_pow = builder->CreateMul(p_pow, p); + p_pow = builder->CreateSRem(p_pow, m); + LLVM::CreateStore(*builder, p_pow, polynomial_powers); + i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1))); + LLVM::CreateStore(*builder, i, hash_iter); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + llvm::Value* hash = LLVM::CreateLoad(*builder, hash_value); + hash = builder->CreateTrunc(hash, llvm::Type::getInt32Ty(context)); + return builder->CreateSRem(hash, capacity); + } case ASR::ttypeType::Tuple: { llvm::Value* tuple_hash = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); ASR::Tuple_t* asr_tuple = ASR::down_cast(el_asr_type); diff --git a/src/libasr/pass/global_stmts.cpp b/src/libasr/pass/global_stmts.cpp index 7fc1e8e6c4..9517ab5e8d 100644 --- a/src/libasr/pass/global_stmts.cpp +++ b/src/libasr/pass/global_stmts.cpp @@ -51,6 +51,7 @@ void pass_wrap_global_stmts(Allocator &al, (ASRUtils::expr_type(value)->type == ASR::ttypeType::Real) || (ASRUtils::expr_type(value)->type == ASR::ttypeType::Complex) || (ASRUtils::expr_type(value)->type == ASR::ttypeType::Character) || + (ASRUtils::expr_type(value)->type == ASR::ttypeType::Byte) || (ASRUtils::expr_type(value)->type == ASR::ttypeType::List) || (ASRUtils::expr_type(value)->type == ASR::ttypeType::Tuple) || (ASRUtils::expr_type(value)->type == ASR::ttypeType::StructType)) { diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 2ef2a14551..18c2a3d88b 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -4310,6 +4310,7 @@ class SymbolTableVisitor : public CommonVisitor { } void visit_Module(const AST::Module_t &x) { + /*Shiju del*/ std::cout << "We here ya yellow bellied mongrels" << "----" << "\n"; ASR::asr_t *tmp0 = nullptr; if (current_scope) { LCOMPILERS_ASSERT(current_scope->asr_owner); @@ -4340,9 +4341,11 @@ class SymbolTableVisitor : public CommonVisitor { module_sym = ASR::down_cast(ASR::down_cast(tmp1)); parent_scope->add_symbol(module_name, ASR::down_cast(tmp1)); current_module_dependencies.reserve(al, 1); + /*Shiju del*/ std::cout << "Did way du goodie?" << "----" << "\n"; for (size_t i=0; im_dependencies = current_module_dependencies.p; @@ -5106,6 +5109,7 @@ class BodyVisitor : public CommonVisitor { } void visit_Module(const AST::Module_t &x) { + /*Shiju del*/ std::cout << "Inside the Module start" << "---" << "\n"; ASR::TranslationUnit_t *unit = ASR::down_cast2(asr); current_scope = unit->m_symtab; LCOMPILERS_ASSERT(current_scope != nullptr); @@ -5122,10 +5126,12 @@ class BodyVisitor : public CommonVisitor { Vec items; items.reserve(al, 4); current_module_dependencies.reserve(al, 1); + /*Shiju del*/ std::cout << "Inside the Module visitor" << "---" << "\n"; for (size_t i=0; i { // Wrap all the global statements into a Function LCompilers::PassOptions pass_options; pass_options.run_fun = func_name; + /*Shiju del*/ std::cout << "Random spot" << "----" << "\n"; pass_wrap_global_stmts(al, *unit, pass_options); + /*Shiju del*/ std::cout << "Another random spot" << "----" << "\n"; ASR::symbol_t *f_sym = unit->m_symtab->get_symbol(func_name); if (f_sym) { @@ -5188,6 +5196,7 @@ class BodyVisitor : public CommonVisitor { items.p = nullptr; items.n = 0; } + /*Shiju del*/ std::cout << "End of the line browski" << "----" << "\n"; tmp = asr; } @@ -9034,6 +9043,7 @@ Result python_ast_to_asr(Allocator &al, LocationManager } ASR::TranslationUnit_t *tu = ASR::down_cast2(unit); + /*shiju*/ std::cout << ";; ASR after SymbolTable Visitor\n" << pickle(*tu, false, true, compiler_options.po.with_intrinsic_mods) << "\n"; if (compiler_options.po.dump_all_passes) { std::ofstream outfile ("pass_00_initial_asr_01.clj"); outfile << ";; ASR after SymbolTable Visitor\n" << pickle(*tu, false, true, compiler_options.po.with_intrinsic_mods) << "\n"; @@ -9057,6 +9067,7 @@ Result python_ast_to_asr(Allocator &al, LocationManager #endif if (!compiler_options.symtab_only) { + /*Shiju del*/ std::cout << "Injection" << "---" << "\n"; auto res2 = body_visitor(al, lm, *ast_m, diagnostics, unit, main_module, module_name, ast_overload, allow_implicit_casting, eval_count); if (res2.ok) { From 712b97f8f108a19503db6dd62a1bacb6ab12280d Mon Sep 17 00:00:00 2001 From: swamishiju Date: Wed, 26 Mar 2025 22:45:54 +0530 Subject: [PATCH 4/4] Test: added test --- integration_tests/CMakeLists.txt | 1 + integration_tests/test_bytes_01.py | 20 ++++++++++++++++++++ src/libasr/codegen/asr_to_llvm.cpp | 1 - src/lpython/semantics/python_ast_to_asr.cpp | 11 ----------- 4 files changed, 21 insertions(+), 12 deletions(-) create mode 100644 integration_tests/test_bytes_01.py diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 2c360fd51d..31023e28a5 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -595,6 +595,7 @@ RUN(NAME test_set_discard LABELS cpython llvm llvm_jit) RUN(NAME test_set_from_list LABELS cpython llvm llvm_jit) RUN(NAME test_set_clear LABELS cpython llvm) RUN(NAME test_set_pop LABELS cpython llvm) +RUN(NAME test_bytes_01 LABELS cpython llvm llvm_jit) RUN(NAME test_global_set LABELS cpython llvm llvm_jit) RUN(NAME test_for_loop LABELS cpython llvm llvm_jit c) RUN(NAME modules_01 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/test_bytes_01.py b/integration_tests/test_bytes_01.py new file mode 100644 index 0000000000..6b5e2b9537 --- /dev/null +++ b/integration_tests/test_bytes_01.py @@ -0,0 +1,20 @@ +def f(): + a: bytes = b"This is a test string" + b: bytes = b"This is another test string" + c: bytes = b"""Bigger test string with docstrings + Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do + eiusmod tempor incididunt ut labore et dolore magna aliqua. """ + + +def g(a: bytes) -> bytes: + return a + + +def h() -> bytes: + bar: bytes + bar = g(b"fiwabcd") + return b"12jw19\\xq0" + + +f() +h() diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index cf139cf906..6129f3fedf 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -974,7 +974,6 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor LCOMPILERS_ASSERT(x.m_symtab->get_symbol(item) != nullptr); ASR::symbol_t *mod = x.m_symtab->get_symbol(item); - std::cout << mod->type << "unit"; visit_symbol(*mod); } diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 18c2a3d88b..2ef2a14551 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -4310,7 +4310,6 @@ class SymbolTableVisitor : public CommonVisitor { } void visit_Module(const AST::Module_t &x) { - /*Shiju del*/ std::cout << "We here ya yellow bellied mongrels" << "----" << "\n"; ASR::asr_t *tmp0 = nullptr; if (current_scope) { LCOMPILERS_ASSERT(current_scope->asr_owner); @@ -4341,11 +4340,9 @@ class SymbolTableVisitor : public CommonVisitor { module_sym = ASR::down_cast(ASR::down_cast(tmp1)); parent_scope->add_symbol(module_name, ASR::down_cast(tmp1)); current_module_dependencies.reserve(al, 1); - /*Shiju del*/ std::cout << "Did way du goodie?" << "----" << "\n"; for (size_t i=0; im_dependencies = current_module_dependencies.p; @@ -5109,7 +5106,6 @@ class BodyVisitor : public CommonVisitor { } void visit_Module(const AST::Module_t &x) { - /*Shiju del*/ std::cout << "Inside the Module start" << "---" << "\n"; ASR::TranslationUnit_t *unit = ASR::down_cast2(asr); current_scope = unit->m_symtab; LCOMPILERS_ASSERT(current_scope != nullptr); @@ -5126,12 +5122,10 @@ class BodyVisitor : public CommonVisitor { Vec items; items.reserve(al, 4); current_module_dependencies.reserve(al, 1); - /*Shiju del*/ std::cout << "Inside the Module visitor" << "---" << "\n"; for (size_t i=0; i { // Wrap all the global statements into a Function LCompilers::PassOptions pass_options; pass_options.run_fun = func_name; - /*Shiju del*/ std::cout << "Random spot" << "----" << "\n"; pass_wrap_global_stmts(al, *unit, pass_options); - /*Shiju del*/ std::cout << "Another random spot" << "----" << "\n"; ASR::symbol_t *f_sym = unit->m_symtab->get_symbol(func_name); if (f_sym) { @@ -5196,7 +5188,6 @@ class BodyVisitor : public CommonVisitor { items.p = nullptr; items.n = 0; } - /*Shiju del*/ std::cout << "End of the line browski" << "----" << "\n"; tmp = asr; } @@ -9043,7 +9034,6 @@ Result python_ast_to_asr(Allocator &al, LocationManager } ASR::TranslationUnit_t *tu = ASR::down_cast2(unit); - /*shiju*/ std::cout << ";; ASR after SymbolTable Visitor\n" << pickle(*tu, false, true, compiler_options.po.with_intrinsic_mods) << "\n"; if (compiler_options.po.dump_all_passes) { std::ofstream outfile ("pass_00_initial_asr_01.clj"); outfile << ";; ASR after SymbolTable Visitor\n" << pickle(*tu, false, true, compiler_options.po.with_intrinsic_mods) << "\n"; @@ -9067,7 +9057,6 @@ Result python_ast_to_asr(Allocator &al, LocationManager #endif if (!compiler_options.symtab_only) { - /*Shiju del*/ std::cout << "Injection" << "---" << "\n"; auto res2 = body_visitor(al, lm, *ast_m, diagnostics, unit, main_module, module_name, ast_overload, allow_implicit_casting, eval_count); if (res2.ok) {