1#include <gmock/gmock.h>
2#include <gtest/gtest.h>
3
4#include <ATen/ATen.h>
5#include <ATen/Parallel.h>
6#include <ATen/core/interned_strings.h>
7#include <ATen/core/ivalue.h>
8#include <ATen/core/jit_type_base.h>
9#include <test/cpp/jit/test_utils.h>
10#include <torch/csrc/jit/passes/remove_mutation.h>
11#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
12#include <torch/csrc/jit/tensorexpr/kernel.h>
13
14#include <torch/csrc/autograd/engine.h>
15#include <torch/csrc/autograd/generated/variable_factories.h>
16#include <torch/csrc/autograd/profiler.h>
17#include <torch/csrc/autograd/variable.h>
18#include <torch/csrc/jit/api/function_impl.h>
19#include <torch/csrc/jit/api/module.h>
20#include <torch/csrc/jit/codegen/fuser/interface.h>
21#include <torch/csrc/jit/frontend/ir_emitter.h>
22#include <torch/csrc/jit/frontend/tracer.h>
23#include <torch/csrc/jit/ir/alias_analysis.h>
24#include <torch/csrc/jit/ir/attributes.h>
25#include <torch/csrc/jit/ir/irparser.h>
26#include <torch/csrc/jit/ir/scope.h>
27#include <torch/csrc/jit/ir/type_hashing.h>
28#include <torch/csrc/jit/jit_log.h>
29#include <torch/csrc/jit/passes/bailout_graph.h>
30#include <torch/csrc/jit/passes/canonicalize.h>
31#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
32#include <torch/csrc/jit/passes/constant_propagation.h>
33#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
34#include <torch/csrc/jit/passes/dead_code_elimination.h>
35#include <torch/csrc/jit/passes/graph_fuser.h>
36#include <torch/csrc/jit/passes/guard_elimination.h>
37#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
38#include <torch/csrc/jit/passes/insert_guards.h>
39#include <torch/csrc/jit/passes/liveness.h>
40#include <torch/csrc/jit/passes/loop_unrolling.h>
41#include <torch/csrc/jit/passes/lower_grad_of.h>
42#include <torch/csrc/jit/passes/lower_tuples.h>
43#include <torch/csrc/jit/passes/pass_manager.h>
44#include <torch/csrc/jit/passes/requires_grad_analysis.h>
45#include <torch/csrc/jit/passes/restore_mutation.h>
46#include <torch/csrc/jit/passes/shape_analysis.h>
47#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
48#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
49#include <torch/csrc/jit/runtime/argument_spec.h>
50#include <torch/csrc/jit/runtime/autodiff.h>
51#include <torch/csrc/jit/runtime/custom_operator.h>
52#include <torch/csrc/jit/runtime/decomposition_registry.h>
53#include <torch/csrc/jit/runtime/graph_executor.h>
54#include <torch/csrc/jit/runtime/interpreter.h>
55#include <torch/csrc/jit/runtime/jit_trace.h>
56#include <torch/csrc/jit/runtime/profiling_record.h>
57#include <torch/csrc/jit/runtime/symbolic_script.h>
58#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
59#include <torch/csrc/jit/serialization/import.h>
60#include <torch/csrc/jit/testing/file_check.h>
61#include <torch/jit.h>
62#include <torch/script.h>
63
64#include <onnx/onnx_pb.h>
65
66#include <c10/util/Exception.h>
67#include <c10/util/ThreadLocalDebugInfo.h>
68
69#include <torch/csrc/jit/passes/freeze_module.h>
70#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
71#include <algorithm>
72#include <cstddef>
73#include <functional>
74#include <iostream>
75#include <memory>
76#include <set>
77#include <stdexcept>
78#include <string>
79#include <tuple>
80#include <unordered_map>
81#include <unordered_set>
82#include <utility>
83#include <vector>
84
85namespace torch {
86namespace jit {
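// Alias-analysis kind used when registering the custom test operators below;
// FROM_SCHEMA derives aliasing behavior from the operator's schema string.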
87inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
88 return c10::AliasAnalysisKind::FROM_SCHEMA;
89}
90
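// Printing helper so vectors show up as "{a, b, c}" in test/debug output.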
91template <typename T>
92std::ostream& operator<<(std::ostream& out, const std::vector<T>& list) {
93 size_t i = 0;
94 out << "{";
95 for (auto&& e : list) {
96 if (i++ > 0)
97 out << ", ";
98 out << e;
99 }
100 out << "}";
101 return out;
102}
103
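// Symbols are interned qualified strings: constructing the same name twice
// yields the same Symbol, and newly registered symbols get sequentially
// assigned ids (which the symstart arithmetic below relies on).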
104TEST(InternedStringsTest, Basic) {
105 ASSERT_EQ(prim::Param, Symbol::prim("Param"));
106 ASSERT_EQ(prim::Return, Symbol::prim("Return"));
107 ASSERT_EQ(prim::Return.toUnqualString(), std::string("Return"));
108 ASSERT_EQ(prim::Return.toQualString(), std::string("prim::Return"));
109 Symbol newsym = Symbol::aten("__NEW_SYMBOL");
110 size_t symstart = newsym;
111 ASSERT_EQ(newsym.toQualString(), std::string("aten::__NEW_SYMBOL"));
112 // TODO: This test is a bit too close to the implementation details.
113 ASSERT_EQ(Symbol::aten("What"), symstart + 1);
114 ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
115 ASSERT_EQ(Symbol::aten("What"), symstart + 1);
116 ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
117 ASSERT_EQ(Symbol(symstart + 2).toUnqualString(), std::string("What2"));
118}
119
120TEST(FromQualStringTest, Basic) {
121 ASSERT_EQ(Symbol::fromQualString("prim::Param"), Symbol::prim("Param"));
122 ASSERT_EQ(Symbol::fromQualString("aten::mm"), Symbol::aten("mm"));
123 ASSERT_EQ(Symbol::fromQualString("onnx::LSTM"), Symbol::onnx("LSTM"));
124 ASSERT_EQ(Symbol::fromQualString("attr::value"), Symbol::attr("value"));
125 ASSERT_EQ(Symbol::fromQualString("scope::"), Symbol::scope(""));
126 ASSERT_EQ(Symbol::fromQualString("::").toUnqualString(), std::string(""));
127 ASSERT_EQ(
128 Symbol::fromQualString("::").ns().toQualString(),
129 std::string("namespaces::"));
130 ASSERT_EQ(
131 Symbol::fromQualString("new_ns::param").toUnqualString(),
132 std::string("param"));
133 ASSERT_EQ(
134 Symbol::fromQualString("new_ns::param").ns().toUnqualString(),
135 std::string("new_ns"));
136 ASSERT_EQ(
137 Symbol::fromQualString("new_ns::param").ns(),
138 Symbol::fromQualString("namespaces::new_ns"));
139
  auto bad_inputs = {"scope", ":", ""};
  for (auto input : bad_inputs) {
    ASSERT_ANY_THROW(Symbol::fromQualString(input));
  }
148}
149
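// Run slow_conv2d forward/backward eagerly, build an equivalent JIT graph,
// differentiate it with differentiate(), and check that the interpreter's
// outputs and gradients match the eager results.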
150TEST(THNNConvTest, Basic) {
151 std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
152 std::vector<int64_t> kernel_size = {3, 5};
153 std::vector<int64_t> stride = {1, 2};
154 std::vector<int64_t> padding = {2, 1};
155 constexpr int out_channels = 5;
156
157 // make inputs
158 at::Tensor input = torch::randn(input_size);
159 at::Tensor weight = torch::randn(
160 {out_channels, input_size[1], kernel_size[0], kernel_size[1]});
161 at::Tensor bias = torch::randn({out_channels});
162
163 // run forward eagerly
164 at::Tensor output = at::_slow_conv2d_forward(
165 input, weight, kernel_size, bias, stride, padding);
166
167 // make grad_outputs
168 at::Tensor grad_output =
169 torch::randn_like(output, at::MemoryFormat::Preserve);
170
171 // run backward eagerly
172 at::Tensor grad_input, grad_weight, grad_bias;
173 std::tie(grad_input, grad_weight, grad_bias) = at::_slow_conv2d_backward(
174 grad_output,
175 input,
176 weight,
177 kernel_size,
178 stride,
179 padding,
180 {true, true, true});
181
182 // make JIT graph
183 auto graph = std::make_shared<Graph>();
184 auto ksz_val = graph->insertConstant(kernel_size);
185 auto kst_val = graph->insertConstant(stride);
186 auto pad_val = graph->insertConstant(padding);
187
188 auto inputg = graph->addInput("self");
189 auto weightg = graph->addInput("weight");
190 auto biasg = graph->addInput("bias");
191
192 Value* conv = graph->insert(
193 aten::_slow_conv2d_forward,
194 {inputg, weightg, ksz_val, biasg, kst_val, pad_val});
195 auto outputs = conv->node()->outputs();
196 for (auto output : outputs) {
197 graph->registerOutput(output);
198 }
199 LowerAllTuples(graph);
200 graph->lint();
201
202 // differentiate JIT graph
203 EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
204 ConstantPropagation(graph);
205 auto grad_spec = differentiate(graph);
206 LowerGradOf(*grad_spec.df);
207
208 // prepare JIT inputs / gradients
209 tensor_list tensors_in;
210 tensors_in.push_back(input);
211 tensors_in.push_back(weight);
212 tensors_in.push_back(bias);
213
214 tensor_list tensor_grads_in;
215 tensor_grads_in.push_back(grad_output);
216
217 // Get outputs from the interpreter
218 tensor_list tensors_out, tensor_grads_out;
219 std::tie(tensors_out, tensor_grads_out) =
220 runGradient(grad_spec, tensors_in, tensor_grads_in);
221
222 // prepare expected structs
223 tensor_list expected_tensors_out, expected_tensor_grads_out;
224 expected_tensors_out.push_back(output);
225 expected_tensor_grads_out.push_back(grad_input);
226 expected_tensor_grads_out.push_back(grad_weight);
227 expected_tensor_grads_out.push_back(grad_bias);
228
229 // Compare results
230 assertAllClose(tensors_out, expected_tensors_out);
231 assertAllClose(tensor_grads_out, expected_tensor_grads_out);
232}
233
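// Same structure as the conv test above, but for native_batch_norm, which also
// updates running_mean/running_var in-place (hence the separate eager/JIT
// clones prepared below).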
234TEST(ATenNativeBatchNormTest, Basic) {
235 // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor
236 // running_mean, Tensor running_var, bool training, float momentum, float eps)
237 // -> (Tensor, Tensor, Tensor)
238 std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
239 bool training = true;
240 float momentum = 0.9;
241 float eps = 1e-5;
242
243 // make inputs
244 at::Tensor input = torch::randn(input_size);
245 at::Tensor weight = torch::randn({input_size[1]});
246 at::Tensor bias = torch::randn({input_size[1]});
247 at::Tensor running_mean = torch::randn({input_size[1]});
248 at::Tensor running_var = torch::randn({input_size[1]});
249
  // running_mean and running_var are updated in-place, so clone them and give
  // the eager and JIT runs their own copies
251 at::Tensor running_mean_eager = running_mean.clone();
252 at::Tensor running_var_eager = running_var.clone();
253 at::Tensor running_mean_jit = running_mean.clone();
254 at::Tensor running_var_jit = running_var.clone();
255
256 // run forward eagerly
257 at::Tensor output, savemean, saveinvstd;
258 std::tie(output, savemean, saveinvstd) = at::native_batch_norm(
259 input,
260 weight,
261 bias,
262 running_mean_eager,
263 running_var_eager,
264 training,
265 momentum,
266 eps);
267
268 // make grad_outputs
269 at::Tensor grad_output =
270 torch::randn_like(output, at::MemoryFormat::Preserve);
271 at::Tensor grad_savemean =
272 torch::zeros_like(savemean, at::MemoryFormat::Preserve);
273 at::Tensor grad_saveinvstd =
274 torch::zeros_like(saveinvstd, at::MemoryFormat::Preserve);
275
276 // run backward eagerly
277 at::Tensor grad_input, grad_weight, grad_bias;
278 // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor
279 // weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor
280 // save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor,
281 // Tensor, Tensor)
282 std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward(
283 grad_output,
284 input,
285 weight,
286 running_mean_eager,
287 running_var_eager,
288 savemean,
289 saveinvstd,
290 training,
291 eps,
292 {true, true, true});
293
294 // make JIT graph
295 auto graph = std::make_shared<Graph>();
296 auto training_val = graph->insertConstant(IValue(training));
297 auto momentum_val = graph->insertConstant(IValue(momentum));
298 auto eps_val = graph->insertConstant(IValue(eps));
299
300 auto inputg = graph->addInput("self");
301 auto weightg = graph->addInput("weight");
302 auto biasg = graph->addInput("bias");
303 auto running_meang = graph->addInput("running_mean");
304 auto running_varg = graph->addInput("running_var");
305
306 Value* bn = graph->insert(
307 aten::native_batch_norm,
308 {inputg,
309 weightg,
310 biasg,
311 running_meang,
312 running_varg,
313 training_val,
314 momentum_val,
315 eps_val});
316 auto outputs = bn->node()->outputs();
317 for (auto output : outputs) {
318 graph->registerOutput(output);
319 }
320 LowerAllTuples(graph);
321 graph->lint();
322
323 // differentiate JIT graph
324 EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
325 ConstantPropagation(graph);
326 auto grad_spec = differentiate(graph);
327 LowerGradOf(*grad_spec.df);
328
329 // prepare JIT inputs / gradients
330 tensor_list tensors_in;
331 tensors_in.push_back(input);
332 tensors_in.push_back(weight);
333 tensors_in.push_back(bias);
334 tensors_in.push_back(running_mean_jit);
335 tensors_in.push_back(running_var_jit);
336
337 tensor_list tensor_grads_in;
338 tensor_grads_in.push_back(grad_output);
339 tensor_grads_in.push_back(grad_savemean);
340 tensor_grads_in.push_back(grad_saveinvstd);
341
342 // Get outputs from the interpreter
343 tensor_list tensors_out, tensor_grads_out;
344 std::tie(tensors_out, tensor_grads_out) =
345 runGradient(grad_spec, tensors_in, tensor_grads_in);
346
347 // prepare expected structs
348 tensor_list expected_tensors_out, expected_tensor_grads_out;
349 expected_tensors_out.push_back(output);
350 expected_tensors_out.push_back(savemean);
351 expected_tensors_out.push_back(saveinvstd);
352 expected_tensors_out.push_back(running_mean_eager);
353 expected_tensors_out.push_back(running_var_eager);
354 expected_tensor_grads_out.push_back(grad_input);
355 expected_tensor_grads_out.push_back(grad_weight);
356 expected_tensor_grads_out.push_back(grad_bias);
357
358 tensors_out.push_back(running_mean_jit);
359 tensors_out.push_back(running_var_jit);
360
361 // Compare results
362 assertAllClose(tensors_out, expected_tensors_out);
363 assertAllClose(tensor_grads_out, expected_tensor_grads_out);
364}
365
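// CustomFuseGraph should merge both aten::mul nodes into a single
// prim::FusionGroup whose subgraph contains exactly the two multiplications.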
366TEST(CustomFusionTest, Basic) {
367#if defined(FBCODE_CAFFE2)
368 return;
369#endif
370
371 auto graph_string = R"IR(
372 graph(%0 : Float(2, 3, 4),
373 %1 : Float(2, 3, 4)):
374 %2 : Tensor = aten::mul(%0, %1)
375 %3 : Tensor = aten::mul(%2, %0)
376 return (%3))IR";
377 auto g = std::make_shared<Graph>();
378 torch::jit::parseIR(graph_string, g.get());
379
380 torch::jit::overrideCanFuseOnCPU(true);
381 CustomFuseGraph(
382 g,
383 [](Node* n) { return n->kind() != prim::Param; },
384 Symbol::fromQualString("prim::FusionGroup"));
385 torch::jit::overrideCanFuseOnCPU(false);
386
387 const auto& nodes = g->nodes();
388 auto fusion_group =
389 std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
390 return node->kind() == Symbol::fromQualString("prim::FusionGroup");
391 });
392 AT_ASSERT(fusion_group != nodes.end());
393
394 auto subgraph = fusion_group->g(attr::Subgraph);
395 auto hits = 0;
396 // two multiplications
397 for (const auto& n : subgraph->nodes()) {
398 (void)n;
399 hits++;
400 }
401 AT_ASSERT(hits == 2);
402}
403
404TEST(CustomFusionTest, NestedBlocks) {
405#if defined(FBCODE_CAFFE2)
406 return;
407#endif
408
409 auto graph_string = R"IR(
410 graph(%0 : Float(2, 3, 4),
411 %1 : Float(2, 3, 4),
412 %2 : Float(2, 3, 4)):
413 %3 : int = prim::Constant[value=1]()
414 %4 : Tensor = prim::If(%2)
415 block0():
416 %5 : Tensor = aten::mul(%0, %2)
417 %6 : Tensor = aten::mul(%5, %1)
418 -> (%6)
419 block1():
420 %7 : Tensor = aten::add(%0, %2, %3)
421 %8 : Tensor = aten::add(%7, %1, %3)
422 -> (%8)
423 %9 : Tensor = aten::add(%4, %2, %3)
424 return (%4))IR";
425 auto g = std::make_shared<Graph>();
426 torch::jit::parseIR(graph_string, g.get());
427
428 CustomFuseGraph(
429 g,
430 [](Node* n) { return n->kind() == aten::mul; },
431 Symbol::fromQualString("prim::FusionGroup"));
432
433 // Could be done in more efficient ways, but this is only a test.
434 std::function<bool(const Block*, Symbol)> dfs = [&](const Block* b,
435 Symbol s) {
436 for (auto node : b->nodes()) {
437 if (node->kind() == s)
438 return true;
439 for (auto nested_b : node->blocks())
440 if (dfs(nested_b, s))
441 return true;
442 }
443 return false;
444 };
445
446 AT_ASSERT(dfs(g->block(), Symbol::fromQualString("prim::FusionGroup")));
447}
448
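// Small TorchScript snippets exercising if/else and while lowering; they are
// compiled below and executed directly through the interpreter.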
449static const auto cf_examples = R"JIT(
450 def if_test(a, b):
451 # FIXME: use 0 instead of a.
452 # c = 0
453 c = a
454 if bool(a < b):
455 c = b
456 else:
457 c = a
458 return c
459 def if_one(a, b):
460 c = b
461 if bool(a < b):
462 c = a
463 return c
464 def while_test(a, i):
465 while bool(i < 3):
466 a *= a
467 i += 1
468 return a
469)JIT";
470
471TEST(ControlFlowTest, Basic) {
472 auto cu = compile(cf_examples);
473
474 auto run = [&](const std::string& name, std::vector<IValue> stack) {
475 auto graph = toGraphFunction(cu->get_function(name)).graph();
476 Code code(graph, "");
477 InterpreterState interp(code);
478 interp.run(stack);
479 return stack;
480 };
481
482 auto L = [](int64_t l) { return IValue(scalar_to_tensor(at::Scalar(l))); };
483 auto V = [](IValue t) { return std::move(t).toTensor().item<int64_t>(); };
484 auto run_binary = [&](const std::string& name, int64_t a, int64_t b) {
485 return V(run(name, {L(a), L(b)})[0]);
486 };
487 ASSERT_EQ(2, run_binary("if_test", 1, 2));
488 ASSERT_EQ(3, run_binary("if_test", 3, 2));
489 ASSERT_EQ(2, run_binary("if_one", 2, 3));
490 ASSERT_EQ(2, run_binary("if_one", 3, 2));
491 ASSERT_EQ(256, run_binary("while_test", 2, 0));
492}
493
494#if defined(__has_feature)
495#if __has_feature(address_sanitizer)
496#define HAS_ASANUBSAN 1
497#endif
498#endif
499
500#ifndef HAS_ASANUBSAN
501// This test fails vptr UBSAN checks
502
503TEST(ProtoTest, Basic) {
504 ::ONNX_NAMESPACE::ModelProto proto;
505 proto.set_producer_name("foo");
506}
507#endif
508
509// test a few features that are not directly used in schemas yet
510TEST(SchemaParserTest, NestedArrays) {
511 // nested arrays
512 auto s = parseSchema("at::what(int[][4] foo) -> ()");
513 ASSERT_TRUE(s.arguments().at(0).N() == 4);
514 ASSERT_TRUE(IntType::get()->isSubtypeOf(*s.arguments()
515 .at(0)
516 .type()
517 ->expectRef<ListType>()
518 .getElementType()
519 ->expectRef<ListType>()
520 .getElementType()));
521 auto s2 = parseSchema("at::what(int[][] foo) -> ()");
522 ASSERT_TRUE(IntType::get()->isSubtypeOf(*s2.arguments()
523 .at(0)
524 .type()
525 ->expectRef<ListType>()
526 .getElementType()
527 ->expectRef<ListType>()
528 .getElementType()));
529}
530
531TEST(SchemaParserTest, OutVariant) {
532 auto schema_with_out = parseSchema(
533 "at::foo(Tensor self, *, Tensor(a!) f, Tensor(b!) l) -> (Tensor(a!) f, Tensor(b!) l)");
534 ASSERT_TRUE(schema_with_out.arguments().at(1).is_out());
535 ASSERT_TRUE(schema_with_out.arguments().at(2).is_out());
536
537 auto schema_without_out =
538 parseSchema("at::foo(Tensor self, *, int scalar) -> (int)");
539
540 for (const auto& arg : schema_without_out.arguments()) {
541 ASSERT_TRUE(!arg.is_out());
542 }
543
544 auto schema_with_is_write = parseSchema(
545 "aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))");
546
547 for (const auto& arg : schema_with_is_write.arguments()) {
548 ASSERT_TRUE(!arg.is_out());
549 }
550}
551
552// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
553TEST(SchemaParserTest, NamedReturns) {
554 // named returns
555 parseSchema("at::what(Tensor! i_will_be_written_to) -> ()");
556 auto s3 =
557 parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)");
558 ASSERT_TRUE(s3.returns().at(0).name() == "the_return");
559 ASSERT_TRUE(s3.returns().at(1).name() == "the_return2");
560}
561
562TEST(SchemaParserTest, Futures) {
563 // futures
564 auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
565 ASSERT_TRUE(IntType::get()->isSubtypeOf(
566 *s4.arguments().at(0).type()->expectRef<FutureType>().getElementType()));
567}
568
569TEST(SchemaParserTest, AnnotatedAliasSets) {
570 // test tensor with annotated alias sets
571 parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))");
572}
573
574TEST(SchemaParserTest, TensorListAnnotatedAliasSets) {
575 const auto s = parseSchema(
576 "at::foo(Tensor(a!) self, Tensor(b!)[] out)"
577 " -> ()");
578 const AliasInfo* selfAliasInfo = s.arguments().at(0).alias_info();
579 const AliasInfo* outAliasInfo = s.arguments().at(1).alias_info();
580 ASSERT_TRUE(
581 selfAliasInfo->beforeSets() ==
582 std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
583 ASSERT_TRUE(selfAliasInfo->isWrite());
584
585 ASSERT_TRUE(outAliasInfo->isWrite());
586 ASSERT_TRUE(outAliasInfo->beforeSets().empty());
587 ASSERT_EQ(outAliasInfo->containedTypes().size(), 1);
588
589 auto containedType = outAliasInfo->containedTypes()[0];
590
591 ASSERT_TRUE(containedType.isWrite());
592 ASSERT_TRUE(
593 containedType.beforeSets() ==
594 std::unordered_set<Symbol>{Symbol::fromQualString("alias::b")});
595}
596
597TEST(SchemaParserTest, AnnotatedAliasWithoutBeforeSet) {
598 EXPECT_THAT(
599 []() { parseSchema("at::foo(Tensor(!) self) -> Tensor"); },
600 ::testing::Throws<std::runtime_error>(::testing::Property(
601 &std::runtime_error::what,
602 ::testing::HasSubstr("expected ident but found '!' here"))));
603}
604
605TEST(SchemaParserTest, BeforeAfterSets) {
606 const auto s = parseSchema(
607 "at::what(Tensor(b|c)[](a!) list, Tensor(c) element)"
608 " -> (Tensor(b|c)[](a!))");
609
610 // The list itself is annotated with `a`
611 const AliasInfo* aliasInfo = s.arguments().at(0).alias_info();
612 ASSERT_NE(aliasInfo, nullptr);
613 ASSERT_TRUE(
614 aliasInfo->beforeSets() ==
615 std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
616 ASSERT_TRUE(aliasInfo->isWrite());
617
618 // Check the contained types
619 ASSERT_TRUE(!aliasInfo->containedTypes().empty());
620 const auto& containedAliasInfo = aliasInfo->containedTypes()[0];
621 const auto expected = std::unordered_set<Symbol>{
622 Symbol::fromQualString("alias::b"),
623 Symbol::fromQualString("alias::c"),
624 };
625 ASSERT_TRUE(containedAliasInfo.beforeSets() == expected);
626 ASSERT_TRUE(containedAliasInfo.afterSets() == expected);
627 ASSERT_FALSE(containedAliasInfo.isWrite());
628}
629
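// The "(b -> b|c)" annotation specifies separate before/after alias sets:
// before the call the element may only alias set `b`; after the call it may
// alias `b` or `c`.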
630TEST(SchemaParserTest, BeforeAfterSets2) {
631 const auto s = parseSchema(
632 "at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)"
633 " -> (Tensor(b|c)[](a!))");
634
635 // The list itself is annotated with `a`
636 const AliasInfo* aliasInfo = s.arguments().at(0).alias_info();
637 ASSERT_NE(aliasInfo, nullptr);
638 ASSERT_EQ(
639 aliasInfo->beforeSets(),
640 std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
641 ASSERT_EQ(
642 aliasInfo->afterSets(),
643 std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
644 ASSERT_TRUE(aliasInfo->isWrite());
645 ASSERT_EQ(aliasInfo->containedTypes().size(), 1);
646
647 // Check the contained types
648 ASSERT_TRUE(!aliasInfo->containedTypes().empty());
649 const auto& containedAliasInfo = aliasInfo->containedTypes()[0];
650 const auto expectedBefore = std::unordered_set<Symbol>{
651 Symbol::fromQualString("alias::b"),
652 };
653 const auto expectedAfter = std::unordered_set<Symbol>{
654 Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")};
655 ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore);
656 ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter);
657 ASSERT_FALSE(containedAliasInfo.isWrite());
658}
659
660TEST(TopologicalIndexTest, Basic) {
661 Graph graph;
662 auto node1 = graph.create(prim::AutogradZero);
663 auto node2 = graph.create(prim::AutogradZero);
664 auto node3 = graph.create(prim::AutogradZero);
665 auto node4 = graph.create(prim::AutogradZero);
666
667 graph.appendNode(node4);
668 graph.prependNode(node1);
669 node2->insertAfter(node1);
670 node3->insertBefore(node4);
671
672 // nodes should be in numerical order
673 ASSERT_TRUE(node1->isBefore(node2));
674 ASSERT_TRUE(node1->isBefore(node3));
675 ASSERT_TRUE(node1->isBefore(node4));
676 ASSERT_TRUE(node2->isAfter(node1));
677 ASSERT_TRUE(node2->isBefore(node3));
678 ASSERT_TRUE(node2->isBefore(node4));
679 ASSERT_FALSE(node3->isBefore(node1));
680 ASSERT_FALSE(node3->isBefore(node2));
681 ASSERT_FALSE(node3->isAfter(node4));
682
  // Build up a block structure
684 // node3
685 // /\ ...
686 // A B block1
687 // \ ...
688 // C block2
689 auto block1 = node3->addBlock();
690 auto A = graph.create(prim::AutogradZero);
691 block1->appendNode(A);
692 auto B = graph.create(prim::AutogradZero);
693 block1->appendNode(B);
694 auto block2 = B->addBlock();
695 auto C = graph.create(prim::AutogradZero);
696 block2->appendNode(C);
697
  // Check isBefore across different block levels
699 ASSERT_TRUE(node1->isBefore(A));
700 ASSERT_TRUE(A->isBefore(B));
701 ASSERT_TRUE(A->isBefore(C));
702
703 // make sure things don't blow up on deletions
704 node2->destroy();
705 auto node2p = graph.create(prim::AutogradZero);
706 node2p->insertAfter(node1);
707 ASSERT_TRUE(node1->isBefore(node2p));
708 ASSERT_TRUE(node2p->isBefore(node3));
709}
710
711TEST(TopologicalIndexTest, Reindex) {
712 // Induce reindexing to test that path
713 Graph graph;
714 std::map<size_t, Node*> nodes;
715
716 auto anchor = graph.create(prim::AutogradZero);
717 graph.appendNode(anchor);
718 // Inserting to the same place a lot will trigger reindexing
719 for (auto i = 0; i < 100; ++i) {
720 auto n = graph.create(prim::AutogradZero);
721 n->insertAfter(anchor);
722 nodes[i] = n;
723 }
724
725 // Nodes should be in reverse order
726 for (auto i = 0; i < 100; ++i) {
727 for (auto j = i + 1; j < 100; ++j) {
728 ASSERT_TRUE(nodes[i]->isAfter(nodes[j]));
729 }
730 }
731}
732
733at::Tensor invokeTestRecordFunction(at::Tensor& t) {
734 RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
735
736 auto t2 = t.pow(2);
737 return t2;
738}
739
740static const auto invokeTestRecordFunction_JIT = R"JIT(
741 def foo(self, t):
742 t2 = t.pow(2)
743 return t2
744
745 def forward(self, t):
746 return self.foo(t)
747)JIT";
748
749at::Tensor invokeTestRecordFunctionJIT(at::Tensor& t) {
750 RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
751
752 auto module = std::make_shared<script::Module>(
753 "RecordFunctionTestModule", std::make_shared<script::CompilationUnit>());
754 module->define(invokeTestRecordFunction_JIT);
755 return module->forward({t}).toTensor();
756}
757
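// Each recorded entry is (function name, list of argument/result sizes); a
// scalar is recorded as an empty size vector.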
758using TracedTestValues =
759 std::vector<std::tuple<std::string, std::vector<std::vector<int64_t>>>>;
760
761void checkTracedInputs(const TracedTestValues& inputs) {
762 bool found_test = false;
763 bool found_pow = false;
764 bool found_mul = false;
765 for (const auto& input : inputs) {
766 const auto& fn = std::get<0>(input);
767 const auto& sizes = std::get<1>(input);
768
769 if (fn == "test") {
770 found_test = true;
771 TORCH_CHECK(sizes.size() == 1);
772 TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
773 } else if (fn == "aten::pow") {
774 found_pow = true;
775 TORCH_CHECK(sizes.size() == 2);
776 TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
777 TORCH_CHECK(sizes[1].empty());
778 } else if (fn == "aten::mul") {
779 found_mul = true;
780 TORCH_CHECK(sizes.size() > 1);
781 TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
782 }
783 }
784 TORCH_CHECK(found_test);
785 TORCH_CHECK(found_pow);
786 TORCH_CHECK(found_mul);
787}
788
789void checkTracedOutputs(const TracedTestValues& outputs) {
790 bool found_test = false;
791 bool found_pow = false;
792 bool found_mul = false;
793 for (const auto& output : outputs) {
794 const auto& fn = std::get<0>(output);
795 const auto& sizes = std::get<1>(output);
796
797 if (fn == "test") {
798 found_test = true;
799 TORCH_CHECK(sizes.empty());
800 } else if (fn == "aten::pow") {
801 found_pow = true;
802 TORCH_CHECK(sizes.size() == 1);
803 TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
804 } else if (fn == "aten::mul") {
805 found_mul = true;
806 TORCH_CHECK(sizes.size() == 1);
807 TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
808 }
809 }
810 TORCH_CHECK(found_test);
811 TORCH_CHECK(found_pow);
812 TORCH_CHECK(found_mul);
813}
814
815static bool bad_scope = false;
816template <RecordScope scope, size_t* cnt>
817std::unique_ptr<at::ObserverContext> checkScopeCallback(
818 const at::RecordFunction& fn) {
819 if (fn.scope() == scope) {
820 ++(*cnt);
821 } else {
822 bad_scope = true;
823 }
824 return nullptr;
825}
826
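// Registers a global callback restricted to a single RecordScope; the counter
// is incremented only for events observed in that scope.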
827template <RecordScope scope, size_t* cnt>
828void pushScopedCallback() {
829 at::addGlobalCallback(
830 at::RecordFunctionCallback(checkScopeCallback<scope, cnt>)
831 .scopes({scope}));
832}
833
834// These cannot be function-local because that would prohibit them
835// from being used as template arguments prior to C++17.
836static size_t fun_cnt;
837static size_t ts_fun_cnt;
838static size_t user_scope_cnt;
839
840void checkScopeCallbacks() {
841 static bool found_function_scope;
842 static bool found_method_scope;
843 static bool found_user_scope;
844 found_function_scope = false;
845 found_method_scope = false;
846 found_user_scope = false;
847 at::addGlobalCallback(at::RecordFunctionCallback(
848 [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
849 if (fn.scope() == at::RecordScope::FUNCTION &&
850 std::string(fn.name()) == "test_function") {
851 found_function_scope = true;
852 }
853 if (fn.scope() == at::RecordScope::TORCHSCRIPT_FUNCTION &&
854 std::string(fn.name()) == "test_method") {
855 found_method_scope = true;
856 }
857 if (fn.scope() == at::RecordScope::USER_SCOPE &&
858 std::string(fn.name()) == "test_user_scope") {
859 found_user_scope = true;
860 }
861 return nullptr;
862 }));
863
864 bad_scope = false;
865 fun_cnt = 0;
866 pushScopedCallback<at::RecordScope::FUNCTION, &fun_cnt>();
867 ts_fun_cnt = 0;
868 pushScopedCallback<at::RecordScope::TORCHSCRIPT_FUNCTION, &ts_fun_cnt>();
869 user_scope_cnt = 0;
870 pushScopedCallback<at::RecordScope::USER_SCOPE, &user_scope_cnt>();
871
872 TORCH_CHECK(at::hasCallbacks());
873
874 {
875 RECORD_TORCHSCRIPT_FUNCTION("test_method", {});
876 { RECORD_FUNCTION("test_function", {}); }
877 { RECORD_USER_SCOPE("test_user_scope"); }
878 }
879
880 TORCH_CHECK(!bad_scope);
881 TORCH_CHECK(fun_cnt == 1);
882 TORCH_CHECK(ts_fun_cnt == 1);
883 TORCH_CHECK(user_scope_cnt == 1);
884
885 TORCH_CHECK(found_function_scope);
886 TORCH_CHECK(found_method_scope);
887 TORCH_CHECK(found_user_scope);
888}
889
890static TracedTestValues traced_inputs;
891static TracedTestValues traced_outputs;
892static std::unordered_set<std::string> ts_input_names;
893static std::unordered_set<std::string> ts_output_names;
894
895std::unique_ptr<at::ObserverContext> tracedInputsCallback(
896 const RecordFunction& fn) {
897 if (fn.scope() == RecordScope::FUNCTION) {
898 auto inputs = fn.inputs();
899 std::vector<std::vector<int64_t>> sizes;
900 for (const auto& input : inputs) {
901 if (input.isTensor()) {
902 sizes.push_back(input.toTensor().sizes().vec());
903 } else if (input.isScalar()) {
904 // NOLINTNEXTLINE(modernize-use-emplace)
905 sizes.push_back(std::vector<int64_t>());
906 }
907 }
908 traced_inputs.push_back(std::make_tuple(fn.name(), sizes));
909 } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
910 ts_input_names.insert(fn.name());
911 }
912 return nullptr;
913}
914
915void tracedOutputsCallback(const RecordFunction& fn, ObserverContext* ctx_ptr) {
916 if (fn.scope() == RecordScope::FUNCTION) {
917 auto outputs = fn.outputs();
918 std::vector<std::vector<int64_t>> sizes;
919 for (const auto& output : outputs) {
920 if (output.isTensor()) {
921 sizes.push_back(output.toTensor().sizes().vec());
922 } else if (output.isScalar()) {
923 sizes.emplace_back();
924 }
925 }
926 traced_outputs.push_back(std::make_tuple(fn.name(), sizes));
927 } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
928 ts_output_names.insert(fn.name());
929 }
930}
931
932TEST(RecordFunctionTest, TracedTestInputsOutputs) {
933 // disabling the inlining of method calls
934 GraphOptimizerEnabledGuard opt_guard(false);
935
936 // [(fn, [[sizes], [sizes], ...]), ...]
937 addGlobalCallback(
938 RecordFunctionCallback(tracedInputsCallback, tracedOutputsCallback)
939 .needsInputs(true)
940 .needsOutputs(true));
941
942 TracedTestValues eager_inputs, eager_outputs, jit_inputs, jit_outputs;
943 {
944 auto t = torch::randn({1, 2, 3}, at::kCPU);
945 t.set_requires_grad(true);
946 auto t2 = invokeTestRecordFunction(t);
947 t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
948 eager_inputs = traced_inputs;
949 eager_outputs = traced_outputs;
950 traced_inputs.clear();
951 traced_outputs.clear();
952
953 TORCH_CHECK(ts_input_names.empty());
954 TORCH_CHECK(ts_output_names.empty());
955
956 t = torch::randn({1, 2, 3}, at::kCPU);
957 t.set_requires_grad(true);
958 t2 = invokeTestRecordFunctionJIT(t);
959 t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
960 jit_inputs = traced_inputs;
961 jit_outputs = traced_outputs;
962 traced_inputs.clear();
963 traced_outputs.clear();
964 }
965
966 TORCH_CHECK(ts_input_names.find("forward") != ts_input_names.end());
967 TORCH_CHECK(ts_input_names.find("foo") != ts_input_names.end());
968 TORCH_CHECK(ts_output_names.find("forward") != ts_output_names.end());
969 TORCH_CHECK(ts_output_names.find("foo") != ts_output_names.end());
970
971 checkTracedInputs(eager_inputs);
972 checkTracedOutputs(eager_outputs);
973 checkTracedInputs(jit_inputs);
974 checkTracedOutputs(jit_outputs);
975 at::clearCallbacks();
976}
977
978static int sampled_cb_ctr = 0;
979std::unique_ptr<ObserverContext> sampledCallback(const RecordFunction& fn) {
980 if (std::string(fn.name()) == "test") {
981 ++sampled_cb_ctr;
982 }
983 return nullptr;
984}
985
986static int non_sampled_cb_ctr = 0;
987std::unique_ptr<ObserverContext> nonSampledCallback(const RecordFunction& fn) {
988 if (std::string(fn.name()) == "test") {
989 ++non_sampled_cb_ctr;
990 }
991 return nullptr;
992}
993
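// samplingProb is the per-event probability that a sampled callback fires:
// 0.0 should never run, 1.0 should always run, and 0.5 should fire for some
// but not all of the 1000 invocations below.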
994TEST(RecordFunctionTest, SampledCallbacks) {
995 // disabling the inlining of method calls
996 GraphOptimizerEnabledGuard opt_guard(false);
997
998 // test sampled callbacks
999 sampled_cb_ctr = 0;
1000 auto setup_sampled_callback = [](double sampling_prob) {
1001 return addGlobalCallback(
1002 RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob));
1003 };
1004
1005 addGlobalCallback(RecordFunctionCallback(nonSampledCallback));
1006
1007 auto handle = setup_sampled_callback(0.5);
1008
1009 auto run_test_function = []() {
1010 auto t = torch::randn({1, 2, 3}, at::kCPU);
1011 for (auto k = 0; k < 1000; k++) {
1012 invokeTestRecordFunction(t);
1013 }
1014 };
1015
1016 run_test_function();
1017 TORCH_CHECK(non_sampled_cb_ctr == 1000);
1018 TORCH_CHECK(sampled_cb_ctr > 0 && sampled_cb_ctr < 1000);
1019
1020 sampled_cb_ctr = 0;
1021 removeCallback(handle);
1022 handle = setup_sampled_callback(0.0);
1023 run_test_function();
1024
1025 TORCH_CHECK(non_sampled_cb_ctr == 2000);
1026 TORCH_CHECK(sampled_cb_ctr == 0);
1027
1028 sampled_cb_ctr = 0;
1029 removeCallback(handle);
1030 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
1031 handle = setup_sampled_callback(1.0);
1032 run_test_function();
1033
1034 TORCH_CHECK(non_sampled_cb_ctr == 3000);
1035 TORCH_CHECK(sampled_cb_ctr == 1000);
1036 clearCallbacks();
1037
1038 // test the scope of the callbacks
1039 checkScopeCallbacks();
1040 clearCallbacks();
1041}
1042
1043TEST(RecordFunctionTest, RecordFunctionGuard) {
1044 // disabling the inlining of method calls
1045 GraphOptimizerEnabledGuard opt_guard(false);
1046
1047 static std::vector<std::string> fn_names;
1048 static std::mutex guard_mtx;
1049
1050 // check record function guard
1051 addGlobalCallback(RecordFunctionCallback(
1052 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1053 std::lock_guard<std::mutex> lock(guard_mtx);
1054 // NOLINTNEXTLINE(modernize-use-emplace)
1055 fn_names.push_back(fn.name());
1056 return nullptr;
1057 }));
1058 {
1059 RecordFunctionGuard g1(false);
1060 {
1061 RECORD_USER_SCOPE("A");
1062 {
1063 RecordFunctionGuard g2(true);
1064 RECORD_USER_SCOPE("B");
1065 {
1066 DisableRecordFunctionGuard g3;
1067 RECORD_USER_SCOPE("C");
1068 }
1069 }
1070 { RECORD_USER_SCOPE("D"); }
1071 }
1072 }
1073 TORCH_CHECK(fn_names.size() == 1);
1074 TORCH_CHECK(fn_names[0] == "B");
1075 clearCallbacks();
1076}
1077
1078static std::vector<size_t> ids;
1079
1080template <size_t id>
1081auto add_remove_test_add_cb() {
1082 return addGlobalCallback(RecordFunctionCallback(
1083 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1084 ids.push_back(id);
1085 return nullptr;
1086 }));
1087}
1088
1089TEST(RecordFunctionTest, Callbacks) {
1090 // disabling the inlining of method calls
1091 GraphOptimizerEnabledGuard opt_guard(false);
1092
1093 auto h1 = add_remove_test_add_cb<1>();
1094 add_remove_test_add_cb<2>();
1095 auto h3 = add_remove_test_add_cb<3>();
1096
1097 { RECORD_USER_SCOPE("test"); }
1098
1099 TORCH_CHECK(ids.size() == 3);
1100 TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
1101 TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
1102 TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
1103
1104 ids.clear();
1105 removeCallback(h1);
1106
1107 { RECORD_USER_SCOPE("test"); }
1108
1109 TORCH_CHECK(ids.size() == 2);
1110 TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
1111 TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
1112
1113 ids.clear();
1114 removeCallback(h3);
1115
1116 { RECORD_USER_SCOPE("test"); }
1117
1118 TORCH_CHECK(ids.size() == 1);
1119 TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
1120
1121 clearCallbacks();
1122
1123 // thread local / global callbacks
1124
1125 ids.clear();
1126 add_remove_test_add_cb<1>();
1127
1128 { RECORD_USER_SCOPE("test"); }
1129
1130 TORCH_CHECK(ids.size() == 1);
1131 TORCH_CHECK(ids[0] == 1);
1132 ids.clear();
1133
1134 auto th = std::thread([]() {
1135 addThreadLocalCallback(RecordFunctionCallback(
1136 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1137 ids.push_back(2);
1138 return nullptr;
1139 }));
1140
1141 { RECORD_USER_SCOPE("test_thread"); }
1142 });
1143 th.join();
1144 TORCH_CHECK(ids.size() == 2);
1145 TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
1146 TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
1147 ids.clear();
1148
1149 { RECORD_USER_SCOPE("test"); }
1150
1151 TORCH_CHECK(ids.size() == 1);
1152 TORCH_CHECK(ids[0] == 1);
1153 ids.clear();
1154
1155 clearCallbacks();
1156
1157 // START: thread local / global context check callbacks
1158 struct TestContext : public ObserverContext {
1159 int a{0};
1160 std::string b;
1161 };
1162 ids.clear();
1163 { // START: global test
1164 addGlobalCallback(RecordFunctionCallback(
1165 [](const RecordFunction&
1166 /* unused */) -> std::unique_ptr<at::ObserverContext> {
1167 auto ctx = std::make_unique<TestContext>();
1168 ctx->a = 123;
1169 ctx->b = "test_str";
1170 ids.push_back(1);
1171 return ctx;
1172 },
1173 [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
1174 auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
1175 TORCH_CHECK(ctx_ptr != nullptr);
1176 TORCH_CHECK(ctx->a == 123);
1177 TORCH_CHECK(ctx->b == "test_str");
1178 }));
1179
1180 { RECORD_USER_SCOPE("test"); }
1181
1182 TORCH_CHECK(ids.size() == 1);
1183 TORCH_CHECK(ids[0] == 1);
1184 ids.clear();
1185 } // END: global test
1186 { // START: thread local test
1187 auto ctx_th = std::thread([]() {
1188 const std::string test_str = "test thread str";
1189 addThreadLocalCallback(RecordFunctionCallback(
1190 [](const RecordFunction&
1191 /* unused */) -> std::unique_ptr<at::ObserverContext> {
1192 auto ctx = std::make_unique<TestContext>();
1193 ctx->a = 234;
1194 ctx->b = "test_thread_str";
1195 ids.push_back(2);
1196 return ctx;
1197 },
1198 [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
1199 auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
1200 TORCH_CHECK(ctx_ptr != nullptr);
1201 TORCH_CHECK(ctx->a == 234);
1202 TORCH_CHECK(ctx->b == "test_thread_str");
1203 }));
1204
1205 // Will call both global and thread local callbacks.
1206 { RECORD_USER_SCOPE("test_thread"); }
1207 });
1208 ctx_th.join();
1209 TORCH_CHECK(ids.size() == 2);
1210 TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
1211 TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
1212 ids.clear();
1213 } // END: thread local test
1214
1215 clearCallbacks();
1216}
1217
1218TEST(RecordFunctionTest, ShouldRun) {
1219 // disabling the inlining of method calls
1220 GraphOptimizerEnabledGuard opt_guard(false);
1221
1222 static bool ran = false;
1223 auto handle = addGlobalCallback(RecordFunctionCallback(
1224 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1225 ran = true;
1226 return nullptr;
1227 }));
1228
1229 { RECORD_USER_SCOPE("test"); }
1230
1231 EXPECT_TRUE(ran) << "first run didn't happen";
1232 ran = false;
1233
1234 disableCallback(handle);
1235
1236 { RECORD_USER_SCOPE("test"); }
1237
1238 EXPECT_FALSE(ran) << "second run happened but shouldn't have";
1239 ran = false;
1240
1241 reenableCallback(handle);
1242
1243 { RECORD_USER_SCOPE("test"); }
1244
1245 EXPECT_TRUE(ran) << "run after re-enable didn't happen";
1246 ran = false;
1247
1248 clearCallbacks();
1249}
1250
1251TEST(RecordFunctionTest, Basic) {
1252 // disabling the inlining of method calls
1253 GraphOptimizerEnabledGuard opt_guard(false);
1254
1255 static std::string recorded_op;
1256 static bool has_ids = false;
1257
1258 // test propagation of TLS callbacks
1259 std::thread t([]() {
1260 RecordFunctionGuard enable_rec_fn;
1261 auto handle = addThreadLocalCallback(RecordFunctionCallback(
1262 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1263 recorded_op = fn.name();
1264 return nullptr;
1265 }));
1266 ThreadLocalState state;
1267 std::thread t_child([state]() {
1268 ThreadLocalStateGuard g_tls(state);
1269 RECORD_USER_SCOPE("test_in_thread");
1270 });
1271 t_child.join();
1272 EXPECT_EQ(recorded_op, "test_in_thread");
1273 removeCallback(handle);
1274 });
1275 t.join();
1276 clearCallbacks();
1277
1278 // test set ids
1279 addGlobalCallback(
1280 RecordFunctionCallback(
1281 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1282 has_ids = fn.handle() > 0;
1283 return nullptr;
1284 })
1285 .needsIds(true));
1286 { RECORD_USER_SCOPE("test"); }
1287 TORCH_CHECK(has_ids);
1288 clearCallbacks();
1289 has_ids = false;
1290 addGlobalCallback(RecordFunctionCallback(
1291 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1292 has_ids = fn.handle() > 0;
1293 return nullptr;
1294 }));
1295 { RECORD_USER_SCOPE("test"); }
1296 TORCH_CHECK(!has_ids);
1297 clearCallbacks();
1298}
1299
1300TEST(RecordFunctionTest, OperatorNameOverload) {
1301 static std::set<std::string> operator_names;
1302 at::addGlobalCallback(at::RecordFunctionCallback(
1303 [](const at::RecordFunction& fn)
1304 -> std::unique_ptr<at::ObserverContext> {
1305 c10::optional<c10::OperatorName> op_name =
1306 fn.operator_name();
1307 if (op_name.has_value()) {
1308 operator_names.insert(c10::toString(*op_name));
1309 } else {
1310 operator_names.insert("No Operator Name");
1311 }
1312 return nullptr;
1313 })
1314 .scopes({at::RecordScope::FUNCTION}));
1315 auto t = torch::randn({1, 2, 3}, at::kCPU);
1316 t.set_requires_grad(false);
1317 auto t2 = t.pow(2);
1318
1319 at::clearCallbacks();
1320 EXPECT_TRUE(operator_names.count("No Operator Name") == 0)
1321 << "Expected that all traced operators had an associated OperatorName object";
1322 EXPECT_TRUE(operator_names.count("aten::randn") == 1)
1323 << "Expected aten::randn to have been called and recorded, but it was not";
1324 EXPECT_TRUE(operator_names.count("aten::pow.Tensor_Scalar") == 1)
1325 << "Expected aten::pow.Tensor_Scalar to have been called and recorded, but it was not";
1326}
1327
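// DebugInfoBase subclass used to check that ThreadLocalDebugInfo entries are
// propagated to at::launch tasks and through the autograd backward pass.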
1328class TestThreadLocalDebugInfo : public c10::DebugInfoBase {
1329 public:
1330 int getModelId() const {
1331 return model_id_;
1332 }
1333
1334 void setModelId(int model_id) {
1335 model_id_ = model_id;
1336 }
1337
1338 // NOLINTNEXTLINE(modernize-use-override,modernize-use-equals-default)
1339 virtual ~TestThreadLocalDebugInfo() {}
1340
1341 private:
1342 int model_id_ = 0;
1343};
1344
1345void checkDebugInfo(c10::DebugInfoKind kind, int model_id) {
1346 auto* debug_info = c10::ThreadLocalDebugInfo::get(kind);
1347 TORCH_CHECK(debug_info != nullptr);
1348 auto* test_debug_info = dynamic_cast<TestThreadLocalDebugInfo*>(debug_info);
1349 TORCH_CHECK(test_debug_info != nullptr);
1350 TORCH_CHECK(test_debug_info->getModelId() == model_id);
1351}
1352
1353TEST(ThreadLocalDebugInfoTest, Basic) {
1354 static std::atomic<bool> done{false};
1355
1356 TORCH_CHECK(
1357 c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
1358 auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
1359 debug_info->setModelId(42);
1360 {
1361 c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
1362 checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1363 }
1364
1365 // check that thread local debug info is propagated through fork calls
1366 TORCH_CHECK(
1367 c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
1368 {
1369 c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
1370 at::launch([]() {
1371 checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1372 done = true;
1373 });
1374 }
1375 while (!done) {
1376 }
1377
1378 // check that thread local debug info is propagated through backward pass
1379 TORCH_CHECK(
1380 c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
1381 done = false;
1382 auto handle = addGlobalCallback(RecordFunctionCallback(
1383 [](const RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
1384 checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1385 done = true;
1386 return nullptr;
1387 }));
1388 {
1389 c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
1390 auto t = torch::randn({1, 2, 3}, at::kCPU);
1391 t.set_requires_grad(true);
1392 auto t2 = t.pow(2);
1393 t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
1394 }
1395 removeCallback(handle);
1396 TORCH_CHECK(done);
1397
1398 // check nested debug info
1399 TORCH_CHECK(
1400 c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
1401 {
1402 c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
1403 {
1404 checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1405 {
1406 auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
1407 debug_info->setModelId(314);
1408 c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO_2, debug_info);
1409 {
1410 checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1411 checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
1412 done = false;
1413 at::launch([]() {
1414 checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1415 checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
1416 done = true;
1417 });
1418 while (!done) {
1419 }
1420 }
1421 }
1422 }
1423 }
1424}
1425
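// SymInts wrapping concrete integers should behave the same as the plain
// int64_t overloads of the corresponding ops.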
1426TEST(TestSymIntArrayRef, BasicConversion) {
1427 const size_t X = 2, Y = 4, Z = 5;
1428 std::vector<int64_t> tgt_size_v{2, 4, 5};
1429 std::vector<c10::SymInt> tgt_size({SymInt(X), SymInt(Y), SymInt(Z)});
1430 auto a = at::randn({1, 4, 1}, at::kCPU);
1431 auto b = a.expand_symint(tgt_size);
1432 auto c = a.expand(tgt_size_v);
1433 ASSERT_TRUE(torch::allclose(b, c));
1434}
1435
1436TEST(TestSymInt, NarrowCopyWithSymbolicInt) {
1437 static const size_t LENGTH = 5;
1438 auto a = at::randn({10}, at::kCPU);
1439 c10::SymInt si(LENGTH);
1440 auto b = a.narrow_copy_symint(0, 0, si);
1441 auto c = a.narrow(0, 0, LENGTH);
1442 ASSERT_TRUE(torch::allclose(b, c));
1443}
1444
1445TEST(TestSymInt, NarrowCopy) {
1446 static const size_t LENGTH = 5;
1447 auto a = at::randn({10}, at::kCPU);
1448 auto b = a.narrow_copy(0, 0, LENGTH);
1449 auto c = a.narrow(0, 0, LENGTH);
1450 ASSERT_TRUE(torch::allclose(b, c));
1451}
1452
1453TEST(TestSymInt, AddSymbolicInt) {
1454 c10::SymInt a(5);
1455 c10::SymInt b(3);
1456 ASSERT_TRUE((a + b).expect_int() == 8);
1457}
1458
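// After profiling completes, replaceBlockWithFallbackGraph should leave a
// single prim::FallbackGraph node whose subgraph still contains the original
// ops, and running it should produce the same result as the unprofiled
// interpreter.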
1459TEST(FallbackGraphsTest, Basic) {
1460 auto x = at::randn({1}, at::kCPU);
1461 auto y = at::randn({1}, at::kCPU);
1462 auto stack = createStack({x.clone(), y.clone()});
1463
1464 auto graph_string = R"IR(
1465 graph(%0 : Float(1),
1466 %1 : Float(1)):
1467 %2 : Tensor = aten::mul(%0, %1)
1468 %3 : Tensor = aten::mul(%2, %0)
1469 return (%3))IR";
1470 auto graph = std::make_shared<Graph>();
1471 torch::jit::parseIR(graph_string, graph.get());
1472
1473 {
1474 Code code(graph, "");
1475 InterpreterState interpreter{code};
1476 interpreter.run(stack);
1477 }
1478 at::Tensor et;
1479 pop(stack, et);
1480 float ef = et.item<float>();
1481 {
1482 EnableProfilingGuard epg;
1483 GraphFunction f("fallbackGraphs", graph, nullptr);
1484 for (size_t i = 0; i < getNumProfiledRuns() + 1; i++) {
1485 stack.emplace_back(x.clone());
1486 stack.emplace_back(y.clone());
1487 if (i == getNumProfiledRuns()) {
1488 // we will be modifying a profiled graph
1489 // before ProfilingGraphExecutor
1490 // will optimize it in the next iteration
1491 auto opt_graph = lastExecutedOptimizedGraph();
1492 // this is safe to do since we are done profiling
1493 ProfilingRecord::removeProfileCounter(opt_graph->block());
1494 replaceBlockWithFallbackGraph(opt_graph->block(), opt_graph->inputs());
1495 auto it = opt_graph->block()->nodes().begin();
1496 ASSERT_EQ(it->kind(), prim::FallbackGraph);
1497 auto fallback = *it++;
1498 ASSERT_EQ(it, opt_graph->block()->nodes().end());
1499 ASSERT_TRUE(fallback->hasAttribute(attr::Subgraph));
1500 testing::FileCheck()
1501 .check("Tensor = aten::mul")
1502 ->check("Tensor = aten::mul")
1503 ->run(*fallback->g(attr::Subgraph));
1504 }
1505 f.run(stack);
1506 at::Tensor at;
1507 pop(stack, at);
1508 float af = at.item<float>();
1509 ASSERT_EQ(af, ef);
1510 }
1511
1512 auto opt_graph = lastExecutedOptimizedGraph();
1513 testing::FileCheck()
1514 .check("(Tensor) = prim::CallFunction")
1515 ->run(*opt_graph);
1516 }
1517}
1518
1519// TODO this test wasn't running and is broken.
1520// TEST(AutogradProfilerTest, Basic) {
1521// constexpr int batch_size = 4;
1522// constexpr int input_size = 256;
1523// constexpr int seq_len = 32;
1524
1525// int hidden_size = 2 * input_size;
1526// auto input = torch::randn({seq_len, batch_size, input_size}, at::kCPU);
1527// auto hx = torch::randn({batch_size, hidden_size}, at::kCPU);
1528// auto cx = torch::randn({batch_size, hidden_size}, at::kCPU);
1529// auto w_ih = t_def(torch::randn({4 * hidden_size, input_size}, at::kCPU));
1530// auto w_hh = t_def(torch::randn({4 * hidden_size, hidden_size}, at::kCPU));
1531
1532// std::stringstream ss;
1533// {
1534// RecordProfile guard(ss);
1535// for (size_t i = 0; i < 100; ++i) {
1536// std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
1537// }
1538// }
1539
1540// std::string result = ss.str();
1541// size_t count = 0;
1542// for (size_t pos = 0; (pos = result.find("tanh", pos)) != std::string::npos;
1543// count++, pos++) {
1544// }
//   ASSERT_EQ(count, 200);
1546// }
1547
1548TEST(NoneSchemaMatchTest, Basic) {
1549 RegisterOperators reg({
1550 Operator(
1551 "prim::test_none() -> int?",
1552 [](Stack& stack) { push(stack, IValue()); },
1553 aliasAnalysisFromSchema()),
1554 Operator(
1555 "prim::is_none(int? a) -> bool",
1556 [](Stack& stack) {
1557 IValue a = pop(stack);
1558 if (a.isNone()) {
1559 push(stack, true);
1560 } else {
1561 push(stack, false);
1562 }
1563 },
1564 aliasAnalysisFromSchema()),
1565 });
1566
1567 // Constant propagation will run test_none and produce a None,
1568 // testing that its type is set appropriately and schema matching doesn't
1569 // fail when running is_none
1570
1571 auto r = std::make_shared<Graph>();
1572 auto& g = *r;
1573 auto opt_int = g.insert(Symbol::fromQualString("prim::test_none"), {});
1574 auto out_bool = g.insert(Symbol::fromQualString("prim::is_none"), {opt_int});
1575 g.registerOutput(out_bool);
1576 ConstantPropagation(r);
1577
1578 auto nodes = r->block()->nodes();
  // check that constant propagation ran without failure
1580 AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1);
1581}
1582
1583static int testPassValue = 0;
1584void fakePass(std::shared_ptr<Graph>& g) {
1585 testPassValue++;
1586 return;
1587}
1588
1589RegisterPass p(fakePass);
1590
1591TEST(PassManagementTest, Basic) {
1592 std::shared_ptr<Graph> graph = std::make_shared<Graph>();
1593 parseIR(
1594 R"IR(
1595graph(%a):
1596 return (%a))IR",
1597 &*graph);
1598
1599 std::vector<IValue> stack = {IValue(torch::randn({22}, at::kCPU))};
1600 auto run = [&](std::shared_ptr<Graph>& graph, std::vector<IValue> stack) {
1601 GraphExecutor executor(graph, "");
1602 executor.run(stack);
1603 return stack;
1604 };
1605 run(graph, stack);
  // the registered custom pass is only expected to have run when the
  // profiling executor is disabled
1607 if (!getExecutorMode()) {
1608 AT_ASSERT(testPassValue);
1609 }
1610}
1611
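// Helpers used by tests further below: checkShape compares a TensorType's
// concrete sizes against an expected shape (optionally reading the type off
// the profile node feeding a node's first input), and countNodes recursively
// counts nodes matching a predicate, including nested blocks.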
1612static void checkShape(TypePtr typ, std::vector<int64_t> expected) {
1613 auto ptp = typ->expect<TensorType>();
1614 ASSERT_EQ(ptp->sizes().concrete_sizes().value(), expected);
1615}
1616
1617static void checkShape(
1618 Node* n,
1619 std::vector<int64_t> expected,
1620 bool prev = true) {
1621 auto profile = (prev) ? n->inputs().at(0)->node() : n;
1622 checkShape(profile->output()->type(), expected);
1623}
1624
1625void count_(
1626 Block* block,
1627 const std::function<bool(Node* n)>& pred,
1628 size_t& count) {
1629 for (Node* n : block->nodes()) {
1630 if (pred(n)) {
1631 count++;
1632 }
1633
1634 for (Block* ib : n->blocks()) {
1635 count_(ib, pred, count);
1636 }
1637 }
1638}
1639
1640size_t countNodes(
1641 const std::shared_ptr<Graph>& graph,
1642 const std::function<bool(Node* n)>& pred) {
1643 size_t count = 0;
1644 count_(graph->block(), pred, count);
1645 return count;
1646}
1647
bool true_pred(Node* n) {
  return true;
}

bool is_loop(Node* n) {
  return n->kind() == prim::Loop;
}
1655
1656TEST(LoopPeelerTest, NoInductionVariableUse) {
1657 // do not use an induction variable explicitly
1658 static const auto str_func_def = R"JIT(
1659 def test_peel_n_times():
1660 sum = 0
1661 for i in range(10):
1662 sum += 2
1663 return sum
1664 )JIT";
1665
1666 auto cu = compile(str_func_def);
1667 auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
1668 auto stack = createStack({});
1669 // peeling loop once
1670 {
1671 LoopsPeeler peeler(true_pred, 1);
1672 auto copy = f.graph()->copy();
1673 peeler.run(copy);
1674 int num_loops =
1675 std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1676 ASSERT_EQ(num_loops, 2);
1677 Code code(copy, "");
1678 InterpreterState interpreter{code};
1679 interpreter.run(stack);
1680 ASSERT_EQ(stack.back().toInt(), 20);
1681 }
1682
1683 // test peeling more than one iteration
1684 {
1685 LoopsPeeler peeler(true_pred, 3);
1686 auto copy = f.graph()->copy();
1687 peeler.run(copy);
1688 int num_loops =
1689 std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1690 ASSERT_EQ(num_loops, 2);
1691 Code code(copy, "");
1692 InterpreterState interpreter{code};
1693 interpreter.run(stack);
1694 ASSERT_EQ(stack.back().toInt(), 20);
1695 }
1696}
1697
1698TEST(LoopPeelerTest, YesInductionVariableUse) {
1699 // uses the induction variable
1700 static const auto str_func_def = R"JIT(
1701 def test_peel_n_times():
1702 sum = 0
1703 for i in range(10):
1704 sum += i
1705 return sum
1706 )JIT";
1707
1708 auto cu = compile(str_func_def);
1709 auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
1710 auto stack = createStack({});
1711 // peeling loop once
1712 {
1713 LoopsPeeler peeler(true_pred, 1);
1714 auto copy = f.graph()->copy();
1715 peeler.run(copy);
1716 int num_loops =
1717 std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1718 ASSERT_EQ(num_loops, 2);
1719 Code code(copy, "");
1720 InterpreterState interpreter{code};
1721 interpreter.run(stack);
1722 ASSERT_EQ(stack.back().toInt(), 45);
1723 }
1724
1725 // test peeling more than one iteration
1726 {
1727 LoopsPeeler peeler(true_pred, 3);
1728 auto copy = f.graph()->copy();
1729 peeler.run(copy);
1730 int num_loops =
1731 std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1732 ASSERT_EQ(num_loops, 2);
1733 Code code(copy, "");
1734 InterpreterState interpreter{code};
1735 interpreter.run(stack);
1736 ASSERT_EQ(stack.back().toInt(), 45);
1737 }
1738}
1739
1740TEST(LoopPeelerTest, LoopWithTerminationCondition) {
1741 // tests with explicit termination conditions
1742 static const auto str_func_def = R"JIT(
1743 def test_with_cond_times():
1744 sum = 0
1745 i = 0
1746 while (sum < 2):
1747 sum += i
1748 i += 1
1749 return sum
1750 )JIT";
1751
  // peeling enough iterations drives the termination condition to false,
  // so the remaining original loop never executes
1754 auto cu = compile(str_func_def);
1755 auto& f = toGraphFunction(cu->get_function("test_with_cond_times"));
1756 auto stack = createStack({});
1757 // peeling 5 iterations should update the termination
1758 // condition to false
1759 {
1760 LoopsPeeler peeler(true_pred, 5);
1761 auto copy = f.graph()->copy();
1762 peeler.run(copy);
1763 int num_loops =
1764 std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1765 ASSERT_EQ(num_loops, 2);
1766 Code code(copy, "");
1767 InterpreterState interpreter{code};
1768 interpreter.run(stack);
1769 ASSERT_EQ(stack.back().toInt(), 3);
1770 }
1771
1772 // the termination condition remains true
1773 {
1774 LoopsPeeler peeler(true_pred, 1);
1775 auto copy = f.graph()->copy();
1776 peeler.run(copy);
1777 int num_loops =
1778 std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1779 ASSERT_EQ(num_loops, 2);
1780 Code code(copy, "");
1781 InterpreterState interpreter{code};
1782 interpreter.run(stack);
1783 ASSERT_EQ(stack.back().toInt(), 3);
1784 }
1785}
1786
1787// tests simple nested loops
1788TEST(LoopPeelerTest, SimpleNestedLoops) {
1789 static const auto str_func_def = R"JIT(
1790 def test_nested_loops():
1791 sum = 0
1792 i = 0
1793 for i in range(10):
1794 for j in range(10):
1795 sum += i + j
1796 return sum
1797 )JIT";
1798
1799 auto cu = compile(str_func_def);
1800 auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
1801 auto stack = createStack({});
1802
1803 {
1804 LoopsPeeler peeler(true_pred, 1);
1805 auto copy = f.graph()->copy();
1806 peeler.run(copy);
1807 ASSERT_EQ(countNodes(copy, is_loop), 5);
1808 Code code(copy, "");
1809 InterpreterState interpreter{code};
1810 interpreter.run(stack);
1811 ASSERT_EQ(stack.back().toInt(), 900);
1812 }
1813
1814 {
1815 LoopsPeeler peeler(true_pred, 5);
1816 auto copy = f.graph()->copy();
1817 peeler.run(copy);
1818 ASSERT_EQ(countNodes(copy, is_loop), 5);
1819 Code code(copy, "");
1820 InterpreterState interpreter{code};
1821 interpreter.run(stack);
1822 ASSERT_EQ(stack.back().toInt(), 900);
1823 }
1824}
1825
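// Tests peeling of nested loops where the inner loop is a while loop with a
// data-dependent termination condition.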
1826TEST(LoopPeelerTest, SimpleNestedLoops2) {
1827 static const auto str_func_def = R"JIT(
1828 def test_nested_loops():
1829 sum = 0
1830 i = 0
1831 for i in range(10):
1832 j = 0
1833 while sum < 2:
1834 sum += i + j
1835 j += 1
1836 return sum
1837 )JIT";
1838
1839 auto cu = compile(str_func_def);
1840 auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
1841 auto stack = createStack({});
1842 {
1843 LoopsPeeler peeler(true_pred, 1);
1844 auto copy = f.graph()->copy();
1845 peeler.run(copy);
1846 ASSERT_EQ(countNodes(copy, is_loop), 5);
1847 Code code(copy, "");
1848 InterpreterState interpreter{code};
1849 interpreter.run(stack);
1850 ASSERT_EQ(stack.back().toInt(), 3);
1851 }
1852
1853 {
1854 LoopsPeeler peeler(true_pred, 5);
1855 auto copy = f.graph()->copy();
1856 peeler.run(copy);
1857 ASSERT_EQ(countNodes(copy, is_loop), 5);
1858 Code code(copy, "");
1859 InterpreterState interpreter{code};
1860 interpreter.run(stack);
1861 ASSERT_EQ(stack.back().toInt(), 3);
1862 }
1863}
1864
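// Traces the LSTM graph with concrete inputs, checks that the traced graph's
// input types match those inputs, and that both the traced and the original
// graph reproduce the output observed during tracing.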
1865TEST(JitTracing, Basic) {
1866 constexpr int batch_size = 4;
1867 constexpr int input_size = 256;
1868
1869 int hidden_size = 2 * input_size;
1870
1871 auto input = at::randn({batch_size, input_size}, at::kCPU);
1872 auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
1873 auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
1874 auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
1875 auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
1876
1877 auto graph = build_lstm();
1878 auto stack = createStack({input, hx, cx, w_ih, w_hh});
1879 auto traced = TraceGraph(graph, stack);
1880
1881 // Check that the inputs of traced graph have the same type as the inputs
1882 // specified here.
1883 ASSERT_EQ(*traced->inputs().at(0)->type(), *TensorType::create(input));
1884 ASSERT_EQ(*traced->inputs().at(1)->type(), *TensorType::create(hx));
1885 ASSERT_EQ(*traced->inputs().at(2)->type(), *TensorType::create(cx));
1886 ASSERT_EQ(*traced->inputs().at(3)->type(), *TensorType::create(w_ih));
1887 ASSERT_EQ(*traced->inputs().at(4)->type(), *TensorType::create(w_hh));
1888
1889 Tensor prof_out;
1890 pop(stack, prof_out);
1891
1892 {
1893 stack = createStack({input, hx, cx, w_ih, w_hh});
1894 Code cd(traced, "traced");
1895 InterpreterState is{cd};
1896 is.run(stack);
1897 Tensor traced_out;
1898 pop(stack, traced_out);
    ASSERT_TRUE(torch::allclose(prof_out, traced_out));
1900 }
1901
1902 {
1903 stack = createStack({input, hx, cx, w_ih, w_hh});
1904 Code cd(graph, "graph");
1905 InterpreterState is{cd};
1906 is.run(stack);
1907 Tensor scripted_out;
1908 pop(stack, scripted_out);
    ASSERT_TRUE(torch::allclose(prof_out, scripted_out));
1910 }
1911}
1912
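// Profiles a simple pointwise function, inserts prim::Guard nodes from the
// recorded shapes, and checks that EliminateRedundantGuards leaves only the
// guards on the graph inputs.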
1913// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
1914TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
1915 static const auto basic_example = R"JIT(
1916 def basic(x, y):
1917 a = x + y
1918 b = x * y
1919 c = x + 1
1920 d = a - c
1921 e = b - c
1922 return d + e
1923 )JIT";
1924
1925 auto cu = compile(basic_example);
1926 auto& fun = toGraphFunction(cu->get_function("basic"));
1927 auto pr = ProfilingRecord::instrumentGraph(fun.graph());
1928 auto x = at::randn({2, 3}, at::kCPU);
1929 auto y = at::randn({2, 3}, at::kCPU);
1930 auto stack = createStack({x, y});
1931 // introduce some profiling information
1932 Code cd(pr->profiled_graph_, "");
1933 InterpreterState is{cd};
1934 is.run(stack);
1935 auto copy = pr->profiled_graph_->copy();
1936 ProfilingRecord::removeProfileCounter(copy->block());
1937 InsertGuards(copy);
1938 auto nodes = copy->block()->nodes();
1939 auto guard = std::find_if(nodes.begin(), nodes.end(), [](Node* n) {
1940 return n->kind() == prim::Guard;
1941 });
1942 ASSERT_NE(guard, nodes.end());
1943 ASSERT_EQ(
1944 guard->input()->type()->expectRef<TensorType>().sizes().size(),
1945 c10::nullopt);
1946 checkShape(*guard, {2, 3}, false);
1947 auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
1948 int num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
1949 ASSERT_EQ(num_guards, 12);
1950 // now eliminate as many guards as possible
1951 // we should be left with two guards on x and y's defs
1952 EliminateRedundantGuards(copy);
1953 num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
1954 ASSERT_EQ(num_guards, 2);
1955}
1956
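// Checks that InsertBailOuts replaces each remaining prim::Guard with a
// prim::BailOut node whose first input is the prim::BailoutTemplate.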
1957TEST(InsertBailOutsTest, Basic) {
1958 static const auto basic_example = R"JIT(
1959 def basic_loop(x, y):
1960
1961 a = x + 1
1962 b = y + 2
1963 c = x + y + 3
1964
1965 for i in range(10):
1966 a = a + b
1967 # invariant
1968 d = b * c
1969 #
1970 a = a - d
1971
1972 e = a + 4
1973 return e
1974 )JIT";
1975
1976 auto cu = compile(basic_example);
1977 auto& fun = toGraphFunction(cu->get_function("basic_loop"));
1978 auto pr = ProfilingRecord::instrumentGraph(fun.graph());
1979 auto x = at::randn({2, 3}, at::kCPU);
1980 auto y = at::randn({2, 3}, at::kCPU);
1981 auto stack = createStack({x, y});
1982 // introduce some profiling information
1983 Code cd(pr->profiled_graph_, "");
1984 InterpreterState is{cd};
1985 is.run(stack);
1986 auto copy = pr->profiled_graph_->copy();
1987 ProfilingRecord::removeProfileCounter(copy->block());
1988 InsertGuards(copy);
1989 EliminateRedundantGuards(copy);
1990 auto nodes = copy->block()->nodes();
1991 auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
1992 auto num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
1993 ASSERT_EQ(num_guards, 3);
1994 InsertBailOuts(copy);
1995 auto is_bailout = [](Node* n) { return n->kind() == prim::BailOut; };
1996 auto num_bailouts = std::count_if(nodes.begin(), nodes.end(), is_bailout);
1997 ASSERT_EQ(num_guards, num_bailouts);
1998 std::vector<Node*> bailouts(num_bailouts);
1999 std::copy_if(nodes.begin(), nodes.end(), bailouts.begin(), is_bailout);
2000
2001 for (auto blo : bailouts) {
2002 ASSERT_EQ(blo->inputs().at(0)->node()->kind(), prim::BailoutTemplate);
2003 }
2004}
2005
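// Runs the instrumented LSTM graph and checks that prim::profile nodes
// record the expected input shapes for the add, mul, and tanh nodes.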
2006TEST(ProfilerTest, Basic) {
2007 constexpr int batch_size = 4;
2008 constexpr int input_size = 256;
2009
2010 int hidden_size = 2 * input_size;
2011
2012 auto input = at::randn({batch_size, input_size}, at::kCPU);
2013 auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
2014 auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
2015 auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
2016 auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
2017
2018 auto g = build_lstm();
2019 auto stack = createStack({input, hx, cx, w_ih, w_hh});
2020
2021 auto& opt_graph = *g.get();
2022 ArgumentSpecCreator arg_spec_creator(opt_graph);
2023 ArgumentSpec spec =
2024 arg_spec_creator.create(autograd::GradMode::is_enabled(), stack);
2025 arg_spec_creator.specializeTypes(opt_graph, spec);
2026 auto pr = ProfilingRecord::instrumentGraph(g);
2027 Code cd(pr->profiled_graph_, "");
2028 InterpreterState is{cd};
2029 is.run(stack);
2030
2031 // profiled types are stored as attributes and show up in the dump, e.g.
2032 // Tensor = prim::profile[profiled_type=Double(4, 256, strides=[256, 1],
2033 // requires_grad=0, device=cpu)
2034 testing::FileCheck()
2035 .check("Tensor = prim::profile[profiled_type")
2036 ->check_same("256")
2037 ->run(*pr->profiled_graph_);
2038
2039 auto begin = pr->profiled_graph_->block()->nodes().begin();
2040 auto end = pr->profiled_graph_->block()->nodes().end();
2041 auto mm =
2042 std::find_if(begin, end, [](Node* n) { return n->kind() == aten::add; });
2043 ASSERT_NE(mm, end);
2044 std::vector<int64_t> mm_expected{4, 2048};
2045 std::vector<int64_t> eltwise{4, 512};
2046 checkShape(mm->inputs().at(0)->node()->ty(attr::profiled_type), mm_expected);
2047 auto mul_n =
2048 std::find_if(begin, end, [](Node* n) { return n->kind() == aten::mul; });
2049 ASSERT_NE(mul_n, end);
2050 checkShape(mul_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
  auto tanh_n =
      std::find_if(begin, end, [](Node* n) { return n->kind() == aten::tanh; });
  ASSERT_NE(tanh_n, end);
  checkShape(tanh_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
2054}
2055
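// Checks that optional tensor inputs (Tensor?) are profiled: both the
// observed shape and whether a None value was ever seen.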
2056TEST(ProfilerTest, OptionalProfiling) {
2057 auto graph = std::make_shared<Graph>();
2058 std::unordered_map<std::string, Value*> vmap;
2059 parseIR(
2060 R"IR(
2061graph(%inp : Tensor,
2062 %weight : Tensor,
2063 %bias : Tensor?):
2064 %1 : Tensor = aten::linear(%inp, %weight, %bias)
2065 return (%1))IR",
2066 &*graph,
2067 vmap);
2068
2069 auto pr = ProfilingRecord::instrumentGraph(graph);
2070 pr->profiling_count_ = 2;
2071
2072 auto input = torch::randn({1, 2});
2073 auto weight = torch::randn({2, 2});
2074 auto bias = torch::randn({1, 2});
2075
2076 auto stack = createStack({input, weight, bias});
2077 Code cd(pr->profiled_graph_, "");
2078 InterpreterState is{cd};
2079 is.run(stack);
2080
2081 testing::FileCheck()
2082 .check_count("Tensor? = prim::profile[profiled_type", 1, true)
2083 ->run(*pr->profiled_graph_);
2084
2085 // make sure we recorded the shape
2086 auto begin = pr->profiled_graph_->block()->nodes().begin();
2087 auto end = pr->profiled_graph_->block()->nodes().end();
2088 auto linear = std::find_if(
2089 begin, end, [](Node* n) { return n->kind() == aten::linear; });
2090 ASSERT_NE(linear, end);
2091 std::vector<int64_t> bias_expected_shape = {1, 2};
2092 auto profiled_bias = linear->namedInput("bias")->node();
2093 checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
2094 ASSERT_EQ(0, profiled_bias->i(attr::seen_none));
2095
2096 auto none_bias = c10::IValue();
2097
2098 stack.clear();
2099 stack.emplace_back(input);
2100 stack.emplace_back(weight);
2101 stack.emplace_back(none_bias);
2102 is = InterpreterState{cd};
2103 is.run(stack);
2104
2105 // make sure we recorded that "None" was seen.
2106 begin = pr->profiled_graph_->block()->nodes().begin();
2107 end = pr->profiled_graph_->block()->nodes().end();
2108 linear = std::find_if(
2109 begin, end, [](Node* n) { return n->kind() == aten::linear; });
2110 ASSERT_NE(linear, end);
2111 profiled_bias = linear->namedInput("bias")->node();
2112 checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
2113 ASSERT_EQ(1, profiled_bias->i(attr::seen_none));
2114}
2115
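// Checks that constants inlined from callees carry the correct
// InlinedCallStack, and that nodes originating in the caller have none.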
2116TEST(CallStackTest, Basic) {
2117 const auto text = R"(
2118def ham(x):
2119 return x/7
2120
2121def bar(x):
2122 return x*3
2123
2124def baz(x):
2125 return ham(x)*x
2126
2127def foo(x):
2128 return bar(x)*baz(x)*11
2129 )";
2130 auto cu = compile(text);
2131 const auto& foo = toGraphFunction(cu->get_function("foo"));
2132 for (Node* n : foo.optimized_graph()->nodes()) {
2133 if (n->kind() == prim::Constant) {
2134 if (!n->hasAttribute(attr::value) ||
2135 n->kindOf(attr::value) != AttributeKind::i) {
2136 continue;
2137 }
2138 int v = n->i(attr::value);
2139 switch (v) {
2140 case 3: {
2141 // Const 3 comes from function 'bar', which gets inlined to 'foo'.
2142 // The callstack for the corresponding node should contain only the
2143 // function 'bar'.
2144 ASSERT_TRUE(n->callstack());
2145 auto callstack_vector = (*n->callstack())->vec();
2146 ASSERT_EQ(callstack_vector.size(), 1);
2147 ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("bar"));
2148 break;
2149 }
2150 case 7: {
2151 // Const 7 comes from function 'ham', which gets inlined to 'baz',
2152 // which is then inlined to 'foo'. The callstack for the corresponding
2153 // node should contain these two functions.
2154 ASSERT_TRUE(n->callstack());
2155 auto callstack_vector = (*n->callstack())->vec();
2156 ASSERT_EQ(callstack_vector.size(), 2);
2157 ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("baz"));
2158 ASSERT_EQ(std::get<0>(callstack_vector[1]), &cu->get_function("ham"));
2159 break;
2160 }
2161 case 11: {
2162 // Const 11 comes from function 'foo', which is not inlined anywhere
2163 // and thus it should not have a callstack.
2164 ASSERT_FALSE(n->callstack());
2165 break;
2166 }
2167 }
2168 }
2169 }
2170
2171 // Check that inlining doesn't corrupt callstack of the callee's nodes.
2172 const auto& baz = toGraphFunction(cu->get_function("baz"));
2173 for (Node* n : baz.optimized_graph()->nodes()) {
2174 if (n->kind() == prim::Constant) {
2175 if (!n->hasAttribute(attr::value) ||
2176 n->kindOf(attr::value) != AttributeKind::i) {
2177 continue;
2178 }
2179 int v = n->i(attr::value);
2180 ASSERT_TRUE(v == 7);
2181 // Const 7 comes from function 'ham', which gets inlined to 'baz'. 'baz'
2182 // was also inlined into 'foo', but when looking at the graph of 'baz' we
2183 // should only see a callstack of depth 1 (containing only 'ham').
2184 ASSERT_TRUE(n->callstack());
2185 auto callstack_vector = (*n->callstack())->vec();
2186 ASSERT_EQ(callstack_vector.size(), 1);
2187 ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("ham"));
2188 }
2189 }
2190}
2191
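// Checks that nodes inlined through the same call chain share a single
// cached InlinedCallStack object.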
2192TEST(CallStackTest, Caching) {
2193 const auto text = R"(
2194
2195def a(x):
2196 print("a1")
2197 print("a2")
2198 return x
2199
2200def b(x):
2201 print("b1")
2202 print("b2")
2203 a(x)
2204 return x
2205
2206def c(x):
2207 print("c1")
2208 print("c2")
2209 b(x)
2210 return x
2211 )";
2212 auto cu = compile(text);
2213 const auto& baz = toGraphFunction(cu->get_function("c"));
2214 std::unordered_map<std::string, InlinedCallStack*> callstack_objects;
2215 for (Node* n : baz.optimized_graph()->nodes()) {
2216 if (n->kind() == prim::Constant) {
2217 if (!n->hasAttribute(attr::value) ||
2218 n->kindOf(attr::value) != AttributeKind::s) {
2219 continue;
2220 }
2221 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
2222 std::string v = n->s(attr::value);
2223 if (n->callstack()) {
2224 callstack_objects[v] = n->callstack()->get();
2225 }
2226 }
2227 }
2228 // We expect to see nodes prim::Constant[value="a1"] and
2229 // prim::Constant[value="a2"] inlined to function 'c'. Their callstacks are
2230 // the same (a->b->c), so we want to make sure we're not creating different
2231 // callstack entries for them.
2232 ASSERT_TRUE(callstack_objects.count("a1") && callstack_objects.count("a2"));
2233 ASSERT_TRUE(callstack_objects.at("a1") == callstack_objects.at("a2"));
2234}
2235
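// Checks that nodes inside the blocks of an inlined prim::If keep callstack
// and source-range annotations pointing back at the original module code.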
2236TEST(InlinedCallStackTest, BlockAnnotation) {
2237 Module a("A");
2238 a.define(R"(
2239 def forward(self, x, y, z: int):
2240 if (z == 1):
2241 return x + y
2242 else:
2243 return x * y
2244 )");
2245 Module b("B");
2246 b.define(R"(
2247 def forward(self, x):
2248 return x + 2
2249 )");
2250 Module c("C");
2251 c.register_module("A0", a);
2252 c.register_module("B0", b);
2253 c.define(R"(
2254 def forward(self, x, y, z: int):
2255 return self.A0.forward(x, y, z) + self.B0.forward(x)
2256 )");
2257
2258 auto graph =
2259 toGraphFunction(c.get_method("forward").function()).optimized_graph();
2260 std::stringstream add_ss, mul_ss;
2261 for (Node* n : graph->nodes()) {
2262 if (n->kind() == prim::If) {
2263 for (Block* block : n->blocks()) {
2264 for (Node* if_node : block->nodes()) {
2265 if (if_node->kind() == aten::add) {
2266 for (const auto& e : if_node->callstack().value()->vec()) {
2267 add_ss << std::get<1>(e);
2268 }
2269 add_ss << if_node->sourceRange();
2270 }
2271 if (if_node->kind() == aten::mul) {
2272 for (const auto& e : if_node->callstack().value()->vec()) {
2273 mul_ss << std::get<1>(e);
2274 }
2275 mul_ss << if_node->sourceRange();
2276 }
2277 }
2278 }
2279 }
2280 }
2281 ASSERT_NE(add_ss.str().find("line 3"), std::string::npos);
2282 ASSERT_NE(add_ss.str().find("line 4"), std::string::npos);
2283 ASSERT_NE(
2284 add_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
2285 ASSERT_NE(add_ss.str().find("return x + y"), std::string::npos);
2286 ASSERT_NE(mul_ss.str().find("line 3"), std::string::npos);
2287 ASSERT_NE(mul_ss.str().find("line 6"), std::string::npos);
2288 ASSERT_NE(
2289 mul_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
2290 ASSERT_NE(mul_ss.str().find("return x * y"), std::string::npos);
2291}
2292
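// Checks the module hierarchy reported for each node, including frames
// introduced by a module calling its own methods (reported as SELF).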
2293// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
2294TEST(InlinedCallStackTest, SelfCallMethods) {
2295 Module a("A");
2296 a.define(R"(
2297 def my_new_method(self, x):
2298 return x * 3
2299 def forward_impl_(self, x, y):
2300 return self.my_new_method(x) + y
2301 def forward(self, x, y):
2302 y = y + 2
2303 return self.forward_impl_(x, y)
2304 )");
2305 Module b("B");
2306 b.define(R"(
2307 def forward(self, x):
2308 return x + 2
2309 )");
2310 Module c("C");
2311 c.register_module("A0", a);
2312 c.register_module("B0", b);
2313 c.define(R"(
2314 def call_b(self, x):
2315 return self.B0.forward(x)
2316 def forward(self, x, y):
2317 return self.A0.forward(x, y) + self.call_b(x)
2318 )");
2319
2320 auto graph =
2321 toGraphFunction(c.get_method("forward").function()).optimized_graph();
2322 std::unordered_map<std::string, size_t> module_hierarchies;
2323 for (Node* n : graph->nodes()) {
2324 auto hierarchy = torch::jit::utils::getNodesModuleHierarchy(*n);
2325 if (module_hierarchies.count(hierarchy) == 0) {
2326 module_hierarchies[hierarchy] = 0;
2327 }
2328 module_hierarchies[hierarchy] += 1;
2329 }
2330 ASSERT_EQ(module_hierarchies["A0(A)"], 2);
2331 ASSERT_EQ(module_hierarchies["A0(A).SELF(A).SELF(A)"], 2);
2332 ASSERT_EQ(module_hierarchies["A0(A).SELF(A)"], 1);
2333 ASSERT_EQ(module_hierarchies["SELF(C)"], 1);
2334 ASSERT_EQ(module_hierarchies["SELF(C).B0(B)"], 1);
2335}
2336
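// Checks which node kinds canRunWithAutograd accepts: generic aten and prim
// ops do, fusion groups and custom-namespace ops do not.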
2337TEST(AutogradSymbolsTest, Basic) {
2338 Symbol sym = Symbol::fromQualString("aten::test_symbol");
2339 Graph graph;
2340 auto node = graph.create(sym);
2341 TORCH_CHECK(canRunWithAutograd(node));
2342
2343 sym = Symbol::fromQualString("prim::test_symbol");
2344 node = graph.create(sym);
2345 TORCH_CHECK(canRunWithAutograd(node));
2346
2347 sym = Symbol::fromQualString("prim::FusionGroup");
2348 node = graph.create(sym);
2349 TORCH_CHECK(!canRunWithAutograd(node));
2350
2351 sym = Symbol::fromQualString("custom::test_symbol");
2352 node = graph.create(sym);
2353 TORCH_CHECK(!canRunWithAutograd(node));
2354}
2355
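// Compiling a default argument without a type hint should fail; with the
// hint it should succeed.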
2356TEST(DefaultArgTypeHintingTest, Basic) {
2357 const auto text_non_hinted = R"(
2358
2359def a(x, y=1):
2360 print("a1")
2361 print("a2")
2362 return x
2363 )";
2364
2365 const auto text_hinted = R"(
2366
2367def a(x, y:int=1):
2368 print("a1")
2369 print("a2")
2370 return x
2371 )";
2372
  ASSERT_ANY_THROW(compile(text_non_hinted));
  ASSERT_NO_THROW(compile(text_hinted));
2380}
2381
2382// Basic set case.
2383TEST(FuturesTest, Basic) {
2384 auto f1 = c10::make_intrusive<Future>(IntType::get());
2385 ASSERT_FALSE(f1->completed());
2386 ASSERT_FALSE(f1->hasValue());
2387 int32_t sat1 = 0;
2388 int32_t sat2 = 0;
2389 f1->addCallback([&](Future& /* unused */) { ++sat1; });
2390 f1->markCompleted(43);
2391 ASSERT_TRUE(f1->completed());
2392 ASSERT_TRUE(f1->hasValue());
2393 ASSERT_FALSE(f1->hasError());
2394 ASSERT_EQ(sat1, 1);
2395 ASSERT_EQ(f1->constValue().toInt(), 43);
2396 ASSERT_EQ(f1->value().toInt(), 43);
2397 f1->addCallback([&](Future& /* unused */) { ++sat2; });
2398 ASSERT_EQ(sat1, 1);
2399 ASSERT_EQ(sat2, 1);
2400}
2401
2402// Basic error cases.
2403TEST(FuturesTest, Error) {
2404 auto f1 = c10::make_intrusive<Future>(IntType::get());
2405 int sat1 = 0;
2406 int sat2 = 0;
2407 f1->addCallback([&](Future& /* unused */) { ++sat1; });
2408 f1->setError(
2409 std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
2410 ASSERT_EQ(sat1, 1);
2411 ASSERT_TRUE(f1->completed());
2412 ASSERT_TRUE(f1->hasError());
2413 ASSERT_FALSE(f1->hasValue());
2414 try {
2415 (void)f1->value();
2416 ASSERT_TRUE(false); // Supposed to throw.
2417 } catch (const std::exception& e) {
2418 ASSERT_TRUE(strcmp(e.what(), "Failed") == 0);
2419 }
2420 f1->addCallback([&](Future& /* unused */) { ++sat2; });
2421 ASSERT_EQ(sat1, 1);
2422 ASSERT_EQ(sat2, 1);
2423 f1->setErrorIfNeeded(
2424 std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup")));
2425 ASSERT_TRUE(strcmp(f1->tryRetrieveErrorMessage().c_str(), "Failed") == 0);
2426 ASSERT_EQ(sat1, 1);
2427 ASSERT_EQ(sat2, 1);
2428 try {
2429 (void)f1->constValue();
2430 ASSERT_TRUE(false); // Supposed to throw.
2431 } catch (const std::exception& e) {
2432 // Original error should be logged.
2433 ASSERT_TRUE(std::string(e.what()).find("Failed") != std::string::npos);
2434 }
2435}
2436
// then()
2438TEST(FuturesTest, Then) {
2439 auto f1 = c10::make_intrusive<Future>(IntType::get());
2440 auto f2 = f1->then(
2441 [](Future& f1) -> IValue { return f1.constValue().toInt() + 1; },
2442 IntType::get());
2443 auto f3 = f2->then(
2444 [](Future& f2) -> IValue { return f2.constValue().toInt() * 3; },
2445 IntType::get());
2446 bool done = false;
2447 f3->addCallback([&done](Future& f3) {
2448 ASSERT_EQ(f3.constValue().toInt(), (42 + 1) * 3);
2449 done = true;
2450 });
2451 ASSERT_FALSE(done);
2452 f1->markCompleted(42);
2453 ASSERT_TRUE(done);
2454}
2455
2456// collectAll()
2457TEST(FuturesTest, CollectAll) {
2458 auto s1 = c10::make_intrusive<Future>(IntType::get());
2459 auto s2 = c10::make_intrusive<Future>(IntType::get());
2460 auto s3 = c10::make_intrusive<Future>(IntType::get());
2461
2462 // Empty case
2463 c10::List<intrusive_ptr<ivalue::Future>> futures(
2464 FutureType::create(IntType::get()));
2465 auto c1 = collectAll(futures);
2466 ASSERT_TRUE(c1->completed());
2467 ASSERT_EQ(c1->value().toList().size(), 0);
2468 ASSERT_TRUE(
2469 *(c1->value().toList().elementType()) ==
2470 *FutureType::create(IntType::get()));
2471
2472 // 1-element, initially not completed.
2473 futures.push_back(s1);
2474 auto c2 = collectAll(futures);
2475 ASSERT_FALSE(c2->completed());
2476 s1->markCompleted(5);
2477 ASSERT_TRUE(c2->completed());
2478 ASSERT_EQ(c2->value().toList().size(), 1);
2479 ASSERT_TRUE(
2480 *(c2->value().toList().elementType()) ==
2481 *FutureType::create(IntType::get()));
2482 ASSERT_EQ(c2->value().toList().get(0).toFuture()->value().toInt(), 5);
2483
2484 // 1-element, already completed
2485 auto c3 = collectAll(futures);
2486 ASSERT_TRUE(c3->completed());
2487 ASSERT_EQ(c3->value().toList().size(), 1);
2488 ASSERT_EQ(c3->value().toList().get(0).toFuture()->value().toInt(), 5);
2489
2490 // 3 elements.
2491 futures.push_back(s2);
2492 futures.push_back(s3);
2493 auto c4 = collectAll(futures);
2494 ASSERT_FALSE(c4->completed());
2495 s3->markCompleted(7);
2496 ASSERT_FALSE(c4->completed());
2497 s2->markCompleted(6);
2498 ASSERT_TRUE(c4->completed());
2499 ASSERT_EQ(c4->value().toList().size(), 3);
2500 ASSERT_EQ(c4->value().toList().get(0).toFuture()->value().toInt(), 5);
2501 ASSERT_EQ(c4->value().toList().get(1).toFuture()->value().toInt(), 6);
2502 ASSERT_EQ(c4->value().toList().get(2).toFuture()->value().toInt(), 7);
2503 ASSERT_TRUE(
2504 *(c4->value().toList().elementType()) ==
2505 *FutureType::create(IntType::get()));
2506
2507 // Handle exception in the list.
2508 auto s4 = c10::make_intrusive<Future>(IntType::get());
2509 futures.push_back(s4);
2510 auto c5 = collectAll(futures);
2511 ASSERT_FALSE(c5->completed());
2512 s4->setError(
2513 std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
2514 ASSERT_TRUE(c5->completed());
2515 try {
2516 c5->value();
2517 ASSERT_TRUE(false); // supposed to throw
2518 } catch (const std::exception& e) {
2519 ASSERT_EQ(std::string(e.what()), "Failed");
2520 }
2521}
2522
2523// collectAny()
2524TEST(FuturesTest, CollectAny) {
2525 auto s1 = c10::make_intrusive<Future>(IntType::get());
2526
2527 // Empty case
2528 c10::List<intrusive_ptr<ivalue::Future>> futures(
2529 FutureType::create(IntType::get()));
2530 auto c1 = collectAny(futures);
2531 ASSERT_TRUE(c1->completed());
2532
2533 // 1 element, not yet satisfied
2534 futures.push_back(s1);
2535 auto c2 = collectAny(futures);
2536 ASSERT_FALSE(c2->completed());
2537 s1->markCompleted(5);
2538 ASSERT_TRUE(c2->completed());
2539 ASSERT_TRUE(c2->value().isInt());
2540 ASSERT_EQ(c2->value().toInt(), 5);
2541
2542 // 1 element already satisfied.
2543 auto c3 = collectAny(futures);
2544 ASSERT_TRUE(c3->completed());
2545 ASSERT_TRUE(c3->value().isInt());
2546 ASSERT_EQ(c3->value().toInt(), 5);
2547
2548 // 2 elements
2549 futures.clear();
2550 auto s2 = c10::make_intrusive<Future>(IntType::get());
2551 auto s3 = c10::make_intrusive<Future>(IntType::get());
2552 futures.push_back(s2);
2553 futures.push_back(s3);
2554 auto c4 = collectAny(futures);
2555 ASSERT_FALSE(c4->completed());
2556 s3->markCompleted(7);
2557 ASSERT_TRUE(c4->completed());
2558 ASSERT_EQ(c4->value().toInt(), 7);
2559 s2->markCompleted(1);
2560 ASSERT_EQ(c4->value().toInt(), 7);
2561}
2562
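// Checks that wrapPropagateTLSState lets future callbacks observe the
// thread-local state (here, the enabled profiler) of the registering thread,
// for both addCallback() and then().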
2563TEST(TLSFutureCallbacksTest, Basic) {
2564 // cb that verifies the profiler is enabled
2565 auto profilerEnabledCb = [](Future& /* unused */) {
2566 ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
2567 };
2568 // test running callbacks with propagation of TLS state.
2569 {
2570 // Enable the profiler in this thread
2571 torch::autograd::profiler::enableProfilerLegacy(
2572 torch::autograd::profiler::ProfilerConfig(
2573 torch::autograd::profiler::ProfilerState::CPU, false, false));
2574 auto s1 = c10::make_intrusive<Future>(IntType::get());
2575 s1->addCallback(wrapPropagateTLSState(profilerEnabledCb));
2576 std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
2577 // Since we join here, we can ensure that all callbacks corresponding to
2578 // markCompleted() have finished.
2579 t.join();
2580 torch::autograd::profiler::disableProfilerLegacy();
2581 }
2582 // then() with TLS State
2583 {
2584 // Enable the profiler in this thread
2585 torch::autograd::profiler::enableProfilerLegacy(
2586 torch::autograd::profiler::ProfilerConfig(
2587 torch::autograd::profiler::ProfilerState::CPU, false, false));
2588 auto s1 = c10::make_intrusive<Future>(IntType::get());
2589 auto s2 = s1->then(
2590 wrapPropagateTLSState([&profilerEnabledCb](Future& s1) {
2591 profilerEnabledCb(s1);
2592 return at::IValue(1);
2593 }),
2594 IntType::get());
2595 std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
2596 t.join();
2597 s2->wait();
2598 torch::autograd::profiler::disableProfilerLegacy();
2599 }
2600}
2601
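// Checks that the profiler can be disabled and its events consolidated from
// inside a future callback, whether the callback runs on another thread or
// inline on the completing thread.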
2602TEST(ProfilerDisableInCallbackTest, Basic) {
2603 // cb that verifies the profiler is enabled
2604 auto profilerEnabledCb = []() {
2605 ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
2606 };
2607 torch::autograd::profiler::enableProfilerLegacy(
2608 torch::autograd::profiler::ProfilerConfig(
2609 torch::autograd::profiler::ProfilerState::CPU, false, false));
2610 auto s1 = c10::make_intrusive<Future>(IntType::get());
2611 auto verifyProfilerCb =
2612 wrapPropagateTLSState([&profilerEnabledCb](Future& /* unused */) {
2613 // Ensure the profiler is still enabled in this thread.
2614 profilerEnabledCb();
2615 auto t1 = torch::ones({2, 2});
2616 auto t2 = torch::ones({2, 2});
2617 torch::add(t1, t2);
2618 // Don't cleanup TLSState, and just consolidate.
2619 auto opts =
2620 torch::autograd::profiler::ProfilerDisableOptions(false, true);
2621 auto thread_event_lists =
2622 // NOLINTNEXTLINE(performance-move-const-arg)
2623 torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
        // Ensure that the events from this thread are still profiled and we
        // obtain the expected events in our consolidated list when calling
        // disableProfilerLegacy().
2627 bool found_ones = false;
2628 bool found_add = false;
2629 for (const auto& li : thread_event_lists) {
2630 for (const auto& evt : li) {
2631 if (strcmp(evt.name(), "aten::add") == 0) {
2632 found_add = true;
2633 } else if (strcmp(evt.name(), "aten::ones") == 0) {
2634 found_ones = true;
2635 }
2636 }
2637 if (found_add && found_ones) {
2638 break;
2639 }
2640 }
2641 ASSERT_TRUE(found_ones);
2642 ASSERT_TRUE(found_add);
2643 });
2644
2645 s1->addCallback(verifyProfilerCb);
2646 // Disable the profiler, but do not consolidate results in the main thread.
2647 auto opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
2648 // NOLINTNEXTLINE(performance-move-const-arg)
2649 torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
2650 std::thread t([s1 = std::move(s1)]() { s1->markCompleted(at::IValue(1)); });
2651 t.join();
2652
2653 // Similar to above test, but verifies correctness in the case where
2654 // continuation runs on the main thread.
2655 torch::autograd::profiler::enableProfilerLegacy(
2656 torch::autograd::profiler::ProfilerConfig(
2657 torch::autograd::profiler::ProfilerState::CPU, false, false));
2658 s1 = c10::make_intrusive<Future>(IntType::get());
2659 s1->addCallback(verifyProfilerCb);
2660 // Runs callback inline
2661 s1->markCompleted(at::IValue(1));
2662 opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
2663 // NOLINTNEXTLINE(performance-move-const-arg)
2664 torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
2665}
2666
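// Checks that RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS attaches the
// given debug handle to the profiler event, while plain user scopes report -1.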
2667TEST(RecordDebugHandles, Basic) {
2668 // Enable the profiler in this thread
2669 const std::set<torch::autograd::profiler::ActivityType> activities(
2670 {torch::autograd::profiler::ActivityType::CPU});
2671 torch::autograd::profiler::prepareProfiler(
2672 torch::autograd::profiler::ProfilerConfig(
2673 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2674 activities);
2675 torch::autograd::profiler::enableProfiler(
2676 torch::autograd::profiler::ProfilerConfig(
2677 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2678 activities);
2679 {
2680 RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
2681 float x{5.9999}, y{2.1212};
2682 float z = x / y;
2683 (void)z;
2684 }
2685 {
2686 RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
2687 float x{5.9999}, y{2.1212};
2688 float z = x / y;
2689 (void)z;
2690 }
2691 auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
2692 const auto& kineto_events = profiler_results_ptr->events();
2693 size_t my_events{0};
2694 for (const auto& e : kineto_events) {
2695 if (e.name() == "my_function") {
2696 ASSERT_EQ(e.debugHandle(), 42);
2697 my_events++;
2698 } else if (e.name() == "not_my_function") {
2699 ASSERT_EQ(e.debugHandle(), -1);
2700 my_events++;
2701 }
2702 }
2703 ASSERT_EQ(my_events, 2);
2704}
2705
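// Checks that profiling can be restricted to particular record scopes: with
// only LITE_INTERPRETER scopes enabled, ordinary ops are not recorded but
// edge scopes still are.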
2706TEST(RecordDebugHandles, ScopedCallbacks) {
2707 // Enable the profiler in this thread
2708 torch::autograd::profiler::prepareProfiler(
2709 torch::autograd::profiler::ProfilerConfig(
2710 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2711 {torch::autograd::profiler::ActivityType::CPU});
2712 torch::autograd::profiler::enableProfiler(
2713 torch::autograd::profiler::ProfilerConfig(
2714 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2715 {torch::autograd::profiler::ActivityType::CPU});
2716
2717 {
2718 auto a = torch::rand({128, 128});
2719 auto b = torch::rand({128, 128});
2720 auto c = a + b;
2721 }
2722 auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
2723 ASSERT_TRUE(profiler_results_ptr->events().size() > 0);
2724
2725 // Enable the profiler in this thread
2726 torch::autograd::profiler::prepareProfiler(
2727 torch::autograd::profiler::ProfilerConfig(
2728 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2729 {torch::autograd::profiler::ActivityType::CPU});
2730 torch::autograd::profiler::enableProfiler(
2731 torch::autograd::profiler::ProfilerConfig(
2732 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2733 {torch::autograd::profiler::ActivityType::CPU},
2734 {at::RecordScope::LITE_INTERPRETER});
2735 {
2736 auto a = torch::rand({128, 128});
2737 auto b = torch::rand({128, 128});
2738 auto c = a + b;
2739 }
2740 profiler_results_ptr = torch::autograd::profiler::disableProfiler();
2741 ASSERT_TRUE(profiler_results_ptr->events().size() == 0);
2742
2743 torch::autograd::profiler::prepareProfiler(
2744 torch::autograd::profiler::ProfilerConfig(
2745 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2746 {torch::autograd::profiler::ActivityType::CPU});
2747 torch::autograd::profiler::enableProfiler(
2748 torch::autograd::profiler::ProfilerConfig(
2749 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2750 {torch::autograd::profiler::ActivityType::CPU},
2751 {at::RecordScope::LITE_INTERPRETER});
2752 {
2753 RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
2754 auto a = torch::rand({128, 128});
2755 auto b = torch::rand({128, 128});
2756 auto c = a + b;
2757 }
2758 {
2759 RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
2760 auto a = torch::rand({128, 128});
2761 auto b = torch::rand({128, 128});
2762 auto c = a + b;
2763 }
2764 profiler_results_ptr = torch::autograd::profiler::disableProfiler();
2765 const auto& kineto_events = profiler_results_ptr->events();
2766 for (const auto& e : kineto_events) {
2767 if (e.name() == "my_function") {
2768 ASSERT_EQ(e.debugHandle(), 42);
2769 }
2770 }
2771 ASSERT_TRUE(profiler_results_ptr->events().size() == 1);
2772}
2773
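// Calls a scripted function with a positional argument, a keyword argument,
// and a defaulted parameter.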
2774TEST(IValueKWargsTest, Basic) {
2775 const auto text = R"(
2776 def foo(a : int, b : int, c : int = 4):
2777 return a + 2*b + 3*c
2778 )";
2779 auto cu = compile(text);
2780 auto result = cu->get_function("foo")({1}, {{"b", 3}});
2781 ASSERT_EQ(result.toInt(), 19);
2782}
2783
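// Exercises torch::profiler::impl::computeFlops for several ops, including
// malformed argument cases that should report zero flops.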
2784TEST(ComputeFlopsTest, Basic) {
2785 uint64_t flops = 0;
2786
2787 // Test unknown operator
2788 std::unordered_map<std::string, c10::IValue> extra_args;
2789 flops = torch::profiler::impl::computeFlops(
2790 std::string("aten::unknown"), extra_args);
2791 ASSERT_EQ(flops, 0);
2792
2793 // Test aten::conv2d
2794 extra_args.clear();
2795 std::vector<int64_t> input_size = {4, 5, 6, 7};
2796 std::vector<int64_t> weight_size = {3, 5, 2, 1};
2797 std::vector<int64_t> padding = {1, 0};
2798 std::vector<int64_t> stride = {1, 1};
2799 std::vector<int64_t> dilation = {0, 0};
2800 extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
2801 extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_size));
2802 extra_args["groups"] = 1;
2803 extra_args["padding"] = at::IValue(at::IntArrayRef(padding));
2804 extra_args["stride"] = at::IValue(at::IntArrayRef(stride));
2805 extra_args["dilation"] = at::IValue(at::IntArrayRef(dilation));
2806 flops = torch::profiler::impl::computeFlops(
2807 std::string("aten::conv2d"), extra_args);
2808 ASSERT_EQ(flops, 13440);
2809
2810 // Test aten::conv2d fail
2811 input_size = {4, 5, 6, 7};
2812 weight_size = {4, 5, 6};
2813 extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
2814 extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_size));
2815 flops = torch::profiler::impl::computeFlops(
2816 std::string("aten::conv2d"), extra_args);
2817 ASSERT_EQ(flops, 0);
2818
2819 // Test aten::conv2d fail 2
2820 weight_size = {3, 5, 2, 1};
2821 stride = {0, 0};
  extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_size));
2823 extra_args["stride"] = at::IValue(at::IntArrayRef(stride));
2824 flops = torch::profiler::impl::computeFlops(
2825 std::string("aten::conv2d"), extra_args);
2826 ASSERT_EQ(flops, 0);
2827
2828 // Test aten::conv2d fail 3
2829 extra_args.clear();
2830 input_size = {4, 5, 6, 7};
2831 extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
2832 flops = torch::profiler::impl::computeFlops(
2833 std::string("aten::conv2d"), extra_args);
2834 ASSERT_EQ(flops, 0);
2835
2836 // Test aten::mm
2837 extra_args.clear();
2838 std::vector<int64_t> mat1_sizes = {3, 4, 5, 6};
2839 std::vector<int64_t> mat2_sizes = {6, 5, 4, 3};
2840 extra_args["mat1_size"] = at::IValue(at::IntArrayRef(mat1_sizes));
2841 extra_args["mat2_size"] = at::IValue(at::IntArrayRef(mat2_sizes));
2842 flops =
2843 torch::profiler::impl::computeFlops(std::string("aten::mm"), extra_args);
2844 ASSERT_EQ(flops, 43200);
2845
2846 // Test aten::addmm
2847 flops = torch::profiler::impl::computeFlops(
2848 std::string("aten::addmm"), extra_args);
2849 ASSERT_EQ(flops, 43200);
2850
2851 // Test aten::bmm
2852 extra_args.clear();
2853 mat1_sizes = {7, 5, 6};
2854 mat2_sizes = {7, 6, 3};
2855 extra_args["mat1_size"] = at::IValue(at::IntArrayRef(mat1_sizes));
2856 extra_args["mat2_size"] = at::IValue(at::IntArrayRef(mat2_sizes));
2857 flops =
2858 torch::profiler::impl::computeFlops(std::string("aten::bmm"), extra_args);
2859 ASSERT_EQ(flops, 1260);
2860
2861 // Test aten::baddbmm
2862 flops = torch::profiler::impl::computeFlops(
2863 std::string("aten::baddbmm"), extra_args);
2864 ASSERT_EQ(flops, 1260);
2865
  // Test aten::mm with missing size arguments (should report 0 flops)
2867 extra_args.clear();
2868 flops =
2869 torch::profiler::impl::computeFlops(std::string("aten::mm"), extra_args);
2870 ASSERT_EQ(flops, 0);
2871
2872 // Test aten::add.Tensor
2873 extra_args.clear();
2874 std::vector<int64_t> mat_sizes = {3, 4, 5, 6};
2875 extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes));
2876 flops =
2877 torch::profiler::impl::computeFlops(std::string("aten::add"), extra_args);
2878 ASSERT_EQ(flops, 360);
2879
2880 // Test aten::mul.Tensor
2881 extra_args.clear();
2882 mat_sizes = {3, 4, 5, 6};
2883 extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes));
2884 flops =
2885 torch::profiler::impl::computeFlops(std::string("aten::mul"), extra_args);
2886 ASSERT_EQ(flops, 360);
2887}
2888
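// A tensor that requires grad cannot be inserted as a constant.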
2889TEST(TestConstant, TensorGrad) {
2890 auto graph = std::make_shared<Graph>();
2891 IValue ten = torch::randn({3, 5}).requires_grad_(true);
2892 auto con = tryInsertConstant(*graph, ten);
2893 ASSERT_TRUE(con == c10::nullopt);
2894}
2895
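// RemoveTensorMutation rewrites in-place ops only when the filter returns
// true for the node.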
2896TEST(TestMutation, Basic) {
2897 auto graph = std::make_shared<Graph>();
2898 std::unordered_map<std::string, Value*> vmap;
2899 parseIR(
2900 R"IR(
2901graph(%x.1 : Tensor):
2902 %2 : int = prim::Constant[value=1]()
2903 %9 : int = prim::Constant[value=4]()
2904 %x.3 : Tensor = aten::add(%x.1, %2, %2)
2905 %7 : Tensor = aten::add_(%x.3, %2, %2)
2906 %y.1 : Tensor = aten::add(%x.3, %9, %2)
2907 return (%y.1))IR",
2908 &*graph,
2909 vmap);
2910 RemoveTensorMutation(graph, [](Node*) { return false; });
2911 testing::FileCheck().check("aten::add_")->run(*graph);
2912 RemoveTensorMutation(graph, [](Node*) { return true; });
2913 testing::FileCheck().check_not("aten::add_")->run(*graph);
2914}
2915
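// aten::relu_ should be rewritten to its functional form.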
2916TEST(TestInplaceToFunctionalActivation, Basic) {
2917 auto graph = std::make_shared<Graph>();
2918 std::unordered_map<std::string, Value*> vmap;
2919 parseIR(
2920 R"IR(
2921graph(%x.1 : Tensor):
2922 %2 : int = prim::Constant[value=1]()
2923 %x.3 : Tensor = aten::add(%x.1, %2, %2)
2924 %y : Tensor = aten::relu_(%x.3)
2925 return (%y))IR",
2926 &*graph,
2927 vmap);
2928 InplaceToFunctionalActivation(graph);
2929 testing::FileCheck().check("aten::relu")->run(*graph);
2930 testing::FileCheck().check_not("aten::relu_")->run(*graph);
2931}
2932
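// Registers a custom shape-compute graph for a test op's schema and checks
// that shape propagation uses it.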
2933TEST(TestRegisterShapeOp, Basic) {
2934 auto graph = std::make_shared<Graph>();
2935 std::unordered_map<std::string, Value*> vmap;
2936 parseIR(
2937 R"IR(
2938graph():
2939 %2 : int = prim::Constant[value=5]()
2940 %3: int[] = prim::ListConstruct(%2, %2)
2941 return (%3))IR",
2942 &*graph,
2943 vmap);
2944
2945 auto g2 = std::make_shared<Graph>();
2946 parseIR(
2947 R"IR(
2948graph():
2949 %2 : Tensor = prim::MakeTestTensor()
2950 return (%2))IR",
2951 &*g2,
2952 vmap);
2953
2954 const FunctionSchema& schema = g2->nodes().begin()->schema();
2955 torch::jit::RegisterShapeComputeGraphForSchema(schema, graph);
2956 PropagateShapesOnGraph(g2);
2957 testing::FileCheck().check("5, 5")->run(*g2);
2958}
2959
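// aten::relu should be rewritten to its in-place form.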
2960TEST(TestFunctionalToInplaceActivation, Basic) {
2961 auto graph = std::make_shared<Graph>();
2962 std::unordered_map<std::string, Value*> vmap;
2963 parseIR(
2964 R"IR(
2965graph(%x.1 : Tensor):
2966 %2 : int = prim::Constant[value=1]()
2967 %x.3 : Tensor = aten::add(%x.1, %2, %2)
2968 %y : Tensor = aten::relu(%x.3)
2969 return (%y))IR",
2970 &*graph,
2971 vmap);
2972 FunctionalToInplaceActivation(graph);
2973 testing::FileCheck().check("aten::relu_")->run(*graph);
2974 testing::FileCheck().check_not("aten::relu(")->run(*graph);
2975}
2976
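// With the profiling executor the executed graph contains prim::profile
// nodes; with the simple executor it does not.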
2977TEST(TestFunctionExecutor, SimpleExecutorTest) {
2978 auto graph = std::make_shared<Graph>();
2979 parseIR(
2980 R"IR(
2981graph(%x.1 : Tensor):
2982 %2 : int = prim::Constant[value=1]()
2983 %x.3 : Tensor = aten::add(%x.1, %2, %2)
2984 %y : Tensor = aten::relu(%x.3)
2985 return (%y))IR",
2986 &*graph);
2987 {
2988 auto func = torch::make_unique<GraphFunction>(
2989 "name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::PROFILING);
2990 auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
2991 Stack stack = {a};
2992 func->run(stack);
2993 auto g = lastExecutedOptimizedGraph();
2994 testing::FileCheck()
2995 .check("prim::profile")
2996 ->check("aten::add")
2997 ->check("aten::relu")
2998 ->run(*g);
2999 }
3000 {
3001 auto func = torch::make_unique<GraphFunction>(
3002 "name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::SIMPLE);
3003 auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
3004 Stack stack = {a};
3005 func->run(stack);
3006 auto g = func->getDebugState().graph;
3007 testing::FileCheck()
3008 .check_not("prim::profile")
3009 ->check("aten::add")
3010 ->check("aten::relu")
3011 ->run(*g);
3012 }
3013}
3014
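// Runs the registered decomposition for aten::var and compares its result
// against the eager implementation.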
3015TEST(TestFunctionExecutor, RunDecompositionTest) {
3016 static auto* func = torch::jit::GetDecompositionExecutor(
3017 "aten::var(Tensor self, bool unbiased=True) -> Tensor");
3018 for (bool unbiased : {true, false}) {
3019 auto input = at::rand({4, 4});
3020 Stack stack = {input, unbiased};
3021 func->run(stack);
3022 at::Tensor out = pop(stack).toTensor();
3023 ASSERT_TRUE(at::allclose(out, input.var(unbiased)));
3024 }
3025}
3026
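// Every registered shape-compute graph should pass the shape-graph linter.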
3027TEST(TestShapeGraphLinting, Basic) {
3028 auto schemas = RegisteredShapeComputeSchemas();
3029 for (const auto& schema : schemas) {
    // arange does not actually support complex, leave as
    // union[int, float] for now
3032 if (schema->name() == "aten::arange") {
3033 continue;
3034 }
3035 auto g = shapeComputeGraphForSchema(*schema);
3036 TORCH_INTERNAL_ASSERT(g);
3037 LintShapeComputeGraph(schema, *g);
3038 }
3039}
3040
// TODO: move to test_kernel once the global fusion settings can be passed as
// explicit parameters
3043class Composed : public ::testing::Test {
3044 public:
3045 void SetUp() override {
3046 torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
3047 }
3048};
3049
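// Fuses the graph into a composed op with dynamic shapes and checks both the
// fallback path (inputs of unexpected size) and the path where pre-allocated
// output tensors passed on the stack are written in place.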
3050TEST_F(Composed, ComposedOp) {
3051 struct WithCPUFuser {
3052 WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
3053 overrideCanFuseOnCPU(val);
3054 }
3055
3056 ~WithCPUFuser() {
3057 overrideCanFuseOnCPU(cpuFuserEnabled);
3058 }
3059
3060 bool cpuFuserEnabled;
3061 };
3062
3063#ifdef TORCH_ENABLE_LLVM
3064 const auto graph_string = R"IR(
3065 graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
3066 %1 : Float(5, 3, strides=[1, 5], device=cpu)):
3067 %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1)
3068 %3 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %2)
3069 %4 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %3)
3070 return (%3, %4))IR";
3071 auto graph = std::make_shared<Graph>();
3072 parseIR(graph_string, &*graph);
3073
3074 // wrong input sizes so we hit the fallback path
3075 auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
3076 auto b = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat))
3077 .transpose(0, 1);
3078 auto ref1 = a * (a * b);
3079 auto ref2 = a * ref1;
3080 WithCPUFuser g(true);
3081 bool fusable_on_device = torch::jit::tensorexpr::getTEMustUseLLVMOnCPU();
3082 torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
3083 FuseTensorExprs(
3084 graph,
3085 /*min_group_size*/ 2,
3086 /*add_composed_op*/ true,
3087 /*fuse_to_dynamic_shapes*/ true);
3088 Code code(graph, "");
3089 InterpreterState interpreter{code};
3090 std::vector<IValue> stack = {a, b};
3091 interpreter.run(stack);
3092 at::Tensor out2 = pop(stack).toTensor();
3093 at::Tensor out1 = pop(stack).toTensor();
3094 ASSERT_TRUE(at::allclose(ref1, out1));
3095 ASSERT_TRUE(at::allclose(ref2, out2));
3096
3097 auto inp_1 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
3098 auto inp_2 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
3099 stack = {inp_1, inp_2, a, b};
3100 InterpreterState interpreter2{code};
3101 interpreter2.run(stack);
3102 out2 = pop(stack).toTensor();
3103 out1 = pop(stack).toTensor();
3104 ASSERT_TRUE(at::allclose(ref1, out1));
3105 ASSERT_TRUE(at::allclose(ref2, out2));
  // inp_1 is on the bottom of the stack and corresponds to the second output;
  // inp_2 is on the top and corresponds to the first output.
3108 ASSERT_TRUE(at::allclose(inp_1, ref2));
3109 ASSERT_TRUE(at::allclose(inp_2, ref1));
3110 torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = fusable_on_device;
3111#endif
3112}
3113
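// quantized::linear_prepack produces a custom class instance; constant
// propagation should be able to fold it away.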
3114TEST(ConstantPropagation, CustomClassesCanBePropagated) {
3115 const auto src = R"IR(
3116 graph():
3117 %none: NoneType = prim::Constant()
3118 %dim: int = prim::Constant[value=3]()
3119 %shape: int[] = prim::ListConstruct(%dim, %dim)
3120 %weight: Tensor = aten::ones(%shape, %none, %none, %none, %none)
3121 %scale: float = prim::Constant[value=1.]()
3122 %zero_point: int = prim::Constant[value=0]()
3123 %dtype: int = prim::Constant[value=12]()
3124 %weight_q: Tensor = aten::quantize_per_tensor(%weight, %scale, %zero_point, %dtype)
3125 %params: __torch__.torch.classes.quantized.LinearPackedParamsBase = quantized::linear_prepack(%weight_q, %none)
3126 return (%params)
3127 )IR";
3128 auto graph = std::make_shared<Graph>();
3129 std::unordered_map<std::string, Value*> vmap;
3130 parseIR(src, graph.get(), vmap);
3131
3132 ConstantPropagation(graph);
3133
3134 testing::FileCheck().check_not("quantized::linear_prepack")->run(*graph);
3135}
3136
3137} // namespace jit
3138} // namespace torch
3139