#include <gtest/gtest.h>

#include "test/cpp/tensorexpr/test_base.h"

#include <cstdint>
#include <functional>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include <c10/util/irange.h>
#include <torch/csrc/jit/tensorexpr/cpp_codegen.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>
11
12namespace torch {
13namespace jit {
14
15using namespace torch::jit::tensorexpr;
16
// Prints `node` with a fresh CppPrinter and asserts, via gtest's fatal
// ASSERT_EQ, that the generated text equals `expected` exactly.
//
// The body is wrapped in `do { ... } while (0)` so that the helper locals
// (`ss`, `printer`) are scoped to a single expansion. The macro can therefore
// be used more than once in the same scope without redefinition errors, and it
// behaves as one statement (safe as the body of an unbraced if/else).
#define STR_CHECK(node, expected)  \
  do {                             \
    std::stringstream ss;          \
    CppPrinter printer(&ss);       \
    printer.visit(node);           \
    ASSERT_EQ(ss.str(), expected); \
  } while (0)
22
// Prints `node` with a fresh CppPrinter and runs the FileCheck directives in
// `pattern` (lines of the form `# CHECK: ...`) against the generated text.
//
// Wrapped in `do { ... } while (0)` for the same reasons as STR_CHECK: the
// helper locals stay private to one expansion, and the macro acts as a single
// statement.
#define FILE_CHECK(node, pattern)                                 \
  do {                                                            \
    std::stringstream ss;                                         \
    CppPrinter printer(&ss);                                      \
    printer.visit(node);                                          \
    torch::jit::testing::FileCheck().run(pattern, ss.str());      \
  } while (0)
28
29TEST(CppPrinter, IntImm) {
30 auto i = alloc<IntImm>(10);
31 STR_CHECK(i, "10");
32}
33
34TEST(CppPrinter, FloatImm) {
35 auto f = alloc<FloatImm>(10);
36 STR_CHECK(f, "10.f");
37}
38
39TEST(CppPrinter, FloatImm1) {
40 auto f = alloc<FloatImm>(10);
41 STR_CHECK(f, "10.f");
42}
43
44TEST(CppPrinter, DoubleImm) {
45 auto d = alloc<DoubleImm>(10);
46 STR_CHECK(d, "10.0");
47}
48
49TEST(CppPrinter, DoubleImm1) {
50 auto d = alloc<DoubleImm>(10.1);
51 STR_CHECK(d, "10.1");
52}
53
54TEST(CppPrinter, HalfImm) {
55 auto h = alloc<HalfImm>(10);
56 STR_CHECK(h, "10");
57}
58
59TEST(CppPrinter, Add) {
60 auto add = alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2));
61 STR_CHECK(add, "1 + 2");
62}
63
64TEST(CppPrinter, AddExpr1) {
65 auto add = alloc<Add>(
66 alloc<Add>(alloc<IntImm>(0), alloc<IntImm>(1)),
67 alloc<Sub>(alloc<IntImm>(2), alloc<IntImm>(3)));
68 STR_CHECK(add, "(0 + 1) + (2 - 3)");
69}
70
71TEST(CppPrinter, AddExpr2) {
72 auto add = alloc<Add>(
73 alloc<Mul>(alloc<IntImm>(0), alloc<IntImm>(1)),
74 alloc<Sub>(alloc<IntImm>(2), alloc<IntImm>(3)));
75 STR_CHECK(add, "0 * 1 + (2 - 3)");
76}
77
78TEST(CppPrinter, AddExpr3) {
79 auto add = alloc<Add>(
80 alloc<Add>(alloc<IntImm>(0), alloc<IntImm>(1)),
81 alloc<Div>(alloc<IntImm>(2), alloc<IntImm>(3)));
82 STR_CHECK(add, "(0 + 1) + 2 / 3");
83}
84
85TEST(CppPrinter, Mod) {
86 auto mod = alloc<Mod>(alloc<IntImm>(1), alloc<IntImm>(2));
87 STR_CHECK(mod, "1 % 2");
88}
89
90TEST(CppPrinter, ModFloat) {
91 auto mod = alloc<Mod>(alloc<FloatImm>(1), alloc<FloatImm>(2));
92 STR_CHECK(mod, "std::fmod(1.f, 2.f)");
93}
94
95TEST(CppPrinter, Max) {
96 auto max = alloc<Max>(alloc<IntImm>(1), alloc<IntImm>(2), false);
97 STR_CHECK(max, "std::max(1, 2)");
98}
99
100TEST(CppPrinter, MaxFloat) {
101 auto max = alloc<Max>(alloc<FloatImm>(1), alloc<FloatImm>(2), false);
102 STR_CHECK(max, "std::max(1.f, 2.f)");
103}
104
105TEST(CppPrinter, MaxHalf) {
106 auto max = alloc<Max>(alloc<HalfImm>(1), alloc<HalfImm>(2), false);
107 STR_CHECK(max, "(1 < 2) ? 2 : 1");
108}
109
110TEST(CppPrinter, And) {
111 auto v = alloc<And>(alloc<IntImm>(1), alloc<IntImm>(2));
112 STR_CHECK(v, "1 & 2");
113}
114
115TEST(CppPrinter, CompareSelect) {
116 auto cs = alloc<CompareSelect>(
117 alloc<IntImm>(1),
118 alloc<IntImm>(2),
119 alloc<FloatImm>(1),
120 alloc<FloatImm>(2),
121 CompareSelectOperation::kLE);
122 STR_CHECK(cs, "((1 <= 2) ? 1.f : 2.f)");
123}
124
125TEST(CppPrinter, IfThenElse) {
126 auto cond = alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2));
127 auto true_value = alloc<Sub>(alloc<IntImm>(0), alloc<IntImm>(1));
128 auto false_value = alloc<Mul>(alloc<IntImm>(2), alloc<IntImm>(3));
129 auto v = alloc<IfThenElse>(cond, true_value, false_value);
130 STR_CHECK(v, "((1 + 2) ? 0 - 1 : 2 * 3)");
131}
132
133TEST(CppPrinter, AllocateFree) {
134 BufHandle buf("x", {2, 3}, kInt);
135 AllocatePtr alloc = Allocate::make(buf);
136 FreePtr free = Free::make(buf);
137 BlockPtr block = Block::make({alloc, free});
138
139 const std::string pattern = R"(
140 # CHECK: {
141 # CHECK: int* x = static_cast<int*>(malloc(24));
142 # CHECK: free(x);
143 # CHECK: }
144 )";
145 FILE_CHECK(block, pattern);
146}
147
148TEST(CppPrinter, LoadStore) {
149 BufHandle a("A", {2, 3}, kInt);
150 BufHandle b("B", {3, 4}, kInt);
151 auto store = b.store({2, 2}, a.load(1, 1));
152 STR_CHECK(
153 store, "B[(0 + 2 * (1 * 4)) + 2 * 1] = A[(0 + 1 * (1 * 3)) + 1 * 1];\n");
154}
155
156TEST(CppPrinter, Var) {
157 auto var = alloc<Var>("x", kInt);
158 STR_CHECK(var, "x");
159}
160
161TEST(CppPrinter, Cast) {
162 auto cast = alloc<Cast>(kFloat, alloc<IntImm>(1));
163 STR_CHECK(cast, "static_cast<float>(1)");
164}
165
166TEST(CppPrinter, BitCast) {
167 auto cast = alloc<BitCast>(kInt, alloc<FloatImm>(20));
168 STR_CHECK(cast, "std::bitcast<float, int>(20.f)");
169}
170
171TEST(CppPrinter, Let) {
172 auto var = alloc<Var>("x", kFloat);
173 auto val = alloc<FloatImm>(2);
174 auto let = alloc<Let>(var, val);
175 STR_CHECK(let, "float x = 2.f;\n");
176}
177
178TEST(CppPrinter, For) {
179 constexpr int N = 1024;
180 BufHandle a("A", {N}, kInt);
181 BufHandle b("B", {N}, kInt);
182 BufHandle c("C", {N}, kInt);
183 VarHandle i("i", kInt);
184 auto f = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i))));
185 const std::string pattern = R"(
186 # CHECK: for (int i = 0; i < 1024; i++) {
187 # CHECK: C[i] = (A[i]) + (B[i]);
188 # CHECK: }
189 )";
190 FILE_CHECK(f, pattern);
191}
192
193TEST(CppPrinter, Cond) {
194 BufHandle x("X", {1}, kInt);
195 auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
196 auto cond =
197 Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1));
198 const std::string pattern = R"(
199 # CHECK: if (((X[0] < 10) ? 1 : 0)) {
200 # CHECK: X[0] = (X[0]) + 1;
201 # CHECK: } else {
202 # CHECK: X[0] = (X[0]) - 1;
203 # CHECK: }
204 )";
205 FILE_CHECK(cond, pattern);
206}
207
208TEST(CppPrinter, Intrinsics) {
209 const std::unordered_set<IntrinsicsOp, std::hash<int>> unsupported_ops{
210 kRand, kSigmoid};
211 for (const auto i : c10::irange(static_cast<uint32_t>(kMaxIntrinsicsOp))) {
212 IntrinsicsOp op = static_cast<IntrinsicsOp>(i);
213 if (unsupported_ops.count(op)) {
214 continue;
215 }
216
217 if (Intrinsics::OpArgCount(op) == 1) {
218 auto v = alloc<Intrinsics>(op, alloc<FloatImm>(2.0f));
219 STR_CHECK(v, "std::" + v->func_name() + "(2.f)");
220 } else {
221 auto v =
222 alloc<Intrinsics>(op, alloc<FloatImm>(1.0f), alloc<FloatImm>(2.0f));
223 STR_CHECK(v, "std::" + v->func_name() + "(1.f, 2.f)");
224 }
225 }
226}
227
228TEST(CppPrinter, ExternalCall) {
229 std::vector<ExprPtr> dims{alloc<IntImm>(2), alloc<IntImm>(2)};
230 auto output = alloc<Buf>("out", dims, kFloat);
231 auto buf_arg1 = alloc<Buf>("a", dims, kFloat);
232 auto buf_arg2 = alloc<Buf>("b", dims, kFloat);
233 auto scalar_arg = alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2));
234 std::vector<BufPtr> buf_args{buf_arg1, buf_arg2};
235 std::vector<ExprPtr> scalar_args{scalar_arg};
236 auto call =
237 alloc<ExternalCall>(output, "nnc_aten_matmul", buf_args, scalar_args);
238 const std::string pattern = R"(
239 # CHECK: {
240 # CHECK: void* buf_ptrs[]{out, a, b};
241 # CHECK: int64_t buf_ranks[]{2, 2, 2};
242 # CHECK: int64_t buf_dims[]{2, 2, 2, 2, 2, 2};
243 # CHECK: int8_t buf_dtypes[]{6, 6, 6};
244 # CHECK: int64_t extra_args[]{1 + 2};
245 # CHECK: nnc_aten_matmul(
246 # CHECK: 3,
247 # CHECK: buf_ptrs,
248 # CHECK: buf_ranks,
249 # CHECK: buf_dims,
250 # CHECK: buf_dtypes,
251 # CHECK: 1,
252 # CHECK: extra_args);
253 # CHECK: }
254 )";
255 FILE_CHECK(call, pattern);
256}
257
258} // namespace jit
259} // namespace torch
260