1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/csrc/autograd/generated/variable_factories.h> |
4 | #include <torch/csrc/jit/frontend/ir_emitter.h> |
5 | #include <torch/csrc/jit/ir/ir.h> |
6 | #include <torch/csrc/jit/ir/irparser.h> |
7 | #include <torch/csrc/jit/jit_log.h> |
8 | #include <torch/csrc/jit/passes/pass_manager.h> |
9 | #include <torch/csrc/jit/passes/tensorexpr_fuser.h> |
10 | |
// Test that tensor type specializations are available in
// the custom passes
13 | |
14 | namespace torch { |
15 | namespace jit { |
16 | |
17 | namespace { |
18 | |
19 | bool hasTensorTypeSpecializations(torch::jit::Block* block) { |
20 | for (Value* v : block->inputs()) { |
21 | if (hasTensorTypeSpecialization(v)) |
22 | return true; |
23 | } |
24 | for (Node* n : block->nodes()) { |
25 | for (torch::jit::Block* b : n->blocks()) { |
26 | if (hasTensorTypeSpecializations(b)) |
27 | return true; |
28 | } |
29 | for (Value* v : n->outputs()) { |
30 | if (hasTensorTypeSpecialization(v)) |
31 | return true; |
32 | } |
33 | } |
34 | return false; |
35 | } |
36 | |
// Side-channel for the test below: set by detectTTSpecializationPass to
// record whether the graph it last saw contained any tensor type
// specializations. The test resets it to false before running the graph.
static bool hasSpecializations = false;

// Custom pass registered via RegisterPass in the test below. Dumps the
// graph (only when JIT graph logging is enabled) and stores in
// `hasSpecializations` whether tensor type specializations are still
// visible on the graph at the point custom passes run.
void detectTTSpecializationPass(std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("In detectTTSpecialization Custom Post Pass: " , graph);
  hasSpecializations = hasTensorTypeSpecializations(graph->block());
}
42 | |
43 | } // namespace |
44 | |
// Verifies that a custom pass (registered through RegisterPass) can
// observe tensor type specializations on the graph it receives.
TEST(SpecializationsInCustomPasses, Basic) {
  // Register the detection pass and clear the result flag it writes.
  RegisterPass p(detectTTSpecializationPass);
  hasSpecializations = false;
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  // Two chained aten::mul ops on unspecialized Tensor inputs; the
  // source-location comments in the IR are cosmetic.
  parseIR(
      R"IR(
graph(%a.1 : Tensor,
      %b.1 : Tensor):
  %c.1 : Tensor = aten::mul(%a.1, %b.1) # misc/test_specializations.py:5:8
  %d.1 : Tensor = aten::mul(%c.1, %b.1) # misc/test_specializations.py:6:8
  return (%d.1)
)IR" ,
      &*graph);

  // Run the graph once with concrete CPU tensors so the executor has
  // real input types to specialize on; the same IValue is reused for
  // both graph inputs.
  IValue ival = IValue(torch::randn({22}, at::kCPU));
  std::vector<IValue> stack = {ival, ival};
  auto run = [&](std::shared_ptr<Graph>& graph, std::vector<IValue> stack) {
    GraphExecutor executor(graph, "" );
    executor.run(stack);
    return stack;
  };
  run(graph, stack);

  // Profiling mode will not be run with simple executor
  // NOTE(review): the assertion is gated on !getExecutorMode() —
  // presumably the simple/legacy executor specializes input tensor
  // types eagerly, while under the profiling executor specializations
  // may not be present when custom passes run. Confirm against the
  // executor implementation before relying on this.
  if (!getExecutorMode()) {
    EXPECT_TRUE(hasSpecializations);
  }
}
73 | |
74 | } // namespace jit |
75 | } // namespace torch |
76 | |