1 | #include <gtest/gtest.h> |
2 | |
3 | #include <ATen/ATen.h> |
4 | #include <ATen/core/interned_strings.h> |
5 | #include <ATen/core/ivalue.h> |
6 | #include <c10/util/irange.h> |
7 | |
8 | #include <torch/csrc/autograd/engine.h> |
9 | #include <torch/csrc/autograd/generated/variable_factories.h> |
10 | #include <torch/csrc/autograd/variable.h> |
11 | #include <torch/csrc/jit/api/module.h> |
12 | #include <torch/csrc/jit/codegen/cuda/interface.h> |
13 | #include <torch/csrc/jit/codegen/fuser/interface.h> |
14 | #include <torch/csrc/jit/frontend/ir_emitter.h> |
15 | #include <torch/csrc/jit/frontend/tracer.h> |
16 | #include <torch/csrc/jit/ir/alias_analysis.h> |
17 | #include <torch/csrc/jit/ir/attributes.h> |
18 | #include <torch/csrc/jit/ir/irparser.h> |
19 | #include <torch/csrc/jit/passes/canonicalize.h> |
20 | #include <torch/csrc/jit/passes/common_subexpression_elimination.h> |
21 | #include <torch/csrc/jit/passes/constant_propagation.h> |
22 | #include <torch/csrc/jit/passes/create_autodiff_subgraphs.h> |
23 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
24 | #include <torch/csrc/jit/passes/graph_fuser.h> |
25 | #include <torch/csrc/jit/passes/lower_grad_of.h> |
26 | #include <torch/csrc/jit/passes/lower_tuples.h> |
27 | #include <torch/csrc/jit/passes/requires_grad_analysis.h> |
28 | #include <torch/csrc/jit/passes/shape_analysis.h> |
29 | #include <torch/csrc/jit/passes/utils/subgraph_utils.h> |
30 | #include <torch/csrc/jit/runtime/argument_spec.h> |
31 | #include <torch/csrc/jit/runtime/autodiff.h> |
32 | #include <torch/csrc/jit/runtime/custom_operator.h> |
33 | #include <torch/csrc/jit/runtime/graph_executor.h> |
34 | #include <torch/csrc/jit/runtime/interpreter.h> |
35 | #include <torch/csrc/jit/runtime/symbolic_script.h> |
36 | #include <torch/csrc/jit/serialization/import.h> |
37 | #include <torch/csrc/jit/testing/file_check.h> |
38 | |
39 | #include <onnx/onnx_pb.h> |
40 | |
41 | #include <c10/util/Exception.h> |
42 | |
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>
54 | |
55 | namespace torch { |
56 | namespace jit { |
57 | |
58 | class FuserTest : public ::testing::Test { |
59 | void SetUp() override { |
60 | old_nvfuser_value_ = fuser::cuda::setEnabled(false); |
61 | } |
62 | void TearDown() override { |
63 | fuser::cuda::setEnabled(old_nvfuser_value_); |
64 | } |
65 | |
66 | private: |
67 | bool old_nvfuser_value_; |
68 | }; |
69 | |
70 | TEST_F(FuserTest, TestSimple_CUDA) { |
71 | #if defined(FBCODE_CAFFE2) |
72 | return; |
73 | #endif |
74 | const auto graph_string = R"IR( |
75 | graph(%0 : Tensor, |
76 | %1 : Tensor): |
77 | %2 : Tensor = aten::mul(%0, %1) |
78 | return (%2))IR" ; |
79 | Graph graph; |
80 | torch::jit::parseIR(graph_string, &graph); |
81 | |
82 | auto a = at::rand({3, 4}, at::kCUDA); |
83 | auto b = at::rand({4, 3}, at::kCUDA).transpose(0, 1); |
84 | auto o = at::zeros({3, 4}, at::kCUDA); |
85 | auto outputs = debugLaunchGraph(graph, {a, b}); |
86 | ASSERT_EQ(outputs.size(), 1); |
87 | auto o2 = a * b; |
88 | float max_diff = (o2 - outputs[0]).abs().max().item<double>(); |
89 | // std::cout << "max diff: " << max_diff << "\n"; |
90 | ASSERT_EQ(max_diff, 0); |
91 | } |
92 | |
93 | TEST_F(FuserTest, TestOne_CUDA) { |
94 | #if defined(FBCODE_CAFFE2) |
95 | return; |
96 | #endif |
97 | auto testOne = [&](int ti, int tj) { |
98 | const auto graph_string = R"IR( |
99 | graph(%0 : Tensor, |
100 | %1 : Tensor, |
101 | %2 : Tensor, |
102 | %3 : Tensor, |
103 | %4 : Tensor): |
104 | %5 : Tensor = aten::sigmoid(%4) |
105 | %6 : Tensor = aten::sigmoid(%3) |
106 | %7 : Tensor = aten::tanh(%2) |
107 | %8 : Tensor = aten::sigmoid(%1) |
108 | %9 : Tensor = aten::mul(%6, %0) |
109 | %10 : Tensor = aten::mul(%5, %7) |
110 | %11 : int = prim::Constant[value=1]() |
111 | %12 : Tensor = aten::add(%9, %10, %11) |
112 | %13 : Tensor = aten::tanh(%12) |
113 | %14 : Tensor = aten::mul(%8, %13) |
114 | return (%14, %12))IR" ; |
115 | Graph graph; |
116 | torch::jit::parseIR(graph_string, &graph); |
117 | |
118 | graph.lint(); |
119 | |
120 | std::vector<at::Tensor> inputs; |
121 | // We want to generate input/output tensors with dimension 128x128x32, but |
122 | // with different internal strides. To do this, we generate a tensor |
123 | // with the "wrong" dimensions, and then use transpose to get an |
124 | // appropriately sized view. |
125 | std::generate_n( |
126 | std::back_inserter(inputs), graph.inputs().size(), [ti, tj] { |
127 | std::array<int64_t, 3> dims = {128, 128, 32}; |
128 | std::swap(dims[ti], dims[tj]); |
129 | return at::rand(dims, at::kCUDA).transpose(ti, tj); |
130 | }); |
131 | |
132 | auto t22 = inputs[4].sigmoid(); |
133 | auto t20 = inputs[3].sigmoid(); |
134 | auto t18 = inputs[2].tanh(); |
135 | auto t16 = inputs[1].sigmoid(); |
136 | auto t14 = t20 * inputs[0]; |
137 | auto t11 = t22 * t18; |
138 | auto out1 = t14 + t11; |
139 | auto t5 = out1.tanh(); |
140 | auto out0 = t16 * t5; |
141 | |
142 | auto outputs = debugLaunchGraph(graph, inputs); |
143 | ASSERT_EQ(outputs.size(), graph.outputs().size()); |
144 | ASSERT_TRUE(out0.is_same_size(outputs.front())); |
145 | float max_diff = (outputs.front() - out0).abs().max().item<double>(); |
146 | ASSERT_TRUE(max_diff < 1e-6); |
147 | }; |
148 | testOne(0, 0); |
149 | testOne(0, 1); |
150 | testOne(1, 2); |
151 | testOne(0, 2); |
152 | } |
153 | |
154 | TEST_F(FuserTest, FusedConcat_CUDA) { |
155 | #if defined(FBCODE_CAFFE2) |
156 | return; |
157 | #endif |
158 | const auto graph_string0 = R"IR( |
159 | graph(%0 : Tensor, |
160 | %1 : Tensor): |
161 | %2 : Tensor = aten::mul(%0, %1) |
162 | %3 : Tensor = prim::FusedConcat[dim=0](%0, %2) |
163 | return (%2, %3))IR" ; |
164 | const auto graph_string1 = R"IR( |
165 | graph(%0 : Tensor, |
166 | %1 : Tensor): |
167 | %2 : Tensor = aten::mul(%0, %1) |
168 | %3 : Tensor = prim::FusedConcat[dim=1](%0, %2) |
169 | return (%2, %3))IR" ; |
170 | const auto graph_string2 = R"IR( |
171 | graph(%0 : Tensor, |
172 | %1 : Tensor): |
173 | %2 : Tensor = aten::mul(%0, %1) |
174 | %3 : Tensor = prim::FusedConcat[dim=2](%0, %2) |
175 | return (%2, %3))IR" ; |
176 | |
177 | auto a = at::rand({3, 4, 5}, at::kCUDA); |
178 | auto b = at::rand({4, 3, 5}, at::kCUDA).transpose(0, 1); |
179 | const auto o_r = a * b; |
180 | |
181 | std::vector<std::string> graph_strings{ |
182 | graph_string0, graph_string1, graph_string2}; |
183 | for (const auto i : c10::irange(graph_strings.size())) { |
184 | Graph g; |
185 | torch::jit::parseIR(graph_strings[i], &g); |
186 | |
187 | auto outputs = debugLaunchGraph(g, {a, b}); |
188 | ASSERT_EQ(outputs.size(), 2); |
189 | |
190 | float max_diff = (o_r - outputs[0]).abs().max().item<double>(); |
191 | ASSERT_EQ(max_diff, 0); |
192 | |
193 | const auto o2_r = at::cat({a, o_r}, i); |
194 | float max_diff2 = (o2_r - outputs[1]).abs().max().item<double>(); |
195 | ASSERT_EQ(max_diff2, 0); |
196 | }; |
197 | } |
198 | |
199 | TEST_F(FuserTest, FusionAliasing) { |
200 | #if defined(FBCODE_CAFFE2) |
201 | return; |
202 | #endif |
203 | const auto graph_string = R"IR( |
204 | graph(%0 : Tensor, |
205 | %1 : Tensor): |
206 | %12 : int = prim::Constant[value=1]() |
207 | %2.1 : Tensor = aten::mul(%0, %1) |
208 | %2 : Tensor = aten::mul(%2.1, %1) |
209 | %3 : Tensor = aten::add_(%2, %1, %12) |
210 | %4 : Tensor = aten::mul(%2, %1) |
211 | %5 : Tensor = aten::add(%2, %4, %12) |
212 | return (%5))IR" ; |
213 | auto g = std::make_shared<Graph>(); |
214 | torch::jit::parseIR(graph_string, g.get()); |
215 | |
216 | g->lint(); |
217 | FuseGraph(g); |
218 | |
219 | // We should not be able to fuse across the in-place operation here. |
220 | testing::FileCheck() |
221 | .check("prim::FusionGroup_0" ) |
222 | ->check("aten::add_" ) |
223 | ->check("prim::FusionGroup_1" ) |
224 | ->run(*g); |
225 | } |
226 | |
227 | TEST_F(FuserTest, KernelCaching) { |
228 | #if defined(FBCODE_CAFFE2) |
229 | return; |
230 | #endif |
231 | |
232 | // Constructs two functionally equivalent graphs |
233 | const auto graph0_string = R"IR( |
234 | graph(%0 : Float(2, 3, 4), |
235 | %1 : Float(2, 3, 4)): |
236 | %c0 : Float(2, 3, 4) = aten::mul(%0, %1) |
237 | %d0 : Float(2, 3, 4) = aten::mul(%c0, %0) |
238 | return (%d0))IR" ; |
239 | auto g0 = std::make_shared<Graph>(); |
240 | torch::jit::parseIR(graph0_string, g0.get()); |
241 | |
242 | const auto graph1_string = R"IR( |
243 | graph(%0 : Float(2, 3, 4), |
244 | %1 : Float(2, 3, 4)): |
245 | %c1 : Float(2, 3, 4) = aten::mul(%0, %1) |
246 | %d1 : Float(2, 3, 4) = aten::mul(%c1, %0) |
247 | return (%d1))IR" ; |
248 | auto g1 = std::make_shared<Graph>(); |
249 | torch::jit::parseIR(graph1_string, g1.get()); |
250 | |
251 | auto getFusionGroup = [](const std::shared_ptr<Graph>& graph) { |
252 | const auto& nodes = graph->nodes(); |
253 | auto maybe_fusion_group = |
254 | std::find_if(nodes.begin(), nodes.end(), [](const Node* node) { |
255 | return node->kind() == prim::FusionGroup; |
256 | }); |
257 | TORCH_CHECK( |
258 | maybe_fusion_group != nodes.end(), |
259 | "testRegisterFusionCachesKernel: could not create FusionGroup" ); |
260 | return *maybe_fusion_group; |
261 | }; |
262 | |
263 | // Creates two alpha-equivalent fusion groups |
264 | torch::jit::overrideCanFuseOnCPU(true); |
265 | FuseGraph(g0); |
266 | FuseGraph(g1); |
267 | torch::jit::overrideCanFuseOnCPU(false); |
268 | auto fg0 = getFusionGroup(g0); |
269 | auto fg1 = getFusionGroup(g1); |
270 | |
271 | // Registers both with the fusion compiler. |
272 | auto expected_key = registerFusion(fg0); |
273 | auto second_key = registerFusion(fg1); |
274 | |
275 | // Because the graphs are alpha-equivalent, they should return the same key |
276 | // and therefore share a KernelSpec to share kernels for specializations |
277 | ASSERT_EQ(second_key, expected_key); |
278 | } |
279 | } // namespace jit |
280 | } // namespace torch |
281 | |