Skip to content

Commit

Permalink
Added implementation and test cases for MIRMetadata
Browse files Browse the repository at this point in the history
  • Loading branch information
AjayBrahmakshatriya committed May 19, 2020
1 parent 08c3ae8 commit 8a2e1e0
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 1 deletion.
29 changes: 29 additions & 0 deletions include/graphit/midend/mir.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <iostream>
#include <unordered_set>
#include <graphit/midend/mir_visitor.h>
#include <graphit/midend/mir_metadata.h>
#include <graphit/midend/var.h>
#include <assert.h>
#include <graphit/midend/field_vector_property.h>
Expand Down Expand Up @@ -54,6 +55,8 @@ namespace graphit {
return to<T>(cloneNode());
}

// We use a single map to hold all metadata on the MIR Node
std::unordered_map<std::string, std::shared_ptr<MIRMetadata>> metadata_map;
protected:
template<typename T = MIRNode>
std::shared_ptr<T> self() {
Expand All @@ -68,6 +71,32 @@ namespace graphit {
// as I slowly add in support for copy functionalities
return nullptr;
};
public:
// Functions to set and retrieve metadata of different types
template<typename T>
void setMetadata(std::string mdname, T val) {
typename MIRMetadataImpl<T>::Ptr mdnode = std::make_shared<MIRMetadataImpl<T>>(val);
metadata_map[mdname] = mdnode;
}
// This function is safe to be called even if the metadata with
// the specified name doesn't exist
template<typename T>
bool hasMetadata(std::string mdname) {
if (metadata_map.find(mdname) == metadata_map.end())
return false;
typename MIRMetadata::Ptr mdnode = metadata_map[mdname];
if (!mdnode->isa<T>())
return false;
return true;
}
// This function should be called only after confirming that the
// metadata with the given name exists
template <typename T>
T getMetadata(std::string mdname) {
assert(hasMetadata<T>(mdname));
typename MIRMetadata::Ptr mdnode = metadata_map[mdname];
return mdnode->to<T>()->val;
}
};

struct Expr : public MIRNode {
Expand Down
46 changes: 46 additions & 0 deletions include/graphit/midend/mir_metadata.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#ifndef MIR_METADATA_H
#define MIR_METADATA_H

#include <memory>
#include <cassert>
namespace graphit {
namespace mir {

template<typename T>
class MIRMetadataImpl;

// The abstract class for the mir metadata
// Different templated metadata types inherit from this type
class MIRMetadata: public std::enable_shared_from_this<MIRMetadata> {
public:
typedef std::shared_ptr<MIRMetadata> Ptr;
virtual ~MIRMetadata() = default;


template <typename T>
bool isa (void) {
if(std::dynamic_pointer_cast<MIRMetadataImpl<T>>(shared_from_this()))
return true;
return false;
}
template <typename T>
std::shared_ptr<MIRMetadataImpl<T>> to(void) {

This comment has been minimized.

Copy link
@yunmingzhang17

yunmingzhang17 May 19, 2020

Collaborator

Could we use " typename MIRMetadataImp::Ptr" instead of "std::shared_ptr<MIRMetadataImpl>"? I think " typename MIRMetadataImp::Ptr" is more consistent with our other type declarations?

std::shared_ptr<MIRMetadataImpl<T>> ret = std::dynamic_pointer_cast<MIRMetadataImpl<T>>(shared_from_this());
assert(ret != nullptr);
return ret;
}
};

// Templated metadata class for each type
template<typename T>
class MIRMetadataImpl: public MIRMetadata {
public:
typedef std::shared_ptr<MIRMetadataImpl<T>> Ptr;
T val;
MIRMetadataImpl(T _val): val(_val) {
}
};

}
}
#endif
74 changes: 73 additions & 1 deletion test/c++/midend_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,76 @@ TEST_F(MidendTest, SimpleVertexSetDeclAllocWithMain) {
"const vertices : vertexset{Vertex} = new vertexset{Vertex}(5);\n"
"func main() print 4; end");
EXPECT_EQ (0, basicTest(is));
}
}

// Test cases for the MIRMetadata API
TEST_F(MidendTest, SimpleMetadataTest) {
istringstream is("func main() print 4; end");
EXPECT_EQ(0, basicTest(is));
EXPECT_EQ(true, mir_context_->isFunction("main"));

mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main");

main_func->setMetadata<bool>("basic_boolean_md", true);
main_func->setMetadata<int>("basic_int_md", 42);
EXPECT_EQ(true, main_func->hasMetadata<bool>("basic_boolean_md"));
EXPECT_EQ(true, main_func->getMetadata<bool>("basic_boolean_md"));

EXPECT_EQ(true, main_func->hasMetadata<int>("basic_int_md"));
EXPECT_EQ(42, main_func->getMetadata<int>("basic_int_md"));

}
TEST_F(MidendTest, SimpleMetadataTestNoExist) {
istringstream is("func main() print 4; end");
EXPECT_EQ(0, basicTest(is));
EXPECT_EQ(true, mir_context_->isFunction("main"));

mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main");

main_func->setMetadata<int>("basic_int_md", 42);
EXPECT_EQ(false, main_func->hasMetadata<int>("other_int_md"));
EXPECT_EQ(false, main_func->hasMetadata<bool>("basic_int_md"));
}

TEST_F(MidendTest, SimpleMetadataTestString) {
istringstream is("func main() print 4; end");
EXPECT_EQ(0, basicTest(is));
EXPECT_EQ(true, mir_context_->isFunction("main"));

mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main");

main_func->setMetadata<std::string>("basic_str_md", "md value");
EXPECT_EQ(true, main_func->hasMetadata<std::string>("basic_str_md"));
EXPECT_EQ("md value", main_func->getMetadata<std::string>("basic_str_md"));
}

TEST_F(MidendTest, SimpleMetadataTestMIRNodeAsMD) {
istringstream is("const val:int = 42;\nfunc main() print val; end");
EXPECT_EQ(0, basicTest(is));
EXPECT_EQ(true, mir_context_->isFunction("main"));
EXPECT_EQ(1, mir_context_->getConstants().size());

mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main");
mir::VarDecl::Ptr decl = mir_context_->getConstants()[0];

main_func->setMetadata<mir::MIRNode::Ptr>("used_var_md", decl);

EXPECT_EQ(true, main_func->hasMetadata<mir::MIRNode::Ptr>("used_var_md"));
mir::MIRNode::Ptr mdnode = main_func->getMetadata<mir::MIRNode::Ptr>("used_var_md");
EXPECT_EQ(true, mir::isa<mir::VarDecl>(mdnode));
}

TEST_F(MidendTest, SimpleMetadataTestMIRNodeVectorAsMD) {
istringstream is("const val:int = 42;\nconst val2: int = 55;\nfunc main() print val + val2; end");
EXPECT_EQ(0, basicTest(is));
EXPECT_EQ(true, mir_context_->isFunction("main"));
EXPECT_EQ(2, mir_context_->getConstants().size());

mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main");
std::vector<mir::VarDecl::Ptr> decls = mir_context_->getConstants();

main_func->setMetadata<std::vector<mir::VarDecl::Ptr>>("used_vars_md", decls);

EXPECT_EQ(true, main_func->hasMetadata<std::vector<mir::VarDecl::Ptr>>("used_vars_md"));
EXPECT_EQ(2, main_func->getMetadata<std::vector<mir::VarDecl::Ptr>>("used_vars_md").size());
}

This comment has been minimized.

Copy link
@yunmingzhang17

yunmingzhang17 May 19, 2020

Collaborator

We should probably add some test cases for flexInt once it is integrated?

0 comments on commit 8a2e1e0

Please sign in to comment.