1 | #include <gtest/gtest.h> |
2 | |
3 | #include <stdexcept> |
4 | #include "test/cpp/tensorexpr/test_base.h" |
5 | |
6 | #include <torch/csrc/jit/tensorexpr/expr.h> |
7 | #include <torch/csrc/jit/tensorexpr/ir.h> |
8 | #include <torch/csrc/jit/tensorexpr/ir_verifier.h> |
9 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
10 | #include <torch/csrc/jit/tensorexpr/tensor.h> |
11 | #include <torch/csrc/jit/testing/file_check.h> |
12 | |
13 | #include <sstream> |
14 | namespace torch { |
15 | namespace jit { |
16 | |
17 | using namespace torch::jit::tensorexpr; |
18 | |
19 | TEST(IRVerifier, BitwiseOps) { |
20 | VarPtr X = alloc<Var>("x" , kInt); |
21 | VarPtr Y = alloc<Var>("y" , kFloat); |
22 | { |
23 | auto a = alloc<And>(X, Y); |
24 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
25 | EXPECT_ANY_THROW(verify(a)); |
26 | } |
27 | { |
28 | auto a = alloc<Or>(X, Y); |
29 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
30 | EXPECT_ANY_THROW(verify(a)); |
31 | } |
32 | { |
33 | auto a = alloc<Xor>(X, Y); |
34 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
35 | EXPECT_ANY_THROW(verify(a)); |
36 | } |
37 | { |
38 | auto a = alloc<Lshift>(X, Y); |
39 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
40 | EXPECT_ANY_THROW(verify(a)); |
41 | } |
42 | { |
43 | auto a = alloc<Rshift>(X, Y); |
44 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
45 | EXPECT_ANY_THROW(verify(a)); |
46 | } |
47 | } |
48 | |
49 | TEST(IRVerifier, CompareSelect) { |
50 | ExprPtr X = alloc<IntImm>(1); |
51 | ExprPtr Y = alloc<FloatImm>(3.14f); |
52 | { |
53 | auto a = alloc<CompareSelect>(X, X, X, Y, kEQ); |
54 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
55 | EXPECT_ANY_THROW(verify(a)); |
56 | } |
57 | { |
58 | auto a = alloc<CompareSelect>(X, Y, X, X, kEQ); |
59 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
60 | EXPECT_ANY_THROW(verify(a)); |
61 | } |
62 | } |
63 | |
64 | TEST(IRVerifier, Ramp) { |
65 | VarPtr I = alloc<Var>("i" , kInt); |
66 | VarPtr J = alloc<Var>("j" , kFloat); |
67 | { |
68 | auto a = alloc<Ramp>(I, J, 4); |
69 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
70 | EXPECT_ANY_THROW(verify(a)); |
71 | } |
72 | } |
73 | |
74 | TEST(IRVerifier, Load) { |
75 | VarPtr I = alloc<Var>("i" , kInt); |
76 | VarPtr J = alloc<Var>("j" , kLong); |
77 | VarPtr K = alloc<Var>("k" , kFloat); |
78 | BufPtr B = alloc<Buf>( |
79 | "b" , |
80 | std::vector<ExprPtr>({alloc<IntImm>(10), alloc<IntImm>(20)}), |
81 | kFloat); |
82 | { |
83 | // Indices with different int dtypes (kInt, kLong) are ok |
84 | auto a = alloc<Load>(B, std::vector<ExprPtr>({I, J})); |
85 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
86 | EXPECT_NO_THROW(verify(a)); |
87 | } |
88 | { |
89 | // Float index |
90 | auto a = alloc<Load>(B, std::vector<ExprPtr>({K, K})); |
91 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
92 | EXPECT_ANY_THROW(verify(a)); |
93 | } |
94 | { |
95 | // Multilanes are only allowed in flattened indices |
96 | auto multilane_index = alloc<Ramp>(I, alloc<IntImm>(1), 4); |
97 | auto a = alloc<Load>(B, std::vector<ExprPtr>({I, multilane_index})); |
98 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
99 | EXPECT_ANY_THROW(verify(a)); |
100 | } |
101 | } |
102 | |
103 | TEST(IRVerifier, IfThenElse) { |
104 | VarPtr I = alloc<Var>("i" , kInt); |
105 | VarPtr J = alloc<Var>("j" , kLong); |
106 | VarPtr K = alloc<Var>("k" , kFloat); |
107 | { |
108 | // Condition must be integral |
109 | auto a = alloc<IfThenElse>(K, I, I); |
110 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
111 | EXPECT_ANY_THROW(verify(a)); |
112 | } |
113 | { |
114 | // Dtypes of true and false exprs must match |
115 | auto a = alloc<IfThenElse>(I, I, J); |
116 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
117 | EXPECT_ANY_THROW(verify(a)); |
118 | } |
119 | { |
120 | // Can't have multiple lanes in condition expr |
121 | auto a = alloc<IfThenElse>(alloc<Broadcast>(I, 4), I, I); |
122 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
123 | EXPECT_ANY_THROW(verify(a)); |
124 | } |
125 | } |
126 | |
127 | TEST(IRVerifier, For) { |
128 | VarPtr I = alloc<Var>("i" , kInt); |
129 | VarPtr J = alloc<Var>("j" , kInt); |
130 | StmtPtr body = alloc<Block>(std::vector<StmtPtr>({})); |
131 | { |
132 | // Can't have nullptr as a Var |
133 | auto a = alloc<For>(nullptr, I, J, body); |
134 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
135 | EXPECT_ANY_THROW(verify(a)); |
136 | } |
137 | } |
138 | |
139 | TEST(IRVerifier, Block) { |
140 | VarPtr I = alloc<Var>("i" , kInt); |
141 | BufPtr B = alloc<Buf>("B" , std::vector<ExprPtr>({alloc<IntImm>(10)}), kInt); |
142 | { |
143 | StmtPtr store = alloc<Store>(B, std::vector<ExprPtr>({I}), I); |
144 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
145 | StmtPtr block1 = alloc<Block>(std::vector<StmtPtr>({store})); |
146 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
147 | StmtPtr block2 = alloc<Block>(std::vector<StmtPtr>({store})); |
148 | // Stmt can't have multiple parrents, thus inserting it into several blocks |
149 | // is illegal |
150 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
151 | EXPECT_ANY_THROW(verify(block2)); |
152 | } |
153 | } |
154 | |
155 | TEST(IRVerifier, Store) { |
156 | VarPtr I = alloc<Var>("i" , kInt); |
157 | VarPtr J = alloc<Var>("j" , kLong); |
158 | VarPtr K = alloc<Var>("k" , kFloat); |
159 | BufPtr B = alloc<Buf>( |
160 | "b" , |
161 | std::vector<ExprPtr>({alloc<IntImm>(10), alloc<IntImm>(20)}), |
162 | kFloat); |
163 | { |
164 | // Indices with different int dtypes (kInt, kLong) are ok |
165 | auto a = alloc<Store>(B, std::vector<ExprPtr>({I, J}), K); |
166 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
167 | EXPECT_NO_THROW(verify(a)); |
168 | } |
169 | { |
170 | // Float index |
171 | auto a = alloc<Store>(B, std::vector<ExprPtr>({K, K}), K); |
172 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
173 | EXPECT_ANY_THROW(verify(a)); |
174 | } |
175 | { |
176 | // Multilanes are only allowed in flattened indices |
177 | auto multilane_index = alloc<Ramp>(I, alloc<IntImm>(1), 4); |
178 | auto a = alloc<Store>(B, std::vector<ExprPtr>({I, multilane_index}), K); |
179 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
180 | EXPECT_ANY_THROW(verify(a)); |
181 | } |
182 | { |
183 | // Value and buf dtypes mismatch |
184 | auto a = alloc<Store>(B, std::vector<ExprPtr>({I}), I); |
185 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) |
186 | EXPECT_ANY_THROW(verify(a)); |
187 | } |
188 | } |
189 | |
190 | } // namespace jit |
191 | } // namespace torch |
192 | |