Skip to content

Commit

Permalink
Map OneHot & unit test (#258)
Browse files Browse the repository at this point in the history
Signed-off-by: yuenan.li <[email protected]>

Co-authored-by: yuenan.li <[email protected]>
  • Loading branch information
liyuenan2333 and yuenan.li authored Jan 5, 2022
1 parent 8e4ab68 commit 7c63ba6
Show file tree
Hide file tree
Showing 5 changed files with 439 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/tim/vx/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#include "tim/vx/ops/maxunpool2d.h"
#include "tim/vx/ops/moments.h"
#include "tim/vx/ops/nbg.h"
#include "tim/vx/ops/onehot.h"
#include "tim/vx/ops/pad.h"
#include "tim/vx/ops/pool2d.h"
#include "tim/vx/ops/reduce.h"
Expand Down
59 changes: 59 additions & 0 deletions include/tim/vx/ops/onehot.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_VX_OPERATION_ONE_HOT_H_
#define TIM_VX_OPERATION_ONE_HOT_H_
#include "tim/vx/direct_map_op.h"

namespace tim {
namespace vx {
namespace ops {

/**
* ## OneHot
*
* Create a one-hot tensor.
*
* - depth : A scalar defining the depth of the one hot dimension.
* - on_value : A scalar defining the value to fill in output.
* - off_value : A scalar defining the value to fill in output.
* - axis : The axis to fill.
*/

class OneHot : public DirectMapOp {
public:
OneHot(Graph* graph, int32_t depth, float on_value = 1, float off_value = 0,
int32_t axis = 0);

std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;

protected:
int32_t depth_;
float on_value_;
float off_value_;
int32_t axis_;
};
} // namespace ops
} // namespace vx
} // namespace tim
#endif
2 changes: 1 addition & 1 deletion src/tim/vx/ops/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow.
||CEIL|Planned 21Q4|[tf.math.ceil](https://tensorflow.google.cn/api_docs/python/tf/math/ceil)
||SEQUENCE_MASK|Planned 21Q4|[tf.math.ceil](https://tensorflow.google.cn/api_docs/python/tf/sequence_mask)
||REPEAT|Planned 21Q4|[tf.repeat](https://tensorflow.google.cn/api_docs/python/tf/repeat)
||ONE_HOT|Planned 21Q4|[tf.one_hot](https://tensorflow.google.cn/api_docs/python/tf/one_hot)
OneHot|ONE_HOT|Mapped|[tf.one_hot](https://tensorflow.google.cn/api_docs/python/tf/one_hot)
||NMS|Planned 21Q4|[tf.image.non_max_suppression](https://tensorflow.google.cn/api_docs/python/tf/image/non_max_suppression)
||SCATTER_ND_UPDATE|Planned 21Q4|[tf.compat.v1.scatter_nd_update](https://tensorflow.google.cn/api_docs/python/tf/compat/v1/scatter_nd_update)
||GELU|Planned 21Q4|[tf.nn.gelu](https://tensorflow.google.cn/api_docs/python/tf/nn/gelu)
Expand Down
52 changes: 52 additions & 0 deletions src/tim/vx/ops/onehot.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#include "tim/vx/ops/onehot.h"

#include "direct_map_op_impl.h"
#include "vsi_nn_pub.h"

namespace tim {
namespace vx {
namespace ops {
OneHot::OneHot(Graph* graph, int32_t depth, float on_value, float off_value,
int32_t axis)
: DirectMapOp(graph, VSI_NN_OP_ONE_HOT),
depth_(depth),
on_value_(on_value),
off_value_(off_value),
axis_(axis) {
this->impl()->node()->nn_param.one_hot.depth = depth_;
this->impl()->node()->nn_param.one_hot.on_value = on_value_;
this->impl()->node()->nn_param.one_hot.off_value = off_value_;
this->impl()->node()->nn_param.one_hot.axis = axis_;
}

std::shared_ptr<Operation> OneHot::Clone(std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<OneHot>(this->depth_, this->on_value_,
this->off_value_, this->axis_);
}

} // namespace ops
} // namespace vx
} // namespace tim
Loading

0 comments on commit 7c63ba6

Please sign in to comment.