1#include <gtest/gtest.h>
2
3#include <torch/csrc/autograd/generated/variable_factories.h>
4#include <torch/csrc/jit/frontend/ir_emitter.h>
5#include <torch/csrc/jit/ir/ir.h>
6#include <torch/csrc/jit/ir/irparser.h>
7#include <torch/csrc/jit/jit_log.h>
8#include <torch/csrc/jit/passes/pass_manager.h>
9#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
10
// Test that tensor type specializations are available to
// custom (user-registered) passes.
13
14namespace torch {
15namespace jit {
16
17namespace {
18
19bool hasTensorTypeSpecializations(torch::jit::Block* block) {
20 for (Value* v : block->inputs()) {
21 if (hasTensorTypeSpecialization(v))
22 return true;
23 }
24 for (Node* n : block->nodes()) {
25 for (torch::jit::Block* b : n->blocks()) {
26 if (hasTensorTypeSpecializations(b))
27 return true;
28 }
29 for (Value* v : n->outputs()) {
30 if (hasTensorTypeSpecialization(v))
31 return true;
32 }
33 }
34 return false;
35}
36
37static bool hasSpecializations = false;
38void detectTTSpecializationPass(std::shared_ptr<Graph>& graph) {
39 GRAPH_DUMP("In detectTTSpecialization Custom Post Pass: ", graph);
40 hasSpecializations = hasTensorTypeSpecializations(graph->block());
41}
42
43} // namespace
44
45TEST(SpecializationsInCustomPasses, Basic) {
46 RegisterPass p(detectTTSpecializationPass);
47 hasSpecializations = false;
48 std::shared_ptr<Graph> graph = std::make_shared<Graph>();
49 parseIR(
50 R"IR(
51graph(%a.1 : Tensor,
52 %b.1 : Tensor):
53 %c.1 : Tensor = aten::mul(%a.1, %b.1) # misc/test_specializations.py:5:8
54 %d.1 : Tensor = aten::mul(%c.1, %b.1) # misc/test_specializations.py:6:8
55 return (%d.1)
56 )IR",
57 &*graph);
58
59 IValue ival = IValue(torch::randn({22}, at::kCPU));
60 std::vector<IValue> stack = {ival, ival};
61 auto run = [&](std::shared_ptr<Graph>& graph, std::vector<IValue> stack) {
62 GraphExecutor executor(graph, "");
63 executor.run(stack);
64 return stack;
65 };
66 run(graph, stack);
67
68 // Priofiling mode will not be run with simple executor
69 if (!getExecutorMode()) {
70 EXPECT_TRUE(hasSpecializations);
71 }
72}
73
74} // namespace jit
75} // namespace torch
76