Skip to content

Commit

Permalink
fixed bug when broadcast dimensions is negative
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Xin <[email protected]>
  • Loading branch information
Chen Xin authored and sunshinemyson committed Oct 8, 2022
1 parent a038df2 commit 3fed6d6
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 3 additions & 3 deletions include/tim/vx/ops/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ namespace ops {
*
* Input:
* - input.
*
*
* Attribute:
* - shape: the shape which broadcast to.
* - dimensions(optional): Which dimension in the target shape each dimension
* - dimensions(optional): Which dimension in the target shape each dimension
* of the operand shape corresponds to. For BroadcastInDim.
*/

Expand All @@ -51,7 +51,7 @@ class Broadcast : public BuiltinOp {

protected:
const std::vector<int32_t> shape_;
const std::vector<int32_t> dimensions_;
std::vector<int32_t> dimensions_;
};

} // namespace ops
Expand Down
4 changes: 4 additions & 0 deletions src/tim/vx/ops/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ Broadcast::Broadcast(Graph* graph, const std::vector<int32_t>& shape,
this->impl()->node()->nn_param.expand_broadcast.dimensions_num = dimensions_.size();
if (dimensions.size() > 0)
{
int dim_num = shape.size();
for (uint32_t i = 0; i < dimensions.size(); ++i) {
dimensions_[i] += (dimensions[i] < 0 ? dim_num : 0U);
}
this->impl()->node()->nn_param.expand_broadcast.dimensions = (uint32_t*)dimensions_.data();
} else {
this->impl()->node()->nn_param.expand_broadcast.dimensions = nullptr;
Expand Down

0 comments on commit 3fed6d6

Please sign in to comment.