gtest-ify JIT tests, through the letter c (pytorch#45249)
Summary:
Pull Request resolved: pytorch#45249

Reland of pytorch#45055 and
pytorch#45020

See pytorch#45018 for context.
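
For reference, the mechanical pattern behind every file in this diff: a hand-rolled void testFoo() function, previously wired into the JIT test harness by hand, becomes a self-registering gtest TEST case, and the test/cpp/jit/test_base.h include generally gives way to <gtest/gtest.h>. A minimal before/after sketch using the AutodiffTest conversion from this diff (the assertion body is a placeholder, not the real test code):

// Before: plain function that the old JIT test harness had to list and call by hand.
// void testADFormulas() {
//   ... assertions ...
// }

// After: a gtest case that registers itself; it can be selected with the standard
// gtest flags, e.g. --gtest_filter=AutodiffTest.ADFormulas.
#include <gtest/gtest.h>

TEST(AutodiffTest, ADFormulas) {
  // Test body is unchanged apart from de-indentation; placeholder assertion only.
  ASSERT_EQ(2 + 2, 4);
}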

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D23892645

Pulled By: suo

fbshipit-source-id: e7fe58d5e1a5a0c44f4e2aec9694145afabde0fd
suo authored and facebook-github-bot committed Sep 24, 2020
1 parent 29dc3c5 commit 6d21d5f
Showing 16 changed files with 641 additions and 634 deletions.
7 changes: 6 additions & 1 deletion test/cpp/jit/CMakeLists.txt
@@ -2,7 +2,10 @@ set(JIT_TEST_ROOT ${TORCH_ROOT}/test/cpp/jit)
 
 # Build separate libraries the define custom classes/operators used from our Python tests.
 # These are intended to be used with torch.ops.load_library() in our Python test suite.
-add_library(torchbind_test SHARED ${JIT_TEST_ROOT}/test_custom_class.cpp)
+add_library(torchbind_test SHARED
+  ${JIT_TEST_ROOT}/test_custom_class_registrations.h
+  ${JIT_TEST_ROOT}/test_custom_class_registrations.cpp
+)
 target_link_libraries(torchbind_test torch)
 
 add_library(jitbackend_test SHARED ${JIT_TEST_ROOT}/test_backend.cpp)
@@ -30,6 +33,8 @@ set(JIT_TEST_SRCS
   ${JIT_TEST_ROOT}/test_cleanup_passes.cpp
   ${JIT_TEST_ROOT}/test_create_autodiff_subgraphs.cpp
   ${JIT_TEST_ROOT}/test_custom_class.cpp
+  ${JIT_TEST_ROOT}/test_custom_class_registrations.h
+  ${JIT_TEST_ROOT}/test_custom_class_registrations.cpp
   ${JIT_TEST_ROOT}/test_custom_operators.cpp
   ${JIT_TEST_ROOT}/test_dce.cpp
   ${JIT_TEST_ROOT}/test_fuser.cpp
9 changes: 5 additions & 4 deletions test/cpp/jit/test_autodiff.cpp
@@ -1,4 +1,5 @@
#include "test/cpp/jit/test_base.h"
#include <gtest/gtest.h>

#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/frontend/tracer.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
@@ -83,7 +84,7 @@ variable_list grad(
       fmap(inputs, get_edge));
 }
 
-void testADFormulas() {
+TEST(AutodiffTest, ADFormulas) {
   const auto cast = [](const Variable& v) {
     return static_cast<at::Tensor>(v);
   };
@@ -174,7 +175,7 @@ void testADFormulas() {
   }
 }
 
-void testDifferentiate() {
+TEST(AutodiffTest, Differentiate) {
   // Note: can't use IRParser for this test due to issue #23989
   auto graph = std::make_shared<Graph>();
   std::vector<int64_t> sizes{2, 3, 4};
@@ -229,7 +230,7 @@ void testDifferentiate() {
       ->run(*grad_spec.df);
 }
 
-void testDifferentiateWithRequiresGrad() {
+TEST(AutodiffTest, DifferentiateWithRequiresGrad) {
   const auto graph_string = R"IR(
     graph(%0 : Tensor,
           %1 : Tensor):
12 changes: 6 additions & 6 deletions test/cpp/jit/test_class_import.cpp
@@ -1,7 +1,7 @@
-#include <test/cpp/jit/test_base.h>
-#include <test/cpp/jit/test_utils.h>
+#include <gtest/gtest.h>
 
 #include <ATen/core/qualified_name.h>
+#include <test/cpp/jit/test_utils.h>
 #include <torch/csrc/jit/frontend/resolver.h>
 #include <torch/csrc/jit/serialization/import_source.h>
 #include <torch/torch.h>
@@ -45,7 +45,7 @@ static void import_libs(
   si.loadType(QualifiedName(class_name));
 }
 
-void testClassImport() {
+TEST(ClassImportTest, Basic) {
   auto cu1 = std::make_shared<CompilationUnit>();
   auto cu2 = std::make_shared<CompilationUnit>();
   std::vector<at::IValue> constantTable;
@@ -80,7 +80,7 @@ void testClassImport() {
   ASSERT_FALSE(c);
 }
 
-void testScriptObject() {
+TEST(ClassImportTest, ScriptObject) {
   Module m1("m1");
   Module m2("m2");
   std::vector<at::IValue> constantTable;
@@ -114,7 +114,7 @@ def __init__(self, x):
     return x
 )JIT";
 
-void testClassDerive() {
+TEST(ClassImportTest, ClassDerive) {
   auto cu = std::make_shared<CompilationUnit>();
   auto cls = ClassType::create("foo.bar", cu);
   const auto self = SimpleSelf(cls);
@@ -142,7 +142,7 @@ class FooBar1234(Module):
     return (self.f).top()
 )JIT";
 
-void testSaveLoadTorchbind() {
+TEST(ClassImportTest, CustomClass) {
   auto cu1 = std::make_shared<CompilationUnit>();
   std::vector<at::IValue> constantTable;
   // Import different versions of FooTest into two namespaces.
4 changes: 3 additions & 1 deletion test/cpp/jit/test_class_parser.cpp
@@ -1,3 +1,5 @@
+#include <gtest/gtest.h>
+
 #include <test/cpp/jit/test_base.h>
 #include <torch/csrc/jit/frontend/parser.h>
 #include <torch/csrc/jit/frontend/resolver.h>
@@ -15,7 +17,7 @@ const auto testSource = R"JIT(
   an_attribute : Tensor
 )JIT";
 
-void testClassParser() {
+TEST(ClassParserTest, Basic) {
   Parser p(std::make_shared<Source>(testSource));
   std::vector<Def> definitions;
   std::vector<Resolver> resolvers;
37 changes: 18 additions & 19 deletions test/cpp/jit/test_cleanup_passes.cpp
@@ -1,19 +1,19 @@
+#include <gtest/gtest.h>
+
 #include <torch/csrc/jit/frontend/ir_emitter.h>
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/irparser.h>
 #include <torch/csrc/jit/testing/file_check.h>
-#include "test/cpp/jit/test_base.h"
 
 namespace torch {
 namespace jit {
 
-void testCleanUpPasses() {
+TEST(CleanupPassTest, Basic) {
   // Tests stability of clean up passes when dealing with constant pooling
   // and constant propagation.
-  {
-    auto graph = std::make_shared<Graph>();
-    parseIR(
-        R"IR(
+  auto graph = std::make_shared<Graph>();
+  parseIR(
+      R"IR(
 graph(%cond.1 : Tensor,
       %suffix.1 : str):
   %3 : bool = aten::Bool(%cond.1) # o.py:6:7
@@ -31,20 +31,19 @@ graph(%cond.1 : Tensor,
     -> (%12)
   return (%25)
 )IR",
-        &*graph);
-    runCleanupPasses(graph);
-    testing::FileCheck()
-        .check_count(
-            "prim::Constant[value=\"same string with a twist\"]",
-            1,
-            /*exactly=*/true)
-        ->run(*graph);
+      &*graph);
+  runCleanupPasses(graph);
+  testing::FileCheck()
+      .check_count(
+          "prim::Constant[value=\"same string with a twist\"]",
+          1,
+          /*exactly=*/true)
+      ->run(*graph);
 
-    auto graph_after_pass_once = graph->toString();
-    runCleanupPasses(graph);
-    auto graph_after_pass_twice = graph->toString();
-    ASSERT_EQ(graph_after_pass_once, graph_after_pass_twice);
-  }
+  auto graph_after_pass_once = graph->toString();
+  runCleanupPasses(graph);
+  auto graph_after_pass_twice = graph->toString();
+  ASSERT_EQ(graph_after_pass_once, graph_after_pass_twice);
 }
 } // namespace jit
 } // namespace torch
50 changes: 24 additions & 26 deletions test/cpp/jit/test_code_template.cpp
@@ -1,6 +1,6 @@
#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"
#include <gtest/gtest.h>

#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/frontend/code_template.h"

namespace torch {
@@ -33,31 +33,29 @@ static const auto ct_expect = R"(
 int notest(int a)
 )";
 
-void testCodeTemplate() {
-  {
-    TemplateEnv e;
-    e.s("hi", "foo");
-    e.v("what", {"is", "this"});
-    TemplateEnv c(e);
-    c.s("hi", "foo2");
-    ASSERT_EQ(e.s("hi"), "foo");
-    ASSERT_EQ(c.s("hi"), "foo2");
-    ASSERT_EQ(e.v("what")[0], "is");
-  }
+TEST(TestCodeTemplate, Copying) {
+  TemplateEnv e;
+  e.s("hi", "foo");
+  e.v("what", {"is", "this"});
+  TemplateEnv c(e);
+  c.s("hi", "foo2");
+  ASSERT_EQ(e.s("hi"), "foo");
+  ASSERT_EQ(c.s("hi"), "foo2");
+  ASSERT_EQ(e.v("what")[0], "is");
+}
 
-  {
-    TemplateEnv e;
-    e.v("args", {"hi", "8"});
-    e.v("bar", {"what\non many\nlines...", "7"});
-    e.s("a", "3");
-    e.s("b", "4");
-    e.v("stuff", {"things...", "others"});
-    e.v("empty", {});
-    auto s = ct.format(e);
-    // std::cout << "'" << s << "'\n";
-    // std::cout << "'" << ct_expect << "'\n";
-    ASSERT_EQ(s, ct_expect);
-  }
+TEST(TestCodeTemplate, Formatting) {
+  TemplateEnv e;
+  e.v("args", {"hi", "8"});
+  e.v("bar", {"what\non many\nlines...", "7"});
+  e.s("a", "3");
+  e.s("b", "4");
+  e.v("stuff", {"things...", "others"});
+  e.v("empty", {});
+  auto s = ct.format(e);
+  // std::cout << "'" << s << "'\n";
+  // std::cout << "'" << ct_expect << "'\n";
+  ASSERT_EQ(s, ct_expect);
 }
 
 } // namespace jit
87 changes: 44 additions & 43 deletions test/cpp/jit/test_constant_pooling.cpp
@@ -1,36 +1,37 @@
+#include <gtest/gtest.h>
+
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/irparser.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/jit/testing/file_check.h>
-#include "test/cpp/jit/test_base.h"
 
 #include <sstream>
 #include <string>
 
 namespace torch {
 namespace jit {
 
-void testConstantPooling() {
-  {
-    auto graph = std::make_shared<Graph>();
-    parseIR(
-        R"IR(
+TEST(ConstantPoolingTest, Int) {
+  auto graph = std::make_shared<Graph>();
+  parseIR(
+      R"IR(
 graph():
   %8 : int = prim::Constant[value=1]()
   %10 : int = prim::Constant[value=1]()
   return (%8, %10)
 )IR",
-        &*graph);
-    ConstantPooling(graph);
-    testing::FileCheck()
-        .check_count("prim::Constant", 1, /*exactly*/ true)
-        ->run(*graph);
-  }
-  {
-    auto graph = std::make_shared<Graph>();
-    parseIR(
-        R"IR(
+      &*graph);
+  ConstantPooling(graph);
+  testing::FileCheck()
+      .check_count("prim::Constant", 1, /*exactly*/ true)
+      ->run(*graph);
+}
+
+TEST(ConstantPoolingTest, PoolingAcrossBlocks) {
+  auto graph = std::make_shared<Graph>();
+  parseIR(
+      R"IR(
 graph(%cond : Tensor):
   %a : str = prim::Constant[value="bcd"]()
   %3 : bool = aten::Bool(%cond)
@@ -44,17 +45,18 @@ graph(%cond : Tensor):
   %7 : (str, str) = prim::TupleConstruct(%a, %b)
   return (%7)
 )IR",
-        &*graph);
-    ConstantPooling(graph);
-    testing::FileCheck()
-        .check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
-        ->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
-        ->run(*graph);
-  }
-  {
-    auto graph = std::make_shared<Graph>();
-    parseIR(
-        R"IR(
+      &*graph);
+  ConstantPooling(graph);
+  testing::FileCheck()
+      .check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
+      ->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
+      ->run(*graph);
+}
+
+TEST(ConstantPoolingTest, PoolingDifferentDevices) {
+  auto graph = std::make_shared<Graph>();
+  parseIR(
+      R"IR(
 graph():
   %2 : int = prim::Constant[value=2]()
   %1 : int = prim::Constant[value=1]()
@@ -70,22 +72,21 @@ graph():
   prim::Print(%x, %y, %z)
   return (%1)
 )IR",
-        &*graph);
-    // three tensors created - two different devices among the three
-    // don't have good support for parsing tensor constants
-    ConstantPropagation(graph);
-    ConstantPooling(graph);
-    testing::FileCheck()
-        .check_count(
-            "Float(2:1, requires_grad=0, device=cpu) = prim::Constant",
-            1,
-            /*exactly*/ true)
-        ->check_count(
-            "Long(2:1, requires_grad=0, device=cpu) = prim::Constant",
-            1,
-            /*exactly*/ true)
-        ->run(*graph);
-  }
+      &*graph);
+  // three tensors created - two different devices among the three
+  // don't have good support for parsing tensor constants
+  ConstantPropagation(graph);
+  ConstantPooling(graph);
+  testing::FileCheck()
+      .check_count(
+          "Float(2:1, requires_grad=0, device=cpu) = prim::Constant",
+          1,
+          /*exactly*/ true)
+      ->check_count(
+          "Long(2:1, requires_grad=0, device=cpu) = prim::Constant",
+          1,
+          /*exactly*/ true)
+      ->run(*graph);
 }
 } // namespace jit
 } // namespace torch
5 changes: 3 additions & 2 deletions test/cpp/jit/test_create_autodiff_subgraphs.cpp
@@ -1,12 +1,13 @@
#include "test/cpp/jit/test_base.h"
#include <gtest/gtest.h>

#include "test/cpp/jit/test_utils.h"

#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"

namespace torch {
namespace jit {

void testCreateAutodiffSubgraphs() {
TEST(CreateAutodiffSubgraphsTest, Basic) {
auto graph = build_lstm();
CreateAutodiffSubgraphs(graph, /*threshold=*/2);
// all of the ops are within the DifferentiableGraph