Add the shape inference function in the builder. (PaddlePaddle#457)
* Add InferShape for Builder.

* Move x86 tune-key collection out of infershape.

* Replace std::unordered_map with absl::flat_hash_map in AttrMapType.

* Fix some errors in the types used by InferShape & InferDtype.

Co-authored-by: wangone <[email protected]>
wzzju and wenming2014 authored Sep 29, 2021
1 parent 291de9a commit 4b443c2
Showing 18 changed files with 273 additions and 262 deletions.
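
The core of the change is visible in cinn/frontend/net_builder.cc below: every op-building method now calls InferShape(instr) before the instruction is appended, so each output Variable receives its shape and dtype at build time. A minimal sketch of the pattern, assembled from the relu6 hunk in this diff (every other op follows the same steps):

// Build the instruction, set its attributes, run shape/dtype inference,
// then append it to the program under construction.
Variable NetBuilder::relu6(const Variable& a, float threshold) {
  Instruction instr("relu6", {a});
  instr.SetAttr("threshold", threshold);
  InferShape(instr);         // fills each output's shape and type
  AppendInstruction(instr);
  return instr.GetOutput(0);
}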
46 changes: 42 additions & 4 deletions cinn/frontend/base_builder.cc
@@ -1,27 +1,40 @@
#include "cinn/frontend/base_builder.h"

#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "cinn/common/common.h"
#include "cinn/common/context.h"
#include "cinn/common/type.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"

namespace cinn {
namespace frontend {

using common::Context;
using common::Type;
using hlir::framework::AttrMapType;
using hlir::framework::Operator;
using hlir::framework::shape_t;

BaseBuilder::BaseBuilder(const std::string& name) : name_(name) {}

Program BaseBuilder::Build() {
Program program{std::move(instrs_), std::move(inputs_)};
program.Validate();
return program;
}

Placeholder BaseBuilder::CreateInput(const common::Type& type,
const std::vector<int>& shape,
const std::string& id_hint) {
Placeholder BaseBuilder::CreateInput(const Type& type, const std::vector<int>& shape, const std::string& id_hint) {
if (!id_hint.empty()) {
CheckVarNameValid(id_hint);
}
std::string id = id_hint.empty() ? common::Context::Global().NewName("placeholder") : id_hint;
std::string id = id_hint.empty() ? Context::Global().NewName("placeholder") : id_hint;

inputs_.emplace_back(id);
auto& var = inputs_.back();
@@ -30,5 +43,30 @@ Placeholder BaseBuilder::CreateInput(const common::Type& type,
return Placeholder(var);
}

void BaseBuilder::InferShape(Instruction instr) const {
using shape_func_t = std::function<std::vector<shape_t>(const std::vector<shape_t>&, const AttrMapType&)>;
using type_func_t = std::function<std::vector<Type>(const std::vector<Type>&, const AttrMapType&)>;
const auto& op_infershape = Operator::GetAttrs<shape_func_t>("infershape");
const auto& op_inferdtype = Operator::GetAttrs<type_func_t>("inferdtype");

size_t size = instr->inputs.size();
std::vector<shape_t> in_shapes(size);
std::vector<Type> in_types(size);
std::transform(
instr->inputs.begin(), instr->inputs.end(), in_shapes.begin(), [](const Variable& var) { return var->shape; });
std::transform(
instr->inputs.begin(), instr->inputs.end(), in_types.begin(), [](const Variable& var) { return var->type; });

auto key = Operator::Get(instr->op_type);
auto out_shapes = op_infershape[key](in_shapes, instr->attrs);
auto out_types = op_inferdtype[key](in_types, instr->attrs);

auto& outs = instr->outputs;
for (size_t i = 0; i < outs.size(); i++) {
outs[i]->shape = out_shapes[i];
outs[i]->type = out_types[i];
}
}

} // namespace frontend
} // namespace cinn
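
For reference, BaseBuilder::InferShape fetches the "infershape" and "inferdtype" functions registered for the op and applies them to the collected input shapes and types, writing the results onto the instruction's outputs. A hypothetical pair of such functions for an elementwise-style op, written only to illustrate the shape_func_t and type_func_t signatures used above (these are not the actual CINN registrations):

// Sketch: output reuses the shape and dtype of the first input.
std::vector<shape_t> ElementwiseInferShape(const std::vector<shape_t>& inputs,
                                           const AttrMapType& attrs) {
  return {inputs[0]};
}

std::vector<Type> ElementwiseInferDtype(const std::vector<Type>& inputs,
                                        const AttrMapType& attrs) {
  return {inputs[0]};
}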
6 changes: 4 additions & 2 deletions cinn/frontend/base_builder.h
@@ -1,18 +1,18 @@
#pragma once

#include <string>
#include <utility>
#include <vector>

#include "cinn/common/type.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/op.h"

namespace cinn {
namespace frontend {

class BaseBuilder {
public:
explicit BaseBuilder(const std::string& name) : name_(name) {}
explicit BaseBuilder(const std::string& name);

Program Build();

@@ -26,6 +26,8 @@ class BaseBuilder {
protected:
void AppendInstruction(const Instruction& instr) { instrs_.push_back(instr); }

void InferShape(Instruction instr) const;

std::string name_;
std::vector<Instruction> instrs_;
std::vector<Variable> inputs_;
18 changes: 17 additions & 1 deletion cinn/frontend/net_builder.cc
@@ -1,8 +1,8 @@
#include "cinn/frontend/net_builder.h"

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "cinn/frontend/syntax.h"

@@ -11,6 +11,7 @@ namespace frontend {

Variable NetBuilder::add(const Variable& a, const Variable& b) {
Instruction instr("elementwise_add", {a, b});
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}
@@ -19,6 +20,7 @@ Variable NetBuilder::mul(const Variable& a, const Variable& b, int x_num_col_dim
Instruction instr("mul", {a, b});
instr.SetAttr("x_num_col_dims", x_num_col_dims);
instr.SetAttr("y_num_col_dims", y_num_col_dims);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}
@@ -28,33 +30,38 @@ Variable NetBuilder::mulbias(
Instruction instr("mulbias", {a, b, c});
instr.SetAttr("x_num_col_dims", x_num_col_dims);
instr.SetAttr("y_num_col_dims", y_num_col_dims);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(1);
}

Variable NetBuilder::elementwise_add(const Variable& a, const Variable& b, int axis) {
Instruction instr("elementwise_add", {a, b});
instr.SetAttr("axis", axis);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::elementwise_mul(const Variable& a, const Variable& b, int axis) {
Instruction instr("elementwise_mul", {a, b});
instr.SetAttr("axis", axis);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::relu(const Variable& a) {
Instruction instr("relu", {a});
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::relu6(const Variable& a, float threshold) {
Instruction instr("relu6", {a});
instr.SetAttr("threshold", threshold);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}
@@ -75,6 +82,7 @@ Variable NetBuilder::conv2d(const Variable& a,
instr.SetAttr("groups", groups);
instr.SetAttr("data_format", data_format);
instr.SetAttr("padding_algorithm", padding_algorithm);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}
@@ -95,6 +103,7 @@ Variable NetBuilder::depthwise_conv2d(const Variable& a,
instr.SetAttr("groups", groups);
instr.SetAttr("data_format", data_format);
instr.SetAttr("padding_algorithm", padding_algorithm);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}
@@ -122,6 +131,7 @@ Variable NetBuilder::pool2d(const Variable& a,
instr.SetAttr("data_format", data_format);
instr.SetAttr("adaptive", adaptive);
instr.SetAttr("padding_algorithm", padding_algorithm);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}
@@ -139,6 +149,7 @@ Variable NetBuilder::batchnorm(const Variable& a,
instr.SetAttr("epsilon", epsilon);
instr.SetAttr("momentum", momentum);
instr.SetAttr("data_layout", data_layout);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}
@@ -148,6 +159,7 @@ Variable NetBuilder::scale(const Variable& a, float scale, float bias, bool bias
instr.SetAttr("scale", scale);
instr.SetAttr("bias", bias);
instr.SetAttr("bias_after_scale", bias_after_scale);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}
@@ -156,12 +168,14 @@ Variable NetBuilder::softmax(const Variable& a, int axis, const std::string& dat
Instruction instr("softmax", {a});
instr.SetAttr("axis", axis);
instr.SetAttr("data_format", data_format);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::sigmoid(const Variable& a) {
Instruction instr("sigmoid", {a});
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}
@@ -178,6 +192,7 @@ Variable NetBuilder::slice(const Variable& a,
instr.SetAttr("ends", ends);
instr.SetAttr("infer_flags", infer_flags);
instr.SetAttr("decrease_axis", decrease_axis);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}
@@ -186,6 +201,7 @@ Variable NetBuilder::dropout_infer(const Variable& a, float dropout_prob, const
Instruction instr("dropout_infer", {a});
instr.SetAttr("dropout_prob", dropout_prob);
instr.SetAttr("dropout_implementation", dropout_implementation);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}
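
With these calls in place, output variables carry their shape and dtype as soon as an op is built, before any graph-level pass runs. An illustrative usage sketch, assuming NetBuilder inherits CreateInput from BaseBuilder and that a Placeholder converts to a Variable (the shapes and names here are hypothetical):

NetBuilder builder("example");
auto a = builder.CreateInput(common::Float(32), {1, 64}, "A");
auto b = builder.CreateInput(common::Float(32), {1, 64}, "B");
auto c = builder.add(a, b);
// c->shape and c->type are already populated at this point, e.g. shape {1, 64}.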
7 changes: 2 additions & 5 deletions cinn/frontend/net_builder_test.cc
@@ -69,12 +69,11 @@ TEST(net_build, program_execute_multi_elementwise_add) {
#else
Target target = common::DefaultHostTarget();
#endif

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
LOG(INFO) << "graph:\n" << graph->Visualize();

hlir::framework::ApplyPass(graph.get(), "InferShape");
auto scope = BuildScope(target, graph);

hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

@@ -109,11 +108,9 @@ TEST(net_build, program_execute_fc) {
#else
Target target = common::DefaultHostTarget();
#endif
auto graph = std::make_shared<hlir::framework::Graph>(program, target);

hlir::framework::ApplyPass(graph.get(), "InferShape");
auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);

hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

4 changes: 3 additions & 1 deletion cinn/hlir/framework/graph.cc
@@ -29,7 +29,9 @@ Graph::Graph(const frontend::Program& prog, const Target& target) {
}
int out_idx = 0;
for (auto& output_v : temp->outputs) {
auto* output_data = new NodeData(node_ptr, out_idx++, 0, output_v->id);
dtype_dict[output_v->id] = output_v->type;
shape_dict[output_v->id] = output_v->shape;
auto* output_data = new NodeData(node_ptr, out_idx++, 0, output_v->id);
node_tmp->LinkTo(output_data);
this->RegisterNode(output_v->id, output_data);
}
30 changes: 15 additions & 15 deletions cinn/hlir/framework/node.h
@@ -1,14 +1,13 @@
#pragma once
#include <absl/container/flat_hash_map.h>
#include <absl/types/variant.h>

#include <memory>
#include <string>
#include <tuple>
#include <absl/container/flat_hash_map.h>
#include <utility>
#include <vector>

#include <absl/types/variant.h>
#include <absl/container/flat_hash_map.h>

#include "cinn/common/graph_utils.h"
#include "cinn/common/shared.h"
#include "cinn/hlir/framework/op.h"
@@ -19,15 +18,16 @@ namespace framework {
class Node;
class NodeData;

using NodePtr = std::shared_ptr<Node>;
using AttrType = absl::variant<bool,
float,
int,
std::string,
std::vector<bool>,
std::vector<int>,
std::vector<float>,
std::vector<std::string>>;
using NodePtr = std::shared_ptr<Node>;
using AttrType = absl::variant<bool,
float,
int,
std::string,
std::vector<bool>,
std::vector<int>,
std::vector<float>,
std::vector<std::string>>;
using AttrMapType = absl::flat_hash_map<std::string, AttrType>;

/**
* \brief Attributes of each node in graph.
@@ -93,7 +93,7 @@ class Node : public common::GraphNode {
inline uint32_t num_inputs() { return is_variable() ? 1 : this->op()->num_inputs; }

template <class... Args>
static NodePtr Create(Args &&... args) {
static NodePtr Create(Args &&...args) {
return std::make_shared<Node>(std::forward<Args>(args)...);
}

@@ -125,7 +125,7 @@ class NodeData : public common::GraphNode {
const char *op_name,
std::string node_name,
std::vector<NodeData> inputs,
std::string id = nullptr,
std::string id = nullptr,
absl::flat_hash_map<std::string, attr_t> attrs = absl::flat_hash_map<std::string, attr_t>()) {
auto res = std::make_shared<NodeData>();
res->id_ = std::move(id);