1 | #include <gtest/gtest.h> |
---|---|
2 | |
3 | #include <torch/csrc/jit/api/compilation_unit.h> |
4 | #include <torch/csrc/jit/api/module.h> |
5 | #include <torch/csrc/jit/passes/inliner.h> |
6 | #include <torch/csrc/jit/testing/file_check.h> |
7 | |
// TorchScript source for a three-level call chain: foo3 -> foo2 -> foo1.
// Each function prints exactly once, so once every call is inlined into
// foo3 its graph should contain three print operations.
// NOTE: the function bodies must stay indented -- TorchScript is parsed
// with Python's indentation rules, so unindented bodies fail to compile.
const auto testSource = R"JIT(
def foo1(x):
    print("one")
    return x

def foo2(x):
    print("two")
    return foo1(x)

def foo3(x):
    print("three")
    return foo2(x)
)JIT";
21 | |
22 | namespace torch { |
23 | namespace jit { |
24 | using namespace testing; |
25 | |
26 | struct InlinerGuard { |
27 | explicit InlinerGuard(bool shouldInline) |
28 | : oldState_(getInlineEverythingMode()) { |
29 | getInlineEverythingMode() = shouldInline; |
30 | } |
31 | |
32 | ~InlinerGuard() { |
33 | getInlineEverythingMode() = oldState_; |
34 | } |
35 | |
36 | bool oldState_; |
37 | }; |
38 | |
39 | TEST(InlinerTest, Basic) { |
40 | // disable automatic inlining so we can test it manually |
41 | InlinerGuard guard(/*shouldInline=*/false); |
42 | |
43 | CompilationUnit cu(testSource); |
44 | auto& fn = cu.get_function("foo3"); |
45 | |
46 | auto g = toGraphFunction(fn).graph(); |
47 | Inline(*g); |
48 | FileCheck().check_count("prim::Print", 3)->run(*g); |
49 | } |
50 | } // namespace jit |
51 | } // namespace torch |
52 |