Skip to content

Commit

Permalink
Factor out AssignIdsToCustomAggregatorOps.
Browse files Browse the repository at this point in the history
Implements `AssignIdsToCustomAggregatorOps` in cpp, replicating the python implementation of `assign_ids_to_custom_aggregator_ops`.
Accordingly, this change removes `assign_ids_to_custom_aggregator_ops` from `py_function_lib`.

PiperOrigin-RevId: 589349801
  • Loading branch information
dansuh17 authored and tensorflower-gardener committed Dec 9, 2023
1 parent 6812199 commit 4a617a0
Show file tree
Hide file tree
Showing 17 changed files with 212 additions and 114 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")

package(
Expand Down Expand Up @@ -33,3 +34,28 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
],
)

cc_library(
name = "assign_ids",
srcs = ["assign_ids.cc"],
hdrs = ["assign_ids.h"],
compatible_with = get_compatible_with_portable(),
deps = [
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:graph_def",
"//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
],
)

tf_cc_test(
name = "assign_ids_test",
srcs = ["assign_ids_test.cc"],
deps = [
":assign_ids",
"//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl",
"//tensorflow/core:protos_all_cc",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:protobuf",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.h"

#include <cstdint>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"

namespace stablehlo::quantization {
namespace {

using ::tensorflow::GraphDef;
using ::tensorflow::NodeDef;
using ::tensorflow::calibrator::CalibratorSingleton;

} // namespace

void AssignIdsToCustomAggregatorOps(GraphDef& graph_def) {
MutateNodeDefs(graph_def, [](NodeDef& node_def) {
if (node_def.op() == "CustomAggregator") {
const int64_t new_id = CalibratorSingleton::IssueNewId();
(*node_def.mutable_attr())["id"].set_s(absl::StrCat(new_id));
}
});
}

} // namespace stablehlo::quantization
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_ASSIGN_IDS_H_
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_ASSIGN_IDS_H_

#include "tensorflow/core/framework/graph.pb.h"

namespace stablehlo::quantization {

// Assigns unique ids to each CustomAggregator op found in `graph_def`. The
// ids are set to the `id` attribute. The ids are used during the calibration
// step to identify the collected quantization statistics for each
// CustsomAggregator op.
void AssignIdsToCustomAggregatorOps(tensorflow::GraphDef& graph_def);

} // namespace stablehlo::quantization

#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_ASSIGN_IDS_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tsl/platform/protobuf.h" // IWYU pragma: keep for tsl::protobuf

namespace stablehlo::quantization {
namespace {

using ::tensorflow::GraphDef;
using ::testing::IsEmpty;
using ::testing::Not;
using ::testing::SizeIs;
using ::tsl::protobuf::TextFormat;

TEST(AssignIdsTest, IdsAddedToCustomAggregatorOps) {
GraphDef graph_def;
ASSERT_TRUE(TextFormat::ParseFromString(
R"pb(
node { op: "CustomAggregator" name: "foo" }
)pb",
&graph_def));

AssignIdsToCustomAggregatorOps(graph_def);

ASSERT_THAT(graph_def.node(), SizeIs(1));
EXPECT_TRUE(graph_def.node()[0].attr().contains("id"));
EXPECT_THAT(graph_def.node()[0].attr().at("id").s(), Not(IsEmpty()));
}

TEST(AssignIdsTest, IdsNotAddedForNonCustomAggregatorOps) {
GraphDef graph_def;
ASSERT_TRUE(TextFormat::ParseFromString(
R"pb(
node { op: "NotCustomAggregator" name: "bar" }
)pb",
&graph_def));

AssignIdsToCustomAggregatorOps(graph_def);

ASSERT_THAT(graph_def.node(), SizeIs(1));
EXPECT_FALSE(graph_def.node()[0].attr().contains("id"));
}

} // namespace
} // namespace stablehlo::quantization
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ tf_python_pybind_extension(
deps = [
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:debugger",
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:io",
"//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:assign_ids",
"//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:statistics",
"//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc",
"//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // IWYU pragma: keep
#include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil
#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h"
Expand All @@ -47,6 +48,7 @@ namespace py = pybind11;
namespace {

using ::stablehlo::quantization::AddCalibrationStatistics;
using ::stablehlo::quantization::AssignIdsToCustomAggregatorOps;
using ::stablehlo::quantization::EnableDebugging;
using ::stablehlo::quantization::io::CreateTmpDir;
using ::tensorflow::SignatureDef;
Expand Down Expand Up @@ -81,14 +83,13 @@ PYBIND11_MODULE(pywrap_quantization, m) {
tags.insert(quantization_options.tags().begin(),
quantization_options.tags().end());

const absl::StatusOr<ExportedModel> exported_model =
absl::StatusOr<ExportedModel> exported_model =
QuantizePtqModelPreCalibration(src_saved_model_path, signature_keys,
tags, quantization_options,
function_aliases);
if (!exported_model.ok()) return exported_model.status();

ExportedModel exported_model_for_calibration =
py_function_library.AssignIdsToCustomAggregatorOps(*exported_model);
AssignIdsToCustomAggregatorOps(*exported_model->mutable_graph_def());

const absl::StatusOr<std::string> precalibrated_saved_model_dir =
CreateTmpDir();
Expand All @@ -99,7 +100,7 @@ PYBIND11_MODULE(pywrap_quantization, m) {
}

py_function_library.SaveExportedModel(
*precalibrated_saved_model_dir, exported_model_for_calibration,
*precalibrated_saved_model_dir, *exported_model,
src_saved_model_path, tags, signature_def_map);

py_function_library.RunCalibration(
Expand All @@ -109,7 +110,7 @@ PYBIND11_MODULE(pywrap_quantization, m) {
representative_dataset);

if (absl::Status status = AddCalibrationStatistics(
*exported_model_for_calibration.mutable_graph_def(),
*exported_model->mutable_graph_def(),
quantization_options.calibration_options(),
py_function_library);
!status.ok()) {
Expand All @@ -119,7 +120,7 @@ PYBIND11_MODULE(pywrap_quantization, m) {
}

if (quantization_options.has_debugger_options()) {
EnableDebugging(exported_model_for_calibration,
EnableDebugging(*exported_model,
quantization_options.debugger_options(),
py_function_library, src_saved_model_path, tags,
signature_def_map);
Expand All @@ -134,13 +135,13 @@ PYBIND11_MODULE(pywrap_quantization, m) {
}

py_function_library.SaveExportedModel(
*calibrated_saved_model_path, exported_model_for_calibration,
src_saved_model_path, tags, signature_def_map);
*calibrated_saved_model_path, *exported_model, src_saved_model_path,
tags, signature_def_map);

const absl::flat_hash_map<std::string, std::string>
function_aliases_after_calibration(
exported_model_for_calibration.function_aliases().begin(),
exported_model_for_calibration.function_aliases().end());
exported_model->function_aliases().begin(),
exported_model->function_aliases().end());

const absl::StatusOr<ExportedModel> post_calibrated_exported_model =
QuantizePtqModelPostCalibration(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ cc_library(
srcs = ["calibrator_singleton.cc"],
hdrs = ["calibrator_singleton.h"],
compatible_with = get_compatible_with_portable(),
visibility = ["//visibility:private"],
deps = [
":calibration_statistics_collector_average_min_max",
":calibration_statistics_collector_base",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
Expand Down Expand Up @@ -107,6 +108,11 @@ std::optional<CalibrationStatistics> CalibratorSingleton::GetStatistics(
return instance.id_to_collector_[id_str]->GetStatistics();
}

int64_t CalibratorSingleton::IssueNewId() {
CalibratorSingleton& instance = GetInstance();
return instance.next_id_++;
}

void CalibratorSingleton::AssignIfNotExists(
std::string id_str, const CalibrationOptions& calib_opts) {
CalibratorSingleton& instance = GetInstance();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATOR_SINGLETON_H_
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATOR_SINGLETON_H_

#include <atomic>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
Expand Down Expand Up @@ -66,12 +68,20 @@ class CalibratorSingleton {
static std::optional<CalibrationStatistics> GetStatistics(
absl::string_view id);

// Issues a new node ID that uniquely identifies a set of calibration
// statistics.
static int64_t IssueNewId();

private:
static CalibratorSingleton& GetInstance();
static absl::Mutex lock_;
static void AssignIfNotExists(std::string id_str,
const CalibrationOptions& calib_opts);

// Indicates the next id for a set of calibration statistics. For every new ID
// issued this will be incremented atomically.
std::atomic<int64_t> next_id_{0};

absl::flat_hash_map<std::string,
std::unique_ptr<CalibrationStatisticsCollectorBase>>
id_to_collector_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h"

#include <cstdint>
#include <optional>
#include <vector>

Expand Down Expand Up @@ -201,6 +202,12 @@ TEST(CalibratorSingletonTest, SimpleAverageMinMax) {
EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 3);
}

TEST(CalibratorSingletonTest, IssueNewIdGeneratesNewId) {
const int64_t id = CalibratorSingleton::IssueNewId();
const int64_t next_id = CalibratorSingleton::IssueNewId();
EXPECT_NE(id, next_id);
}

} // namespace
} // namespace calibrator
} // namespace tensorflow
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ tf_python_pybind_extension(
":type_casters",
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:debugger",
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:io",
"//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:assign_ids",
"//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:statistics",
"//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc",
"//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,6 @@ class PyFunctionLibrary {
public:
virtual ~PyFunctionLibrary() = default;

// Assigns UUIDs to each CustomAggregator op found in each GraphDef in
// `exported_model`. The UUIDs are set to the `id` attributes. The UUIDs will
// be used during calibration step to identify the collected quantization
// statistics for each CustsomAggregator op.
//
// If the function signature changes, likely its corresponding .pyi type
// hinting and definition should also change.
// LINT.IfChange
virtual ExportedModel AssignIdsToCustomAggregatorOps(
const ExportedModel& exported_model) const = 0;
// LINT.ThenChange(
// pywrap_function_lib.pyi:assign_ids_to_custom_aggregator_ops,
// py_function_lib.py:assign_ids_to_custom_aggregator_ops,
// )

// Saves `exported_model` to `dst_saved_model_path` as SavedModel.
// `src_saved_model_path` is the path to the source SavedModel from which the
// exported model is produced. It is used to copy the asset files to
Expand Down
Loading

0 comments on commit 4a617a0

Please sign in to comment.