1#include <gtest/gtest.h>
2
3#include <ATen/ATen.h>
4#include <ATen/core/interned_strings.h>
5#include <ATen/core/ivalue.h>
6#include <c10/util/irange.h>
7
8#include <torch/csrc/autograd/engine.h>
9#include <torch/csrc/autograd/generated/variable_factories.h>
10#include <torch/csrc/autograd/variable.h>
11#include <torch/csrc/jit/api/module.h>
12#include <torch/csrc/jit/codegen/cuda/interface.h>
13#include <torch/csrc/jit/codegen/fuser/interface.h>
14#include <torch/csrc/jit/frontend/ir_emitter.h>
15#include <torch/csrc/jit/frontend/tracer.h>
16#include <torch/csrc/jit/ir/alias_analysis.h>
17#include <torch/csrc/jit/ir/attributes.h>
18#include <torch/csrc/jit/ir/irparser.h>
19#include <torch/csrc/jit/passes/canonicalize.h>
20#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
21#include <torch/csrc/jit/passes/constant_propagation.h>
22#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
23#include <torch/csrc/jit/passes/dead_code_elimination.h>
24#include <torch/csrc/jit/passes/graph_fuser.h>
25#include <torch/csrc/jit/passes/lower_grad_of.h>
26#include <torch/csrc/jit/passes/lower_tuples.h>
27#include <torch/csrc/jit/passes/requires_grad_analysis.h>
28#include <torch/csrc/jit/passes/shape_analysis.h>
29#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
30#include <torch/csrc/jit/runtime/argument_spec.h>
31#include <torch/csrc/jit/runtime/autodiff.h>
32#include <torch/csrc/jit/runtime/custom_operator.h>
33#include <torch/csrc/jit/runtime/graph_executor.h>
34#include <torch/csrc/jit/runtime/interpreter.h>
35#include <torch/csrc/jit/runtime/symbolic_script.h>
36#include <torch/csrc/jit/serialization/import.h>
37#include <torch/csrc/jit/testing/file_check.h>
38
39#include <onnx/onnx_pb.h>
40
41#include <c10/util/Exception.h>
42
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>
54
55namespace torch {
56namespace jit {
57
58class FuserTest : public ::testing::Test {
59 void SetUp() override {
60 old_nvfuser_value_ = fuser::cuda::setEnabled(false);
61 }
62 void TearDown() override {
63 fuser::cuda::setEnabled(old_nvfuser_value_);
64 }
65
66 private:
67 bool old_nvfuser_value_;
68};
69
70TEST_F(FuserTest, TestSimple_CUDA) {
71#if defined(FBCODE_CAFFE2)
72 return;
73#endif
74 const auto graph_string = R"IR(
75 graph(%0 : Tensor,
76 %1 : Tensor):
77 %2 : Tensor = aten::mul(%0, %1)
78 return (%2))IR";
79 Graph graph;
80 torch::jit::parseIR(graph_string, &graph);
81
82 auto a = at::rand({3, 4}, at::kCUDA);
83 auto b = at::rand({4, 3}, at::kCUDA).transpose(0, 1);
84 auto o = at::zeros({3, 4}, at::kCUDA);
85 auto outputs = debugLaunchGraph(graph, {a, b});
86 ASSERT_EQ(outputs.size(), 1);
87 auto o2 = a * b;
88 float max_diff = (o2 - outputs[0]).abs().max().item<double>();
89 // std::cout << "max diff: " << max_diff << "\n";
90 ASSERT_EQ(max_diff, 0);
91}
92
93TEST_F(FuserTest, TestOne_CUDA) {
94#if defined(FBCODE_CAFFE2)
95 return;
96#endif
97 auto testOne = [&](int ti, int tj) {
98 const auto graph_string = R"IR(
99 graph(%0 : Tensor,
100 %1 : Tensor,
101 %2 : Tensor,
102 %3 : Tensor,
103 %4 : Tensor):
104 %5 : Tensor = aten::sigmoid(%4)
105 %6 : Tensor = aten::sigmoid(%3)
106 %7 : Tensor = aten::tanh(%2)
107 %8 : Tensor = aten::sigmoid(%1)
108 %9 : Tensor = aten::mul(%6, %0)
109 %10 : Tensor = aten::mul(%5, %7)
110 %11 : int = prim::Constant[value=1]()
111 %12 : Tensor = aten::add(%9, %10, %11)
112 %13 : Tensor = aten::tanh(%12)
113 %14 : Tensor = aten::mul(%8, %13)
114 return (%14, %12))IR";
115 Graph graph;
116 torch::jit::parseIR(graph_string, &graph);
117
118 graph.lint();
119
120 std::vector<at::Tensor> inputs;
121 // We want to generate input/output tensors with dimension 128x128x32, but
122 // with different internal strides. To do this, we generate a tensor
123 // with the "wrong" dimensions, and then use transpose to get an
124 // appropriately sized view.
125 std::generate_n(
126 std::back_inserter(inputs), graph.inputs().size(), [ti, tj] {
127 std::array<int64_t, 3> dims = {128, 128, 32};
128 std::swap(dims[ti], dims[tj]);
129 return at::rand(dims, at::kCUDA).transpose(ti, tj);
130 });
131
132 auto t22 = inputs[4].sigmoid();
133 auto t20 = inputs[3].sigmoid();
134 auto t18 = inputs[2].tanh();
135 auto t16 = inputs[1].sigmoid();
136 auto t14 = t20 * inputs[0];
137 auto t11 = t22 * t18;
138 auto out1 = t14 + t11;
139 auto t5 = out1.tanh();
140 auto out0 = t16 * t5;
141
142 auto outputs = debugLaunchGraph(graph, inputs);
143 ASSERT_EQ(outputs.size(), graph.outputs().size());
144 ASSERT_TRUE(out0.is_same_size(outputs.front()));
145 float max_diff = (outputs.front() - out0).abs().max().item<double>();
146 ASSERT_TRUE(max_diff < 1e-6);
147 };
148 testOne(0, 0);
149 testOne(0, 1);
150 testOne(1, 2);
151 testOne(0, 2);
152}
153
154TEST_F(FuserTest, FusedConcat_CUDA) {
155#if defined(FBCODE_CAFFE2)
156 return;
157#endif
158 const auto graph_string0 = R"IR(
159 graph(%0 : Tensor,
160 %1 : Tensor):
161 %2 : Tensor = aten::mul(%0, %1)
162 %3 : Tensor = prim::FusedConcat[dim=0](%0, %2)
163 return (%2, %3))IR";
164 const auto graph_string1 = R"IR(
165 graph(%0 : Tensor,
166 %1 : Tensor):
167 %2 : Tensor = aten::mul(%0, %1)
168 %3 : Tensor = prim::FusedConcat[dim=1](%0, %2)
169 return (%2, %3))IR";
170 const auto graph_string2 = R"IR(
171 graph(%0 : Tensor,
172 %1 : Tensor):
173 %2 : Tensor = aten::mul(%0, %1)
174 %3 : Tensor = prim::FusedConcat[dim=2](%0, %2)
175 return (%2, %3))IR";
176
177 auto a = at::rand({3, 4, 5}, at::kCUDA);
178 auto b = at::rand({4, 3, 5}, at::kCUDA).transpose(0, 1);
179 const auto o_r = a * b;
180
181 std::vector<std::string> graph_strings{
182 graph_string0, graph_string1, graph_string2};
183 for (const auto i : c10::irange(graph_strings.size())) {
184 Graph g;
185 torch::jit::parseIR(graph_strings[i], &g);
186
187 auto outputs = debugLaunchGraph(g, {a, b});
188 ASSERT_EQ(outputs.size(), 2);
189
190 float max_diff = (o_r - outputs[0]).abs().max().item<double>();
191 ASSERT_EQ(max_diff, 0);
192
193 const auto o2_r = at::cat({a, o_r}, i);
194 float max_diff2 = (o2_r - outputs[1]).abs().max().item<double>();
195 ASSERT_EQ(max_diff2, 0);
196 };
197}
198
199TEST_F(FuserTest, FusionAliasing) {
200#if defined(FBCODE_CAFFE2)
201 return;
202#endif
203 const auto graph_string = R"IR(
204 graph(%0 : Tensor,
205 %1 : Tensor):
206 %12 : int = prim::Constant[value=1]()
207 %2.1 : Tensor = aten::mul(%0, %1)
208 %2 : Tensor = aten::mul(%2.1, %1)
209 %3 : Tensor = aten::add_(%2, %1, %12)
210 %4 : Tensor = aten::mul(%2, %1)
211 %5 : Tensor = aten::add(%2, %4, %12)
212 return (%5))IR";
213 auto g = std::make_shared<Graph>();
214 torch::jit::parseIR(graph_string, g.get());
215
216 g->lint();
217 FuseGraph(g);
218
219 // We should not be able to fuse across the in-place operation here.
220 testing::FileCheck()
221 .check("prim::FusionGroup_0")
222 ->check("aten::add_")
223 ->check("prim::FusionGroup_1")
224 ->run(*g);
225}
226
227TEST_F(FuserTest, KernelCaching) {
228#if defined(FBCODE_CAFFE2)
229 return;
230#endif
231
232 // Constructs two functionally equivalent graphs
233 const auto graph0_string = R"IR(
234 graph(%0 : Float(2, 3, 4),
235 %1 : Float(2, 3, 4)):
236 %c0 : Float(2, 3, 4) = aten::mul(%0, %1)
237 %d0 : Float(2, 3, 4) = aten::mul(%c0, %0)
238 return (%d0))IR";
239 auto g0 = std::make_shared<Graph>();
240 torch::jit::parseIR(graph0_string, g0.get());
241
242 const auto graph1_string = R"IR(
243 graph(%0 : Float(2, 3, 4),
244 %1 : Float(2, 3, 4)):
245 %c1 : Float(2, 3, 4) = aten::mul(%0, %1)
246 %d1 : Float(2, 3, 4) = aten::mul(%c1, %0)
247 return (%d1))IR";
248 auto g1 = std::make_shared<Graph>();
249 torch::jit::parseIR(graph1_string, g1.get());
250
251 auto getFusionGroup = [](const std::shared_ptr<Graph>& graph) {
252 const auto& nodes = graph->nodes();
253 auto maybe_fusion_group =
254 std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
255 return node->kind() == prim::FusionGroup;
256 });
257 TORCH_CHECK(
258 maybe_fusion_group != nodes.end(),
259 "testRegisterFusionCachesKernel: could not create FusionGroup");
260 return *maybe_fusion_group;
261 };
262
263 // Creates two alpha-equivalent fusion groups
264 torch::jit::overrideCanFuseOnCPU(true);
265 FuseGraph(g0);
266 FuseGraph(g1);
267 torch::jit::overrideCanFuseOnCPU(false);
268 auto fg0 = getFusionGroup(g0);
269 auto fg1 = getFusionGroup(g1);
270
271 // Registers both with the fusion compiler.
272 auto expected_key = registerFusion(fg0);
273 auto second_key = registerFusion(fg1);
274
275 // Because the graphs are alpha-equivalent, they should return the same key
276 // and therefore share a KernelSpec to share kernels for specializations
277 ASSERT_EQ(second_key, expected_key);
278}
279} // namespace jit
280} // namespace torch
281