1#include <gtest/gtest.h>
2
3#include <torch/csrc/jit/api/compilation_unit.h>
4#include <torch/csrc/jit/api/module.h>
5#include <torch/csrc/jit/passes/inliner.h>
6#include <torch/csrc/jit/testing/file_check.h>
7
// TorchScript program exercised by the inliner tests below. The three
// functions form a call chain foo3 -> foo2 -> foo1, and every function body
// issues exactly one print() before returning, so fully inlining foo3 yields
// one print per function in the chain.
constexpr const char* testSource = R"JIT(
def foo1(x):
    print("one")
    return x

def foo2(x):
    print("two")
    return foo1(x)

def foo3(x):
    print("three")
    return foo2(x)
)JIT";
21
22namespace torch {
23namespace jit {
24using namespace testing;
25
26struct InlinerGuard {
27 explicit InlinerGuard(bool shouldInline)
28 : oldState_(getInlineEverythingMode()) {
29 getInlineEverythingMode() = shouldInline;
30 }
31
32 ~InlinerGuard() {
33 getInlineEverythingMode() = oldState_;
34 }
35
36 bool oldState_;
37};
38
39TEST(InlinerTest, Basic) {
40 // disable automatic inlining so we can test it manually
41 InlinerGuard guard(/*shouldInline=*/false);
42
43 CompilationUnit cu(testSource);
44 auto& fn = cu.get_function("foo3");
45
46 auto g = toGraphFunction(fn).graph();
47 Inline(*g);
48 FileCheck().check_count("prim::Print", 3)->run(*g);
49}
50} // namespace jit
51} // namespace torch
52