|
16 | 16 |
|
17 | 17 | #include "ngraph_builder.h"
|
18 | 18 | #include "ngraph/op/util/logical_reduction.hpp"
|
| 19 | +#include "ngraph_backend_manager.h" |
19 | 20 | #include "ngraph_conversions.h"
|
20 | 21 | #include "ngraph_log.h"
|
| 22 | +#include "ngraph_mark_for_clustering.h" |
21 | 23 | #include "ngraph_utils.h"
|
22 | 24 |
|
23 | 25 | #include "ngraph/builder/autobroadcast.hpp"
|
@@ -1914,6 +1916,90 @@ static Status TranslateFusedBatchNormGradOp(
|
1914 | 1916 | return Status::OK();
|
1915 | 1917 | }
|
1916 | 1918 |
|
| 1919 | +static Status TranslateGatherV2Op( |
| 1920 | + const Node* op, const std::vector<const Tensor*>& static_input_map, |
| 1921 | + Builder::OpMap& ng_op_map) { |
| 1922 | + shared_ptr<ng::Node> ng_input; |
| 1923 | + TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 0, &ng_input)); |
| 1924 | + |
| 1925 | + std::vector<int64> tf_indices; |
| 1926 | + TF_RETURN_IF_ERROR( |
| 1927 | + GetStaticInputVector(op, 1, static_input_map, &tf_indices)); |
| 1928 | + // It seems indices cannot be negative, so no need to handle that |
| 1929 | + std::vector<size_t> indices(tf_indices.size()); |
| 1930 | + std::transform(tf_indices.begin(), tf_indices.end(), indices.begin(), |
| 1931 | + [](int64 x) { return (size_t)(x); }); |
| 1932 | + |
| 1933 | + std::vector<int64> tf_axis; |
| 1934 | + TF_RETURN_IF_ERROR(GetStaticInputVector(op, 2, static_input_map, &tf_axis)); |
| 1935 | + |
| 1936 | + if (tf_axis.size() > 1) { |
| 1937 | + return errors::Internal("Found axis in GatherV2 op (", op->name(), |
| 1938 | + ") translation to be non scalar, of size ", |
| 1939 | + tf_axis.size()); |
| 1940 | + } |
| 1941 | + |
| 1942 | + std::string backend_name; |
| 1943 | + TF_RETURN_IF_ERROR(ngraph_bridge::GetNodeBackend(op, &backend_name)); |
| 1944 | + |
| 1945 | + if (backend_name != "NNPI") { |
| 1946 | + return errors::Internal("In translating GatherV2 op ", op->name(), |
| 1947 | + " found requested backend ", backend_name, |
| 1948 | + " which is unsupported"); |
| 1949 | + } |
| 1950 | + |
| 1951 | + ng::runtime::Backend* backend = BackendManager::GetBackend(backend_name); |
| 1952 | + auto coords = ng::Coordinate(indices); |
| 1953 | + // Negative axis is supported. Accounting for that |
| 1954 | + auto ng_input_shape = ng_input->get_shape(); |
| 1955 | + size_t ng_input_rank = ng_input_shape.size(); |
| 1956 | + size_t axis; |
| 1957 | + if (tf_axis[0] >= 0) { |
| 1958 | + axis = tf_axis[0]; |
| 1959 | + } else { |
| 1960 | + axis = tf_axis[0] + ng_input_rank; |
| 1961 | + } |
| 1962 | + if (axis < 0 || axis >= ng_input_rank) { |
| 1963 | + return errors::InvalidArgument("Expected axis in the range [-", |
| 1964 | + ng_input_rank, ", ", ng_input_rank, |
| 1965 | + "), but got ", tf_axis[0]); |
| 1966 | + } |
| 1967 | + |
| 1968 | + for (size_t indices_idx = 0; indices_idx < indices.size(); indices_idx++) { |
| 1969 | + if (indices[indices_idx] >= ng_input_shape[axis]) { |
| 1970 | + // TODO: this error returnign must be generalized when indices = vector of |
| 1971 | + // vectors is supported |
| 1972 | + return errors::InvalidArgument("indices[0,", indices_idx, "] = ", |
| 1973 | + indices[indices_idx], " is not in [0, ", |
| 1974 | + ng_input_shape[axis], ")"); |
| 1975 | + } |
| 1976 | + } |
| 1977 | + |
| 1978 | + vector<size_t> possibly_empty_node_size(ng_input_shape); |
| 1979 | + possibly_empty_node_size[axis] = indices.size(); |
| 1980 | + |
| 1981 | + if (std::any_of(possibly_empty_node_size.begin(), |
| 1982 | + possibly_empty_node_size.end(), |
| 1983 | + [](size_t x) { return x == 0; })) { |
| 1984 | + std::vector<std::string> const_values( |
| 1985 | + ng::shape_size(possibly_empty_node_size), "0"); |
| 1986 | + auto ng_empty = ConstructNgNode<ng::op::Constant>( |
| 1987 | + op->name(), ng_input->get_element_type(), |
| 1988 | + ng::Shape(possibly_empty_node_size), const_values); |
| 1989 | + SaveNgOp(ng_op_map, op->name(), ng_empty); |
| 1990 | + } else { |
| 1991 | + shared_ptr<ng::Node> ng_gather = |
| 1992 | + backend->get_backend_op("Gather", &ng_input, &coords, &axis); |
| 1993 | + if (ng_gather == nullptr) { |
| 1994 | + return errors::Internal("In translating GatherV2 op ", op->name(), |
| 1995 | + " backend could not return valid ngraph node"); |
| 1996 | + } |
| 1997 | + SaveNgOp(ng_op_map, op->name(), ng_gather); |
| 1998 | + } |
| 1999 | + |
| 2000 | + return Status::OK(); |
| 2001 | +} |
| 2002 | + |
1917 | 2003 | static Status TranslateFusedConv2DOp(
|
1918 | 2004 | const Node* op, const std::vector<const Tensor*>& static_input_map,
|
1919 | 2005 | Builder::OpMap& ng_op_map) {
|
@@ -4252,6 +4338,7 @@ const static std::map<
|
4252 | 4338 | {"FusedBatchNorm", TranslateFusedBatchNormOp},
|
4253 | 4339 | {"FusedBatchNormV2", TranslateFusedBatchNormOp},
|
4254 | 4340 | {"FusedBatchNormGrad", TranslateFusedBatchNormGradOp},
|
| 4341 | + {"GatherV2", TranslateGatherV2Op}, |
4255 | 4342 | {"_FusedConv2D", TranslateFusedConv2DOp},
|
4256 | 4343 | {"Greater", TranslateBinaryOp<ngraph::op::Greater>},
|
4257 | 4344 | {"GreaterEqual", TranslateBinaryOp<ngraph::op::GreaterEq>},
|
|
0 commit comments