From 3e54bb249b05cb0024d656f34d145148f33c5199 Mon Sep 17 00:00:00 2001 From: mxi-box Date: Tue, 29 Mar 2022 23:42:33 +0800 Subject: [PATCH] support TF version >= 2.8 --- tensorflow_binding/src/ctc_op_kernel.cc | 2 +- tensorflow_binding/src/warpctc_op.cc | 12 ++++++++++-- tensorflow_binding/warpctc_tensorflow/__init__.py | 7 ------- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tensorflow_binding/src/ctc_op_kernel.cc b/tensorflow_binding/src/ctc_op_kernel.cc index 9918da8..ed3000b 100644 --- a/tensorflow_binding/src/ctc_op_kernel.cc +++ b/tensorflow_binding/src/ctc_op_kernel.cc @@ -4,7 +4,7 @@ #endif #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/util/sparse/sparse_tensor.h" #include "ctc.h" diff --git a/tensorflow_binding/src/warpctc_op.cc b/tensorflow_binding/src/warpctc_op.cc index 6e4088f..3f6e276 100644 --- a/tensorflow_binding/src/warpctc_op.cc +++ b/tensorflow_binding/src/warpctc_op.cc @@ -5,8 +5,9 @@ #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/shape_inference.h" #include "ctc.h" @@ -17,7 +18,14 @@ REGISTER_OP("WarpCTC") .Input("input_lengths: int32") .Attr("blank_label: int = 0") .Output("costs: float32") - .Output("gradients: float32"); + .Output("gradients: float32") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ::tensorflow::shape_inference::ShapeHandle input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &input)); + c->set_output(0, c->Vector(c->Dim(input, 1))); + c->set_output(1, input); + return ::tensorflow::Status::OK(); + }); namespace tf = tensorflow; diff --git a/tensorflow_binding/warpctc_tensorflow/__init__.py b/tensorflow_binding/warpctc_tensorflow/__init__.py index c407f77..8531a94 100644 --- a/tensorflow_binding/warpctc_tensorflow/__init__.py +++ b/tensorflow_binding/warpctc_tensorflow/__init__.py @@ -49,10 +49,3 @@ def _CTCLossGrad(op, grad_loss, _): grad = op.outputs[1] return [_BroadcastMul(grad_loss, grad), None, None, None] - -@ops.RegisterShape("WarpCTC") -def _CTCLossShape(op): - inputs_shape = op.inputs[0].get_shape().with_rank(3) - batch_size = inputs_shape[1] - return [batch_size, inputs_shape] -