1 | #include <gmock/gmock.h> |
2 | #include <gtest/gtest.h> |
3 | |
4 | #include <ATen/ATen.h> |
5 | #include <ATen/Parallel.h> |
6 | #include <ATen/core/interned_strings.h> |
7 | #include <ATen/core/ivalue.h> |
8 | #include <ATen/core/jit_type_base.h> |
9 | #include <test/cpp/jit/test_utils.h> |
10 | #include <torch/csrc/jit/passes/remove_mutation.h> |
11 | #include <torch/csrc/jit/passes/tensorexpr_fuser.h> |
12 | #include <torch/csrc/jit/tensorexpr/kernel.h> |
13 | |
14 | #include <torch/csrc/autograd/engine.h> |
15 | #include <torch/csrc/autograd/generated/variable_factories.h> |
16 | #include <torch/csrc/autograd/profiler.h> |
17 | #include <torch/csrc/autograd/variable.h> |
18 | #include <torch/csrc/jit/api/function_impl.h> |
19 | #include <torch/csrc/jit/api/module.h> |
20 | #include <torch/csrc/jit/codegen/fuser/interface.h> |
21 | #include <torch/csrc/jit/frontend/ir_emitter.h> |
22 | #include <torch/csrc/jit/frontend/tracer.h> |
23 | #include <torch/csrc/jit/ir/alias_analysis.h> |
24 | #include <torch/csrc/jit/ir/attributes.h> |
25 | #include <torch/csrc/jit/ir/irparser.h> |
26 | #include <torch/csrc/jit/ir/scope.h> |
27 | #include <torch/csrc/jit/ir/type_hashing.h> |
28 | #include <torch/csrc/jit/jit_log.h> |
29 | #include <torch/csrc/jit/passes/bailout_graph.h> |
30 | #include <torch/csrc/jit/passes/canonicalize.h> |
31 | #include <torch/csrc/jit/passes/common_subexpression_elimination.h> |
32 | #include <torch/csrc/jit/passes/constant_propagation.h> |
33 | #include <torch/csrc/jit/passes/create_autodiff_subgraphs.h> |
34 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
35 | #include <torch/csrc/jit/passes/graph_fuser.h> |
36 | #include <torch/csrc/jit/passes/guard_elimination.h> |
37 | #include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h> |
38 | #include <torch/csrc/jit/passes/insert_guards.h> |
39 | #include <torch/csrc/jit/passes/liveness.h> |
40 | #include <torch/csrc/jit/passes/loop_unrolling.h> |
41 | #include <torch/csrc/jit/passes/lower_grad_of.h> |
42 | #include <torch/csrc/jit/passes/lower_tuples.h> |
43 | #include <torch/csrc/jit/passes/pass_manager.h> |
44 | #include <torch/csrc/jit/passes/requires_grad_analysis.h> |
45 | #include <torch/csrc/jit/passes/restore_mutation.h> |
46 | #include <torch/csrc/jit/passes/shape_analysis.h> |
47 | #include <torch/csrc/jit/passes/symbolic_shape_analysis.h> |
48 | #include <torch/csrc/jit/passes/utils/subgraph_utils.h> |
49 | #include <torch/csrc/jit/runtime/argument_spec.h> |
50 | #include <torch/csrc/jit/runtime/autodiff.h> |
51 | #include <torch/csrc/jit/runtime/custom_operator.h> |
52 | #include <torch/csrc/jit/runtime/decomposition_registry.h> |
53 | #include <torch/csrc/jit/runtime/graph_executor.h> |
54 | #include <torch/csrc/jit/runtime/interpreter.h> |
55 | #include <torch/csrc/jit/runtime/jit_trace.h> |
56 | #include <torch/csrc/jit/runtime/profiling_record.h> |
57 | #include <torch/csrc/jit/runtime/symbolic_script.h> |
58 | #include <torch/csrc/jit/runtime/symbolic_shape_registry.h> |
59 | #include <torch/csrc/jit/serialization/import.h> |
60 | #include <torch/csrc/jit/testing/file_check.h> |
61 | #include <torch/jit.h> |
62 | #include <torch/script.h> |
63 | |
64 | #include <onnx/onnx_pb.h> |
65 | |
66 | #include <c10/util/Exception.h> |
67 | #include <c10/util/ThreadLocalDebugInfo.h> |
68 | |
69 | #include <torch/csrc/jit/passes/freeze_module.h> |
70 | #include <torch/csrc/jit/passes/frozen_graph_optimizations.h> |
71 | #include <algorithm> |
72 | #include <cstddef> |
73 | #include <functional> |
74 | #include <iostream> |
75 | #include <memory> |
76 | #include <set> |
77 | #include <stdexcept> |
78 | #include <string> |
79 | #include <tuple> |
80 | #include <unordered_map> |
81 | #include <unordered_set> |
82 | #include <utility> |
83 | #include <vector> |
84 | |
85 | namespace torch { |
86 | namespace jit { |
87 | inline c10::AliasAnalysisKind aliasAnalysisFromSchema() { |
88 | return c10::AliasAnalysisKind::FROM_SCHEMA; |
89 | } |
90 | |
/// Streams a vector as a brace-enclosed, comma-separated list, e.g.
/// `{1, 2, 3}`. An empty vector is written as `{}`.
///
/// @param out  Destination stream.
/// @param list Elements to print; each one is written with its own
///             `operator<<`.
/// @return `out`, to allow chaining.
template <typename T>
std::ostream& operator<<(std::ostream& out, const std::vector<T>& list) {
  // Empty before the first element, ", " before every later one.
  const char* separator = "";
  out << "{";
  for (auto&& element : list) {
    out << separator;
    out << element;
    separator = ", ";
  }
  out << "}";
  return out;
}
103 | |
104 | TEST(InternedStringsTest, Basic) { |
105 | ASSERT_EQ(prim::Param, Symbol::prim("Param" )); |
106 | ASSERT_EQ(prim::Return, Symbol::prim("Return" )); |
107 | ASSERT_EQ(prim::Return.toUnqualString(), std::string("Return" )); |
108 | ASSERT_EQ(prim::Return.toQualString(), std::string("prim::Return" )); |
109 | Symbol newsym = Symbol::aten("__NEW_SYMBOL" ); |
110 | size_t symstart = newsym; |
111 | ASSERT_EQ(newsym.toQualString(), std::string("aten::__NEW_SYMBOL" )); |
112 | // TODO: This test is a bit too close to the implementation details. |
113 | ASSERT_EQ(Symbol::aten("What" ), symstart + 1); |
114 | ASSERT_EQ(Symbol::aten("What2" ), symstart + 2); |
115 | ASSERT_EQ(Symbol::aten("What" ), symstart + 1); |
116 | ASSERT_EQ(Symbol::aten("What2" ), symstart + 2); |
117 | ASSERT_EQ(Symbol(symstart + 2).toUnqualString(), std::string("What2" )); |
118 | } |
119 | |
120 | TEST(FromQualStringTest, Basic) { |
121 | ASSERT_EQ(Symbol::fromQualString("prim::Param" ), Symbol::prim("Param" )); |
122 | ASSERT_EQ(Symbol::fromQualString("aten::mm" ), Symbol::aten("mm" )); |
123 | ASSERT_EQ(Symbol::fromQualString("onnx::LSTM" ), Symbol::onnx("LSTM" )); |
124 | ASSERT_EQ(Symbol::fromQualString("attr::value" ), Symbol::attr("value" )); |
125 | ASSERT_EQ(Symbol::fromQualString("scope::" ), Symbol::scope("" )); |
126 | ASSERT_EQ(Symbol::fromQualString("::" ).toUnqualString(), std::string("" )); |
127 | ASSERT_EQ( |
128 | Symbol::fromQualString("::" ).ns().toQualString(), |
129 | std::string("namespaces::" )); |
130 | ASSERT_EQ( |
131 | Symbol::fromQualString("new_ns::param" ).toUnqualString(), |
132 | std::string("param" )); |
133 | ASSERT_EQ( |
134 | Symbol::fromQualString("new_ns::param" ).ns().toUnqualString(), |
135 | std::string("new_ns" )); |
136 | ASSERT_EQ( |
137 | Symbol::fromQualString("new_ns::param" ).ns(), |
138 | Symbol::fromQualString("namespaces::new_ns" )); |
139 | |
140 | auto bad_inputs = {"scope" , ":" , "" }; |
141 | for (auto input : bad_inputs) { |
142 | try { |
143 | Symbol::fromQualString(input); |
144 | ASSERT_TRUE(0); |
145 | } catch (const std::exception& c) { |
146 | } |
147 | } |
148 | } |
149 | |
150 | TEST(THNNConvTest, Basic) { |
151 | std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W |
152 | std::vector<int64_t> kernel_size = {3, 5}; |
153 | std::vector<int64_t> stride = {1, 2}; |
154 | std::vector<int64_t> padding = {2, 1}; |
155 | constexpr int out_channels = 5; |
156 | |
157 | // make inputs |
158 | at::Tensor input = torch::randn(input_size); |
159 | at::Tensor weight = torch::randn( |
160 | {out_channels, input_size[1], kernel_size[0], kernel_size[1]}); |
161 | at::Tensor bias = torch::randn({out_channels}); |
162 | |
163 | // run forward eagerly |
164 | at::Tensor output = at::_slow_conv2d_forward( |
165 | input, weight, kernel_size, bias, stride, padding); |
166 | |
167 | // make grad_outputs |
168 | at::Tensor grad_output = |
169 | torch::randn_like(output, at::MemoryFormat::Preserve); |
170 | |
171 | // run backward eagerly |
172 | at::Tensor grad_input, grad_weight, grad_bias; |
173 | std::tie(grad_input, grad_weight, grad_bias) = at::_slow_conv2d_backward( |
174 | grad_output, |
175 | input, |
176 | weight, |
177 | kernel_size, |
178 | stride, |
179 | padding, |
180 | {true, true, true}); |
181 | |
182 | // make JIT graph |
183 | auto graph = std::make_shared<Graph>(); |
184 | auto ksz_val = graph->insertConstant(kernel_size); |
185 | auto kst_val = graph->insertConstant(stride); |
186 | auto pad_val = graph->insertConstant(padding); |
187 | |
188 | auto inputg = graph->addInput("self" ); |
189 | auto weightg = graph->addInput("weight" ); |
190 | auto biasg = graph->addInput("bias" ); |
191 | |
192 | Value* conv = graph->insert( |
193 | aten::_slow_conv2d_forward, |
194 | {inputg, weightg, ksz_val, biasg, kst_val, pad_val}); |
195 | auto outputs = conv->node()->outputs(); |
196 | for (auto output : outputs) { |
197 | graph->registerOutput(output); |
198 | } |
199 | LowerAllTuples(graph); |
200 | graph->lint(); |
201 | |
202 | // differentiate JIT graph |
203 | EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick |
204 | ConstantPropagation(graph); |
205 | auto grad_spec = differentiate(graph); |
206 | LowerGradOf(*grad_spec.df); |
207 | |
208 | // prepare JIT inputs / gradients |
209 | tensor_list tensors_in; |
210 | tensors_in.push_back(input); |
211 | tensors_in.push_back(weight); |
212 | tensors_in.push_back(bias); |
213 | |
214 | tensor_list tensor_grads_in; |
215 | tensor_grads_in.push_back(grad_output); |
216 | |
217 | // Get outputs from the interpreter |
218 | tensor_list tensors_out, tensor_grads_out; |
219 | std::tie(tensors_out, tensor_grads_out) = |
220 | runGradient(grad_spec, tensors_in, tensor_grads_in); |
221 | |
222 | // prepare expected structs |
223 | tensor_list expected_tensors_out, expected_tensor_grads_out; |
224 | expected_tensors_out.push_back(output); |
225 | expected_tensor_grads_out.push_back(grad_input); |
226 | expected_tensor_grads_out.push_back(grad_weight); |
227 | expected_tensor_grads_out.push_back(grad_bias); |
228 | |
229 | // Compare results |
230 | assertAllClose(tensors_out, expected_tensors_out); |
231 | assertAllClose(tensor_grads_out, expected_tensor_grads_out); |
232 | } |
233 | |
234 | TEST(ATenNativeBatchNormTest, Basic) { |
235 | // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor |
236 | // running_mean, Tensor running_var, bool training, float momentum, float eps) |
237 | // -> (Tensor, Tensor, Tensor) |
238 | std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W |
239 | bool training = true; |
240 | float momentum = 0.9; |
241 | float eps = 1e-5; |
242 | |
243 | // make inputs |
244 | at::Tensor input = torch::randn(input_size); |
245 | at::Tensor weight = torch::randn({input_size[1]}); |
246 | at::Tensor bias = torch::randn({input_size[1]}); |
247 | at::Tensor running_mean = torch::randn({input_size[1]}); |
248 | at::Tensor running_var = torch::randn({input_size[1]}); |
249 | |
250 | // running_mean and running_var are changed in-place, so clone and send them |
251 | at::Tensor running_mean_eager = running_mean.clone(); |
252 | at::Tensor running_var_eager = running_var.clone(); |
253 | at::Tensor running_mean_jit = running_mean.clone(); |
254 | at::Tensor running_var_jit = running_var.clone(); |
255 | |
256 | // run forward eagerly |
257 | at::Tensor output, savemean, saveinvstd; |
258 | std::tie(output, savemean, saveinvstd) = at::native_batch_norm( |
259 | input, |
260 | weight, |
261 | bias, |
262 | running_mean_eager, |
263 | running_var_eager, |
264 | training, |
265 | momentum, |
266 | eps); |
267 | |
268 | // make grad_outputs |
269 | at::Tensor grad_output = |
270 | torch::randn_like(output, at::MemoryFormat::Preserve); |
271 | at::Tensor grad_savemean = |
272 | torch::zeros_like(savemean, at::MemoryFormat::Preserve); |
273 | at::Tensor grad_saveinvstd = |
274 | torch::zeros_like(saveinvstd, at::MemoryFormat::Preserve); |
275 | |
276 | // run backward eagerly |
277 | at::Tensor grad_input, grad_weight, grad_bias; |
278 | // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor |
279 | // weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor |
280 | // save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, |
281 | // Tensor, Tensor) |
282 | std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward( |
283 | grad_output, |
284 | input, |
285 | weight, |
286 | running_mean_eager, |
287 | running_var_eager, |
288 | savemean, |
289 | saveinvstd, |
290 | training, |
291 | eps, |
292 | {true, true, true}); |
293 | |
294 | // make JIT graph |
295 | auto graph = std::make_shared<Graph>(); |
296 | auto training_val = graph->insertConstant(IValue(training)); |
297 | auto momentum_val = graph->insertConstant(IValue(momentum)); |
298 | auto eps_val = graph->insertConstant(IValue(eps)); |
299 | |
300 | auto inputg = graph->addInput("self" ); |
301 | auto weightg = graph->addInput("weight" ); |
302 | auto biasg = graph->addInput("bias" ); |
303 | auto running_meang = graph->addInput("running_mean" ); |
304 | auto running_varg = graph->addInput("running_var" ); |
305 | |
306 | Value* bn = graph->insert( |
307 | aten::native_batch_norm, |
308 | {inputg, |
309 | weightg, |
310 | biasg, |
311 | running_meang, |
312 | running_varg, |
313 | training_val, |
314 | momentum_val, |
315 | eps_val}); |
316 | auto outputs = bn->node()->outputs(); |
317 | for (auto output : outputs) { |
318 | graph->registerOutput(output); |
319 | } |
320 | LowerAllTuples(graph); |
321 | graph->lint(); |
322 | |
323 | // differentiate JIT graph |
324 | EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick |
325 | ConstantPropagation(graph); |
326 | auto grad_spec = differentiate(graph); |
327 | LowerGradOf(*grad_spec.df); |
328 | |
329 | // prepare JIT inputs / gradients |
330 | tensor_list tensors_in; |
331 | tensors_in.push_back(input); |
332 | tensors_in.push_back(weight); |
333 | tensors_in.push_back(bias); |
334 | tensors_in.push_back(running_mean_jit); |
335 | tensors_in.push_back(running_var_jit); |
336 | |
337 | tensor_list tensor_grads_in; |
338 | tensor_grads_in.push_back(grad_output); |
339 | tensor_grads_in.push_back(grad_savemean); |
340 | tensor_grads_in.push_back(grad_saveinvstd); |
341 | |
342 | // Get outputs from the interpreter |
343 | tensor_list tensors_out, tensor_grads_out; |
344 | std::tie(tensors_out, tensor_grads_out) = |
345 | runGradient(grad_spec, tensors_in, tensor_grads_in); |
346 | |
347 | // prepare expected structs |
348 | tensor_list expected_tensors_out, expected_tensor_grads_out; |
349 | expected_tensors_out.push_back(output); |
350 | expected_tensors_out.push_back(savemean); |
351 | expected_tensors_out.push_back(saveinvstd); |
352 | expected_tensors_out.push_back(running_mean_eager); |
353 | expected_tensors_out.push_back(running_var_eager); |
354 | expected_tensor_grads_out.push_back(grad_input); |
355 | expected_tensor_grads_out.push_back(grad_weight); |
356 | expected_tensor_grads_out.push_back(grad_bias); |
357 | |
358 | tensors_out.push_back(running_mean_jit); |
359 | tensors_out.push_back(running_var_jit); |
360 | |
361 | // Compare results |
362 | assertAllClose(tensors_out, expected_tensors_out); |
363 | assertAllClose(tensor_grads_out, expected_tensor_grads_out); |
364 | } |
365 | |
366 | TEST(CustomFusionTest, Basic) { |
367 | #if defined(FBCODE_CAFFE2) |
368 | return; |
369 | #endif |
370 | |
371 | auto graph_string = R"IR( |
372 | graph(%0 : Float(2, 3, 4), |
373 | %1 : Float(2, 3, 4)): |
374 | %2 : Tensor = aten::mul(%0, %1) |
375 | %3 : Tensor = aten::mul(%2, %0) |
376 | return (%3))IR" ; |
377 | auto g = std::make_shared<Graph>(); |
378 | torch::jit::parseIR(graph_string, g.get()); |
379 | |
380 | torch::jit::overrideCanFuseOnCPU(true); |
381 | CustomFuseGraph( |
382 | g, |
383 | [](Node* n) { return n->kind() != prim::Param; }, |
384 | Symbol::fromQualString("prim::FusionGroup" )); |
385 | torch::jit::overrideCanFuseOnCPU(false); |
386 | |
387 | const auto& nodes = g->nodes(); |
388 | auto fusion_group = |
389 | std::find_if(nodes.begin(), nodes.end(), [](const Node* node) { |
390 | return node->kind() == Symbol::fromQualString("prim::FusionGroup" ); |
391 | }); |
392 | AT_ASSERT(fusion_group != nodes.end()); |
393 | |
394 | auto subgraph = fusion_group->g(attr::Subgraph); |
395 | auto hits = 0; |
396 | // two multiplications |
397 | for (const auto& n : subgraph->nodes()) { |
398 | (void)n; |
399 | hits++; |
400 | } |
401 | AT_ASSERT(hits == 2); |
402 | } |
403 | |
404 | TEST(CustomFusionTest, NestedBlocks) { |
405 | #if defined(FBCODE_CAFFE2) |
406 | return; |
407 | #endif |
408 | |
409 | auto graph_string = R"IR( |
410 | graph(%0 : Float(2, 3, 4), |
411 | %1 : Float(2, 3, 4), |
412 | %2 : Float(2, 3, 4)): |
413 | %3 : int = prim::Constant[value=1]() |
414 | %4 : Tensor = prim::If(%2) |
415 | block0(): |
416 | %5 : Tensor = aten::mul(%0, %2) |
417 | %6 : Tensor = aten::mul(%5, %1) |
418 | -> (%6) |
419 | block1(): |
420 | %7 : Tensor = aten::add(%0, %2, %3) |
421 | %8 : Tensor = aten::add(%7, %1, %3) |
422 | -> (%8) |
423 | %9 : Tensor = aten::add(%4, %2, %3) |
424 | return (%4))IR" ; |
425 | auto g = std::make_shared<Graph>(); |
426 | torch::jit::parseIR(graph_string, g.get()); |
427 | |
428 | CustomFuseGraph( |
429 | g, |
430 | [](Node* n) { return n->kind() == aten::mul; }, |
431 | Symbol::fromQualString("prim::FusionGroup" )); |
432 | |
433 | // Could be done in more efficient ways, but this is only a test. |
434 | std::function<bool(const Block*, Symbol)> dfs = [&](const Block* b, |
435 | Symbol s) { |
436 | for (auto node : b->nodes()) { |
437 | if (node->kind() == s) |
438 | return true; |
439 | for (auto nested_b : node->blocks()) |
440 | if (dfs(nested_b, s)) |
441 | return true; |
442 | } |
443 | return false; |
444 | }; |
445 | |
446 | AT_ASSERT(dfs(g->block(), Symbol::fromQualString("prim::FusionGroup" ))); |
447 | } |
448 | |
// TorchScript sources used by ControlFlowTest. The script language is
// indentation-sensitive, so each function body is indented under its `def`
// and each branch/loop body under its header.
static const auto cf_examples = R"JIT(
  def if_test(a, b):
    # FIXME: use 0 instead of a.
    # c = 0
    c = a
    if bool(a < b):
      c = b
    else:
      c = a
    return c
  def if_one(a, b):
    c = b
    if bool(a < b):
      c = a
    return c
  def while_test(a, i):
    while bool(i < 3):
      a *= a
      i += 1
    return a
)JIT";
470 | |
471 | TEST(ControlFlowTest, Basic) { |
472 | auto cu = compile(cf_examples); |
473 | |
474 | auto run = [&](const std::string& name, std::vector<IValue> stack) { |
475 | auto graph = toGraphFunction(cu->get_function(name)).graph(); |
476 | Code code(graph, "" ); |
477 | InterpreterState interp(code); |
478 | interp.run(stack); |
479 | return stack; |
480 | }; |
481 | |
482 | auto L = [](int64_t l) { return IValue(scalar_to_tensor(at::Scalar(l))); }; |
483 | auto V = [](IValue t) { return std::move(t).toTensor().item<int64_t>(); }; |
484 | auto run_binary = [&](const std::string& name, int64_t a, int64_t b) { |
485 | return V(run(name, {L(a), L(b)})[0]); |
486 | }; |
487 | ASSERT_EQ(2, run_binary("if_test" , 1, 2)); |
488 | ASSERT_EQ(3, run_binary("if_test" , 3, 2)); |
489 | ASSERT_EQ(2, run_binary("if_one" , 2, 3)); |
490 | ASSERT_EQ(2, run_binary("if_one" , 3, 2)); |
491 | ASSERT_EQ(256, run_binary("while_test" , 2, 0)); |
492 | } |
493 | |
494 | #if defined(__has_feature) |
495 | #if __has_feature(address_sanitizer) |
496 | #define HAS_ASANUBSAN 1 |
497 | #endif |
498 | #endif |
499 | |
500 | #ifndef HAS_ASANUBSAN |
501 | // This test fails vptr UBSAN checks |
502 | |
503 | TEST(ProtoTest, Basic) { |
504 | ::ONNX_NAMESPACE::ModelProto proto; |
505 | proto.set_producer_name("foo" ); |
506 | } |
507 | #endif |
508 | |
509 | // test a few features that are not directly used in schemas yet |
510 | TEST(SchemaParserTest, NestedArrays) { |
511 | // nested arrays |
512 | auto s = parseSchema("at::what(int[][4] foo) -> ()" ); |
513 | ASSERT_TRUE(s.arguments().at(0).N() == 4); |
514 | ASSERT_TRUE(IntType::get()->isSubtypeOf(*s.arguments() |
515 | .at(0) |
516 | .type() |
517 | ->expectRef<ListType>() |
518 | .getElementType() |
519 | ->expectRef<ListType>() |
520 | .getElementType())); |
521 | auto s2 = parseSchema("at::what(int[][] foo) -> ()" ); |
522 | ASSERT_TRUE(IntType::get()->isSubtypeOf(*s2.arguments() |
523 | .at(0) |
524 | .type() |
525 | ->expectRef<ListType>() |
526 | .getElementType() |
527 | ->expectRef<ListType>() |
528 | .getElementType())); |
529 | } |
530 | |
531 | TEST(SchemaParserTest, OutVariant) { |
532 | auto schema_with_out = parseSchema( |
533 | "at::foo(Tensor self, *, Tensor(a!) f, Tensor(b!) l) -> (Tensor(a!) f, Tensor(b!) l)" ); |
534 | ASSERT_TRUE(schema_with_out.arguments().at(1).is_out()); |
535 | ASSERT_TRUE(schema_with_out.arguments().at(2).is_out()); |
536 | |
537 | auto schema_without_out = |
538 | parseSchema("at::foo(Tensor self, *, int scalar) -> (int)" ); |
539 | |
540 | for (const auto& arg : schema_without_out.arguments()) { |
541 | ASSERT_TRUE(!arg.is_out()); |
542 | } |
543 | |
544 | auto schema_with_is_write = parseSchema( |
545 | "aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))" ); |
546 | |
547 | for (const auto& arg : schema_with_is_write.arguments()) { |
548 | ASSERT_TRUE(!arg.is_out()); |
549 | } |
550 | } |
551 | |
552 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
553 | TEST(SchemaParserTest, NamedReturns) { |
554 | // named returns |
555 | parseSchema("at::what(Tensor! i_will_be_written_to) -> ()" ); |
556 | auto s3 = |
557 | parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)" ); |
558 | ASSERT_TRUE(s3.returns().at(0).name() == "the_return" ); |
559 | ASSERT_TRUE(s3.returns().at(1).name() == "the_return2" ); |
560 | } |
561 | |
562 | TEST(SchemaParserTest, Futures) { |
563 | // futures |
564 | auto s4 = parseSchema("at::what(Future(int) foo) -> ()" ); |
565 | ASSERT_TRUE(IntType::get()->isSubtypeOf( |
566 | *s4.arguments().at(0).type()->expectRef<FutureType>().getElementType())); |
567 | } |
568 | |
569 | TEST(SchemaParserTest, AnnotatedAliasSets) { |
570 | // test tensor with annotated alias sets |
571 | parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))" ); |
572 | } |
573 | |
574 | TEST(SchemaParserTest, TensorListAnnotatedAliasSets) { |
575 | const auto s = parseSchema( |
576 | "at::foo(Tensor(a!) self, Tensor(b!)[] out)" |
577 | " -> ()" ); |
578 | const AliasInfo* selfAliasInfo = s.arguments().at(0).alias_info(); |
579 | const AliasInfo* outAliasInfo = s.arguments().at(1).alias_info(); |
580 | ASSERT_TRUE( |
581 | selfAliasInfo->beforeSets() == |
582 | std::unordered_set<Symbol>{Symbol::fromQualString("alias::a" )}); |
583 | ASSERT_TRUE(selfAliasInfo->isWrite()); |
584 | |
585 | ASSERT_TRUE(outAliasInfo->isWrite()); |
586 | ASSERT_TRUE(outAliasInfo->beforeSets().empty()); |
587 | ASSERT_EQ(outAliasInfo->containedTypes().size(), 1); |
588 | |
589 | auto containedType = outAliasInfo->containedTypes()[0]; |
590 | |
591 | ASSERT_TRUE(containedType.isWrite()); |
592 | ASSERT_TRUE( |
593 | containedType.beforeSets() == |
594 | std::unordered_set<Symbol>{Symbol::fromQualString("alias::b" )}); |
595 | } |
596 | |
597 | TEST(SchemaParserTest, AnnotatedAliasWithoutBeforeSet) { |
598 | EXPECT_THAT( |
599 | []() { parseSchema("at::foo(Tensor(!) self) -> Tensor" ); }, |
600 | ::testing::Throws<std::runtime_error>(::testing::Property( |
601 | &std::runtime_error::what, |
602 | ::testing::HasSubstr("expected ident but found '!' here" )))); |
603 | } |
604 | |
605 | TEST(SchemaParserTest, BeforeAfterSets) { |
606 | const auto s = parseSchema( |
607 | "at::what(Tensor(b|c)[](a!) list, Tensor(c) element)" |
608 | " -> (Tensor(b|c)[](a!))" ); |
609 | |
610 | // The list itself is annotated with `a` |
611 | const AliasInfo* aliasInfo = s.arguments().at(0).alias_info(); |
612 | ASSERT_NE(aliasInfo, nullptr); |
613 | ASSERT_TRUE( |
614 | aliasInfo->beforeSets() == |
615 | std::unordered_set<Symbol>{Symbol::fromQualString("alias::a" )}); |
616 | ASSERT_TRUE(aliasInfo->isWrite()); |
617 | |
618 | // Check the contained types |
619 | ASSERT_TRUE(!aliasInfo->containedTypes().empty()); |
620 | const auto& containedAliasInfo = aliasInfo->containedTypes()[0]; |
621 | const auto expected = std::unordered_set<Symbol>{ |
622 | Symbol::fromQualString("alias::b" ), |
623 | Symbol::fromQualString("alias::c" ), |
624 | }; |
625 | ASSERT_TRUE(containedAliasInfo.beforeSets() == expected); |
626 | ASSERT_TRUE(containedAliasInfo.afterSets() == expected); |
627 | ASSERT_FALSE(containedAliasInfo.isWrite()); |
628 | } |
629 | |
630 | TEST(SchemaParserTest, BeforeAfterSets2) { |
631 | const auto s = parseSchema( |
632 | "at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)" |
633 | " -> (Tensor(b|c)[](a!))" ); |
634 | |
635 | // The list itself is annotated with `a` |
636 | const AliasInfo* aliasInfo = s.arguments().at(0).alias_info(); |
637 | ASSERT_NE(aliasInfo, nullptr); |
638 | ASSERT_EQ( |
639 | aliasInfo->beforeSets(), |
640 | std::unordered_set<Symbol>{Symbol::fromQualString("alias::a" )}); |
641 | ASSERT_EQ( |
642 | aliasInfo->afterSets(), |
643 | std::unordered_set<Symbol>{Symbol::fromQualString("alias::a" )}); |
644 | ASSERT_TRUE(aliasInfo->isWrite()); |
645 | ASSERT_EQ(aliasInfo->containedTypes().size(), 1); |
646 | |
647 | // Check the contained types |
648 | ASSERT_TRUE(!aliasInfo->containedTypes().empty()); |
649 | const auto& containedAliasInfo = aliasInfo->containedTypes()[0]; |
650 | const auto expectedBefore = std::unordered_set<Symbol>{ |
651 | Symbol::fromQualString("alias::b" ), |
652 | }; |
653 | const auto expectedAfter = std::unordered_set<Symbol>{ |
654 | Symbol::fromQualString("alias::b" ), Symbol::fromQualString("alias::c" )}; |
655 | ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore); |
656 | ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter); |
657 | ASSERT_FALSE(containedAliasInfo.isWrite()); |
658 | } |
659 | |
660 | TEST(TopologicalIndexTest, Basic) { |
661 | Graph graph; |
662 | auto node1 = graph.create(prim::AutogradZero); |
663 | auto node2 = graph.create(prim::AutogradZero); |
664 | auto node3 = graph.create(prim::AutogradZero); |
665 | auto node4 = graph.create(prim::AutogradZero); |
666 | |
667 | graph.appendNode(node4); |
668 | graph.prependNode(node1); |
669 | node2->insertAfter(node1); |
670 | node3->insertBefore(node4); |
671 | |
672 | // nodes should be in numerical order |
673 | ASSERT_TRUE(node1->isBefore(node2)); |
674 | ASSERT_TRUE(node1->isBefore(node3)); |
675 | ASSERT_TRUE(node1->isBefore(node4)); |
676 | ASSERT_TRUE(node2->isAfter(node1)); |
677 | ASSERT_TRUE(node2->isBefore(node3)); |
678 | ASSERT_TRUE(node2->isBefore(node4)); |
679 | ASSERT_FALSE(node3->isBefore(node1)); |
680 | ASSERT_FALSE(node3->isBefore(node2)); |
681 | ASSERT_FALSE(node3->isAfter(node4)); |
682 | |
683 | // Built up a block structure |
684 | // node3 |
685 | // /\ ... |
686 | // A B block1 |
687 | // \ ... |
688 | // C block2 |
689 | auto block1 = node3->addBlock(); |
690 | auto A = graph.create(prim::AutogradZero); |
691 | block1->appendNode(A); |
692 | auto B = graph.create(prim::AutogradZero); |
693 | block1->appendNode(B); |
694 | auto block2 = B->addBlock(); |
695 | auto C = graph.create(prim::AutogradZero); |
696 | block2->appendNode(C); |
697 | |
698 | // Check isAfter on different block levels |
699 | ASSERT_TRUE(node1->isBefore(A)); |
700 | ASSERT_TRUE(A->isBefore(B)); |
701 | ASSERT_TRUE(A->isBefore(C)); |
702 | |
703 | // make sure things don't blow up on deletions |
704 | node2->destroy(); |
705 | auto node2p = graph.create(prim::AutogradZero); |
706 | node2p->insertAfter(node1); |
707 | ASSERT_TRUE(node1->isBefore(node2p)); |
708 | ASSERT_TRUE(node2p->isBefore(node3)); |
709 | } |
710 | |
711 | TEST(TopologicalIndexTest, Reindex) { |
712 | // Induce reindexing to test that path |
713 | Graph graph; |
714 | std::map<size_t, Node*> nodes; |
715 | |
716 | auto anchor = graph.create(prim::AutogradZero); |
717 | graph.appendNode(anchor); |
718 | // Inserting to the same place a lot will trigger reindexing |
719 | for (auto i = 0; i < 100; ++i) { |
720 | auto n = graph.create(prim::AutogradZero); |
721 | n->insertAfter(anchor); |
722 | nodes[i] = n; |
723 | } |
724 | |
725 | // Nodes should be in reverse order |
726 | for (auto i = 0; i < 100; ++i) { |
727 | for (auto j = i + 1; j < 100; ++j) { |
728 | ASSERT_TRUE(nodes[i]->isAfter(nodes[j])); |
729 | } |
730 | } |
731 | } |
732 | |
733 | at::Tensor invokeTestRecordFunction(at::Tensor& t) { |
734 | RECORD_FUNCTION("test" , std::vector<c10::IValue>({t})); |
735 | |
736 | auto t2 = t.pow(2); |
737 | return t2; |
738 | } |
739 | |
// TorchScript methods for the RecordFunction test module: `forward` calls
// `foo`, which squares its argument. The script language is
// indentation-sensitive, so each method body is indented under its `def`.
static const auto invokeTestRecordFunction_JIT = R"JIT(
  def foo(self, t):
    t2 = t.pow(2)
    return t2

  def forward(self, t):
    return self.foo(t)
)JIT";
748 | |
749 | at::Tensor invokeTestRecordFunctionJIT(at::Tensor& t) { |
750 | RECORD_FUNCTION("test" , std::vector<c10::IValue>({t})); |
751 | |
752 | auto module = std::make_shared<script::Module>( |
753 | "RecordFunctionTestModule" , std::make_shared<script::CompilationUnit>()); |
754 | module->define(invokeTestRecordFunction_JIT); |
755 | return module->forward({t}).toTensor(); |
756 | } |
757 | |
// One entry per recorded function: the function name paired with a list of
// size vectors (one per recorded value).
using TracedTestValues = std::vector<
    std::tuple<std::string, std::vector<std::vector<int64_t>>>>;
760 | |
761 | void checkTracedInputs(const TracedTestValues& inputs) { |
762 | bool found_test = false; |
763 | bool found_pow = false; |
764 | bool found_mul = false; |
765 | for (const auto& input : inputs) { |
766 | const auto& fn = std::get<0>(input); |
767 | const auto& sizes = std::get<1>(input); |
768 | |
769 | if (fn == "test" ) { |
770 | found_test = true; |
771 | TORCH_CHECK(sizes.size() == 1); |
772 | TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3})); |
773 | } else if (fn == "aten::pow" ) { |
774 | found_pow = true; |
775 | TORCH_CHECK(sizes.size() == 2); |
776 | TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3})); |
777 | TORCH_CHECK(sizes[1].empty()); |
778 | } else if (fn == "aten::mul" ) { |
779 | found_mul = true; |
780 | TORCH_CHECK(sizes.size() > 1); |
781 | TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3})); |
782 | } |
783 | } |
784 | TORCH_CHECK(found_test); |
785 | TORCH_CHECK(found_pow); |
786 | TORCH_CHECK(found_mul); |
787 | } |
788 | |
789 | void checkTracedOutputs(const TracedTestValues& outputs) { |
790 | bool found_test = false; |
791 | bool found_pow = false; |
792 | bool found_mul = false; |
793 | for (const auto& output : outputs) { |
794 | const auto& fn = std::get<0>(output); |
795 | const auto& sizes = std::get<1>(output); |
796 | |
797 | if (fn == "test" ) { |
798 | found_test = true; |
799 | TORCH_CHECK(sizes.empty()); |
800 | } else if (fn == "aten::pow" ) { |
801 | found_pow = true; |
802 | TORCH_CHECK(sizes.size() == 1); |
803 | TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3})); |
804 | } else if (fn == "aten::mul" ) { |
805 | found_mul = true; |
806 | TORCH_CHECK(sizes.size() == 1); |
807 | TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3})); |
808 | } |
809 | } |
810 | TORCH_CHECK(found_test); |
811 | TORCH_CHECK(found_pow); |
812 | TORCH_CHECK(found_mul); |
813 | } |
814 | |
815 | static bool bad_scope = false; |
816 | template <RecordScope scope, size_t* cnt> |
817 | std::unique_ptr<at::ObserverContext> checkScopeCallback( |
818 | const at::RecordFunction& fn) { |
819 | if (fn.scope() == scope) { |
820 | ++(*cnt); |
821 | } else { |
822 | bad_scope = true; |
823 | } |
824 | return nullptr; |
825 | } |
826 | |
827 | template <RecordScope scope, size_t* cnt> |
828 | void pushScopedCallback() { |
829 | at::addGlobalCallback( |
830 | at::RecordFunctionCallback(checkScopeCallback<scope, cnt>) |
831 | .scopes({scope})); |
832 | } |
833 | |
// Per-scope invocation counters. They live at namespace scope because a
// function-local variable cannot be used as a template argument prior to
// C++17. The explicit `= 0` only spells out the static zero-initialization.
static size_t fun_cnt = 0;
static size_t ts_fun_cnt = 0;
static size_t user_scope_cnt = 0;
839 | |
840 | void checkScopeCallbacks() { |
841 | static bool found_function_scope; |
842 | static bool found_method_scope; |
843 | static bool found_user_scope; |
844 | found_function_scope = false; |
845 | found_method_scope = false; |
846 | found_user_scope = false; |
847 | at::addGlobalCallback(at::RecordFunctionCallback( |
848 | [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> { |
849 | if (fn.scope() == at::RecordScope::FUNCTION && |
850 | std::string(fn.name()) == "test_function" ) { |
851 | found_function_scope = true; |
852 | } |
853 | if (fn.scope() == at::RecordScope::TORCHSCRIPT_FUNCTION && |
854 | std::string(fn.name()) == "test_method" ) { |
855 | found_method_scope = true; |
856 | } |
857 | if (fn.scope() == at::RecordScope::USER_SCOPE && |
858 | std::string(fn.name()) == "test_user_scope" ) { |
859 | found_user_scope = true; |
860 | } |
861 | return nullptr; |
862 | })); |
863 | |
864 | bad_scope = false; |
865 | fun_cnt = 0; |
866 | pushScopedCallback<at::RecordScope::FUNCTION, &fun_cnt>(); |
867 | ts_fun_cnt = 0; |
868 | pushScopedCallback<at::RecordScope::TORCHSCRIPT_FUNCTION, &ts_fun_cnt>(); |
869 | user_scope_cnt = 0; |
870 | pushScopedCallback<at::RecordScope::USER_SCOPE, &user_scope_cnt>(); |
871 | |
872 | TORCH_CHECK(at::hasCallbacks()); |
873 | |
874 | { |
875 | RECORD_TORCHSCRIPT_FUNCTION("test_method" , {}); |
876 | { RECORD_FUNCTION("test_function" , {}); } |
877 | { RECORD_USER_SCOPE("test_user_scope" ); } |
878 | } |
879 | |
880 | TORCH_CHECK(!bad_scope); |
881 | TORCH_CHECK(fun_cnt == 1); |
882 | TORCH_CHECK(ts_fun_cnt == 1); |
883 | TORCH_CHECK(user_scope_cnt == 1); |
884 | |
885 | TORCH_CHECK(found_function_scope); |
886 | TORCH_CHECK(found_method_scope); |
887 | TORCH_CHECK(found_user_scope); |
888 | } |
889 | |
890 | static TracedTestValues traced_inputs; |
891 | static TracedTestValues traced_outputs; |
892 | static std::unordered_set<std::string> ts_input_names; |
893 | static std::unordered_set<std::string> ts_output_names; |
894 | |
895 | std::unique_ptr<at::ObserverContext> tracedInputsCallback( |
896 | const RecordFunction& fn) { |
897 | if (fn.scope() == RecordScope::FUNCTION) { |
898 | auto inputs = fn.inputs(); |
899 | std::vector<std::vector<int64_t>> sizes; |
900 | for (const auto& input : inputs) { |
901 | if (input.isTensor()) { |
902 | sizes.push_back(input.toTensor().sizes().vec()); |
903 | } else if (input.isScalar()) { |
904 | // NOLINTNEXTLINE(modernize-use-emplace) |
905 | sizes.push_back(std::vector<int64_t>()); |
906 | } |
907 | } |
908 | traced_inputs.push_back(std::make_tuple(fn.name(), sizes)); |
909 | } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) { |
910 | ts_input_names.insert(fn.name()); |
911 | } |
912 | return nullptr; |
913 | } |
914 | |
915 | void tracedOutputsCallback(const RecordFunction& fn, ObserverContext* ctx_ptr) { |
916 | if (fn.scope() == RecordScope::FUNCTION) { |
917 | auto outputs = fn.outputs(); |
918 | std::vector<std::vector<int64_t>> sizes; |
919 | for (const auto& output : outputs) { |
920 | if (output.isTensor()) { |
921 | sizes.push_back(output.toTensor().sizes().vec()); |
922 | } else if (output.isScalar()) { |
923 | sizes.emplace_back(); |
924 | } |
925 | } |
926 | traced_outputs.push_back(std::make_tuple(fn.name(), sizes)); |
927 | } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) { |
928 | ts_output_names.insert(fn.name()); |
929 | } |
930 | } |
931 | |
932 | TEST(RecordFunctionTest, TracedTestInputsOutputs) { |
933 | // disabling the inlining of method calls |
934 | GraphOptimizerEnabledGuard opt_guard(false); |
935 | |
936 | // [(fn, [[sizes], [sizes], ...]), ...] |
937 | addGlobalCallback( |
938 | RecordFunctionCallback(tracedInputsCallback, tracedOutputsCallback) |
939 | .needsInputs(true) |
940 | .needsOutputs(true)); |
941 | |
942 | TracedTestValues eager_inputs, eager_outputs, jit_inputs, jit_outputs; |
943 | { |
944 | auto t = torch::randn({1, 2, 3}, at::kCPU); |
945 | t.set_requires_grad(true); |
946 | auto t2 = invokeTestRecordFunction(t); |
947 | t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve)); |
948 | eager_inputs = traced_inputs; |
949 | eager_outputs = traced_outputs; |
950 | traced_inputs.clear(); |
951 | traced_outputs.clear(); |
952 | |
953 | TORCH_CHECK(ts_input_names.empty()); |
954 | TORCH_CHECK(ts_output_names.empty()); |
955 | |
956 | t = torch::randn({1, 2, 3}, at::kCPU); |
957 | t.set_requires_grad(true); |
958 | t2 = invokeTestRecordFunctionJIT(t); |
959 | t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve)); |
960 | jit_inputs = traced_inputs; |
961 | jit_outputs = traced_outputs; |
962 | traced_inputs.clear(); |
963 | traced_outputs.clear(); |
964 | } |
965 | |
966 | TORCH_CHECK(ts_input_names.find("forward" ) != ts_input_names.end()); |
967 | TORCH_CHECK(ts_input_names.find("foo" ) != ts_input_names.end()); |
968 | TORCH_CHECK(ts_output_names.find("forward" ) != ts_output_names.end()); |
969 | TORCH_CHECK(ts_output_names.find("foo" ) != ts_output_names.end()); |
970 | |
971 | checkTracedInputs(eager_inputs); |
972 | checkTracedOutputs(eager_outputs); |
973 | checkTracedInputs(jit_inputs); |
974 | checkTracedOutputs(jit_outputs); |
975 | at::clearCallbacks(); |
976 | } |
977 | |
978 | static int sampled_cb_ctr = 0; |
979 | std::unique_ptr<ObserverContext> sampledCallback(const RecordFunction& fn) { |
980 | if (std::string(fn.name()) == "test" ) { |
981 | ++sampled_cb_ctr; |
982 | } |
983 | return nullptr; |
984 | } |
985 | |
986 | static int non_sampled_cb_ctr = 0; |
987 | std::unique_ptr<ObserverContext> nonSampledCallback(const RecordFunction& fn) { |
988 | if (std::string(fn.name()) == "test" ) { |
989 | ++non_sampled_cb_ctr; |
990 | } |
991 | return nullptr; |
992 | } |
993 | |
994 | TEST(RecordFunctionTest, SampledCallbacks) { |
995 | // disabling the inlining of method calls |
996 | GraphOptimizerEnabledGuard opt_guard(false); |
997 | |
998 | // test sampled callbacks |
999 | sampled_cb_ctr = 0; |
1000 | auto setup_sampled_callback = [](double sampling_prob) { |
1001 | return addGlobalCallback( |
1002 | RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob)); |
1003 | }; |
1004 | |
1005 | addGlobalCallback(RecordFunctionCallback(nonSampledCallback)); |
1006 | |
1007 | auto handle = setup_sampled_callback(0.5); |
1008 | |
1009 | auto run_test_function = []() { |
1010 | auto t = torch::randn({1, 2, 3}, at::kCPU); |
1011 | for (auto k = 0; k < 1000; k++) { |
1012 | invokeTestRecordFunction(t); |
1013 | } |
1014 | }; |
1015 | |
1016 | run_test_function(); |
1017 | TORCH_CHECK(non_sampled_cb_ctr == 1000); |
1018 | TORCH_CHECK(sampled_cb_ctr > 0 && sampled_cb_ctr < 1000); |
1019 | |
1020 | sampled_cb_ctr = 0; |
1021 | removeCallback(handle); |
1022 | handle = setup_sampled_callback(0.0); |
1023 | run_test_function(); |
1024 | |
1025 | TORCH_CHECK(non_sampled_cb_ctr == 2000); |
1026 | TORCH_CHECK(sampled_cb_ctr == 0); |
1027 | |
1028 | sampled_cb_ctr = 0; |
1029 | removeCallback(handle); |
1030 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
1031 | handle = setup_sampled_callback(1.0); |
1032 | run_test_function(); |
1033 | |
1034 | TORCH_CHECK(non_sampled_cb_ctr == 3000); |
1035 | TORCH_CHECK(sampled_cb_ctr == 1000); |
1036 | clearCallbacks(); |
1037 | |
1038 | // test the scope of the callbacks |
1039 | checkScopeCallbacks(); |
1040 | clearCallbacks(); |
1041 | } |
1042 | |
1043 | TEST(RecordFunctionTest, RecordFunctionGuard) { |
1044 | // disabling the inlining of method calls |
1045 | GraphOptimizerEnabledGuard opt_guard(false); |
1046 | |
1047 | static std::vector<std::string> fn_names; |
1048 | static std::mutex guard_mtx; |
1049 | |
1050 | // check record function guard |
1051 | addGlobalCallback(RecordFunctionCallback( |
1052 | [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> { |
1053 | std::lock_guard<std::mutex> lock(guard_mtx); |
1054 | // NOLINTNEXTLINE(modernize-use-emplace) |
1055 | fn_names.push_back(fn.name()); |
1056 | return nullptr; |
1057 | })); |
1058 | { |
1059 | RecordFunctionGuard g1(false); |
1060 | { |
1061 | RECORD_USER_SCOPE("A" ); |
1062 | { |
1063 | RecordFunctionGuard g2(true); |
1064 | RECORD_USER_SCOPE("B" ); |
1065 | { |
1066 | DisableRecordFunctionGuard g3; |
1067 | RECORD_USER_SCOPE("C" ); |
1068 | } |
1069 | } |
1070 | { RECORD_USER_SCOPE("D" ); } |
1071 | } |
1072 | } |
1073 | TORCH_CHECK(fn_names.size() == 1); |
1074 | TORCH_CHECK(fn_names[0] == "B" ); |
1075 | clearCallbacks(); |
1076 | } |
1077 | |
1078 | static std::vector<size_t> ids; |
1079 | |
1080 | template <size_t id> |
1081 | auto add_remove_test_add_cb() { |
1082 | return addGlobalCallback(RecordFunctionCallback( |
1083 | [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> { |
1084 | ids.push_back(id); |
1085 | return nullptr; |
1086 | })); |
1087 | } |
1088 | |
1089 | TEST(RecordFunctionTest, Callbacks) { |
1090 | // disabling the inlining of method calls |
1091 | GraphOptimizerEnabledGuard opt_guard(false); |
1092 | |
1093 | auto h1 = add_remove_test_add_cb<1>(); |
1094 | add_remove_test_add_cb<2>(); |
1095 | auto h3 = add_remove_test_add_cb<3>(); |
1096 | |
1097 | { RECORD_USER_SCOPE("test" ); } |
1098 | |
1099 | TORCH_CHECK(ids.size() == 3); |
1100 | TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end()); |
1101 | TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end()); |
1102 | TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end()); |
1103 | |
1104 | ids.clear(); |
1105 | removeCallback(h1); |
1106 | |
1107 | { RECORD_USER_SCOPE("test" ); } |
1108 | |
1109 | TORCH_CHECK(ids.size() == 2); |
1110 | TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end()); |
1111 | TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end()); |
1112 | |
1113 | ids.clear(); |
1114 | removeCallback(h3); |
1115 | |
1116 | { RECORD_USER_SCOPE("test" ); } |
1117 | |
1118 | TORCH_CHECK(ids.size() == 1); |
1119 | TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end()); |
1120 | |
1121 | clearCallbacks(); |
1122 | |
1123 | // thread local / global callbacks |
1124 | |
1125 | ids.clear(); |
1126 | add_remove_test_add_cb<1>(); |
1127 | |
1128 | { RECORD_USER_SCOPE("test" ); } |
1129 | |
1130 | TORCH_CHECK(ids.size() == 1); |
1131 | TORCH_CHECK(ids[0] == 1); |
1132 | ids.clear(); |
1133 | |
1134 | auto th = std::thread([]() { |
1135 | addThreadLocalCallback(RecordFunctionCallback( |
1136 | [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> { |
1137 | ids.push_back(2); |
1138 | return nullptr; |
1139 | })); |
1140 | |
1141 | { RECORD_USER_SCOPE("test_thread" ); } |
1142 | }); |
1143 | th.join(); |
1144 | TORCH_CHECK(ids.size() == 2); |
1145 | TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end()); |
1146 | TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end()); |
1147 | ids.clear(); |
1148 | |
1149 | { RECORD_USER_SCOPE("test" ); } |
1150 | |
1151 | TORCH_CHECK(ids.size() == 1); |
1152 | TORCH_CHECK(ids[0] == 1); |
1153 | ids.clear(); |
1154 | |
1155 | clearCallbacks(); |
1156 | |
1157 | // START: thread local / global context check callbacks |
1158 | struct TestContext : public ObserverContext { |
1159 | int a{0}; |
1160 | std::string b; |
1161 | }; |
1162 | ids.clear(); |
1163 | { // START: global test |
1164 | addGlobalCallback(RecordFunctionCallback( |
1165 | [](const RecordFunction& |
1166 | /* unused */) -> std::unique_ptr<at::ObserverContext> { |
1167 | auto ctx = std::make_unique<TestContext>(); |
1168 | ctx->a = 123; |
1169 | ctx->b = "test_str" ; |
1170 | ids.push_back(1); |
1171 | return ctx; |
1172 | }, |
1173 | [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) { |
1174 | auto ctx = dynamic_cast<TestContext*>(ctx_ptr); |
1175 | TORCH_CHECK(ctx_ptr != nullptr); |
1176 | TORCH_CHECK(ctx->a == 123); |
1177 | TORCH_CHECK(ctx->b == "test_str" ); |
1178 | })); |
1179 | |
1180 | { RECORD_USER_SCOPE("test" ); } |
1181 | |
1182 | TORCH_CHECK(ids.size() == 1); |
1183 | TORCH_CHECK(ids[0] == 1); |
1184 | ids.clear(); |
1185 | } // END: global test |
1186 | { // START: thread local test |
1187 | auto ctx_th = std::thread([]() { |
1188 | const std::string test_str = "test thread str" ; |
1189 | addThreadLocalCallback(RecordFunctionCallback( |
1190 | [](const RecordFunction& |
1191 | /* unused */) -> std::unique_ptr<at::ObserverContext> { |
1192 | auto ctx = std::make_unique<TestContext>(); |
1193 | ctx->a = 234; |
1194 | ctx->b = "test_thread_str" ; |
1195 | ids.push_back(2); |
1196 | return ctx; |
1197 | }, |
1198 | [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) { |
1199 | auto ctx = dynamic_cast<TestContext*>(ctx_ptr); |
1200 | TORCH_CHECK(ctx_ptr != nullptr); |
1201 | TORCH_CHECK(ctx->a == 234); |
1202 | TORCH_CHECK(ctx->b == "test_thread_str" ); |
1203 | })); |
1204 | |
1205 | // Will call both global and thread local callbacks. |
1206 | { RECORD_USER_SCOPE("test_thread" ); } |
1207 | }); |
1208 | ctx_th.join(); |
1209 | TORCH_CHECK(ids.size() == 2); |
1210 | TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end()); |
1211 | TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end()); |
1212 | ids.clear(); |
1213 | } // END: thread local test |
1214 | |
1215 | clearCallbacks(); |
1216 | } |
1217 | |
1218 | TEST(RecordFunctionTest, ShouldRun) { |
1219 | // disabling the inlining of method calls |
1220 | GraphOptimizerEnabledGuard opt_guard(false); |
1221 | |
1222 | static bool ran = false; |
1223 | auto handle = addGlobalCallback(RecordFunctionCallback( |
1224 | [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> { |
1225 | ran = true; |
1226 | return nullptr; |
1227 | })); |
1228 | |
1229 | { RECORD_USER_SCOPE("test" ); } |
1230 | |
1231 | EXPECT_TRUE(ran) << "first run didn't happen" ; |
1232 | ran = false; |
1233 | |
1234 | disableCallback(handle); |
1235 | |
1236 | { RECORD_USER_SCOPE("test" ); } |
1237 | |
1238 | EXPECT_FALSE(ran) << "second run happened but shouldn't have" ; |
1239 | ran = false; |
1240 | |
1241 | reenableCallback(handle); |
1242 | |
1243 | { RECORD_USER_SCOPE("test" ); } |
1244 | |
1245 | EXPECT_TRUE(ran) << "run after re-enable didn't happen" ; |
1246 | ran = false; |
1247 | |
1248 | clearCallbacks(); |
1249 | } |
1250 | |
1251 | TEST(RecordFunctionTest, Basic) { |
1252 | // disabling the inlining of method calls |
1253 | GraphOptimizerEnabledGuard opt_guard(false); |
1254 | |
1255 | static std::string recorded_op; |
1256 | static bool has_ids = false; |
1257 | |
1258 | // test propagation of TLS callbacks |
1259 | std::thread t([]() { |
1260 | RecordFunctionGuard enable_rec_fn; |
1261 | auto handle = addThreadLocalCallback(RecordFunctionCallback( |
1262 | [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> { |
1263 | recorded_op = fn.name(); |
1264 | return nullptr; |
1265 | })); |
1266 | ThreadLocalState state; |
1267 | std::thread t_child([state]() { |
1268 | ThreadLocalStateGuard g_tls(state); |
1269 | RECORD_USER_SCOPE("test_in_thread" ); |
1270 | }); |
1271 | t_child.join(); |
1272 | EXPECT_EQ(recorded_op, "test_in_thread" ); |
1273 | removeCallback(handle); |
1274 | }); |
1275 | t.join(); |
1276 | clearCallbacks(); |
1277 | |
1278 | // test set ids |
1279 | addGlobalCallback( |
1280 | RecordFunctionCallback( |
1281 | [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> { |
1282 | has_ids = fn.handle() > 0; |
1283 | return nullptr; |
1284 | }) |
1285 | .needsIds(true)); |
1286 | { RECORD_USER_SCOPE("test" ); } |
1287 | TORCH_CHECK(has_ids); |
1288 | clearCallbacks(); |
1289 | has_ids = false; |
1290 | addGlobalCallback(RecordFunctionCallback( |
1291 | [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> { |
1292 | has_ids = fn.handle() > 0; |
1293 | return nullptr; |
1294 | })); |
1295 | { RECORD_USER_SCOPE("test" ); } |
1296 | TORCH_CHECK(!has_ids); |
1297 | clearCallbacks(); |
1298 | } |
1299 | |
1300 | TEST(RecordFunctionTest, OperatorNameOverload) { |
1301 | static std::set<std::string> operator_names; |
1302 | at::addGlobalCallback(at::RecordFunctionCallback( |
1303 | [](const at::RecordFunction& fn) |
1304 | -> std::unique_ptr<at::ObserverContext> { |
1305 | c10::optional<c10::OperatorName> op_name = |
1306 | fn.operator_name(); |
1307 | if (op_name.has_value()) { |
1308 | operator_names.insert(c10::toString(*op_name)); |
1309 | } else { |
1310 | operator_names.insert("No Operator Name" ); |
1311 | } |
1312 | return nullptr; |
1313 | }) |
1314 | .scopes({at::RecordScope::FUNCTION})); |
1315 | auto t = torch::randn({1, 2, 3}, at::kCPU); |
1316 | t.set_requires_grad(false); |
1317 | auto t2 = t.pow(2); |
1318 | |
1319 | at::clearCallbacks(); |
1320 | EXPECT_TRUE(operator_names.count("No Operator Name" ) == 0) |
1321 | << "Expected that all traced operators had an associated OperatorName object" ; |
1322 | EXPECT_TRUE(operator_names.count("aten::randn" ) == 1) |
1323 | << "Expected aten::randn to have been called and recorded, but it was not" ; |
1324 | EXPECT_TRUE(operator_names.count("aten::pow.Tensor_Scalar" ) == 1) |
1325 | << "Expected aten::pow.Tensor_Scalar to have been called and recorded, but it was not" ; |
1326 | } |
1327 | |
1328 | class TestThreadLocalDebugInfo : public c10::DebugInfoBase { |
1329 | public: |
1330 | int getModelId() const { |
1331 | return model_id_; |
1332 | } |
1333 | |
1334 | void setModelId(int model_id) { |
1335 | model_id_ = model_id; |
1336 | } |
1337 | |
1338 | // NOLINTNEXTLINE(modernize-use-override,modernize-use-equals-default) |
1339 | virtual ~TestThreadLocalDebugInfo() {} |
1340 | |
1341 | private: |
1342 | int model_id_ = 0; |
1343 | }; |
1344 | |
1345 | void checkDebugInfo(c10::DebugInfoKind kind, int model_id) { |
1346 | auto* debug_info = c10::ThreadLocalDebugInfo::get(kind); |
1347 | TORCH_CHECK(debug_info != nullptr); |
1348 | auto* test_debug_info = dynamic_cast<TestThreadLocalDebugInfo*>(debug_info); |
1349 | TORCH_CHECK(test_debug_info != nullptr); |
1350 | TORCH_CHECK(test_debug_info->getModelId() == model_id); |
1351 | } |
1352 | |
1353 | TEST(ThreadLocalDebugInfoTest, Basic) { |
1354 | static std::atomic<bool> done{false}; |
1355 | |
1356 | TORCH_CHECK( |
1357 | c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); |
1358 | auto debug_info = std::make_shared<TestThreadLocalDebugInfo>(); |
1359 | debug_info->setModelId(42); |
1360 | { |
1361 | c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info); |
1362 | checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); |
1363 | } |
1364 | |
1365 | // check that thread local debug info is propagated through fork calls |
1366 | TORCH_CHECK( |
1367 | c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); |
1368 | { |
1369 | c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info); |
1370 | at::launch([]() { |
1371 | checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); |
1372 | done = true; |
1373 | }); |
1374 | } |
1375 | while (!done) { |
1376 | } |
1377 | |
1378 | // check that thread local debug info is propagated through backward pass |
1379 | TORCH_CHECK( |
1380 | c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); |
1381 | done = false; |
1382 | auto handle = addGlobalCallback(RecordFunctionCallback( |
1383 | [](const RecordFunction&) -> std::unique_ptr<at::ObserverContext> { |
1384 | checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); |
1385 | done = true; |
1386 | return nullptr; |
1387 | })); |
1388 | { |
1389 | c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info); |
1390 | auto t = torch::randn({1, 2, 3}, at::kCPU); |
1391 | t.set_requires_grad(true); |
1392 | auto t2 = t.pow(2); |
1393 | t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve)); |
1394 | } |
1395 | removeCallback(handle); |
1396 | TORCH_CHECK(done); |
1397 | |
1398 | // check nested debug info |
1399 | TORCH_CHECK( |
1400 | c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); |
1401 | { |
1402 | c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info); |
1403 | { |
1404 | checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); |
1405 | { |
1406 | auto debug_info = std::make_shared<TestThreadLocalDebugInfo>(); |
1407 | debug_info->setModelId(314); |
1408 | c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO_2, debug_info); |
1409 | { |
1410 | checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); |
1411 | checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314); |
1412 | done = false; |
1413 | at::launch([]() { |
1414 | checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); |
1415 | checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314); |
1416 | done = true; |
1417 | }); |
1418 | while (!done) { |
1419 | } |
1420 | } |
1421 | } |
1422 | } |
1423 | } |
1424 | } |
1425 | |
1426 | TEST(TestSymIntArrayRef, BasicConversion) { |
1427 | const size_t X = 2, Y = 4, Z = 5; |
1428 | std::vector<int64_t> tgt_size_v{2, 4, 5}; |
1429 | std::vector<c10::SymInt> tgt_size({SymInt(X), SymInt(Y), SymInt(Z)}); |
1430 | auto a = at::randn({1, 4, 1}, at::kCPU); |
1431 | auto b = a.expand_symint(tgt_size); |
1432 | auto c = a.expand(tgt_size_v); |
1433 | ASSERT_TRUE(torch::allclose(b, c)); |
1434 | } |
1435 | |
1436 | TEST(TestSymInt, NarrowCopyWithSymbolicInt) { |
1437 | static const size_t LENGTH = 5; |
1438 | auto a = at::randn({10}, at::kCPU); |
1439 | c10::SymInt si(LENGTH); |
1440 | auto b = a.narrow_copy_symint(0, 0, si); |
1441 | auto c = a.narrow(0, 0, LENGTH); |
1442 | ASSERT_TRUE(torch::allclose(b, c)); |
1443 | } |
1444 | |
1445 | TEST(TestSymInt, NarrowCopy) { |
1446 | static const size_t LENGTH = 5; |
1447 | auto a = at::randn({10}, at::kCPU); |
1448 | auto b = a.narrow_copy(0, 0, LENGTH); |
1449 | auto c = a.narrow(0, 0, LENGTH); |
1450 | ASSERT_TRUE(torch::allclose(b, c)); |
1451 | } |
1452 | |
1453 | TEST(TestSymInt, AddSymbolicInt) { |
1454 | c10::SymInt a(5); |
1455 | c10::SymInt b(3); |
1456 | ASSERT_TRUE((a + b).expect_int() == 8); |
1457 | } |
1458 | |
1459 | TEST(FallbackGraphsTest, Basic) { |
1460 | auto x = at::randn({1}, at::kCPU); |
1461 | auto y = at::randn({1}, at::kCPU); |
1462 | auto stack = createStack({x.clone(), y.clone()}); |
1463 | |
1464 | auto graph_string = R"IR( |
1465 | graph(%0 : Float(1), |
1466 | %1 : Float(1)): |
1467 | %2 : Tensor = aten::mul(%0, %1) |
1468 | %3 : Tensor = aten::mul(%2, %0) |
1469 | return (%3))IR" ; |
1470 | auto graph = std::make_shared<Graph>(); |
1471 | torch::jit::parseIR(graph_string, graph.get()); |
1472 | |
1473 | { |
1474 | Code code(graph, "" ); |
1475 | InterpreterState interpreter{code}; |
1476 | interpreter.run(stack); |
1477 | } |
1478 | at::Tensor et; |
1479 | pop(stack, et); |
1480 | float ef = et.item<float>(); |
1481 | { |
1482 | EnableProfilingGuard epg; |
1483 | GraphFunction f("fallbackGraphs" , graph, nullptr); |
1484 | for (size_t i = 0; i < getNumProfiledRuns() + 1; i++) { |
1485 | stack.emplace_back(x.clone()); |
1486 | stack.emplace_back(y.clone()); |
1487 | if (i == getNumProfiledRuns()) { |
1488 | // we will be modifying a profiled graph |
1489 | // before ProfilingGraphExecutor |
1490 | // will optimize it in the next iteration |
1491 | auto opt_graph = lastExecutedOptimizedGraph(); |
1492 | // this is safe to do since we are done profiling |
1493 | ProfilingRecord::removeProfileCounter(opt_graph->block()); |
1494 | replaceBlockWithFallbackGraph(opt_graph->block(), opt_graph->inputs()); |
1495 | auto it = opt_graph->block()->nodes().begin(); |
1496 | ASSERT_EQ(it->kind(), prim::FallbackGraph); |
1497 | auto fallback = *it++; |
1498 | ASSERT_EQ(it, opt_graph->block()->nodes().end()); |
1499 | ASSERT_TRUE(fallback->hasAttribute(attr::Subgraph)); |
1500 | testing::FileCheck() |
1501 | .check("Tensor = aten::mul" ) |
1502 | ->check("Tensor = aten::mul" ) |
1503 | ->run(*fallback->g(attr::Subgraph)); |
1504 | } |
1505 | f.run(stack); |
1506 | at::Tensor at; |
1507 | pop(stack, at); |
1508 | float af = at.item<float>(); |
1509 | ASSERT_EQ(af, ef); |
1510 | } |
1511 | |
1512 | auto opt_graph = lastExecutedOptimizedGraph(); |
1513 | testing::FileCheck() |
1514 | .check("(Tensor) = prim::CallFunction" ) |
1515 | ->run(*opt_graph); |
1516 | } |
1517 | } |
1518 | |
1519 | // TODO this test wasn't running and is broken. |
1520 | // TEST(AutogradProfilerTest, Basic) { |
1521 | // constexpr int batch_size = 4; |
1522 | // constexpr int input_size = 256; |
1523 | // constexpr int seq_len = 32; |
1524 | |
1525 | // int hidden_size = 2 * input_size; |
1526 | // auto input = torch::randn({seq_len, batch_size, input_size}, at::kCPU); |
1527 | // auto hx = torch::randn({batch_size, hidden_size}, at::kCPU); |
1528 | // auto cx = torch::randn({batch_size, hidden_size}, at::kCPU); |
1529 | // auto w_ih = t_def(torch::randn({4 * hidden_size, input_size}, at::kCPU)); |
1530 | // auto w_hh = t_def(torch::randn({4 * hidden_size, hidden_size}, at::kCPU)); |
1531 | |
1532 | // std::stringstream ss; |
1533 | // { |
1534 | // RecordProfile guard(ss); |
1535 | // for (size_t i = 0; i < 100; ++i) { |
1536 | // std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh); |
1537 | // } |
1538 | // } |
1539 | |
1540 | // std::string result = ss.str(); |
1541 | // size_t count = 0; |
1542 | // for (size_t pos = 0; (pos = result.find("tanh", pos)) != std::string::npos; |
1543 | // count++, pos++) { |
1544 | // } |
1545 | // ASSERT_EQ((count, 200); |
1546 | // } |
1547 | |
1548 | TEST(NoneSchemaMatchTest, Basic) { |
1549 | RegisterOperators reg({ |
1550 | Operator( |
1551 | "prim::test_none() -> int?" , |
1552 | [](Stack& stack) { push(stack, IValue()); }, |
1553 | aliasAnalysisFromSchema()), |
1554 | Operator( |
1555 | "prim::is_none(int? a) -> bool" , |
1556 | [](Stack& stack) { |
1557 | IValue a = pop(stack); |
1558 | if (a.isNone()) { |
1559 | push(stack, true); |
1560 | } else { |
1561 | push(stack, false); |
1562 | } |
1563 | }, |
1564 | aliasAnalysisFromSchema()), |
1565 | }); |
1566 | |
1567 | // Constant propagation will run test_none and produce a None, |
1568 | // testing that its type is set appropriately and schema matching doesn't |
1569 | // fail when running is_none |
1570 | |
1571 | auto r = std::make_shared<Graph>(); |
1572 | auto& g = *r; |
1573 | auto opt_int = g.insert(Symbol::fromQualString("prim::test_none" ), {}); |
1574 | auto out_bool = g.insert(Symbol::fromQualString("prim::is_none" ), {opt_int}); |
1575 | g.registerOutput(out_bool); |
1576 | ConstantPropagation(r); |
1577 | |
1578 | auto nodes = r->block()->nodes(); |
1579 | // checking that constant propagation ran wo/failure |
1580 | AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1); |
1581 | } |
1582 | |
1583 | static int testPassValue = 0; |
1584 | void fakePass(std::shared_ptr<Graph>& g) { |
1585 | testPassValue++; |
1586 | return; |
1587 | } |
1588 | |
1589 | RegisterPass p(fakePass); |
1590 | |
1591 | TEST(PassManagementTest, Basic) { |
1592 | std::shared_ptr<Graph> graph = std::make_shared<Graph>(); |
1593 | parseIR( |
1594 | R"IR( |
1595 | graph(%a): |
1596 | return (%a))IR" , |
1597 | &*graph); |
1598 | |
1599 | std::vector<IValue> stack = {IValue(torch::randn({22}, at::kCPU))}; |
1600 | auto run = [&](std::shared_ptr<Graph>& graph, std::vector<IValue> stack) { |
1601 | GraphExecutor executor(graph, "" ); |
1602 | executor.run(stack); |
1603 | return stack; |
1604 | }; |
1605 | run(graph, stack); |
1606 | // we will not run fusion in simple mode |
1607 | if (!getExecutorMode()) { |
1608 | AT_ASSERT(testPassValue); |
1609 | } |
1610 | } |
1611 | |
1612 | static void checkShape(TypePtr typ, std::vector<int64_t> expected) { |
1613 | auto ptp = typ->expect<TensorType>(); |
1614 | ASSERT_EQ(ptp->sizes().concrete_sizes().value(), expected); |
1615 | } |
1616 | |
1617 | static void checkShape( |
1618 | Node* n, |
1619 | std::vector<int64_t> expected, |
1620 | bool prev = true) { |
1621 | auto profile = (prev) ? n->inputs().at(0)->node() : n; |
1622 | checkShape(profile->output()->type(), expected); |
1623 | } |
1624 | |
1625 | void count_( |
1626 | Block* block, |
1627 | const std::function<bool(Node* n)>& pred, |
1628 | size_t& count) { |
1629 | for (Node* n : block->nodes()) { |
1630 | if (pred(n)) { |
1631 | count++; |
1632 | } |
1633 | |
1634 | for (Block* ib : n->blocks()) { |
1635 | count_(ib, pred, count); |
1636 | } |
1637 | } |
1638 | } |
1639 | |
1640 | size_t countNodes( |
1641 | const std::shared_ptr<Graph>& graph, |
1642 | const std::function<bool(Node* n)>& pred) { |
1643 | size_t count = 0; |
1644 | count_(graph->block(), pred, count); |
1645 | return count; |
1646 | } |
1647 | |
1648 | bool true_pred(Node* n) { |
1649 | return true; |
1650 | }; |
1651 | |
1652 | bool is_loop(Node* n) { |
1653 | return n->kind() == prim::Loop; |
1654 | }; |
1655 | |
1656 | TEST(LoopPeelerTest, NoInductionVariableUse) { |
1657 | // do not use an induction variable explicitly |
1658 | static const auto str_func_def = R"JIT( |
1659 | def test_peel_n_times(): |
1660 | sum = 0 |
1661 | for i in range(10): |
1662 | sum += 2 |
1663 | return sum |
1664 | )JIT" ; |
1665 | |
1666 | auto cu = compile(str_func_def); |
1667 | auto& f = toGraphFunction(cu->get_function("test_peel_n_times" )); |
1668 | auto stack = createStack({}); |
1669 | // peeling loop once |
1670 | { |
1671 | LoopsPeeler peeler(true_pred, 1); |
1672 | auto copy = f.graph()->copy(); |
1673 | peeler.run(copy); |
1674 | int num_loops = |
1675 | std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); |
1676 | ASSERT_EQ(num_loops, 2); |
1677 | Code code(copy, "" ); |
1678 | InterpreterState interpreter{code}; |
1679 | interpreter.run(stack); |
1680 | ASSERT_EQ(stack.back().toInt(), 20); |
1681 | } |
1682 | |
1683 | // test peeling more than one iteration |
1684 | { |
1685 | LoopsPeeler peeler(true_pred, 3); |
1686 | auto copy = f.graph()->copy(); |
1687 | peeler.run(copy); |
1688 | int num_loops = |
1689 | std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); |
1690 | ASSERT_EQ(num_loops, 2); |
1691 | Code code(copy, "" ); |
1692 | InterpreterState interpreter{code}; |
1693 | interpreter.run(stack); |
1694 | ASSERT_EQ(stack.back().toInt(), 20); |
1695 | } |
1696 | } |
1697 | |
1698 | TEST(LoopPeelerTest, YesInductionVariableUse) { |
1699 | // uses the induction variable |
1700 | static const auto str_func_def = R"JIT( |
1701 | def test_peel_n_times(): |
1702 | sum = 0 |
1703 | for i in range(10): |
1704 | sum += i |
1705 | return sum |
1706 | )JIT" ; |
1707 | |
1708 | auto cu = compile(str_func_def); |
1709 | auto& f = toGraphFunction(cu->get_function("test_peel_n_times" )); |
1710 | auto stack = createStack({}); |
1711 | // peeling loop once |
1712 | { |
1713 | LoopsPeeler peeler(true_pred, 1); |
1714 | auto copy = f.graph()->copy(); |
1715 | peeler.run(copy); |
1716 | int num_loops = |
1717 | std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); |
1718 | ASSERT_EQ(num_loops, 2); |
1719 | Code code(copy, "" ); |
1720 | InterpreterState interpreter{code}; |
1721 | interpreter.run(stack); |
1722 | ASSERT_EQ(stack.back().toInt(), 45); |
1723 | } |
1724 | |
1725 | // test peeling more than one iteration |
1726 | { |
1727 | LoopsPeeler peeler(true_pred, 3); |
1728 | auto copy = f.graph()->copy(); |
1729 | peeler.run(copy); |
1730 | int num_loops = |
1731 | std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); |
1732 | ASSERT_EQ(num_loops, 2); |
1733 | Code code(copy, "" ); |
1734 | InterpreterState interpreter{code}; |
1735 | interpreter.run(stack); |
1736 | ASSERT_EQ(stack.back().toInt(), 45); |
1737 | } |
1738 | } |
1739 | |
1740 | TEST(LoopPeelerTest, LoopWithTerminationCondition) { |
1741 | // tests with explicit termination conditions |
1742 | static const auto str_func_def = R"JIT( |
1743 | def test_with_cond_times(): |
1744 | sum = 0 |
1745 | i = 0 |
1746 | while (sum < 2): |
1747 | sum += i |
1748 | i += 1 |
1749 | return sum |
1750 | )JIT" ; |
1751 | |
1752 | // the peel changes the termination condition to false |
1753 | // so the original loop doesn't run |
1754 | auto cu = compile(str_func_def); |
1755 | auto& f = toGraphFunction(cu->get_function("test_with_cond_times" )); |
1756 | auto stack = createStack({}); |
1757 | // peeling 5 iterations should update the termination |
1758 | // condition to false |
1759 | { |
1760 | LoopsPeeler peeler(true_pred, 5); |
1761 | auto copy = f.graph()->copy(); |
1762 | peeler.run(copy); |
1763 | int num_loops = |
1764 | std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); |
1765 | ASSERT_EQ(num_loops, 2); |
1766 | Code code(copy, "" ); |
1767 | InterpreterState interpreter{code}; |
1768 | interpreter.run(stack); |
1769 | ASSERT_EQ(stack.back().toInt(), 3); |
1770 | } |
1771 | |
1772 | // the termination condition remains true |
1773 | { |
1774 | LoopsPeeler peeler(true_pred, 1); |
1775 | auto copy = f.graph()->copy(); |
1776 | peeler.run(copy); |
1777 | int num_loops = |
1778 | std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); |
1779 | ASSERT_EQ(num_loops, 2); |
1780 | Code code(copy, "" ); |
1781 | InterpreterState interpreter{code}; |
1782 | interpreter.run(stack); |
1783 | ASSERT_EQ(stack.back().toInt(), 3); |
1784 | } |
1785 | } |
1786 | |
1787 | // tests simple nested loops |
1788 | TEST(LoopPeelerTest, SimpleNestedLoops) { |
1789 | static const auto str_func_def = R"JIT( |
1790 | def test_nested_loops(): |
1791 | sum = 0 |
1792 | i = 0 |
1793 | for i in range(10): |
1794 | for j in range(10): |
1795 | sum += i + j |
1796 | return sum |
1797 | )JIT" ; |
1798 | |
1799 | auto cu = compile(str_func_def); |
1800 | auto& f = toGraphFunction(cu->get_function("test_nested_loops" )); |
1801 | auto stack = createStack({}); |
1802 | |
1803 | { |
1804 | LoopsPeeler peeler(true_pred, 1); |
1805 | auto copy = f.graph()->copy(); |
1806 | peeler.run(copy); |
1807 | ASSERT_EQ(countNodes(copy, is_loop), 5); |
1808 | Code code(copy, "" ); |
1809 | InterpreterState interpreter{code}; |
1810 | interpreter.run(stack); |
1811 | ASSERT_EQ(stack.back().toInt(), 900); |
1812 | } |
1813 | |
1814 | { |
1815 | LoopsPeeler peeler(true_pred, 5); |
1816 | auto copy = f.graph()->copy(); |
1817 | peeler.run(copy); |
1818 | ASSERT_EQ(countNodes(copy, is_loop), 5); |
1819 | Code code(copy, "" ); |
1820 | InterpreterState interpreter{code}; |
1821 | interpreter.run(stack); |
1822 | ASSERT_EQ(stack.back().toInt(), 900); |
1823 | } |
1824 | } |
1825 | |
1826 | TEST(LoopPeelerTest, SimpleNestedLoops2) { |
1827 | static const auto str_func_def = R"JIT( |
1828 | def test_nested_loops(): |
1829 | sum = 0 |
1830 | i = 0 |
1831 | for i in range(10): |
1832 | j = 0 |
1833 | while sum < 2: |
1834 | sum += i + j |
1835 | j += 1 |
1836 | return sum |
1837 | )JIT" ; |
1838 | |
1839 | auto cu = compile(str_func_def); |
1840 | auto& f = toGraphFunction(cu->get_function("test_nested_loops" )); |
1841 | auto stack = createStack({}); |
1842 | { |
1843 | LoopsPeeler peeler(true_pred, 1); |
1844 | auto copy = f.graph()->copy(); |
1845 | peeler.run(copy); |
1846 | ASSERT_EQ(countNodes(copy, is_loop), 5); |
1847 | Code code(copy, "" ); |
1848 | InterpreterState interpreter{code}; |
1849 | interpreter.run(stack); |
1850 | ASSERT_EQ(stack.back().toInt(), 3); |
1851 | } |
1852 | |
1853 | { |
1854 | LoopsPeeler peeler(true_pred, 5); |
1855 | auto copy = f.graph()->copy(); |
1856 | peeler.run(copy); |
1857 | ASSERT_EQ(countNodes(copy, is_loop), 5); |
1858 | Code code(copy, "" ); |
1859 | InterpreterState interpreter{code}; |
1860 | interpreter.run(stack); |
1861 | ASSERT_EQ(stack.back().toInt(), 3); |
1862 | } |
1863 | } |
1864 | |
1865 | TEST(JitTracing, Basic) { |
1866 | constexpr int batch_size = 4; |
1867 | constexpr int input_size = 256; |
1868 | |
1869 | int hidden_size = 2 * input_size; |
1870 | |
1871 | auto input = at::randn({batch_size, input_size}, at::kCPU); |
1872 | auto hx = at::randn({batch_size, hidden_size}, at::kCPU); |
1873 | auto cx = at::randn({batch_size, hidden_size}, at::kCPU); |
1874 | auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU)); |
1875 | auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU)); |
1876 | |
1877 | auto graph = build_lstm(); |
1878 | auto stack = createStack({input, hx, cx, w_ih, w_hh}); |
1879 | auto traced = TraceGraph(graph, stack); |
1880 | |
1881 | // Check that the inputs of traced graph have the same type as the inputs |
1882 | // specified here. |
1883 | ASSERT_EQ(*traced->inputs().at(0)->type(), *TensorType::create(input)); |
1884 | ASSERT_EQ(*traced->inputs().at(1)->type(), *TensorType::create(hx)); |
1885 | ASSERT_EQ(*traced->inputs().at(2)->type(), *TensorType::create(cx)); |
1886 | ASSERT_EQ(*traced->inputs().at(3)->type(), *TensorType::create(w_ih)); |
1887 | ASSERT_EQ(*traced->inputs().at(4)->type(), *TensorType::create(w_hh)); |
1888 | |
1889 | Tensor prof_out; |
1890 | pop(stack, prof_out); |
1891 | |
1892 | { |
1893 | stack = createStack({input, hx, cx, w_ih, w_hh}); |
1894 | Code cd(traced, "traced" ); |
1895 | InterpreterState is{cd}; |
1896 | is.run(stack); |
1897 | Tensor traced_out; |
1898 | pop(stack, traced_out); |
1899 | torch::allclose(prof_out, traced_out); |
1900 | } |
1901 | |
1902 | { |
1903 | stack = createStack({input, hx, cx, w_ih, w_hh}); |
1904 | Code cd(graph, "graph" ); |
1905 | InterpreterState is{cd}; |
1906 | is.run(stack); |
1907 | Tensor scripted_out; |
1908 | pop(stack, scripted_out); |
1909 | torch::allclose(prof_out, scripted_out); |
1910 | } |
1911 | } |
1912 | |
1913 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
1914 | TEST(InsertAndEliminateRedundantGuardsTest, Basic) { |
1915 | static const auto basic_example = R"JIT( |
1916 | def basic(x, y): |
1917 | a = x + y |
1918 | b = x * y |
1919 | c = x + 1 |
1920 | d = a - c |
1921 | e = b - c |
1922 | return d + e |
1923 | )JIT" ; |
1924 | |
1925 | auto cu = compile(basic_example); |
1926 | auto& fun = toGraphFunction(cu->get_function("basic" )); |
1927 | auto pr = ProfilingRecord::instrumentGraph(fun.graph()); |
1928 | auto x = at::randn({2, 3}, at::kCPU); |
1929 | auto y = at::randn({2, 3}, at::kCPU); |
1930 | auto stack = createStack({x, y}); |
1931 | // introduce some profiling information |
1932 | Code cd(pr->profiled_graph_, "" ); |
1933 | InterpreterState is{cd}; |
1934 | is.run(stack); |
1935 | auto copy = pr->profiled_graph_->copy(); |
1936 | ProfilingRecord::removeProfileCounter(copy->block()); |
1937 | InsertGuards(copy); |
1938 | auto nodes = copy->block()->nodes(); |
1939 | auto guard = std::find_if(nodes.begin(), nodes.end(), [](Node* n) { |
1940 | return n->kind() == prim::Guard; |
1941 | }); |
1942 | ASSERT_NE(guard, nodes.end()); |
1943 | ASSERT_EQ( |
1944 | guard->input()->type()->expectRef<TensorType>().sizes().size(), |
1945 | c10::nullopt); |
1946 | checkShape(*guard, {2, 3}, false); |
1947 | auto is_guard = [](Node* n) { return n->kind() == prim::Guard; }; |
1948 | int num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard); |
1949 | ASSERT_EQ(num_guards, 12); |
1950 | // now eliminate as many guards as possible |
1951 | // we should be left with two guards on x and y's defs |
1952 | EliminateRedundantGuards(copy); |
1953 | num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard); |
1954 | ASSERT_EQ(num_guards, 2); |
1955 | } |
1956 | |
1957 | TEST(InsertBailOutsTest, Basic) { |
1958 | static const auto basic_example = R"JIT( |
1959 | def basic_loop(x, y): |
1960 | |
1961 | a = x + 1 |
1962 | b = y + 2 |
1963 | c = x + y + 3 |
1964 | |
1965 | for i in range(10): |
1966 | a = a + b |
1967 | # invariant |
1968 | d = b * c |
1969 | # |
1970 | a = a - d |
1971 | |
1972 | e = a + 4 |
1973 | return e |
1974 | )JIT" ; |
1975 | |
1976 | auto cu = compile(basic_example); |
1977 | auto& fun = toGraphFunction(cu->get_function("basic_loop" )); |
1978 | auto pr = ProfilingRecord::instrumentGraph(fun.graph()); |
1979 | auto x = at::randn({2, 3}, at::kCPU); |
1980 | auto y = at::randn({2, 3}, at::kCPU); |
1981 | auto stack = createStack({x, y}); |
1982 | // introduce some profiling information |
1983 | Code cd(pr->profiled_graph_, "" ); |
1984 | InterpreterState is{cd}; |
1985 | is.run(stack); |
1986 | auto copy = pr->profiled_graph_->copy(); |
1987 | ProfilingRecord::removeProfileCounter(copy->block()); |
1988 | InsertGuards(copy); |
1989 | EliminateRedundantGuards(copy); |
1990 | auto nodes = copy->block()->nodes(); |
1991 | auto is_guard = [](Node* n) { return n->kind() == prim::Guard; }; |
1992 | auto num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard); |
1993 | ASSERT_EQ(num_guards, 3); |
1994 | InsertBailOuts(copy); |
1995 | auto is_bailout = [](Node* n) { return n->kind() == prim::BailOut; }; |
1996 | auto num_bailouts = std::count_if(nodes.begin(), nodes.end(), is_bailout); |
1997 | ASSERT_EQ(num_guards, num_bailouts); |
1998 | std::vector<Node*> bailouts(num_bailouts); |
1999 | std::copy_if(nodes.begin(), nodes.end(), bailouts.begin(), is_bailout); |
2000 | |
2001 | for (auto blo : bailouts) { |
2002 | ASSERT_EQ(blo->inputs().at(0)->node()->kind(), prim::BailoutTemplate); |
2003 | } |
2004 | } |
2005 | |
2006 | TEST(ProfilerTest, Basic) { |
2007 | constexpr int batch_size = 4; |
2008 | constexpr int input_size = 256; |
2009 | |
2010 | int hidden_size = 2 * input_size; |
2011 | |
2012 | auto input = at::randn({batch_size, input_size}, at::kCPU); |
2013 | auto hx = at::randn({batch_size, hidden_size}, at::kCPU); |
2014 | auto cx = at::randn({batch_size, hidden_size}, at::kCPU); |
2015 | auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU)); |
2016 | auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU)); |
2017 | |
2018 | auto g = build_lstm(); |
2019 | auto stack = createStack({input, hx, cx, w_ih, w_hh}); |
2020 | |
2021 | auto& opt_graph = *g.get(); |
2022 | ArgumentSpecCreator arg_spec_creator(opt_graph); |
2023 | ArgumentSpec spec = |
2024 | arg_spec_creator.create(autograd::GradMode::is_enabled(), stack); |
2025 | arg_spec_creator.specializeTypes(opt_graph, spec); |
2026 | auto pr = ProfilingRecord::instrumentGraph(g); |
2027 | Code cd(pr->profiled_graph_, "" ); |
2028 | InterpreterState is{cd}; |
2029 | is.run(stack); |
2030 | |
2031 | // profiled types are stored as attributes and show up in the dump, e.g. |
2032 | // Tensor = prim::profile[profiled_type=Double(4, 256, strides=[256, 1], |
2033 | // requires_grad=0, device=cpu) |
2034 | testing::FileCheck() |
2035 | .check("Tensor = prim::profile[profiled_type" ) |
2036 | ->check_same("256" ) |
2037 | ->run(*pr->profiled_graph_); |
2038 | |
2039 | auto begin = pr->profiled_graph_->block()->nodes().begin(); |
2040 | auto end = pr->profiled_graph_->block()->nodes().end(); |
2041 | auto mm = |
2042 | std::find_if(begin, end, [](Node* n) { return n->kind() == aten::add; }); |
2043 | ASSERT_NE(mm, end); |
2044 | std::vector<int64_t> mm_expected{4, 2048}; |
2045 | std::vector<int64_t> eltwise{4, 512}; |
2046 | checkShape(mm->inputs().at(0)->node()->ty(attr::profiled_type), mm_expected); |
2047 | auto mul_n = |
2048 | std::find_if(begin, end, [](Node* n) { return n->kind() == aten::mul; }); |
2049 | ASSERT_NE(mul_n, end); |
2050 | checkShape(mul_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise); |
2051 | auto tanh_n = |
2052 | std::find_if(begin, end, [](Node* n) { return n->kind() == aten::tanh; }); |
2053 | checkShape(tanh_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise); |
2054 | } |
2055 | |
2056 | TEST(ProfilerTest, OptionalProfiling) { |
2057 | auto graph = std::make_shared<Graph>(); |
2058 | std::unordered_map<std::string, Value*> vmap; |
2059 | parseIR( |
2060 | R"IR( |
2061 | graph(%inp : Tensor, |
2062 | %weight : Tensor, |
2063 | %bias : Tensor?): |
2064 | %1 : Tensor = aten::linear(%inp, %weight, %bias) |
2065 | return (%1))IR" , |
2066 | &*graph, |
2067 | vmap); |
2068 | |
2069 | auto pr = ProfilingRecord::instrumentGraph(graph); |
2070 | pr->profiling_count_ = 2; |
2071 | |
2072 | auto input = torch::randn({1, 2}); |
2073 | auto weight = torch::randn({2, 2}); |
2074 | auto bias = torch::randn({1, 2}); |
2075 | |
2076 | auto stack = createStack({input, weight, bias}); |
2077 | Code cd(pr->profiled_graph_, "" ); |
2078 | InterpreterState is{cd}; |
2079 | is.run(stack); |
2080 | |
2081 | testing::FileCheck() |
2082 | .check_count("Tensor? = prim::profile[profiled_type" , 1, true) |
2083 | ->run(*pr->profiled_graph_); |
2084 | |
2085 | // make sure we recorded the shape |
2086 | auto begin = pr->profiled_graph_->block()->nodes().begin(); |
2087 | auto end = pr->profiled_graph_->block()->nodes().end(); |
2088 | auto linear = std::find_if( |
2089 | begin, end, [](Node* n) { return n->kind() == aten::linear; }); |
2090 | ASSERT_NE(linear, end); |
2091 | std::vector<int64_t> bias_expected_shape = {1, 2}; |
2092 | auto profiled_bias = linear->namedInput("bias" )->node(); |
2093 | checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape); |
2094 | ASSERT_EQ(0, profiled_bias->i(attr::seen_none)); |
2095 | |
2096 | auto none_bias = c10::IValue(); |
2097 | |
2098 | stack.clear(); |
2099 | stack.emplace_back(input); |
2100 | stack.emplace_back(weight); |
2101 | stack.emplace_back(none_bias); |
2102 | is = InterpreterState{cd}; |
2103 | is.run(stack); |
2104 | |
2105 | // make sure we recorded that "None" was seen. |
2106 | begin = pr->profiled_graph_->block()->nodes().begin(); |
2107 | end = pr->profiled_graph_->block()->nodes().end(); |
2108 | linear = std::find_if( |
2109 | begin, end, [](Node* n) { return n->kind() == aten::linear; }); |
2110 | ASSERT_NE(linear, end); |
2111 | profiled_bias = linear->namedInput("bias" )->node(); |
2112 | checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape); |
2113 | ASSERT_EQ(1, profiled_bias->i(attr::seen_none)); |
2114 | } |
2115 | |
2116 | TEST(CallStackTest, Basic) { |
2117 | const auto text = R"( |
2118 | def ham(x): |
2119 | return x/7 |
2120 | |
2121 | def bar(x): |
2122 | return x*3 |
2123 | |
2124 | def baz(x): |
2125 | return ham(x)*x |
2126 | |
2127 | def foo(x): |
2128 | return bar(x)*baz(x)*11 |
2129 | )" ; |
2130 | auto cu = compile(text); |
2131 | const auto& foo = toGraphFunction(cu->get_function("foo" )); |
2132 | for (Node* n : foo.optimized_graph()->nodes()) { |
2133 | if (n->kind() == prim::Constant) { |
2134 | if (!n->hasAttribute(attr::value) || |
2135 | n->kindOf(attr::value) != AttributeKind::i) { |
2136 | continue; |
2137 | } |
2138 | int v = n->i(attr::value); |
2139 | switch (v) { |
2140 | case 3: { |
2141 | // Const 3 comes from function 'bar', which gets inlined to 'foo'. |
2142 | // The callstack for the corresponding node should contain only the |
2143 | // function 'bar'. |
2144 | ASSERT_TRUE(n->callstack()); |
2145 | auto callstack_vector = (*n->callstack())->vec(); |
2146 | ASSERT_EQ(callstack_vector.size(), 1); |
2147 | ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("bar" )); |
2148 | break; |
2149 | } |
2150 | case 7: { |
2151 | // Const 7 comes from function 'ham', which gets inlined to 'baz', |
2152 | // which is then inlined to 'foo'. The callstack for the corresponding |
2153 | // node should contain these two functions. |
2154 | ASSERT_TRUE(n->callstack()); |
2155 | auto callstack_vector = (*n->callstack())->vec(); |
2156 | ASSERT_EQ(callstack_vector.size(), 2); |
2157 | ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("baz" )); |
2158 | ASSERT_EQ(std::get<0>(callstack_vector[1]), &cu->get_function("ham" )); |
2159 | break; |
2160 | } |
2161 | case 11: { |
2162 | // Const 11 comes from function 'foo', which is not inlined anywhere |
2163 | // and thus it should not have a callstack. |
2164 | ASSERT_FALSE(n->callstack()); |
2165 | break; |
2166 | } |
2167 | } |
2168 | } |
2169 | } |
2170 | |
2171 | // Check that inlining doesn't corrupt callstack of the callee's nodes. |
2172 | const auto& baz = toGraphFunction(cu->get_function("baz" )); |
2173 | for (Node* n : baz.optimized_graph()->nodes()) { |
2174 | if (n->kind() == prim::Constant) { |
2175 | if (!n->hasAttribute(attr::value) || |
2176 | n->kindOf(attr::value) != AttributeKind::i) { |
2177 | continue; |
2178 | } |
2179 | int v = n->i(attr::value); |
2180 | ASSERT_TRUE(v == 7); |
2181 | // Const 7 comes from function 'ham', which gets inlined to 'baz'. 'baz' |
2182 | // was also inlined into 'foo', but when looking at the graph of 'baz' we |
2183 | // should only see a callstack of depth 1 (containing only 'ham'). |
2184 | ASSERT_TRUE(n->callstack()); |
2185 | auto callstack_vector = (*n->callstack())->vec(); |
2186 | ASSERT_EQ(callstack_vector.size(), 1); |
2187 | ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("ham" )); |
2188 | } |
2189 | } |
2190 | } |
2191 | |
2192 | TEST(CallStackTest, Caching) { |
2193 | const auto text = R"( |
2194 | |
2195 | def a(x): |
2196 | print("a1") |
2197 | print("a2") |
2198 | return x |
2199 | |
2200 | def b(x): |
2201 | print("b1") |
2202 | print("b2") |
2203 | a(x) |
2204 | return x |
2205 | |
2206 | def c(x): |
2207 | print("c1") |
2208 | print("c2") |
2209 | b(x) |
2210 | return x |
2211 | )" ; |
2212 | auto cu = compile(text); |
2213 | const auto& baz = toGraphFunction(cu->get_function("c" )); |
2214 | std::unordered_map<std::string, InlinedCallStack*> callstack_objects; |
2215 | for (Node* n : baz.optimized_graph()->nodes()) { |
2216 | if (n->kind() == prim::Constant) { |
2217 | if (!n->hasAttribute(attr::value) || |
2218 | n->kindOf(attr::value) != AttributeKind::s) { |
2219 | continue; |
2220 | } |
2221 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
2222 | std::string v = n->s(attr::value); |
2223 | if (n->callstack()) { |
2224 | callstack_objects[v] = n->callstack()->get(); |
2225 | } |
2226 | } |
2227 | } |
2228 | // We expect to see nodes prim::Constant[value="a1"] and |
2229 | // prim::Constant[value="a2"] inlined to function 'c'. Their callstacks are |
2230 | // the same (a->b->c), so we want to make sure we're not creating different |
2231 | // callstack entries for them. |
2232 | ASSERT_TRUE(callstack_objects.count("a1" ) && callstack_objects.count("a2" )); |
2233 | ASSERT_TRUE(callstack_objects.at("a1" ) == callstack_objects.at("a2" )); |
2234 | } |
2235 | |
2236 | TEST(InlinedCallStackTest, BlockAnnotation) { |
2237 | Module a("A" ); |
2238 | a.define(R"( |
2239 | def forward(self, x, y, z: int): |
2240 | if (z == 1): |
2241 | return x + y |
2242 | else: |
2243 | return x * y |
2244 | )" ); |
2245 | Module b("B" ); |
2246 | b.define(R"( |
2247 | def forward(self, x): |
2248 | return x + 2 |
2249 | )" ); |
2250 | Module c("C" ); |
2251 | c.register_module("A0" , a); |
2252 | c.register_module("B0" , b); |
2253 | c.define(R"( |
2254 | def forward(self, x, y, z: int): |
2255 | return self.A0.forward(x, y, z) + self.B0.forward(x) |
2256 | )" ); |
2257 | |
2258 | auto graph = |
2259 | toGraphFunction(c.get_method("forward" ).function()).optimized_graph(); |
2260 | std::stringstream add_ss, mul_ss; |
2261 | for (Node* n : graph->nodes()) { |
2262 | if (n->kind() == prim::If) { |
2263 | for (Block* block : n->blocks()) { |
2264 | for (Node* if_node : block->nodes()) { |
2265 | if (if_node->kind() == aten::add) { |
2266 | for (const auto& e : if_node->callstack().value()->vec()) { |
2267 | add_ss << std::get<1>(e); |
2268 | } |
2269 | add_ss << if_node->sourceRange(); |
2270 | } |
2271 | if (if_node->kind() == aten::mul) { |
2272 | for (const auto& e : if_node->callstack().value()->vec()) { |
2273 | mul_ss << std::get<1>(e); |
2274 | } |
2275 | mul_ss << if_node->sourceRange(); |
2276 | } |
2277 | } |
2278 | } |
2279 | } |
2280 | } |
2281 | ASSERT_NE(add_ss.str().find("line 3" ), std::string::npos); |
2282 | ASSERT_NE(add_ss.str().find("line 4" ), std::string::npos); |
2283 | ASSERT_NE( |
2284 | add_ss.str().find("return self.A0.forward(x, y, z)" ), std::string::npos); |
2285 | ASSERT_NE(add_ss.str().find("return x + y" ), std::string::npos); |
2286 | ASSERT_NE(mul_ss.str().find("line 3" ), std::string::npos); |
2287 | ASSERT_NE(mul_ss.str().find("line 6" ), std::string::npos); |
2288 | ASSERT_NE( |
2289 | mul_ss.str().find("return self.A0.forward(x, y, z)" ), std::string::npos); |
2290 | ASSERT_NE(mul_ss.str().find("return x * y" ), std::string::npos); |
2291 | } |
2292 | |
2293 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
2294 | TEST(InlinedCallStackTest, SelfCallMethods) { |
2295 | Module a("A" ); |
2296 | a.define(R"( |
2297 | def my_new_method(self, x): |
2298 | return x * 3 |
2299 | def forward_impl_(self, x, y): |
2300 | return self.my_new_method(x) + y |
2301 | def forward(self, x, y): |
2302 | y = y + 2 |
2303 | return self.forward_impl_(x, y) |
2304 | )" ); |
2305 | Module b("B" ); |
2306 | b.define(R"( |
2307 | def forward(self, x): |
2308 | return x + 2 |
2309 | )" ); |
2310 | Module c("C" ); |
2311 | c.register_module("A0" , a); |
2312 | c.register_module("B0" , b); |
2313 | c.define(R"( |
2314 | def call_b(self, x): |
2315 | return self.B0.forward(x) |
2316 | def forward(self, x, y): |
2317 | return self.A0.forward(x, y) + self.call_b(x) |
2318 | )" ); |
2319 | |
2320 | auto graph = |
2321 | toGraphFunction(c.get_method("forward" ).function()).optimized_graph(); |
2322 | std::unordered_map<std::string, size_t> module_hierarchies; |
2323 | for (Node* n : graph->nodes()) { |
2324 | auto hierarchy = torch::jit::utils::getNodesModuleHierarchy(*n); |
2325 | if (module_hierarchies.count(hierarchy) == 0) { |
2326 | module_hierarchies[hierarchy] = 0; |
2327 | } |
2328 | module_hierarchies[hierarchy] += 1; |
2329 | } |
2330 | ASSERT_EQ(module_hierarchies["A0(A)" ], 2); |
2331 | ASSERT_EQ(module_hierarchies["A0(A).SELF(A).SELF(A)" ], 2); |
2332 | ASSERT_EQ(module_hierarchies["A0(A).SELF(A)" ], 1); |
2333 | ASSERT_EQ(module_hierarchies["SELF(C)" ], 1); |
2334 | ASSERT_EQ(module_hierarchies["SELF(C).B0(B)" ], 1); |
2335 | } |
2336 | |
2337 | TEST(AutogradSymbolsTest, Basic) { |
2338 | Symbol sym = Symbol::fromQualString("aten::test_symbol" ); |
2339 | Graph graph; |
2340 | auto node = graph.create(sym); |
2341 | TORCH_CHECK(canRunWithAutograd(node)); |
2342 | |
2343 | sym = Symbol::fromQualString("prim::test_symbol" ); |
2344 | node = graph.create(sym); |
2345 | TORCH_CHECK(canRunWithAutograd(node)); |
2346 | |
2347 | sym = Symbol::fromQualString("prim::FusionGroup" ); |
2348 | node = graph.create(sym); |
2349 | TORCH_CHECK(!canRunWithAutograd(node)); |
2350 | |
2351 | sym = Symbol::fromQualString("custom::test_symbol" ); |
2352 | node = graph.create(sym); |
2353 | TORCH_CHECK(!canRunWithAutograd(node)); |
2354 | } |
2355 | |
2356 | TEST(DefaultArgTypeHintingTest, Basic) { |
2357 | const auto text_non_hinted = R"( |
2358 | |
2359 | def a(x, y=1): |
2360 | print("a1") |
2361 | print("a2") |
2362 | return x |
2363 | )" ; |
2364 | |
2365 | const auto text_hinted = R"( |
2366 | |
2367 | def a(x, y:int=1): |
2368 | print("a1") |
2369 | print("a2") |
2370 | return x |
2371 | )" ; |
2372 | |
2373 | try { |
2374 | compile(text_non_hinted); |
2375 | ASSERT_TRUE(0); |
2376 | } catch (const std::exception& c) { |
2377 | } |
2378 | |
2379 | auto cu = compile(text_hinted); |
2380 | } |
2381 | |
2382 | // Basic set case. |
2383 | TEST(FuturesTest, Basic) { |
2384 | auto f1 = c10::make_intrusive<Future>(IntType::get()); |
2385 | ASSERT_FALSE(f1->completed()); |
2386 | ASSERT_FALSE(f1->hasValue()); |
2387 | int32_t sat1 = 0; |
2388 | int32_t sat2 = 0; |
2389 | f1->addCallback([&](Future& /* unused */) { ++sat1; }); |
2390 | f1->markCompleted(43); |
2391 | ASSERT_TRUE(f1->completed()); |
2392 | ASSERT_TRUE(f1->hasValue()); |
2393 | ASSERT_FALSE(f1->hasError()); |
2394 | ASSERT_EQ(sat1, 1); |
2395 | ASSERT_EQ(f1->constValue().toInt(), 43); |
2396 | ASSERT_EQ(f1->value().toInt(), 43); |
2397 | f1->addCallback([&](Future& /* unused */) { ++sat2; }); |
2398 | ASSERT_EQ(sat1, 1); |
2399 | ASSERT_EQ(sat2, 1); |
2400 | } |
2401 | |
2402 | // Basic error cases. |
2403 | TEST(FuturesTest, Error) { |
2404 | auto f1 = c10::make_intrusive<Future>(IntType::get()); |
2405 | int sat1 = 0; |
2406 | int sat2 = 0; |
2407 | f1->addCallback([&](Future& /* unused */) { ++sat1; }); |
2408 | f1->setError( |
2409 | std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed" ))); |
2410 | ASSERT_EQ(sat1, 1); |
2411 | ASSERT_TRUE(f1->completed()); |
2412 | ASSERT_TRUE(f1->hasError()); |
2413 | ASSERT_FALSE(f1->hasValue()); |
2414 | try { |
2415 | (void)f1->value(); |
2416 | ASSERT_TRUE(false); // Supposed to throw. |
2417 | } catch (const std::exception& e) { |
2418 | ASSERT_TRUE(strcmp(e.what(), "Failed" ) == 0); |
2419 | } |
2420 | f1->addCallback([&](Future& /* unused */) { ++sat2; }); |
2421 | ASSERT_EQ(sat1, 1); |
2422 | ASSERT_EQ(sat2, 1); |
2423 | f1->setErrorIfNeeded( |
2424 | std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup" ))); |
2425 | ASSERT_TRUE(strcmp(f1->tryRetrieveErrorMessage().c_str(), "Failed" ) == 0); |
2426 | ASSERT_EQ(sat1, 1); |
2427 | ASSERT_EQ(sat2, 1); |
2428 | try { |
2429 | (void)f1->constValue(); |
2430 | ASSERT_TRUE(false); // Supposed to throw. |
2431 | } catch (const std::exception& e) { |
2432 | // Original error should be logged. |
2433 | ASSERT_TRUE(std::string(e.what()).find("Failed" ) != std::string::npos); |
2434 | } |
2435 | } |
2436 | |
2437 | // then |
2438 | TEST(FuturesTest, Then) { |
2439 | auto f1 = c10::make_intrusive<Future>(IntType::get()); |
2440 | auto f2 = f1->then( |
2441 | [](Future& f1) -> IValue { return f1.constValue().toInt() + 1; }, |
2442 | IntType::get()); |
2443 | auto f3 = f2->then( |
2444 | [](Future& f2) -> IValue { return f2.constValue().toInt() * 3; }, |
2445 | IntType::get()); |
2446 | bool done = false; |
2447 | f3->addCallback([&done](Future& f3) { |
2448 | ASSERT_EQ(f3.constValue().toInt(), (42 + 1) * 3); |
2449 | done = true; |
2450 | }); |
2451 | ASSERT_FALSE(done); |
2452 | f1->markCompleted(42); |
2453 | ASSERT_TRUE(done); |
2454 | } |
2455 | |
2456 | // collectAll() |
2457 | TEST(FuturesTest, CollectAll) { |
2458 | auto s1 = c10::make_intrusive<Future>(IntType::get()); |
2459 | auto s2 = c10::make_intrusive<Future>(IntType::get()); |
2460 | auto s3 = c10::make_intrusive<Future>(IntType::get()); |
2461 | |
2462 | // Empty case |
2463 | c10::List<intrusive_ptr<ivalue::Future>> futures( |
2464 | FutureType::create(IntType::get())); |
2465 | auto c1 = collectAll(futures); |
2466 | ASSERT_TRUE(c1->completed()); |
2467 | ASSERT_EQ(c1->value().toList().size(), 0); |
2468 | ASSERT_TRUE( |
2469 | *(c1->value().toList().elementType()) == |
2470 | *FutureType::create(IntType::get())); |
2471 | |
2472 | // 1-element, initially not completed. |
2473 | futures.push_back(s1); |
2474 | auto c2 = collectAll(futures); |
2475 | ASSERT_FALSE(c2->completed()); |
2476 | s1->markCompleted(5); |
2477 | ASSERT_TRUE(c2->completed()); |
2478 | ASSERT_EQ(c2->value().toList().size(), 1); |
2479 | ASSERT_TRUE( |
2480 | *(c2->value().toList().elementType()) == |
2481 | *FutureType::create(IntType::get())); |
2482 | ASSERT_EQ(c2->value().toList().get(0).toFuture()->value().toInt(), 5); |
2483 | |
2484 | // 1-element, already completed |
2485 | auto c3 = collectAll(futures); |
2486 | ASSERT_TRUE(c3->completed()); |
2487 | ASSERT_EQ(c3->value().toList().size(), 1); |
2488 | ASSERT_EQ(c3->value().toList().get(0).toFuture()->value().toInt(), 5); |
2489 | |
2490 | // 3 elements. |
2491 | futures.push_back(s2); |
2492 | futures.push_back(s3); |
2493 | auto c4 = collectAll(futures); |
2494 | ASSERT_FALSE(c4->completed()); |
2495 | s3->markCompleted(7); |
2496 | ASSERT_FALSE(c4->completed()); |
2497 | s2->markCompleted(6); |
2498 | ASSERT_TRUE(c4->completed()); |
2499 | ASSERT_EQ(c4->value().toList().size(), 3); |
2500 | ASSERT_EQ(c4->value().toList().get(0).toFuture()->value().toInt(), 5); |
2501 | ASSERT_EQ(c4->value().toList().get(1).toFuture()->value().toInt(), 6); |
2502 | ASSERT_EQ(c4->value().toList().get(2).toFuture()->value().toInt(), 7); |
2503 | ASSERT_TRUE( |
2504 | *(c4->value().toList().elementType()) == |
2505 | *FutureType::create(IntType::get())); |
2506 | |
2507 | // Handle exception in the list. |
2508 | auto s4 = c10::make_intrusive<Future>(IntType::get()); |
2509 | futures.push_back(s4); |
2510 | auto c5 = collectAll(futures); |
2511 | ASSERT_FALSE(c5->completed()); |
2512 | s4->setError( |
2513 | std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed" ))); |
2514 | ASSERT_TRUE(c5->completed()); |
2515 | try { |
2516 | c5->value(); |
2517 | ASSERT_TRUE(false); // supposed to throw |
2518 | } catch (const std::exception& e) { |
2519 | ASSERT_EQ(std::string(e.what()), "Failed" ); |
2520 | } |
2521 | } |
2522 | |
2523 | // collectAny() |
2524 | TEST(FuturesTest, CollectAny) { |
2525 | auto s1 = c10::make_intrusive<Future>(IntType::get()); |
2526 | |
2527 | // Empty case |
2528 | c10::List<intrusive_ptr<ivalue::Future>> futures( |
2529 | FutureType::create(IntType::get())); |
2530 | auto c1 = collectAny(futures); |
2531 | ASSERT_TRUE(c1->completed()); |
2532 | |
2533 | // 1 element, not yet satisfied |
2534 | futures.push_back(s1); |
2535 | auto c2 = collectAny(futures); |
2536 | ASSERT_FALSE(c2->completed()); |
2537 | s1->markCompleted(5); |
2538 | ASSERT_TRUE(c2->completed()); |
2539 | ASSERT_TRUE(c2->value().isInt()); |
2540 | ASSERT_EQ(c2->value().toInt(), 5); |
2541 | |
2542 | // 1 element already satisfied. |
2543 | auto c3 = collectAny(futures); |
2544 | ASSERT_TRUE(c3->completed()); |
2545 | ASSERT_TRUE(c3->value().isInt()); |
2546 | ASSERT_EQ(c3->value().toInt(), 5); |
2547 | |
2548 | // 2 elements |
2549 | futures.clear(); |
2550 | auto s2 = c10::make_intrusive<Future>(IntType::get()); |
2551 | auto s3 = c10::make_intrusive<Future>(IntType::get()); |
2552 | futures.push_back(s2); |
2553 | futures.push_back(s3); |
2554 | auto c4 = collectAny(futures); |
2555 | ASSERT_FALSE(c4->completed()); |
2556 | s3->markCompleted(7); |
2557 | ASSERT_TRUE(c4->completed()); |
2558 | ASSERT_EQ(c4->value().toInt(), 7); |
2559 | s2->markCompleted(1); |
2560 | ASSERT_EQ(c4->value().toInt(), 7); |
2561 | } |
2562 | |
2563 | TEST(TLSFutureCallbacksTest, Basic) { |
2564 | // cb that verifies the profiler is enabled |
2565 | auto profilerEnabledCb = [](Future& /* unused */) { |
2566 | ASSERT_TRUE(torch::autograd::profiler::profilerEnabled()); |
2567 | }; |
2568 | // test running callbacks with propagation of TLS state. |
2569 | { |
2570 | // Enable the profiler in this thread |
2571 | torch::autograd::profiler::enableProfilerLegacy( |
2572 | torch::autograd::profiler::ProfilerConfig( |
2573 | torch::autograd::profiler::ProfilerState::CPU, false, false)); |
2574 | auto s1 = c10::make_intrusive<Future>(IntType::get()); |
2575 | s1->addCallback(wrapPropagateTLSState(profilerEnabledCb)); |
2576 | std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); }); |
2577 | // Since we join here, we can ensure that all callbacks corresponding to |
2578 | // markCompleted() have finished. |
2579 | t.join(); |
2580 | torch::autograd::profiler::disableProfilerLegacy(); |
2581 | } |
2582 | // then() with TLS State |
2583 | { |
2584 | // Enable the profiler in this thread |
2585 | torch::autograd::profiler::enableProfilerLegacy( |
2586 | torch::autograd::profiler::ProfilerConfig( |
2587 | torch::autograd::profiler::ProfilerState::CPU, false, false)); |
2588 | auto s1 = c10::make_intrusive<Future>(IntType::get()); |
2589 | auto s2 = s1->then( |
2590 | wrapPropagateTLSState([&profilerEnabledCb](Future& s1) { |
2591 | profilerEnabledCb(s1); |
2592 | return at::IValue(1); |
2593 | }), |
2594 | IntType::get()); |
2595 | std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); }); |
2596 | t.join(); |
2597 | s2->wait(); |
2598 | torch::autograd::profiler::disableProfilerLegacy(); |
2599 | } |
2600 | } |
2601 | |
2602 | TEST(ProfilerDisableInCallbackTest, Basic) { |
2603 | // cb that verifies the profiler is enabled |
2604 | auto profilerEnabledCb = []() { |
2605 | ASSERT_TRUE(torch::autograd::profiler::profilerEnabled()); |
2606 | }; |
2607 | torch::autograd::profiler::enableProfilerLegacy( |
2608 | torch::autograd::profiler::ProfilerConfig( |
2609 | torch::autograd::profiler::ProfilerState::CPU, false, false)); |
2610 | auto s1 = c10::make_intrusive<Future>(IntType::get()); |
2611 | auto verifyProfilerCb = |
2612 | wrapPropagateTLSState([&profilerEnabledCb](Future& /* unused */) { |
2613 | // Ensure the profiler is still enabled in this thread. |
2614 | profilerEnabledCb(); |
2615 | auto t1 = torch::ones({2, 2}); |
2616 | auto t2 = torch::ones({2, 2}); |
2617 | torch::add(t1, t2); |
2618 | // Don't cleanup TLSState, and just consolidate. |
2619 | auto opts = |
2620 | torch::autograd::profiler::ProfilerDisableOptions(false, true); |
2621 | auto thread_event_lists = |
2622 | // NOLINTNEXTLINE(performance-move-const-arg) |
2623 | torch::autograd::profiler::disableProfilerLegacy(std::move(opts)); |
2624 | // Ensure that the events from this thread are still profiled and we |
2625 | // obtain the expected in events in our consolidated list when calling |
2626 | // disableProfilerLegacy(). |
2627 | bool found_ones = false; |
2628 | bool found_add = false; |
2629 | for (const auto& li : thread_event_lists) { |
2630 | for (const auto& evt : li) { |
2631 | if (strcmp(evt.name(), "aten::add" ) == 0) { |
2632 | found_add = true; |
2633 | } else if (strcmp(evt.name(), "aten::ones" ) == 0) { |
2634 | found_ones = true; |
2635 | } |
2636 | } |
2637 | if (found_add && found_ones) { |
2638 | break; |
2639 | } |
2640 | } |
2641 | ASSERT_TRUE(found_ones); |
2642 | ASSERT_TRUE(found_add); |
2643 | }); |
2644 | |
2645 | s1->addCallback(verifyProfilerCb); |
2646 | // Disable the profiler, but do not consolidate results in the main thread. |
2647 | auto opts = torch::autograd::profiler::ProfilerDisableOptions(true, false); |
2648 | // NOLINTNEXTLINE(performance-move-const-arg) |
2649 | torch::autograd::profiler::disableProfilerLegacy(std::move(opts)); |
2650 | std::thread t([s1 = std::move(s1)]() { s1->markCompleted(at::IValue(1)); }); |
2651 | t.join(); |
2652 | |
2653 | // Similar to above test, but verifies correctness in the case where |
2654 | // continuation runs on the main thread. |
2655 | torch::autograd::profiler::enableProfilerLegacy( |
2656 | torch::autograd::profiler::ProfilerConfig( |
2657 | torch::autograd::profiler::ProfilerState::CPU, false, false)); |
2658 | s1 = c10::make_intrusive<Future>(IntType::get()); |
2659 | s1->addCallback(verifyProfilerCb); |
2660 | // Runs callback inline |
2661 | s1->markCompleted(at::IValue(1)); |
2662 | opts = torch::autograd::profiler::ProfilerDisableOptions(true, false); |
2663 | // NOLINTNEXTLINE(performance-move-const-arg) |
2664 | torch::autograd::profiler::disableProfilerLegacy(std::move(opts)); |
2665 | } |
2666 | |
2667 | TEST(RecordDebugHandles, Basic) { |
2668 | // Enable the profiler in this thread |
2669 | const std::set<torch::autograd::profiler::ActivityType> activities( |
2670 | {torch::autograd::profiler::ActivityType::CPU}); |
2671 | torch::autograd::profiler::prepareProfiler( |
2672 | torch::autograd::profiler::ProfilerConfig( |
2673 | torch::autograd::profiler::ProfilerState::KINETO, false, false), |
2674 | activities); |
2675 | torch::autograd::profiler::enableProfiler( |
2676 | torch::autograd::profiler::ProfilerConfig( |
2677 | torch::autograd::profiler::ProfilerState::KINETO, false, false), |
2678 | activities); |
2679 | { |
2680 | RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function" , 42, {}); |
2681 | float x{5.9999}, y{2.1212}; |
2682 | float z = x / y; |
2683 | (void)z; |
2684 | } |
2685 | { |
2686 | RECORD_USER_SCOPE_WITH_INPUTS("not_my_function" , {}); |
2687 | float x{5.9999}, y{2.1212}; |
2688 | float z = x / y; |
2689 | (void)z; |
2690 | } |
2691 | auto profiler_results_ptr = torch::autograd::profiler::disableProfiler(); |
2692 | const auto& kineto_events = profiler_results_ptr->events(); |
2693 | size_t my_events{0}; |
2694 | for (const auto& e : kineto_events) { |
2695 | if (e.name() == "my_function" ) { |
2696 | ASSERT_EQ(e.debugHandle(), 42); |
2697 | my_events++; |
2698 | } else if (e.name() == "not_my_function" ) { |
2699 | ASSERT_EQ(e.debugHandle(), -1); |
2700 | my_events++; |
2701 | } |
2702 | } |
2703 | ASSERT_EQ(my_events, 2); |
2704 | } |
2705 | |
2706 | TEST(RecordDebugHandles, ScopedCallbacks) { |
2707 | // Enable the profiler in this thread |
2708 | torch::autograd::profiler::prepareProfiler( |
2709 | torch::autograd::profiler::ProfilerConfig( |
2710 | torch::autograd::profiler::ProfilerState::KINETO, false, false), |
2711 | {torch::autograd::profiler::ActivityType::CPU}); |
2712 | torch::autograd::profiler::enableProfiler( |
2713 | torch::autograd::profiler::ProfilerConfig( |
2714 | torch::autograd::profiler::ProfilerState::KINETO, false, false), |
2715 | {torch::autograd::profiler::ActivityType::CPU}); |
2716 | |
2717 | { |
2718 | auto a = torch::rand({128, 128}); |
2719 | auto b = torch::rand({128, 128}); |
2720 | auto c = a + b; |
2721 | } |
2722 | auto profiler_results_ptr = torch::autograd::profiler::disableProfiler(); |
2723 | ASSERT_TRUE(profiler_results_ptr->events().size() > 0); |
2724 | |
2725 | // Enable the profiler in this thread |
2726 | torch::autograd::profiler::prepareProfiler( |
2727 | torch::autograd::profiler::ProfilerConfig( |
2728 | torch::autograd::profiler::ProfilerState::KINETO, false, false), |
2729 | {torch::autograd::profiler::ActivityType::CPU}); |
2730 | torch::autograd::profiler::enableProfiler( |
2731 | torch::autograd::profiler::ProfilerConfig( |
2732 | torch::autograd::profiler::ProfilerState::KINETO, false, false), |
2733 | {torch::autograd::profiler::ActivityType::CPU}, |
2734 | {at::RecordScope::LITE_INTERPRETER}); |
2735 | { |
2736 | auto a = torch::rand({128, 128}); |
2737 | auto b = torch::rand({128, 128}); |
2738 | auto c = a + b; |
2739 | } |
2740 | profiler_results_ptr = torch::autograd::profiler::disableProfiler(); |
2741 | ASSERT_TRUE(profiler_results_ptr->events().size() == 0); |
2742 | |
2743 | torch::autograd::profiler::prepareProfiler( |
2744 | torch::autograd::profiler::ProfilerConfig( |
2745 | torch::autograd::profiler::ProfilerState::KINETO, false, false), |
2746 | {torch::autograd::profiler::ActivityType::CPU}); |
2747 | torch::autograd::profiler::enableProfiler( |
2748 | torch::autograd::profiler::ProfilerConfig( |
2749 | torch::autograd::profiler::ProfilerState::KINETO, false, false), |
2750 | {torch::autograd::profiler::ActivityType::CPU}, |
2751 | {at::RecordScope::LITE_INTERPRETER}); |
2752 | { |
2753 | RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function" , 42, {}); |
2754 | auto a = torch::rand({128, 128}); |
2755 | auto b = torch::rand({128, 128}); |
2756 | auto c = a + b; |
2757 | } |
2758 | { |
2759 | RECORD_USER_SCOPE_WITH_INPUTS("not_my_function" , {}); |
2760 | auto a = torch::rand({128, 128}); |
2761 | auto b = torch::rand({128, 128}); |
2762 | auto c = a + b; |
2763 | } |
2764 | profiler_results_ptr = torch::autograd::profiler::disableProfiler(); |
2765 | const auto& kineto_events = profiler_results_ptr->events(); |
2766 | for (const auto& e : kineto_events) { |
2767 | if (e.name() == "my_function" ) { |
2768 | ASSERT_EQ(e.debugHandle(), 42); |
2769 | } |
2770 | } |
2771 | ASSERT_TRUE(profiler_results_ptr->events().size() == 1); |
2772 | } |
2773 | |
2774 | TEST(IValueKWargsTest, Basic) { |
2775 | const auto text = R"( |
2776 | def foo(a : int, b : int, c : int = 4): |
2777 | return a + 2*b + 3*c |
2778 | )" ; |
2779 | auto cu = compile(text); |
2780 | auto result = cu->get_function("foo" )({1}, {{"b" , 3}}); |
2781 | ASSERT_EQ(result.toInt(), 19); |
2782 | } |
2783 | |
2784 | TEST(ComputeFlopsTest, Basic) { |
2785 | uint64_t flops = 0; |
2786 | |
2787 | // Test unknown operator |
2788 | std::unordered_map<std::string, c10::IValue> ; |
2789 | flops = torch::profiler::impl::computeFlops( |
2790 | std::string("aten::unknown" ), extra_args); |
2791 | ASSERT_EQ(flops, 0); |
2792 | |
2793 | // Test aten::conv2d |
2794 | extra_args.clear(); |
2795 | std::vector<int64_t> input_size = {4, 5, 6, 7}; |
2796 | std::vector<int64_t> weight_size = {3, 5, 2, 1}; |
2797 | std::vector<int64_t> padding = {1, 0}; |
2798 | std::vector<int64_t> stride = {1, 1}; |
2799 | std::vector<int64_t> dilation = {0, 0}; |
2800 | extra_args["input_size" ] = at::IValue(at::IntArrayRef(input_size)); |
2801 | extra_args["weight_size" ] = at::IValue(at::IntArrayRef(weight_size)); |
2802 | extra_args["groups" ] = 1; |
2803 | extra_args["padding" ] = at::IValue(at::IntArrayRef(padding)); |
2804 | extra_args["stride" ] = at::IValue(at::IntArrayRef(stride)); |
2805 | extra_args["dilation" ] = at::IValue(at::IntArrayRef(dilation)); |
2806 | flops = torch::profiler::impl::computeFlops( |
2807 | std::string("aten::conv2d" ), extra_args); |
2808 | ASSERT_EQ(flops, 13440); |
2809 | |
2810 | // Test aten::conv2d fail |
2811 | input_size = {4, 5, 6, 7}; |
2812 | weight_size = {4, 5, 6}; |
2813 | extra_args["input_size" ] = at::IValue(at::IntArrayRef(input_size)); |
2814 | extra_args["weight_size" ] = at::IValue(at::IntArrayRef(weight_size)); |
2815 | flops = torch::profiler::impl::computeFlops( |
2816 | std::string("aten::conv2d" ), extra_args); |
2817 | ASSERT_EQ(flops, 0); |
2818 | |
2819 | // Test aten::conv2d fail 2 |
2820 | weight_size = {3, 5, 2, 1}; |
2821 | stride = {0, 0}; |
2822 | extra_args["weight_size" ] = at::IValue(at::IntArrayRef(input_size)); |
2823 | extra_args["stride" ] = at::IValue(at::IntArrayRef(stride)); |
2824 | flops = torch::profiler::impl::computeFlops( |
2825 | std::string("aten::conv2d" ), extra_args); |
2826 | ASSERT_EQ(flops, 0); |
2827 | |
2828 | // Test aten::conv2d fail 3 |
2829 | extra_args.clear(); |
2830 | input_size = {4, 5, 6, 7}; |
2831 | extra_args["input_size" ] = at::IValue(at::IntArrayRef(input_size)); |
2832 | flops = torch::profiler::impl::computeFlops( |
2833 | std::string("aten::conv2d" ), extra_args); |
2834 | ASSERT_EQ(flops, 0); |
2835 | |
2836 | // Test aten::mm |
2837 | extra_args.clear(); |
2838 | std::vector<int64_t> mat1_sizes = {3, 4, 5, 6}; |
2839 | std::vector<int64_t> mat2_sizes = {6, 5, 4, 3}; |
2840 | extra_args["mat1_size" ] = at::IValue(at::IntArrayRef(mat1_sizes)); |
2841 | extra_args["mat2_size" ] = at::IValue(at::IntArrayRef(mat2_sizes)); |
2842 | flops = |
2843 | torch::profiler::impl::computeFlops(std::string("aten::mm" ), extra_args); |
2844 | ASSERT_EQ(flops, 43200); |
2845 | |
2846 | // Test aten::addmm |
2847 | flops = torch::profiler::impl::computeFlops( |
2848 | std::string("aten::addmm" ), extra_args); |
2849 | ASSERT_EQ(flops, 43200); |
2850 | |
2851 | // Test aten::bmm |
2852 | extra_args.clear(); |
2853 | mat1_sizes = {7, 5, 6}; |
2854 | mat2_sizes = {7, 6, 3}; |
2855 | extra_args["mat1_size" ] = at::IValue(at::IntArrayRef(mat1_sizes)); |
2856 | extra_args["mat2_size" ] = at::IValue(at::IntArrayRef(mat2_sizes)); |
2857 | flops = |
2858 | torch::profiler::impl::computeFlops(std::string("aten::bmm" ), extra_args); |
2859 | ASSERT_EQ(flops, 1260); |
2860 | |
2861 | // Test aten::baddbmm |
2862 | flops = torch::profiler::impl::computeFlops( |
2863 | std::string("aten::baddbmm" ), extra_args); |
2864 | ASSERT_EQ(flops, 1260); |
2865 | |
2866 | // Test mm out of range |
2867 | extra_args.clear(); |
2868 | flops = |
2869 | torch::profiler::impl::computeFlops(std::string("aten::mm" ), extra_args); |
2870 | ASSERT_EQ(flops, 0); |
2871 | |
2872 | // Test aten::add.Tensor |
2873 | extra_args.clear(); |
2874 | std::vector<int64_t> mat_sizes = {3, 4, 5, 6}; |
2875 | extra_args["mat_size" ] = at::IValue(at::IntArrayRef(mat_sizes)); |
2876 | flops = |
2877 | torch::profiler::impl::computeFlops(std::string("aten::add" ), extra_args); |
2878 | ASSERT_EQ(flops, 360); |
2879 | |
2880 | // Test aten::mul.Tensor |
2881 | extra_args.clear(); |
2882 | mat_sizes = {3, 4, 5, 6}; |
2883 | extra_args["mat_size" ] = at::IValue(at::IntArrayRef(mat_sizes)); |
2884 | flops = |
2885 | torch::profiler::impl::computeFlops(std::string("aten::mul" ), extra_args); |
2886 | ASSERT_EQ(flops, 360); |
2887 | } |
2888 | |
2889 | TEST(TestConstant, TensorGrad) { |
2890 | auto graph = std::make_shared<Graph>(); |
2891 | IValue ten = torch::randn({3, 5}).requires_grad_(true); |
2892 | auto con = tryInsertConstant(*graph, ten); |
2893 | ASSERT_TRUE(con == c10::nullopt); |
2894 | } |
2895 | |
2896 | TEST(TestMutation, Basic) { |
2897 | auto graph = std::make_shared<Graph>(); |
2898 | std::unordered_map<std::string, Value*> vmap; |
2899 | parseIR( |
2900 | R"IR( |
2901 | graph(%x.1 : Tensor): |
2902 | %2 : int = prim::Constant[value=1]() |
2903 | %9 : int = prim::Constant[value=4]() |
2904 | %x.3 : Tensor = aten::add(%x.1, %2, %2) |
2905 | %7 : Tensor = aten::add_(%x.3, %2, %2) |
2906 | %y.1 : Tensor = aten::add(%x.3, %9, %2) |
2907 | return (%y.1))IR" , |
2908 | &*graph, |
2909 | vmap); |
2910 | RemoveTensorMutation(graph, [](Node*) { return false; }); |
2911 | testing::FileCheck().check("aten::add_" )->run(*graph); |
2912 | RemoveTensorMutation(graph, [](Node*) { return true; }); |
2913 | testing::FileCheck().check_not("aten::add_" )->run(*graph); |
2914 | } |
2915 | |
2916 | TEST(TestInplaceToFunctionalActivation, Basic) { |
2917 | auto graph = std::make_shared<Graph>(); |
2918 | std::unordered_map<std::string, Value*> vmap; |
2919 | parseIR( |
2920 | R"IR( |
2921 | graph(%x.1 : Tensor): |
2922 | %2 : int = prim::Constant[value=1]() |
2923 | %x.3 : Tensor = aten::add(%x.1, %2, %2) |
2924 | %y : Tensor = aten::relu_(%x.3) |
2925 | return (%y))IR" , |
2926 | &*graph, |
2927 | vmap); |
2928 | InplaceToFunctionalActivation(graph); |
2929 | testing::FileCheck().check("aten::relu" )->run(*graph); |
2930 | testing::FileCheck().check_not("aten::relu_" )->run(*graph); |
2931 | } |
2932 | |
2933 | TEST(TestRegisterShapeOp, Basic) { |
2934 | auto graph = std::make_shared<Graph>(); |
2935 | std::unordered_map<std::string, Value*> vmap; |
2936 | parseIR( |
2937 | R"IR( |
2938 | graph(): |
2939 | %2 : int = prim::Constant[value=5]() |
2940 | %3: int[] = prim::ListConstruct(%2, %2) |
2941 | return (%3))IR" , |
2942 | &*graph, |
2943 | vmap); |
2944 | |
2945 | auto g2 = std::make_shared<Graph>(); |
2946 | parseIR( |
2947 | R"IR( |
2948 | graph(): |
2949 | %2 : Tensor = prim::MakeTestTensor() |
2950 | return (%2))IR" , |
2951 | &*g2, |
2952 | vmap); |
2953 | |
2954 | const FunctionSchema& schema = g2->nodes().begin()->schema(); |
2955 | torch::jit::RegisterShapeComputeGraphForSchema(schema, graph); |
2956 | PropagateShapesOnGraph(g2); |
2957 | testing::FileCheck().check("5, 5" )->run(*g2); |
2958 | } |
2959 | |
2960 | TEST(TestFunctionalToInplaceActivation, Basic) { |
2961 | auto graph = std::make_shared<Graph>(); |
2962 | std::unordered_map<std::string, Value*> vmap; |
2963 | parseIR( |
2964 | R"IR( |
2965 | graph(%x.1 : Tensor): |
2966 | %2 : int = prim::Constant[value=1]() |
2967 | %x.3 : Tensor = aten::add(%x.1, %2, %2) |
2968 | %y : Tensor = aten::relu(%x.3) |
2969 | return (%y))IR" , |
2970 | &*graph, |
2971 | vmap); |
2972 | FunctionalToInplaceActivation(graph); |
2973 | testing::FileCheck().check("aten::relu_" )->run(*graph); |
2974 | testing::FileCheck().check_not("aten::relu(" )->run(*graph); |
2975 | } |
2976 | |
2977 | TEST(TestFunctionExecutor, SimpleExecutorTest) { |
2978 | auto graph = std::make_shared<Graph>(); |
2979 | parseIR( |
2980 | R"IR( |
2981 | graph(%x.1 : Tensor): |
2982 | %2 : int = prim::Constant[value=1]() |
2983 | %x.3 : Tensor = aten::add(%x.1, %2, %2) |
2984 | %y : Tensor = aten::relu(%x.3) |
2985 | return (%y))IR" , |
2986 | &*graph); |
2987 | { |
2988 | auto func = torch::make_unique<GraphFunction>( |
2989 | "name" , graph, [](GraphFunction&) {}, ExecutorExecutionMode::PROFILING); |
2990 | auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); |
2991 | Stack stack = {a}; |
2992 | func->run(stack); |
2993 | auto g = lastExecutedOptimizedGraph(); |
2994 | testing::FileCheck() |
2995 | .check("prim::profile" ) |
2996 | ->check("aten::add" ) |
2997 | ->check("aten::relu" ) |
2998 | ->run(*g); |
2999 | } |
3000 | { |
3001 | auto func = torch::make_unique<GraphFunction>( |
3002 | "name" , graph, [](GraphFunction&) {}, ExecutorExecutionMode::SIMPLE); |
3003 | auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); |
3004 | Stack stack = {a}; |
3005 | func->run(stack); |
3006 | auto g = func->getDebugState().graph; |
3007 | testing::FileCheck() |
3008 | .check_not("prim::profile" ) |
3009 | ->check("aten::add" ) |
3010 | ->check("aten::relu" ) |
3011 | ->run(*g); |
3012 | } |
3013 | } |
3014 | |
3015 | TEST(TestFunctionExecutor, RunDecompositionTest) { |
3016 | static auto* func = torch::jit::GetDecompositionExecutor( |
3017 | "aten::var(Tensor self, bool unbiased=True) -> Tensor" ); |
3018 | for (bool unbiased : {true, false}) { |
3019 | auto input = at::rand({4, 4}); |
3020 | Stack stack = {input, unbiased}; |
3021 | func->run(stack); |
3022 | at::Tensor out = pop(stack).toTensor(); |
3023 | ASSERT_TRUE(at::allclose(out, input.var(unbiased))); |
3024 | } |
3025 | } |
3026 | |
3027 | TEST(TestShapeGraphLinting, Basic) { |
3028 | auto schemas = RegisteredShapeComputeSchemas(); |
3029 | for (const auto& schema : schemas) { |
3030 | // arange does not acually support complex, leave as |
3031 | // union[int, float] for now |
3032 | if (schema->name() == "aten::arange" ) { |
3033 | continue; |
3034 | } |
3035 | auto g = shapeComputeGraphForSchema(*schema); |
3036 | TORCH_INTERNAL_ASSERT(g); |
3037 | LintShapeComputeGraph(schema, *g); |
3038 | } |
3039 | } |
3040 | |
3041 | // TODO: move to test_kernel when global settings are explicit |
3042 | // fusion parameters |
3043 | class Composed : public ::testing::Test { |
3044 | public: |
3045 | void SetUp() override { |
3046 | torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false; |
3047 | } |
3048 | }; |
3049 | |
3050 | TEST_F(Composed, ComposedOp) { |
3051 | struct WithCPUFuser { |
3052 | WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) { |
3053 | overrideCanFuseOnCPU(val); |
3054 | } |
3055 | |
3056 | ~WithCPUFuser() { |
3057 | overrideCanFuseOnCPU(cpuFuserEnabled); |
3058 | } |
3059 | |
3060 | bool cpuFuserEnabled; |
3061 | }; |
3062 | |
3063 | #ifdef TORCH_ENABLE_LLVM |
3064 | const auto graph_string = R"IR( |
3065 | graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), |
3066 | %1 : Float(5, 3, strides=[1, 5], device=cpu)): |
3067 | %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1) |
3068 | %3 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %2) |
3069 | %4 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %3) |
3070 | return (%3, %4))IR" ; |
3071 | auto graph = std::make_shared<Graph>(); |
3072 | parseIR(graph_string, &*graph); |
3073 | |
3074 | // wrong input sizes so we hit the fallback path |
3075 | auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); |
3076 | auto b = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)) |
3077 | .transpose(0, 1); |
3078 | auto ref1 = a * (a * b); |
3079 | auto ref2 = a * ref1; |
3080 | WithCPUFuser g(true); |
3081 | bool fusable_on_device = torch::jit::tensorexpr::getTEMustUseLLVMOnCPU(); |
3082 | torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false; |
3083 | FuseTensorExprs( |
3084 | graph, |
3085 | /*min_group_size*/ 2, |
3086 | /*add_composed_op*/ true, |
3087 | /*fuse_to_dynamic_shapes*/ true); |
3088 | Code code(graph, "" ); |
3089 | InterpreterState interpreter{code}; |
3090 | std::vector<IValue> stack = {a, b}; |
3091 | interpreter.run(stack); |
3092 | at::Tensor out2 = pop(stack).toTensor(); |
3093 | at::Tensor out1 = pop(stack).toTensor(); |
3094 | ASSERT_TRUE(at::allclose(ref1, out1)); |
3095 | ASSERT_TRUE(at::allclose(ref2, out2)); |
3096 | |
3097 | auto inp_1 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat)); |
3098 | auto inp_2 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat)); |
3099 | stack = {inp_1, inp_2, a, b}; |
3100 | InterpreterState interpreter2{code}; |
3101 | interpreter2.run(stack); |
3102 | out2 = pop(stack).toTensor(); |
3103 | out1 = pop(stack).toTensor(); |
3104 | ASSERT_TRUE(at::allclose(ref1, out1)); |
3105 | ASSERT_TRUE(at::allclose(ref2, out2)); |
3106 | // inp_1 is on the bottom of the stack, and corresponds |
3107 | // to the second output. inp_2 is on the top corresponds to first output |
3108 | ASSERT_TRUE(at::allclose(inp_1, ref2)); |
3109 | ASSERT_TRUE(at::allclose(inp_2, ref1)); |
3110 | torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = fusable_on_device; |
3111 | #endif |
3112 | } |
3113 | |
3114 | TEST(ConstantPropagation, CustomClassesCanBePropagated) { |
3115 | const auto src = R"IR( |
3116 | graph(): |
3117 | %none: NoneType = prim::Constant() |
3118 | %dim: int = prim::Constant[value=3]() |
3119 | %shape: int[] = prim::ListConstruct(%dim, %dim) |
3120 | %weight: Tensor = aten::ones(%shape, %none, %none, %none, %none) |
3121 | %scale: float = prim::Constant[value=1.]() |
3122 | %zero_point: int = prim::Constant[value=0]() |
3123 | %dtype: int = prim::Constant[value=12]() |
3124 | %weight_q: Tensor = aten::quantize_per_tensor(%weight, %scale, %zero_point, %dtype) |
3125 | %params: __torch__.torch.classes.quantized.LinearPackedParamsBase = quantized::linear_prepack(%weight_q, %none) |
3126 | return (%params) |
3127 | )IR" ; |
3128 | auto graph = std::make_shared<Graph>(); |
3129 | std::unordered_map<std::string, Value*> vmap; |
3130 | parseIR(src, graph.get(), vmap); |
3131 | |
3132 | ConstantPropagation(graph); |
3133 | |
3134 | testing::FileCheck().check_not("quantized::linear_prepack" )->run(*graph); |
3135 | } |
3136 | |
3137 | } // namespace jit |
3138 | } // namespace torch |
3139 | |