Skip to content

Commit

Permalink
rename Mask -> CausalMask
Browse files Browse the repository at this point in the history
  • Loading branch information
yirongjie committed Nov 8, 2023
1 parent c6acd4b commit 7fc795d
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 15 deletions.
4 changes: 2 additions & 2 deletions include/OpDefined.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ enum OpType {
SCALE,
ROPE,
RMSNORM,
MASK,
CAUSALMASK,
LINEAR,
ATTENTION,
EMBEDDING,
Expand All @@ -34,7 +34,7 @@ static const vector<string> OpNames = {
"Scale",
"RoPE",
"RMSNorm",
"Mask",
"CAUSALMASK",
"Linear",
"Attention",
"Embedding",
Expand Down
5 changes: 3 additions & 2 deletions src/backends/cpu/CPUBackend.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "CPUBackend.hpp"
#include "CPUView.hpp"
#include "CPUAdd.hpp"
#include "CPUMask.hpp"
#include "CPUCausalMask.hpp"
#include "CPUMatmul.hpp"
#include "CPURMSNorm.hpp"
#include "CPURoPE.hpp"
Expand Down Expand Up @@ -53,7 +53,7 @@ void CPUBackend::registerOps() {
// addCreator(MATMUL, &_temp);

addCreator(ADD, (CPUBackend::Creator *)(new CPUAddCreator()));
addCreator(MASK, (CPUBackend::Creator *)(new CPUMaskCreator()));
addCreator(CAUSALMASK, (CPUBackend::Creator *)(new CPUCausalMaskCreator()));
addCreator(MATMUL, (CPUBackend::Creator *)(new CPUMatmulCreator()));
addCreator(RMSNORM, (CPUBackend::Creator *)(new CPURMSNormCreator()));
addCreator(ROPE, (CPUBackend::Creator *)(new CPURoPECreator()));
Expand All @@ -65,6 +65,7 @@ void CPUBackend::registerOps() {
addCreator(EMBEDDING, (CPUBackend::Creator *)(new CPUEmbeddingCreator()));
addCreator(MUL, (CPUBackend::Creator *)(new CPUMulCreator()));
addCreator(VIEW, (CPUBackend::Creator *)(new CPUViewCreator()));
addCreator(CAUSALMASK, (CPUBackend::Creator *)(new CPUCausalMaskCreator()));
}

} // namespace mllm
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@

#include "CPUMask.hpp"
#include "CPUCausalMask.hpp"
#include <cmath>

namespace mllm {

// template class CPUMask;
// template class CPUMask;

CPUMask::CPUMask(Backend *bn, string opName, bool multiThread) :
CPUCausalMask::CPUCausalMask(Backend *bn, string opName, bool multiThread) :
Op(bn, opName) {
}

ErrorCode CPUMask::reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
ErrorCode CPUCausalMask::reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
std::cout << "CPUMask reshape" << std::endl;
CHECK_EQ(inputs.size(), 1);
CHECK_EQ(outputs.size(), 1);
outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head(), inputs[0]->sequence(), inputs[0]->dimension());
return NO_ERROR;
}

ErrorCode CPUMask::load(ParamLoader &loader) {
ErrorCode CPUCausalMask::load(ParamLoader &loader) {
std::cout << "CPUMask load" << std::endl;
return NO_ERROR;
}

ErrorCode CPUMask::execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
ErrorCode CPUCausalMask::execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
std::cout << "CPUMask()" << std::endl;
int batch_size = inputs[0]->batch();
int head_num = inputs[0]->head();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

namespace mllm {

class CPUMask final : public Op {
class CPUCausalMask final : public Op {
public:
CPUMask(Backend *bn, string opName, bool multiThread);
virtual ~CPUMask() = default;
CPUCausalMask(Backend *bn, string opName, bool multiThread);
virtual ~CPUCausalMask() = default;
virtual ErrorCode reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
virtual ErrorCode load(ParamLoader &loader) override;
virtual ErrorCode execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
Expand All @@ -18,10 +18,10 @@ class CPUMask final : public Op {
bool support_multi_thread_ = false;
};

class CPUMaskCreator : public CPUBackend::Creator {
class CPUCausalMaskCreator : public CPUBackend::Creator {
public:
virtual Op *create(OpParam op_param, Backend *bn, string name) const {
return new CPUMask(bn, name, false);
return new CPUCausalMask(bn, name, false);
}
};
} // namespace mllm
Expand Down
2 changes: 1 addition & 1 deletion src/express/Express.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ NetTensor *_Causalmask(Context *ctx, std::vector<NetTensor *> inputs, string nam
out_tensor->type = inputs[0]->type;
ctx->idx++;
_STORE_OUT_TENSOR
_NEW_OP(mllm::MASK)
_NEW_OP(mllm::CAUSALMASK)
_UPDATE_INPUT_TENSORS
out_tensor->in = net_op_;
return out_tensor;
Expand Down

0 comments on commit 7fc795d

Please sign in to comment.