Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add fused get target_offsets #9536

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct FusedGetBounddingBoxesCoordGradCaptureState : public AutoGradCaptureState
std::vector<bool> requires_grad;
};

class FusedGetBounddingBoxesCoordGrad
class FusedYolov5GetBounddingBoxesCoordGrad
: public OpExprGradFunction<FusedGetBounddingBoxesCoordGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Expand Down Expand Up @@ -54,7 +54,7 @@ class FusedGetBounddingBoxesCoordGrad
const auto& b2_y2_diff = out_grads.at(7);

in_grads->resize(8);
auto result = JUST(functional::FusedGetBounddingBoxesCoordGrad(
auto result = JUST(functional::FusedYolov5GetBounddingBoxesCoordGrad(
b1_x1_diff, b1_x2_diff, b1_y1_diff, b1_y2_diff, b2_x1_diff, b2_x2_diff, b2_y1_diff,
b2_y2_diff));
CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);
Expand All @@ -65,7 +65,8 @@ class FusedGetBounddingBoxesCoordGrad
}
};

REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_boundding_boxes_coord", FusedGetBounddingBoxesCoordGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_yolov5_get_boundding_boxes_coord",
FusedYolov5GetBounddingBoxesCoordGrad);

} // namespace one
} // namespace oneflow
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct FusedCenterCaptureState : public AutoGradCaptureState {
std::vector<bool> requires_grad;
};

class FusedCenterGrad : public OpExprGradFunction<FusedCenterCaptureState> {
class FusedYolov5GetCenterDistGrad : public OpExprGradFunction<FusedCenterCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Expand Down Expand Up @@ -55,8 +55,8 @@ class FusedCenterGrad : public OpExprGradFunction<FusedCenterCaptureState> {
const auto& b2_y2 = ctx->SavedTensors().at(7);

in_grads->resize(INPUT_LEN);
auto result = JUST(functional::FusedCenterGrad(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1,
b2_y2, rho2_diff));
auto result = JUST(functional::FusedYolov5GetCenterDistGrad(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1,
b1_y2, b2_y1, b2_y2, rho2_diff));

CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);
for (int i = 0; i < INPUT_LEN; i++) {
Expand All @@ -66,7 +66,7 @@ class FusedCenterGrad : public OpExprGradFunction<FusedCenterCaptureState> {
}
};

REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_center_dist", FusedCenterGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_yolov5_get_center_dist", FusedYolov5GetCenterDistGrad);

} // namespace one
} // namespace oneflow
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct FusedCiouAngleCaptureState : public AutoGradCaptureState {
float eps = 1e-8;
};

class FusedGetCiouDiagonalAngleGrad : public OpExprGradFunction<FusedCiouAngleCaptureState> {
class FusedYolov5GetCiouDiagonalAngleGrad : public OpExprGradFunction<FusedCiouAngleCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Expand Down Expand Up @@ -56,7 +56,8 @@ class FusedGetCiouDiagonalAngleGrad : public OpExprGradFunction<FusedCiouAngleCa
const auto& w2 = ctx->SavedTensors().at(2);
const auto& h2 = ctx->SavedTensors().at(3);

auto result = JUST(functional::FusedGetCiouDiagonalAngleGrad(w1, h1, w2, h2, v_diff, ctx->eps));
auto result =
JUST(functional::FusedYolov5GetCiouDiagonalAngleGrad(w1, h1, w2, h2, v_diff, ctx->eps));
CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);

in_grads->resize(INPUT_LEN);
Expand All @@ -67,7 +68,8 @@ class FusedGetCiouDiagonalAngleGrad : public OpExprGradFunction<FusedCiouAngleCa
}
};

REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_ciou_diagonal_angle", FusedGetCiouDiagonalAngleGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_yolov5_get_ciou_diagonal_angle",
FusedYolov5GetCiouDiagonalAngleGrad);

} // namespace one
} // namespace oneflow
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct FusedGetCiouResultGradCaptureState : public AutoGradCaptureState {
bool c2_requires_grad = false;
};

class FusedGetCiouResultGrad : public OpExprGradFunction<FusedGetCiouResultGradCaptureState> {
class FusedYolov5GetCiouResultGrad : public OpExprGradFunction<FusedGetCiouResultGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Expand Down Expand Up @@ -60,7 +60,7 @@ class FusedGetCiouResultGrad : public OpExprGradFunction<FusedGetCiouResultGradC
const auto& c2 = saved_tensors.at(2);

in_grads->resize(4);
auto result = JUST(functional::FusedGetCiouResultGrad(dy, alpha, rho2, c2));
auto result = JUST(functional::FusedYolov5GetCiouResultGrad(dy, alpha, rho2, c2));
CHECK_EQ_OR_RETURN(result->size(), 4);
if (ctx->v_requires_grad && ctx->iou_requires_grad && ctx->rho2_requires_grad
&& ctx->c2_requires_grad) {
Expand All @@ -76,7 +76,7 @@ class FusedGetCiouResultGrad : public OpExprGradFunction<FusedGetCiouResultGradC
AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_ciou_result", FusedGetCiouResultGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_yolov5_get_ciou_result", FusedYolov5GetCiouResultGrad);

} // namespace one
} // namespace oneflow
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct FusedGetConvexDiagonalSquaredCaptureState : public AutoGradCaptureState {
float eps = 1e-8;
};

class FusedGetConvexDiagonalSquaredGrad
class FusedYolov5GetConvexDiagonalSquaredGrad
: public OpExprGradFunction<FusedGetConvexDiagonalSquaredCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
Expand Down Expand Up @@ -64,7 +64,7 @@ class FusedGetConvexDiagonalSquaredGrad
const auto& b2_y2 = ctx->SavedTensors().at(7);

in_grads->resize(INPUT_LEN);
auto result = JUST(functional::FusedGetConvexDiagonalSquaredGrad(
auto result = JUST(functional::FusedYolov5GetConvexDiagonalSquaredGrad(
c2_diff, b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, ctx->eps));

CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);
Expand All @@ -78,8 +78,8 @@ class FusedGetConvexDiagonalSquaredGrad
AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_convex_diagonal_squared",
FusedGetConvexDiagonalSquaredGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_yolov5_get_convex_diagonal_squared",
FusedYolov5GetConvexDiagonalSquaredGrad);

} // namespace one
} // namespace oneflow
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct FusedGetIntersectionAreaCaptureState : public AutoGradCaptureState {
std::vector<bool> requires_grad;
};

class FusedGetIntersectionAreaGrad
class FusedYolov5GetIntersectionAreaGrad
: public OpExprGradFunction<FusedGetIntersectionAreaCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Expand Down Expand Up @@ -56,8 +56,8 @@ class FusedGetIntersectionAreaGrad
const auto& b2_y2 = ctx->SavedTensors().at(7);

in_grads->resize(INPUT_LEN);
auto result = JUST(functional::FusedGetIntersectionAreaGrad(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1,
b1_y2, b2_y1, b2_y2, rho2_diff));
auto result = JUST(functional::FusedYolov5GetIntersectionAreaGrad(
b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, rho2_diff));

CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);
for (int i = 0; i < INPUT_LEN; i++) {
Expand All @@ -67,7 +67,8 @@ class FusedGetIntersectionAreaGrad
}
};

REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_intersection_area", FusedGetIntersectionAreaGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_yolov5_get_intersection_area",
FusedYolov5GetIntersectionAreaGrad);

} // namespace one
} // namespace oneflow
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct FusedGetIouGradCaptureState : public AutoGradCaptureState {
float eps = 1e-8;
};

class FusedGetIouGrad : public OpExprGradFunction<FusedGetIouGradCaptureState> {
class FusedYolov5GetIouGrad : public OpExprGradFunction<FusedGetIouGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
Expand Down Expand Up @@ -67,7 +67,7 @@ class FusedGetIouGrad : public OpExprGradFunction<FusedGetIouGradCaptureState> {
const auto& inter = saved_tensors.at(4);

in_grads->resize(5);
auto result = JUST(functional::FusedGetIouGrad(diou, w1, h1, w2, h2, inter, ctx->eps));
auto result = JUST(functional::FusedYolov5GetIouGrad(diou, w1, h1, w2, h2, inter, ctx->eps));
CHECK_EQ_OR_RETURN(result->size(), 3);
if (ctx->requires_grad) {
in_grads->at(0) = result->at(0);
Expand All @@ -81,7 +81,7 @@ class FusedGetIouGrad : public OpExprGradFunction<FusedGetIouGradCaptureState> {
AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_iou", FusedGetIouGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_yolov5_get_iou", FusedYolov5GetIouGrad);

} // namespace one
} // namespace oneflow
84 changes: 44 additions & 40 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1670,30 +1670,6 @@
Bool align_corners=False, Int64List[3] output_size=None, String data_format="channels_first") => UpsampleTrilinear3DGrad'
bind_python: False

- name: "fused_get_boundding_boxes_coord"
signature: "TensorTuple (Tensor x1, Tensor y1, Tensor w1, Tensor h1, Tensor x2, Tensor y2, Tensor w2, Tensor h2) => FusedGetBounddingBoxesCoord"
bind_python: True

- name: "fused_get_boundding_boxes_coord_grad"
signature: "TensorTuple (Tensor b1_x1_diff, Tensor b1_x2_diff, Tensor b1_y1_diff, Tensor b1_y2_diff, Tensor b2_x1_diff, Tensor b2_x2_diff, Tensor b2_y1_diff, Tensor b2_y2_diff) => FusedGetBounddingBoxesCoordGrad"
bind_python: False

- name: "fused_get_ciou_result"
signature: "TensorTuple (Tensor v, Tensor iou, Tensor rho2, Tensor c2, Float eps) => FusedGetCiouResult"
bind_python: True

- name: "fused_get_ciou_result_grad"
signature: "TensorTuple (Tensor dy ,Tensor alpha, Tensor rho2, Tensor c2) => FusedGetCiouResultGrad"
bind_python: False

- name: "fused_get_iou"
signature: "Tensor (Tensor w1, Tensor h1, Tensor w2, Tensor h2, Tensor inter, Float eps) => FusedGetIou"
bind_python: True

- name: "fused_get_iou_grad"
signature: "TensorTuple (Tensor diou, Tensor w1, Tensor h1, Tensor w2, Tensor h2, Tensor inter, Float eps) => FusedGetIouGrad"
bind_python: False

- name: "abs"
signature: "Tensor (Tensor x) => Abs"
bind_python: True
Expand Down Expand Up @@ -2482,38 +2458,66 @@
signature: "Tensor (Tensor query, Tensor key, Tensor value, Int64 num_heads, Bool causal=False, Int64 query_hidden_slice_start=0, Int64 query_hidden_slice_end=-1, Int64 key_hidden_slice_start=0, Int64 key_hidden_slice_end=-1, Int64 value_hidden_slice_start=0, Int64 value_hidden_slice_end=-1) => FusedMultiHeadAttentionInference"
bind_python: True

- name: "fused_get_center_dist"
signature: "Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2) => FusedCenter"
- name: "fused_yolov5_get_boundding_boxes_coord"
signature: "TensorTuple (Tensor x1, Tensor y1, Tensor w1, Tensor h1, Tensor x2, Tensor y2, Tensor w2, Tensor h2) => FusedYolov5GetBounddingBoxesCoord"
bind_python: True

- name: "fused_yolov5_get_boundding_boxes_coord_grad"
signature: "TensorTuple (Tensor b1_x1_diff, Tensor b1_x2_diff, Tensor b1_y1_diff, Tensor b1_y2_diff, Tensor b2_x1_diff, Tensor b2_x2_diff, Tensor b2_y1_diff, Tensor b2_y2_diff) => FusedYolov5GetBounddingBoxesCoordGrad"
bind_python: False

- name: "fused_yolov5_get_ciou_result"
signature: "TensorTuple (Tensor v, Tensor iou, Tensor rho2, Tensor c2, Float eps) => FusedYolov5GetCiouResult"
bind_python: True

- name: "fused_yolov5_get_ciou_result_grad"
signature: "TensorTuple (Tensor dy ,Tensor alpha, Tensor rho2, Tensor c2) => FusedYolov5GetCiouResultGrad"
bind_python: False

- name: "fused_yolov5_get_iou"
signature: "Tensor (Tensor w1, Tensor h1, Tensor w2, Tensor h2, Tensor inter, Float eps) => FusedYolov5GetIou"
bind_python: True

- name: "fused_get_center_dist_grad"
signature: "TensorTuple (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Tensor rho2_diff) => FusedCenterGrad"
- name: "fused_yolov5_get_iou_grad"
signature: "TensorTuple (Tensor diou, Tensor w1, Tensor h1, Tensor w2, Tensor h2, Tensor inter, Float eps) => FusedYolov5GetIouGrad"
bind_python: False

- name: "fused_get_intersection_area"
signature: "Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2) => FusedGetIntersectionArea"
- name: "fused_yolov5_get_center_dist"
signature: "Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2) => FusedYolov5GetCenterDist"
bind_python: True

- name: "fused_get_intersection_area_grad"
signature: "TensorTuple (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Tensor inter_diff) => FusedGetIntersectionAreaGrad"
- name: "fused_yolov5_get_center_dist_grad"
signature: "TensorTuple (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Tensor rho2_diff) => FusedYolov5GetCenterDistGrad"
bind_python: False

- name: "fused_get_ciou_diagonal_angle"
signature: "Tensor (Tensor w1, Tensor h1, Tensor w2, Tensor h2, Float eps) => FusedGetCiouDiagonalAngle"
- name: "fused_yolov5_get_intersection_area"
signature: "Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2) => FusedYolov5GetIntersectionArea"
bind_python: True

- name: "fused_get_ciou_diagonal_angle_grad"
signature: "TensorTuple (Tensor w1, Tensor h1, Tensor w2, Tensor h2, Tensor v_diff, Float eps) => FusedGetCiouDiagonalAngleGrad"
- name: "fused_yolov5_get_intersection_area_grad"
signature: "TensorTuple (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Tensor inter_diff) => FusedYolov5GetIntersectionAreaGrad"
bind_python: False

- name: "fused_get_convex_diagonal_squared"
signature: "Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Float eps) => FusedGetConvexDiagonalSquared"
- name: "fused_yolov5_get_ciou_diagonal_angle"
signature: "Tensor (Tensor w1, Tensor h1, Tensor w2, Tensor h2, Float eps) => FusedYolov5GetCiouDiagonalAngle"
bind_python: True

- name: "fused_get_convex_diagonal_squared_grad"
signature: "TensorTuple (Tensor c2_diff, Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Float eps) => FusedGetConvexDiagonalSquaredGrad"
- name: "fused_yolov5_get_ciou_diagonal_angle_grad"
signature: "TensorTuple (Tensor w1, Tensor h1, Tensor w2, Tensor h2, Tensor v_diff, Float eps) => FusedYolov5GetCiouDiagonalAngleGrad"
bind_python: False

- name: "fused_yolov5_get_convex_diagonal_squared"
signature: "Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Float eps) => FusedYolov5GetConvexDiagonalSquared"
bind_python: True

- name: "fused_yolov5_get_convex_diagonal_squared_grad"
signature: "TensorTuple (Tensor c2_diff, Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Float eps) => FusedYolov5GetConvexDiagonalSquaredGrad"
bind_python: False

- name: "fused_yolov5_get_target_offsets"
signature: "Tensor (Tensor gxy, Tensor gxi, Float g) => FusedYolov5GetTargetOffsets"
bind_python: True

- name: "grouped_matmul_bias"
signature: "TensorTuple (TensorTuple xs, TensorTuple weights, TensorTuple biases) => GroupedMatmulBias"
bind_python: True
Expand Down
Loading