From 7fc795dc72101a78809605fe1826e107c7b2cf1b Mon Sep 17 00:00:00 2001 From: Yeeethan00 Date: Wed, 8 Nov 2023 08:47:41 +0800 Subject: [PATCH] rename Mask -> CausalMask --- include/OpDefined.hpp | 4 ++-- src/backends/cpu/CPUBackend.cpp | 5 +++-- src/backends/cpu/{CPUMask.cpp => CPUCausalMask.cpp} | 10 +++++----- src/backends/cpu/{CPUMask.hpp => CPUCausalMask.hpp} | 10 +++++----- src/express/Express.cpp | 2 +- 5 files changed, 16 insertions(+), 15 deletions(-) rename src/backends/cpu/{CPUMask.cpp => CPUCausalMask.cpp} (76%) rename src/backends/cpu/{CPUMask.hpp => CPUCausalMask.hpp} (70%) diff --git a/include/OpDefined.hpp b/include/OpDefined.hpp index 2ba3fad4..da1eb699 100644 --- a/include/OpDefined.hpp +++ b/include/OpDefined.hpp @@ -16,7 +16,7 @@ enum OpType { SCALE, ROPE, RMSNORM, - MASK, + CAUSALMASK, LINEAR, ATTENTION, EMBEDDING, @@ -34,7 +34,7 @@ static const vector OpNames = { "Scale", "RoPE", "RMSNorm", - "Mask", + "CAUSALMASK", "Linear", "Attention", "Embedding", diff --git a/src/backends/cpu/CPUBackend.cpp b/src/backends/cpu/CPUBackend.cpp index 0205305b..54b90f38 100644 --- a/src/backends/cpu/CPUBackend.cpp +++ b/src/backends/cpu/CPUBackend.cpp @@ -1,7 +1,7 @@ #include "CPUBackend.hpp" #include "CPUView.hpp" #include "CPUAdd.hpp" -#include "CPUMask.hpp" +#include "CPUCausalMask.hpp" #include "CPUMatmul.hpp" #include "CPURMSNorm.hpp" #include "CPURoPE.hpp" @@ -53,7 +53,7 @@ void CPUBackend::registerOps() { // addCreator(MATMUL, &_temp); addCreator(ADD, (CPUBackend::Creator *)(new CPUAddCreator())); - addCreator(MASK, (CPUBackend::Creator *)(new CPUMaskCreator())); + addCreator(CAUSALMASK, (CPUBackend::Creator *)(new CPUCausalMaskCreator())); addCreator(MATMUL, (CPUBackend::Creator *)(new CPUMatmulCreator())); addCreator(RMSNORM, (CPUBackend::Creator *)(new CPURMSNormCreator())); addCreator(ROPE, (CPUBackend::Creator *)(new CPURoPECreator())); @@ -65,6 +65,7 @@ void CPUBackend::registerOps() { addCreator(EMBEDDING, (CPUBackend::Creator *)(new CPUEmbeddingCreator())); addCreator(MUL, (CPUBackend::Creator *)(new CPUMulCreator())); addCreator(VIEW, (CPUBackend::Creator *)(new CPUViewCreator())); + addCreator(CAUSALMASK, (CPUBackend::Creator *)(new CPUCausalMaskCreator())); } } // namespace mllm diff --git a/src/backends/cpu/CPUMask.cpp b/src/backends/cpu/CPUCausalMask.cpp similarity index 76% rename from src/backends/cpu/CPUMask.cpp rename to src/backends/cpu/CPUCausalMask.cpp index 6a7ae1cf..6716fcd0 100644 --- a/src/backends/cpu/CPUMask.cpp +++ b/src/backends/cpu/CPUCausalMask.cpp @@ -1,5 +1,5 @@ -#include "CPUMask.hpp" +#include "CPUCausalMask.hpp" #include namespace mllm { @@ -7,11 +7,11 @@ namespace mllm { // template class CPUMask; // template class CPUMask; -CPUMask::CPUMask(Backend *bn, string opName, bool multiThread) : +CPUCausalMask::CPUCausalMask(Backend *bn, string opName, bool multiThread) : Op(bn, opName) { } -ErrorCode CPUMask::reshape(vector> inputs, vector> outputs) { +ErrorCode CPUCausalMask::reshape(vector> inputs, vector> outputs) { std::cout << "CPUMask reshape" << std::endl; CHECK_EQ(inputs.size(), 1); CHECK_EQ(outputs.size(), 1); @@ -19,12 +19,12 @@ ErrorCode CPUMask::reshape(vector> inputs, vector> inputs, vector> outputs) { +ErrorCode CPUCausalMask::execute(vector> inputs, vector> outputs) { std::cout << "CPUMask()" << std::endl; int batch_size = inputs[0]->batch(); int head_num = inputs[0]->head(); diff --git a/src/backends/cpu/CPUMask.hpp b/src/backends/cpu/CPUCausalMask.hpp similarity index 70% rename from src/backends/cpu/CPUMask.hpp rename to src/backends/cpu/CPUCausalMask.hpp index 561dd0ef..46405704 100644 --- a/src/backends/cpu/CPUMask.hpp +++ b/src/backends/cpu/CPUCausalMask.hpp @@ -6,10 +6,10 @@ namespace mllm { -class CPUMask final : public Op { +class CPUCausalMask final : public Op { public: - CPUMask(Backend *bn, string opName, bool multiThread); - virtual ~CPUMask() = default; + CPUCausalMask(Backend *bn, string opName, bool multiThread); + virtual ~CPUCausalMask() = default; virtual ErrorCode reshape(vector> inputs, vector> outputs) override; virtual ErrorCode load(ParamLoader &loader) override; virtual ErrorCode execute(vector> inputs, vector> outputs) override; @@ -18,10 +18,10 @@ class CPUMask final : public Op { bool support_multi_thread_ = false; }; -class CPUMaskCreator : public CPUBackend::Creator { +class CPUCausalMaskCreator : public CPUBackend::Creator { public: virtual Op *create(OpParam op_param, Backend *bn, string name) const { - return new CPUMask(bn, name, false); + return new CPUCausalMask(bn, name, false); } }; } // namespace mllm diff --git a/src/express/Express.cpp b/src/express/Express.cpp index 15adaa03..fc7cc090 100644 --- a/src/express/Express.cpp +++ b/src/express/Express.cpp @@ -123,7 +123,7 @@ NetTensor *_Causalmask(Context *ctx, std::vector inputs, string nam out_tensor->type = inputs[0]->type; ctx->idx++; _STORE_OUT_TENSOR - _NEW_OP(mllm::MASK) + _NEW_OP(mllm::CAUSALMASK) _UPDATE_INPUT_TENSORS out_tensor->in = net_op_; return out_tensor;