1 | #include <gtest/gtest.h> |
2 | |
3 | #include "test/cpp/tensorexpr/test_base.h" |
4 | |
5 | #include <c10/util/irange.h> |
6 | #include <torch/csrc/jit/tensorexpr/cpp_codegen.h> |
7 | #include <torch/csrc/jit/tensorexpr/fwd_decls.h> |
8 | #include <torch/csrc/jit/tensorexpr/stmt.h> |
9 | #include <torch/csrc/jit/tensorexpr/tensor.h> |
10 | #include <torch/csrc/jit/testing/file_check.h> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | |
15 | using namespace torch::jit::tensorexpr; |
16 | |
// Prints `node` through a fresh CppPrinter into a local string stream and
// asserts (via gtest ASSERT_EQ) that the printed text equals `expected`.
//
// The statements are wrapped in do { ... } while (0) so the helper locals
// (`ss`, `printer`) stay private to one expansion. This lets the macro be
// used more than once in the same scope and as the body of an unbraced
// if/else, while callers still write `STR_CHECK(node, expected);`.
#define STR_CHECK(node, expected)  \
  do {                             \
    std::stringstream ss;          \
    CppPrinter printer(&ss);       \
    printer.visit(node);           \
    ASSERT_EQ(ss.str(), expected); \
  } while (0)
22 | |
// Prints `node` through a fresh CppPrinter into a local string stream and
// runs torch::jit::testing::FileCheck with `pattern` against the printed text.
//
// The statements are wrapped in do { ... } while (0) so the helper locals
// (`ss`, `printer`) stay private to one expansion. This lets the macro be
// used more than once in the same scope and as the body of an unbraced
// if/else, while callers still write `FILE_CHECK(node, pattern);`.
#define FILE_CHECK(node, pattern)                            \
  do {                                                       \
    std::stringstream ss;                                    \
    CppPrinter printer(&ss);                                 \
    printer.visit(node);                                     \
    torch::jit::testing::FileCheck().run(pattern, ss.str()); \
  } while (0)
28 | |
29 | TEST(CppPrinter, IntImm) { |
30 | auto i = alloc<IntImm>(10); |
31 | STR_CHECK(i, "10" ); |
32 | } |
33 | |
34 | TEST(CppPrinter, FloatImm) { |
35 | auto f = alloc<FloatImm>(10); |
36 | STR_CHECK(f, "10.f" ); |
37 | } |
38 | |
39 | TEST(CppPrinter, FloatImm1) { |
40 | auto f = alloc<FloatImm>(10); |
41 | STR_CHECK(f, "10.f" ); |
42 | } |
43 | |
44 | TEST(CppPrinter, DoubleImm) { |
45 | auto d = alloc<DoubleImm>(10); |
46 | STR_CHECK(d, "10.0" ); |
47 | } |
48 | |
49 | TEST(CppPrinter, DoubleImm1) { |
50 | auto d = alloc<DoubleImm>(10.1); |
51 | STR_CHECK(d, "10.1" ); |
52 | } |
53 | |
54 | TEST(CppPrinter, HalfImm) { |
55 | auto h = alloc<HalfImm>(10); |
56 | STR_CHECK(h, "10" ); |
57 | } |
58 | |
59 | TEST(CppPrinter, Add) { |
60 | auto add = alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2)); |
61 | STR_CHECK(add, "1 + 2" ); |
62 | } |
63 | |
64 | TEST(CppPrinter, AddExpr1) { |
65 | auto add = alloc<Add>( |
66 | alloc<Add>(alloc<IntImm>(0), alloc<IntImm>(1)), |
67 | alloc<Sub>(alloc<IntImm>(2), alloc<IntImm>(3))); |
68 | STR_CHECK(add, "(0 + 1) + (2 - 3)" ); |
69 | } |
70 | |
71 | TEST(CppPrinter, AddExpr2) { |
72 | auto add = alloc<Add>( |
73 | alloc<Mul>(alloc<IntImm>(0), alloc<IntImm>(1)), |
74 | alloc<Sub>(alloc<IntImm>(2), alloc<IntImm>(3))); |
75 | STR_CHECK(add, "0 * 1 + (2 - 3)" ); |
76 | } |
77 | |
78 | TEST(CppPrinter, AddExpr3) { |
79 | auto add = alloc<Add>( |
80 | alloc<Add>(alloc<IntImm>(0), alloc<IntImm>(1)), |
81 | alloc<Div>(alloc<IntImm>(2), alloc<IntImm>(3))); |
82 | STR_CHECK(add, "(0 + 1) + 2 / 3" ); |
83 | } |
84 | |
85 | TEST(CppPrinter, Mod) { |
86 | auto mod = alloc<Mod>(alloc<IntImm>(1), alloc<IntImm>(2)); |
87 | STR_CHECK(mod, "1 % 2" ); |
88 | } |
89 | |
90 | TEST(CppPrinter, ModFloat) { |
91 | auto mod = alloc<Mod>(alloc<FloatImm>(1), alloc<FloatImm>(2)); |
92 | STR_CHECK(mod, "std::fmod(1.f, 2.f)" ); |
93 | } |
94 | |
95 | TEST(CppPrinter, Max) { |
96 | auto max = alloc<Max>(alloc<IntImm>(1), alloc<IntImm>(2), false); |
97 | STR_CHECK(max, "std::max(1, 2)" ); |
98 | } |
99 | |
100 | TEST(CppPrinter, MaxFloat) { |
101 | auto max = alloc<Max>(alloc<FloatImm>(1), alloc<FloatImm>(2), false); |
102 | STR_CHECK(max, "std::max(1.f, 2.f)" ); |
103 | } |
104 | |
105 | TEST(CppPrinter, MaxHalf) { |
106 | auto max = alloc<Max>(alloc<HalfImm>(1), alloc<HalfImm>(2), false); |
107 | STR_CHECK(max, "(1 < 2) ? 2 : 1" ); |
108 | } |
109 | |
110 | TEST(CppPrinter, And) { |
111 | auto v = alloc<And>(alloc<IntImm>(1), alloc<IntImm>(2)); |
112 | STR_CHECK(v, "1 & 2" ); |
113 | } |
114 | |
115 | TEST(CppPrinter, CompareSelect) { |
116 | auto cs = alloc<CompareSelect>( |
117 | alloc<IntImm>(1), |
118 | alloc<IntImm>(2), |
119 | alloc<FloatImm>(1), |
120 | alloc<FloatImm>(2), |
121 | CompareSelectOperation::kLE); |
122 | STR_CHECK(cs, "((1 <= 2) ? 1.f : 2.f)" ); |
123 | } |
124 | |
125 | TEST(CppPrinter, IfThenElse) { |
126 | auto cond = alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2)); |
127 | auto true_value = alloc<Sub>(alloc<IntImm>(0), alloc<IntImm>(1)); |
128 | auto false_value = alloc<Mul>(alloc<IntImm>(2), alloc<IntImm>(3)); |
129 | auto v = alloc<IfThenElse>(cond, true_value, false_value); |
130 | STR_CHECK(v, "((1 + 2) ? 0 - 1 : 2 * 3)" ); |
131 | } |
132 | |
133 | TEST(CppPrinter, AllocateFree) { |
134 | BufHandle buf("x" , {2, 3}, kInt); |
135 | AllocatePtr alloc = Allocate::make(buf); |
136 | FreePtr free = Free::make(buf); |
137 | BlockPtr block = Block::make({alloc, free}); |
138 | |
139 | const std::string pattern = R"( |
140 | # CHECK: { |
141 | # CHECK: int* x = static_cast<int*>(malloc(24)); |
142 | # CHECK: free(x); |
143 | # CHECK: } |
144 | )" ; |
145 | FILE_CHECK(block, pattern); |
146 | } |
147 | |
148 | TEST(CppPrinter, LoadStore) { |
149 | BufHandle a("A" , {2, 3}, kInt); |
150 | BufHandle b("B" , {3, 4}, kInt); |
151 | auto store = b.store({2, 2}, a.load(1, 1)); |
152 | STR_CHECK( |
153 | store, "B[(0 + 2 * (1 * 4)) + 2 * 1] = A[(0 + 1 * (1 * 3)) + 1 * 1];\n" ); |
154 | } |
155 | |
156 | TEST(CppPrinter, Var) { |
157 | auto var = alloc<Var>("x" , kInt); |
158 | STR_CHECK(var, "x" ); |
159 | } |
160 | |
161 | TEST(CppPrinter, Cast) { |
162 | auto cast = alloc<Cast>(kFloat, alloc<IntImm>(1)); |
163 | STR_CHECK(cast, "static_cast<float>(1)" ); |
164 | } |
165 | |
166 | TEST(CppPrinter, BitCast) { |
167 | auto cast = alloc<BitCast>(kInt, alloc<FloatImm>(20)); |
168 | STR_CHECK(cast, "std::bitcast<float, int>(20.f)" ); |
169 | } |
170 | |
171 | TEST(CppPrinter, Let) { |
172 | auto var = alloc<Var>("x" , kFloat); |
173 | auto val = alloc<FloatImm>(2); |
174 | auto let = alloc<Let>(var, val); |
175 | STR_CHECK(let, "float x = 2.f;\n" ); |
176 | } |
177 | |
178 | TEST(CppPrinter, For) { |
179 | constexpr int N = 1024; |
180 | BufHandle a("A" , {N}, kInt); |
181 | BufHandle b("B" , {N}, kInt); |
182 | BufHandle c("C" , {N}, kInt); |
183 | VarHandle i("i" , kInt); |
184 | auto f = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i)))); |
185 | const std::string pattern = R"( |
186 | # CHECK: for (int i = 0; i < 1024; i++) { |
187 | # CHECK: C[i] = (A[i]) + (B[i]); |
188 | # CHECK: } |
189 | )" ; |
190 | FILE_CHECK(f, pattern); |
191 | } |
192 | |
193 | TEST(CppPrinter, Cond) { |
194 | BufHandle x("X" , {1}, kInt); |
195 | auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); |
196 | auto cond = |
197 | Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); |
198 | const std::string pattern = R"( |
199 | # CHECK: if (((X[0] < 10) ? 1 : 0)) { |
200 | # CHECK: X[0] = (X[0]) + 1; |
201 | # CHECK: } else { |
202 | # CHECK: X[0] = (X[0]) - 1; |
203 | # CHECK: } |
204 | )" ; |
205 | FILE_CHECK(cond, pattern); |
206 | } |
207 | |
208 | TEST(CppPrinter, Intrinsics) { |
209 | const std::unordered_set<IntrinsicsOp, std::hash<int>> unsupported_ops{ |
210 | kRand, kSigmoid}; |
211 | for (const auto i : c10::irange(static_cast<uint32_t>(kMaxIntrinsicsOp))) { |
212 | IntrinsicsOp op = static_cast<IntrinsicsOp>(i); |
213 | if (unsupported_ops.count(op)) { |
214 | continue; |
215 | } |
216 | |
217 | if (Intrinsics::OpArgCount(op) == 1) { |
218 | auto v = alloc<Intrinsics>(op, alloc<FloatImm>(2.0f)); |
219 | STR_CHECK(v, "std::" + v->func_name() + "(2.f)" ); |
220 | } else { |
221 | auto v = |
222 | alloc<Intrinsics>(op, alloc<FloatImm>(1.0f), alloc<FloatImm>(2.0f)); |
223 | STR_CHECK(v, "std::" + v->func_name() + "(1.f, 2.f)" ); |
224 | } |
225 | } |
226 | } |
227 | |
228 | TEST(CppPrinter, ExternalCall) { |
229 | std::vector<ExprPtr> dims{alloc<IntImm>(2), alloc<IntImm>(2)}; |
230 | auto output = alloc<Buf>("out" , dims, kFloat); |
231 | auto buf_arg1 = alloc<Buf>("a" , dims, kFloat); |
232 | auto buf_arg2 = alloc<Buf>("b" , dims, kFloat); |
233 | auto scalar_arg = alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2)); |
234 | std::vector<BufPtr> buf_args{buf_arg1, buf_arg2}; |
235 | std::vector<ExprPtr> scalar_args{scalar_arg}; |
236 | auto call = |
237 | alloc<ExternalCall>(output, "nnc_aten_matmul" , buf_args, scalar_args); |
238 | const std::string pattern = R"( |
239 | # CHECK: { |
240 | # CHECK: void* buf_ptrs[]{out, a, b}; |
241 | # CHECK: int64_t buf_ranks[]{2, 2, 2}; |
242 | # CHECK: int64_t buf_dims[]{2, 2, 2, 2, 2, 2}; |
243 | # CHECK: int8_t buf_dtypes[]{6, 6, 6}; |
244 | # CHECK: int64_t extra_args[]{1 + 2}; |
245 | # CHECK: nnc_aten_matmul( |
246 | # CHECK: 3, |
247 | # CHECK: buf_ptrs, |
248 | # CHECK: buf_ranks, |
249 | # CHECK: buf_dims, |
250 | # CHECK: buf_dtypes, |
251 | # CHECK: 1, |
252 | # CHECK: extra_args); |
253 | # CHECK: } |
254 | )" ; |
255 | FILE_CHECK(call, pattern); |
256 | } |
257 | |
258 | } // namespace jit |
259 | } // namespace torch |
260 | |