diff --git a/PhysicsTools/TensorFlow/interface/TensorFlow.h b/PhysicsTools/TensorFlow/interface/TensorFlow.h index a61539c00413d..451b854ed2196 100644 --- a/PhysicsTools/TensorFlow/interface/TensorFlow.h +++ b/PhysicsTools/TensorFlow/interface/TensorFlow.h @@ -106,6 +106,8 @@ namespace tensorflow { // version of the function above that accepts a const session bool closeSession(const Session*& session); + bool checkEmptyInputs(const NamedTensorList& inputs); + // run the session with inputs and outputNames, store output tensors, and control the underlying // thread pool using threadPoolOptions // used for thread scheduling with custom thread pool options diff --git a/PhysicsTools/TensorFlow/src/TensorFlow.cc b/PhysicsTools/TensorFlow/src/TensorFlow.cc index fcb09e2e9c449..3e5cc01ea3e3b 100644 --- a/PhysicsTools/TensorFlow/src/TensorFlow.cc +++ b/PhysicsTools/TensorFlow/src/TensorFlow.cc @@ -256,6 +256,19 @@ namespace tensorflow { return state; } + bool checkEmptyInputs(const NamedTensorList& inputs) { + // check for empty tensors in the inputs + bool isEmpty = false; + for (const auto& input : inputs) { + // Checking using the shape + if (input.second.shape().num_elements() == 0) { + isEmpty = true; + break; + } + } + return isEmpty; + } + void run(Session* session, const NamedTensorList& inputs, const std::vector& outputNames, @@ -268,6 +281,10 @@ namespace tensorflow { // create empty run options RunOptions runOptions; + // Check if the inputs are empty + if (checkEmptyInputs(inputs)) + return; + // run and check the status Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions); if (!status.ok()) { diff --git a/PhysicsTools/TensorFlow/test/BuildFile.xml b/PhysicsTools/TensorFlow/test/BuildFile.xml index 03ca557c61619..b2cfafd6ff027 100644 --- a/PhysicsTools/TensorFlow/test/BuildFile.xml +++ b/PhysicsTools/TensorFlow/test/BuildFile.xml @@ -144,6 +144,13 @@ + + + + + + + diff --git a/PhysicsTools/TensorFlow/test/testEmptyInputs.cc b/PhysicsTools/TensorFlow/test/testEmptyInputs.cc new file mode 100644 index 0000000000000..7272f53045fcc --- /dev/null +++ b/PhysicsTools/TensorFlow/test/testEmptyInputs.cc @@ -0,0 +1,57 @@ +/* + * Tests for working with empty inputs + * + */ + +#include +#include + +#include "PhysicsTools/TensorFlow/interface/TensorFlow.h" + +#include "testBase.h" + +class testEmptyInputs : public testBase { + CPPUNIT_TEST_SUITE(testEmptyInputs); + CPPUNIT_TEST(test); + CPPUNIT_TEST_SUITE_END(); + +public: + std::string pyScript() const override; + void test() override; +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(testEmptyInputs); + +std::string testEmptyInputs::pyScript() const { return "createconstantgraph.py"; } + +void testEmptyInputs::test() { + std::string pbFile = dataPath_ + "/constantgraph.pb"; + + std::cout << "Testing CPU backend" << std::endl; + tensorflow::Backend backend = tensorflow::Backend::cpu; + + // load the graph + tensorflow::Options options{backend}; + tensorflow::GraphDef* graphDef = tensorflow::loadGraphDef(pbFile); + CPPUNIT_ASSERT(graphDef != nullptr); + + // create a new session and add the graphDef + const tensorflow::Session* session = tensorflow::createSession(graphDef, options); + CPPUNIT_ASSERT(session != nullptr); + + // example evaluation with empty tensor + tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 0}); + tensorflow::Tensor scale(tensorflow::DT_FLOAT, {}); + scale.scalar()() = 1.0; + std::vector outputs; + + // run using the convenience helper + outputs.clear(); + tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs); + CPPUNIT_ASSERT(outputs.size() == 0); + + // cleanup + CPPUNIT_ASSERT(tensorflow::closeSession(session)); + CPPUNIT_ASSERT(session == nullptr); + delete graphDef; +}