1#include <algorithm>
2#include <sstream>
3#include <stdexcept>
4
5#include <gtest/gtest.h>
6
7#include <c10/macros/Macros.h>
8#include <c10/util/irange.h>
9#include "test/cpp/tensorexpr/padded_buffer.h"
10#include "test/cpp/tensorexpr/test_base.h"
11#include "torch/csrc/jit/tensorexpr/ir_printer.h"
12
13namespace torch {
14namespace jit {
15
16using namespace torch::jit::tensorexpr;
17
18TEST(ATen, _cast_Float) {
19 const int kTotalSize = 128;
20 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
21 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
22
23 VarHandle index = VarHandle("index", kInt);
24 ExprHandle load_a = a_buf.load(index);
25 ExprHandle to_float = Cast::make(kFloat, load_a);
26 StmtPtr store_b = b_buf.store({index}, to_float);
27 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
28
29 PaddedBuffer<int> a_v(kTotalSize);
30 PaddedBuffer<float> b_v(kTotalSize);
31
32 for (const auto i : c10::irange(kTotalSize)) {
33 a_v(i) = i;
34 }
35
36 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
37 ir_eval(a_v, b_v);
38
39 for (const auto i : c10::irange(kTotalSize)) {
40 ASSERT_EQ(a_v(i), i);
41 ASSERT_EQ(b_v(i), static_cast<float>(i));
42 }
43}
44
45TEST(ATen, negInt) {
46 const int kTotalSize = 128;
47 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
48 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
49
50 VarHandle index = VarHandle("index", kInt);
51 ExprHandle load_a = a_buf.load(index);
52 ExprHandle to_float = Sub::make(0, load_a);
53 StmtPtr store_b = b_buf.store({index}, to_float);
54 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
55
56 PaddedBuffer<int> a_v(kTotalSize);
57 PaddedBuffer<int> b_v(kTotalSize);
58
59 for (const auto i : c10::irange(kTotalSize)) {
60 a_v(i) = i;
61 }
62
63 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
64 ir_eval(a_v, b_v);
65
66 for (const auto i : c10::irange(kTotalSize)) {
67 ASSERT_EQ(a_v(i), i);
68 ASSERT_EQ(b_v(i), -static_cast<float>(i));
69 }
70}
71
72TEST(ATen, negFloat) {
73 const int kTotalSize = 128;
74 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
75 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
76
77 VarHandle index = VarHandle("index", kInt);
78 ExprHandle load_a = a_buf.load(index);
79 ExprHandle to_float = Sub::make(0, load_a);
80 StmtPtr store_b = b_buf.store({index}, to_float);
81 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
82
83 PaddedBuffer<float> a_v(kTotalSize);
84 PaddedBuffer<float> b_v(kTotalSize);
85
86 for (const auto i : c10::irange(kTotalSize)) {
87 a_v(i) = i;
88 }
89
90 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
91 ir_eval(a_v, b_v);
92
93 for (const auto i : c10::irange(kTotalSize)) {
94 ASSERT_EQ(a_v(i), i);
95 ASSERT_EQ(b_v(i), -i);
96 }
97}
98
99TEST(ATen, addInt) {
100 const int kTotalSize = 128;
101 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
102 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
103 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt);
104 BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt);
105
106 VarHandle index = VarHandle("index", kInt);
107 ExprHandle load_a = a_buf.load(index);
108 ExprHandle load_b = b_buf.load(index);
109 ExprHandle load_c = c_buf.load(index);
110 StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c);
111 StmtPtr stmt = For::make(index, 0, kTotalSize, store_d);
112
113 PaddedBuffer<int> a_v(kTotalSize);
114 PaddedBuffer<int> b_v(kTotalSize);
115 PaddedBuffer<int> c_v(kTotalSize);
116 PaddedBuffer<int> d_v(kTotalSize);
117
118 for (const auto i : c10::irange(kTotalSize)) {
119 a_v(i) = i;
120 b_v(i) = 2 * i + 1;
121 c_v(i) = 3 * i + 2;
122 }
123
124 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf});
125 ir_eval(a_v, b_v, c_v, d_v);
126
127 for (const auto i : c10::irange(kTotalSize)) {
128 ASSERT_EQ(a_v(i), i);
129 ASSERT_EQ(b_v(i), 2 * i + 1);
130 ASSERT_EQ(c_v(i), 3 * i + 2);
131 ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i));
132 }
133}
134
135TEST(ATen, addFloat) {
136 const int kTotalSize = 128;
137 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
138 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
139 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat);
140 BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat);
141
142 VarHandle index = VarHandle("index", kInt);
143 ExprHandle load_a = a_buf.load(index);
144 ExprHandle load_b = b_buf.load(index);
145 ExprHandle load_c = c_buf.load(index);
146 StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c);
147 StmtPtr stmt = For::make(index, 0, kTotalSize, store_d);
148
149 PaddedBuffer<float> a_v(kTotalSize);
150 PaddedBuffer<float> b_v(kTotalSize);
151 PaddedBuffer<float> c_v(kTotalSize);
152 PaddedBuffer<float> d_v(kTotalSize);
153
154 for (const auto i : c10::irange(kTotalSize)) {
155 a_v(i) = i;
156 b_v(i) = 2 * i + 1;
157 c_v(i) = 3 * i + 2;
158 }
159
160 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf});
161 ir_eval(a_v, b_v, c_v, d_v);
162
163 for (const auto i : c10::irange(kTotalSize)) {
164 ASSERT_EQ(a_v(i), i);
165 ASSERT_EQ(b_v(i), 2 * i + 1);
166 ASSERT_EQ(c_v(i), 3 * i + 2);
167 ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i));
168 }
169}
170
171TEST(ATen, subInt) {
172 const int kTotalSize = 128;
173 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
174 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
175 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt);
176 BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt);
177
178 VarHandle index = VarHandle("index", kInt);
179 ExprHandle load_a = a_buf.load(index);
180 ExprHandle load_b = b_buf.load(index);
181 ExprHandle load_c = c_buf.load(index);
182 StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c);
183 StmtPtr stmt = For::make(index, 0, kTotalSize, store_d);
184
185 PaddedBuffer<int> a_v(kTotalSize);
186 PaddedBuffer<int> b_v(kTotalSize);
187 PaddedBuffer<int> c_v(kTotalSize);
188 PaddedBuffer<int> d_v(kTotalSize);
189
190 for (const auto i : c10::irange(kTotalSize)) {
191 a_v(i) = i;
192 b_v(i) = 2 * i + 1;
193 c_v(i) = 3 * i + 2;
194 }
195
196 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf});
197 ir_eval(a_v, b_v, c_v, d_v);
198
199 for (const auto i : c10::irange(kTotalSize)) {
200 ASSERT_EQ(a_v(i), i);
201 ASSERT_EQ(b_v(i), 2 * i + 1);
202 ASSERT_EQ(c_v(i), 3 * i + 2);
203 ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i));
204 }
205}
206
207TEST(ATen, subFloat) {
208 const int kTotalSize = 128;
209 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
210 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
211 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat);
212 BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat);
213
214 VarHandle index = VarHandle("index", kInt);
215 ExprHandle load_a = a_buf.load(index);
216 ExprHandle load_b = b_buf.load(index);
217 ExprHandle load_c = c_buf.load(index);
218 StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c);
219 StmtPtr stmt = For::make(index, 0, kTotalSize, store_d);
220
221 PaddedBuffer<float> a_v(kTotalSize);
222 PaddedBuffer<float> b_v(kTotalSize);
223 PaddedBuffer<float> c_v(kTotalSize);
224 PaddedBuffer<float> d_v(kTotalSize);
225
226 for (const auto i : c10::irange(kTotalSize)) {
227 a_v(i) = i;
228 b_v(i) = 2 * i + 1;
229 c_v(i) = 3 * i + 2;
230 }
231
232 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf});
233 ir_eval(a_v, b_v, c_v, d_v);
234
235 for (const auto i : c10::irange(kTotalSize)) {
236 ASSERT_EQ(a_v(i), i);
237 ASSERT_EQ(b_v(i), 2 * i + 1);
238 ASSERT_EQ(c_v(i), 3 * i + 2);
239 ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i));
240 }
241}
242
243TEST(ATen, lerp) {
244 const int kTotalSize = 128;
245 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
246 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
247 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat);
248 BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat);
249
250 VarHandle index = VarHandle("index", kInt);
251 ExprHandle load_a = a_buf.load(index);
252 ExprHandle load_b = b_buf.load(index);
253 ExprHandle load_c = c_buf.load(index);
254 StmtPtr store_d = d_buf.store({index}, load_a + load_c * (load_b - load_a));
255 StmtPtr stmt = For::make(index, 0, kTotalSize, store_d);
256
257 PaddedBuffer<float> a_v(kTotalSize);
258 PaddedBuffer<float> b_v(kTotalSize);
259 PaddedBuffer<float> c_v(kTotalSize);
260 PaddedBuffer<float> d_v(kTotalSize);
261
262 for (const auto i : c10::irange(kTotalSize)) {
263 a_v(i) = i;
264 b_v(i) = 2 * i + 1;
265 c_v(i) = 3 * i + 2;
266 }
267
268 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf});
269 ir_eval(a_v, b_v, c_v, d_v);
270
271 for (const auto i : c10::irange(kTotalSize)) {
272 ASSERT_EQ(a_v(i), i);
273 ASSERT_EQ(b_v(i), 2 * i + 1);
274 ASSERT_EQ(c_v(i), 3 * i + 2);
275 ASSERT_EQ(d_v(i), a_v(i) + c_v(i) * (b_v(i) - a_v(i)));
276 }
277}
278
279TEST(ATen, addcmulInt) {
280 const int kTotalSize = 128;
281 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
282 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
283 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt);
284 BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt);
285 BufHandle e_buf("E", {ExprHandle(kTotalSize)}, kInt);
286
287 VarHandle index = VarHandle("index", kInt);
288 ExprHandle load_a = a_buf.load(index);
289 ExprHandle load_b = b_buf.load(index);
290 ExprHandle load_c = c_buf.load(index);
291 ExprHandle load_d = d_buf.load(index);
292 StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d);
293 StmtPtr stmt = For::make(index, 0, kTotalSize, store_e);
294
295 PaddedBuffer<int> a_v(kTotalSize);
296 PaddedBuffer<int> b_v(kTotalSize);
297 PaddedBuffer<int> c_v(kTotalSize);
298 PaddedBuffer<int> d_v(kTotalSize);
299 PaddedBuffer<int> e_v(kTotalSize);
300
301 for (const auto i : c10::irange(kTotalSize)) {
302 a_v(i) = i;
303 b_v(i) = 2 * i + 1;
304 c_v(i) = 3 * i + 2;
305 d_v(i) = 5 * i + 3;
306 }
307
308 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf});
309 ir_eval(a_v, b_v, c_v, d_v, e_v);
310
311 for (const auto i : c10::irange(kTotalSize)) {
312 ASSERT_EQ(a_v(i), i);
313 ASSERT_EQ(b_v(i), 2 * i + 1);
314 ASSERT_EQ(c_v(i), 3 * i + 2);
315 ASSERT_EQ(d_v(i), 5 * i + 3);
316 ASSERT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i));
317 }
318}
319
320TEST(ATen, addcmulFloat) {
321 const int kTotalSize = 128;
322 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
323 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
324 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat);
325 BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat);
326 BufHandle e_buf("E", {ExprHandle(kTotalSize)}, kFloat);
327
328 VarHandle index = VarHandle("index", kInt);
329 ExprHandle load_a = a_buf.load(index);
330 ExprHandle load_b = b_buf.load(index);
331 ExprHandle load_c = c_buf.load(index);
332 ExprHandle load_d = d_buf.load(index);
333 StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d);
334 StmtPtr stmt = For::make(index, 0, kTotalSize, store_e);
335
336 PaddedBuffer<float> a_v(kTotalSize);
337 PaddedBuffer<float> b_v(kTotalSize);
338 PaddedBuffer<float> c_v(kTotalSize);
339 PaddedBuffer<float> d_v(kTotalSize);
340 PaddedBuffer<float> e_v(kTotalSize);
341
342 for (const auto i : c10::irange(kTotalSize)) {
343 a_v(i) = i;
344 b_v(i) = 2 * i + 1;
345 c_v(i) = 3 * i + 2;
346 d_v(i) = 5 * i + 3;
347 }
348
349 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf});
350 ir_eval(a_v, b_v, c_v, d_v, e_v);
351
352 for (const auto i : c10::irange(kTotalSize)) {
353 ASSERT_EQ(a_v(i), i);
354 ASSERT_EQ(b_v(i), 2 * i + 1);
355 ASSERT_EQ(c_v(i), 3 * i + 2);
356 ASSERT_EQ(d_v(i), 5 * i + 3);
357 ASSERT_FLOAT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i));
358 }
359}
360
361TEST(ATen, mulInt) {
362 const int kTotalSize = 128;
363 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
364 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
365 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt);
366
367 VarHandle index = VarHandle("index", kInt);
368 ExprHandle load_a = a_buf.load(index);
369 ExprHandle load_b = b_buf.load(index);
370 StmtPtr store_c = c_buf.store({index}, load_a * load_b);
371 StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
372
373 PaddedBuffer<int> a_v(kTotalSize);
374 PaddedBuffer<int> b_v(kTotalSize);
375 PaddedBuffer<int> c_v(kTotalSize);
376
377 for (const auto i : c10::irange(kTotalSize)) {
378 a_v(i) = i;
379 b_v(i) = 2 * i + 1;
380 }
381
382 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
383 ir_eval(a_v, b_v, c_v);
384
385 for (const auto i : c10::irange(kTotalSize)) {
386 ASSERT_EQ(a_v(i), i);
387 ASSERT_EQ(b_v(i), 2 * i + 1);
388 ASSERT_EQ(c_v(i), a_v(i) * b_v(i));
389 }
390}
391
392TEST(ATen, mulFloat) {
393 const int kTotalSize = 128;
394 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
395 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
396 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat);
397
398 VarHandle index = VarHandle("index", kInt);
399 ExprHandle load_a = a_buf.load(index);
400 ExprHandle load_b = b_buf.load(index);
401 StmtPtr store_c = c_buf.store({index}, load_a * load_b);
402 StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
403
404 PaddedBuffer<float> a_v(kTotalSize);
405 PaddedBuffer<float> b_v(kTotalSize);
406 PaddedBuffer<float> c_v(kTotalSize);
407
408 for (const auto i : c10::irange(kTotalSize)) {
409 a_v(i) = i;
410 b_v(i) = 2 * i + 1;
411 }
412
413 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
414 ir_eval(a_v, b_v, c_v);
415
416 for (const auto i : c10::irange(kTotalSize)) {
417 ASSERT_EQ(a_v(i), i);
418 ASSERT_EQ(b_v(i), 2 * i + 1);
419 ASSERT_EQ(c_v(i), a_v(i) * b_v(i));
420 }
421}
422
423TEST(ATen, divInt) {
424 const int kTotalSize = 128;
425 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
426 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
427 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt);
428
429 VarHandle index = VarHandle("index", kInt);
430 ExprHandle load_a = a_buf.load(index);
431 ExprHandle load_b = b_buf.load(index);
432 StmtPtr store_c = c_buf.store({index}, load_a / load_b);
433 StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
434
435 PaddedBuffer<int> a_v(kTotalSize);
436 PaddedBuffer<int> b_v(kTotalSize);
437 PaddedBuffer<int> c_v(kTotalSize);
438
439 for (const auto i : c10::irange(kTotalSize)) {
440 a_v(i) = 2 * i + 1;
441 b_v(i) = i + 1;
442 }
443
444 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
445 ir_eval(a_v, b_v, c_v);
446
447 for (const auto i : c10::irange(kTotalSize)) {
448 ASSERT_EQ(a_v(i), 2 * i + 1);
449 ASSERT_EQ(b_v(i), i + 1);
450 ASSERT_EQ(c_v(i), a_v(i) / b_v(i));
451 }
452}
453
454TEST(ATen, divFloat) {
455 const int kTotalSize = 128;
456 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
457 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
458 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat);
459
460 VarHandle index = VarHandle("index", kInt);
461 ExprHandle load_a = a_buf.load(index);
462 ExprHandle load_b = b_buf.load(index);
463 StmtPtr store_c = c_buf.store({index}, load_a / load_b);
464 StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
465
466 PaddedBuffer<float> a_v(kTotalSize);
467 PaddedBuffer<float> b_v(kTotalSize);
468 PaddedBuffer<float> c_v(kTotalSize);
469
470 for (const auto i : c10::irange(kTotalSize)) {
471 a_v(i) = 2 * i + 1;
472 b_v(i) = i + 1;
473 }
474
475 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
476 ir_eval(a_v, b_v, c_v);
477
478 for (const auto i : c10::irange(kTotalSize)) {
479 ASSERT_EQ(a_v(i), 2 * i + 1);
480 ASSERT_EQ(b_v(i), i + 1);
481 ASSERT_EQ(c_v(i), a_v(i) / b_v(i));
482 }
483}
484
485TEST(ATen, maxInt) {
486 const int kTotalSize = 128;
487 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
488 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
489 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt);
490
491 VarHandle index = VarHandle("index", kInt);
492 ExprHandle load_a = a_buf.load(index);
493 ExprHandle load_b = b_buf.load(index);
494 StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true));
495 StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
496
497 PaddedBuffer<int> a_v(kTotalSize);
498 PaddedBuffer<int> b_v(kTotalSize);
499 PaddedBuffer<int> c_v(kTotalSize);
500
501 for (const auto i : c10::irange(kTotalSize)) {
502 a_v(i) = i;
503 b_v(i) = 2 * i + 1;
504 }
505
506 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
507 ir_eval(a_v, b_v, c_v);
508
509 for (const auto i : c10::irange(kTotalSize)) {
510 ASSERT_EQ(a_v(i), i);
511 ASSERT_EQ(b_v(i), 2 * i + 1);
512 ASSERT_EQ(c_v(i), std::max(a_v(i), b_v(i)));
513 }
514}
515
516TEST(ATen, maxFloat) {
517 const int kTotalSize = 128;
518 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
519 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
520 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat);
521
522 VarHandle index = VarHandle("index", kInt);
523 ExprHandle load_a = a_buf.load(index);
524 ExprHandle load_b = b_buf.load(index);
525 StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true));
526 StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
527
528 PaddedBuffer<float> a_v(kTotalSize);
529 PaddedBuffer<float> b_v(kTotalSize);
530 PaddedBuffer<float> c_v(kTotalSize);
531
532 for (const auto i : c10::irange(kTotalSize)) {
533 a_v(i) = i;
534 b_v(i) = 2 * i + 1;
535 }
536
537 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
538 ir_eval(a_v, b_v, c_v);
539
540 for (const auto i : c10::irange(kTotalSize)) {
541 ASSERT_EQ(a_v(i), i);
542 ASSERT_EQ(b_v(i), 2 * i + 1);
543 ASSERT_EQ(c_v(i), std::fmax(a_v(i), b_v(i)));
544 }
545}
546
547TEST(ATen, minInt) {
548 const int kTotalSize = 128;
549 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
550 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
551 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt);
552
553 VarHandle index = VarHandle("index", kInt);
554 ExprHandle load_a = a_buf.load(index);
555 ExprHandle load_b = b_buf.load(index);
556 StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true));
557 StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
558
559 PaddedBuffer<int> a_v(kTotalSize);
560 PaddedBuffer<int> b_v(kTotalSize);
561 PaddedBuffer<int> c_v(kTotalSize);
562
563 for (const auto i : c10::irange(kTotalSize)) {
564 a_v(i) = i;
565 b_v(i) = 2 * i + 1;
566 }
567
568 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
569 ir_eval(a_v, b_v, c_v);
570
571 for (const auto i : c10::irange(kTotalSize)) {
572 ASSERT_EQ(a_v(i), i);
573 ASSERT_EQ(b_v(i), 2 * i + 1);
574 ASSERT_EQ(c_v(i), std::min(a_v(i), b_v(i)));
575 }
576}
577
578TEST(ATen, minFloat) {
579 const int kTotalSize = 128;
580 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
581 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
582 BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat);
583
584 VarHandle index = VarHandle("index", kInt);
585 ExprHandle load_a = a_buf.load(index);
586 ExprHandle load_b = b_buf.load(index);
587 StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true));
588 StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
589
590 PaddedBuffer<float> a_v(kTotalSize);
591 PaddedBuffer<float> b_v(kTotalSize);
592 PaddedBuffer<float> c_v(kTotalSize);
593
594 for (const auto i : c10::irange(kTotalSize)) {
595 a_v(i) = i;
596 b_v(i) = 2 * i + 1;
597 }
598
599 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
600 ir_eval(a_v, b_v, c_v);
601
602 for (const auto i : c10::irange(kTotalSize)) {
603 ASSERT_EQ(a_v(i), i);
604 ASSERT_EQ(b_v(i), 2 * i + 1);
605 ASSERT_EQ(c_v(i), std::fmin(a_v(i), b_v(i)));
606 }
607}
608
609void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() {
610 const int kTotalSize = 128;
611 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
612 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
613
614 VarHandle index = VarHandle("index", kInt);
615 ExprHandle load_a = a_buf.load(index);
616 StmtPtr store_b = b_buf.store({index}, FloatImm::make(1.0f) / load_a);
617 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
618
619 PaddedBuffer<float> a_v(kTotalSize);
620 PaddedBuffer<float> b_v(kTotalSize);
621
622 for (const auto i : c10::irange(kTotalSize)) {
623 a_v(i) = i;
624 }
625
626 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
627 ir_eval(a_v, b_v);
628
629 for (const auto i : c10::irange(kTotalSize)) {
630 ASSERT_EQ(a_v(i), i);
631 ASSERT_EQ(b_v(i), 1.0f / i);
632 }
633}
634
635TEST(ATen, reluInt) {
636 const int kTotalSize = 128;
637 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
638 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
639
640 VarHandle index = VarHandle("index", kInt);
641 ExprHandle load_a = a_buf.load(index);
642 StmtPtr store_b = b_buf.store({index}, Max::make(load_a, 0, false));
643 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
644
645 PaddedBuffer<int> a_v(kTotalSize);
646 PaddedBuffer<int> b_v(kTotalSize);
647
648 for (const auto i : c10::irange(kTotalSize)) {
649 a_v(i) = i - 64;
650 }
651
652 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
653 ir_eval(a_v, b_v);
654
655 for (const auto i : c10::irange(kTotalSize)) {
656 ASSERT_EQ(a_v(i), i - 64);
657 ASSERT_EQ(b_v(i), std::max(a_v(i), 0));
658 }
659}
660
661TEST(ATen, reluFloat) {
662 const int kTotalSize = 128;
663 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
664 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
665
666 VarHandle index = VarHandle("index", kInt);
667 ExprHandle load_a = a_buf.load(index);
668 StmtPtr store_b = b_buf.store(
669 {index}, Max::make(load_a, 0, false) // relu does not propagate nans
670 );
671 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
672
673 PaddedBuffer<float> a_v(kTotalSize);
674 PaddedBuffer<float> b_v(kTotalSize);
675
676 for (const auto i : c10::irange(kTotalSize)) {
677 a_v(i) = i - 64;
678 }
679
680 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
681 ir_eval(a_v, b_v);
682
683 for (const auto i : c10::irange(kTotalSize)) {
684 ASSERT_EQ(a_v(i), i - 64);
685 ASSERT_EQ(b_v(i), std::fmax(a_v(i), 0));
686 }
687}
688
689TEST(ATen, logFloat) {
690 const int kTotalSize = 128;
691 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
692 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
693
694 VarHandle index = VarHandle("index", kInt);
695 ExprHandle load_a = a_buf.load(index);
696 StmtPtr store_b = b_buf.store({index}, log(load_a));
697 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
698
699 PaddedBuffer<float> a_v(kTotalSize);
700 PaddedBuffer<float> b_v(kTotalSize);
701
702 for (const auto i : c10::irange(kTotalSize)) {
703 a_v(i) = i + 10;
704 }
705
706 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
707 ir_eval(a_v, b_v);
708
709 for (const auto i : c10::irange(kTotalSize)) {
710 ASSERT_EQ(a_v(i), i + 10);
711 ASSERT_EQ(b_v(i), std::log(a_v(i)));
712 }
713}
714
715TEST(ATen, fastLogFloat) {
716 const int kTotalSize = 128;
717 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
718 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
719
720 VarHandle index = VarHandle("index", kInt);
721 ExprHandle load_a = a_buf.load(index);
722 StmtPtr store_b = b_buf.store({index}, fast_log(load_a));
723 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
724
725 PaddedBuffer<float> a_v(kTotalSize);
726 PaddedBuffer<float> b_v(kTotalSize);
727
728 for (const auto i : c10::irange(kTotalSize)) {
729 a_v(i) = at::randn({1}).item().to<float>();
730 }
731
732 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
733 ir_eval(a_v, b_v);
734
735 for (const auto i : c10::irange(kTotalSize)) {
736 auto test = b_v(i);
737 auto ref = std::log(a_v(i));
738 if (std::isnan(ref)) {
739 ASSERT_EQ(std::isnan(test), true);
740 } else {
741 ASSERT_FLOAT_EQ(test, ref);
742 }
743 }
744}
745
746TEST(ATen, fastTanhFloat) {
747 const int kTotalSize = 128;
748 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
749 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
750
751 VarHandle index = VarHandle("index", kInt);
752 ExprHandle load_a = a_buf.load(index);
753 StmtPtr store_b = b_buf.store({index}, fast_tanh(load_a));
754 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
755
756 PaddedBuffer<float> a_v(kTotalSize);
757 PaddedBuffer<float> b_v(kTotalSize);
758
759 for (const auto i : c10::irange(kTotalSize)) {
760 a_v(i) = at::randn({1}).item().to<float>();
761 }
762
763 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
764 ir_eval(a_v, b_v);
765
766 for (const auto i : c10::irange(kTotalSize)) {
767 auto test = b_v(i);
768 auto ref = std::tanh(a_v(i));
769 if (std::isnan(ref)) {
770 ASSERT_EQ(std::isnan(test), true);
771 } else {
772 ASSERT_NEAR(test, ref, 1e-6);
773 }
774 }
775}
776
777TEST(ATen, fastSigmoidFloat) {
778 const int kTotalSize = 128;
779 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
780 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
781
782 VarHandle index = VarHandle("index", kInt);
783 ExprHandle load_a = a_buf.load(index);
784 StmtPtr store_b = b_buf.store({index}, fast_sigmoid(load_a));
785 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
786
787 PaddedBuffer<float> a_v(kTotalSize);
788 PaddedBuffer<float> b_v(kTotalSize);
789
790 for (const auto i : c10::irange(kTotalSize)) {
791 a_v(i) = at::randn({1}).item().to<float>();
792 }
793
794 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
795 ir_eval(a_v, b_v);
796
797 for (const auto i : c10::irange(kTotalSize)) {
798 auto test = b_v(i);
799 at::Tensor t = at::ones({1}) * a_v(i);
800 float ref = at::sigmoid(t).item().to<float>();
801 if (std::isnan(ref)) {
802 ASSERT_EQ(std::isnan(test), true);
803 } else {
804 ASSERT_NEAR(test, ref, 1e-6);
805 }
806 }
807}
808
809TEST(ATen, log10Float) {
810 const int kTotalSize = 128;
811 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
812 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
813
814 VarHandle index = VarHandle("index", kInt);
815 ExprHandle load_a = a_buf.load(index);
816 StmtPtr store_b = b_buf.store({index}, log10(load_a));
817 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
818
819 PaddedBuffer<float> a_v(kTotalSize);
820 PaddedBuffer<float> b_v(kTotalSize);
821
822 for (const auto i : c10::irange(kTotalSize)) {
823 a_v(i) = i + 10;
824 }
825
826 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
827 ir_eval(a_v, b_v);
828
829 for (const auto i : c10::irange(kTotalSize)) {
830 ASSERT_EQ(a_v(i), i + 10);
831 ASSERT_EQ(b_v(i), std::log10(a_v(i)));
832 }
833}
834
835TEST(ATen, log2Float) {
836 const int kTotalSize = 128;
837 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
838 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
839
840 VarHandle index = VarHandle("index", kInt);
841 ExprHandle load_a = a_buf.load(index);
842 StmtPtr store_b = b_buf.store({index}, log2(load_a));
843 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
844
845 PaddedBuffer<float> a_v(kTotalSize);
846 PaddedBuffer<float> b_v(kTotalSize);
847
848 for (const auto i : c10::irange(kTotalSize)) {
849 a_v(i) = i + 10;
850 }
851
852 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
853 ir_eval(a_v, b_v);
854
855 for (const auto i : c10::irange(kTotalSize)) {
856 ASSERT_EQ(a_v(i), i + 10);
857 ASSERT_EQ(b_v(i), std::log2(a_v(i)));
858 }
859}
860
861TEST(ATen, expFloat) {
862 const int kTotalSize = 128;
863 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
864 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
865
866 VarHandle index = VarHandle("index", kInt);
867 ExprHandle load_a = a_buf.load(index);
868 StmtPtr store_b = b_buf.store({index}, exp(load_a));
869 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
870
871 PaddedBuffer<float> a_v(kTotalSize);
872 PaddedBuffer<float> b_v(kTotalSize);
873
874 for (const auto i : c10::irange(kTotalSize)) {
875 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
876 a_v(i) = i / 10.0f;
877 }
878
879 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
880 ir_eval(a_v, b_v);
881
882 for (const auto i : c10::irange(kTotalSize)) {
883 ASSERT_EQ(a_v(i), i / 10.0f);
884 ASSERT_EQ(b_v(i), std::exp(a_v(i)));
885 }
886}
887
888TEST(ATen, erfFloat) {
889 const int kTotalSize = 128;
890 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
891 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
892
893 VarHandle index = VarHandle("index", kInt);
894 ExprHandle load_a = a_buf.load(index);
895 StmtPtr store_b = b_buf.store({index}, erf(load_a));
896 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
897
898 PaddedBuffer<float> a_v(kTotalSize);
899 PaddedBuffer<float> b_v(kTotalSize);
900
901 for (const auto i : c10::irange(kTotalSize)) {
902 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
903 a_v(i) = i / 10.0f;
904 }
905
906 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
907 ir_eval(a_v, b_v);
908
909 for (const auto i : c10::irange(kTotalSize)) {
910 ASSERT_EQ(a_v(i), i / 10.0f);
911 ASSERT_EQ(b_v(i), std::erf(a_v(i)));
912 }
913}
914
915TEST(ATen, cosFloat) {
916 const int kTotalSize = 128;
917 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
918 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
919
920 VarHandle index = VarHandle("index", kInt);
921 ExprHandle load_a = a_buf.load(index);
922 StmtPtr store_b = b_buf.store({index}, cos(load_a));
923 StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
924
925 PaddedBuffer<float> a_v(kTotalSize);
926 PaddedBuffer<float> b_v(kTotalSize);
927
928 for (const auto i : c10::irange(kTotalSize)) {
929 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
930 a_v(i) = i / 10.0f;
931 }
932
933 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
934 ir_eval(a_v, b_v);
935
936 for (const auto i : c10::irange(kTotalSize)) {
937 ASSERT_EQ(a_v(i), i / 10.0f);
938 ASSERT_EQ(b_v(i), std::cos(a_v(i)));
939 }
940}
941
942TEST(ATen, eqInt) {
943 constexpr int N = 128;
944 BufHandle a("A", {N}, kInt);
945 BufHandle b("B", {N}, kInt);
946 BufHandle c("C", {N}, kInt);
947 std::vector<int> a_buffer(N, 1);
948 std::vector<int> b_buffer(N, 1);
949 std::vector<int> c_buffer(N, 0);
950
951 VarHandle i("i", kInt);
952 auto memcpy_expr = For::make(
953 i,
954 0,
955 N,
956 c.store(
957 {i},
958 CompareSelect::make(
959 a.load(i), b.load(i), CompareSelectOperation::kEQ)));
960
961 SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c});
962 ir_eval(a_buffer, b_buffer, c_buffer);
963
964 assertAllEqual(c_buffer, 1);
965}
966
967TEST(ATen, geInt) {
968 constexpr int N = 128;
969 BufHandle a("A", {N}, kInt);
970 BufHandle b("B", {N}, kInt);
971 BufHandle c("C", {N}, kInt);
972 std::vector<int> a_buffer(N, 5);
973 std::vector<int> b_buffer(N, 5);
974 std::vector<int> c_buffer(N, 0);
975
976 VarHandle i("i", kInt);
977 auto memcpy_expr = For::make(
978 i,
979 0,
980 N,
981 c.store(
982 {i},
983 CompareSelect::make(
984 a.load(i), b.load(i), CompareSelectOperation::kGE)));
985
986 SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c});
987 ir_eval(a_buffer, b_buffer, c_buffer);
988
989 assertAllEqual(c_buffer, 1);
990}
991
992TEST(ATen, gtInt) {
993 constexpr int N = 128;
994 BufHandle a("A", {N}, kInt);
995 BufHandle b("B", {N}, kInt);
996 BufHandle c("C", {N}, kInt);
997 std::vector<int> a_buffer(N, 6);
998 std::vector<int> b_buffer(N, 3);
999 std::vector<int> c_buffer(N, 0);
1000
1001 VarHandle i("i", kInt);
1002 auto memcpy_expr = For::make(
1003 i,
1004 0,
1005 N,
1006 c.store(
1007 {i},
1008 CompareSelect::make(
1009 a.load(i), b.load(i), CompareSelectOperation::kGT)));
1010
1011 SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c});
1012 ir_eval(a_buffer, b_buffer, c_buffer);
1013
1014 assertAllEqual(c_buffer, 1);
1015}
1016
1017TEST(ATen, leInt) {
1018 constexpr int N = 128;
1019 BufHandle a("A", {N}, kInt);
1020 BufHandle b("B", {N}, kInt);
1021 BufHandle c("C", {N}, kInt);
1022 std::vector<int> a_buffer(N, 5);
1023 std::vector<int> b_buffer(N, 5);
1024 std::vector<int> c_buffer(N, 0);
1025
1026 VarHandle i("i", kInt);
1027 auto memcpy_expr = For::make(
1028 i,
1029 0,
1030 N,
1031 c.store(
1032 {i},
1033 CompareSelect::make(
1034 a.load(i), b.load(i), CompareSelectOperation::kLE)));
1035
1036 SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c});
1037 ir_eval(a_buffer, b_buffer, c_buffer);
1038
1039 assertAllEqual(c_buffer, 1);
1040}
1041
1042TEST(ATen, ltInt) {
1043 constexpr int N = 128;
1044 BufHandle a("A", {N}, kInt);
1045 BufHandle b("B", {N}, kInt);
1046 BufHandle c("C", {N}, kInt);
1047 std::vector<int> a_buffer(N, 5);
1048 std::vector<int> b_buffer(N, 5);
1049 std::vector<int> c_buffer(N, 1);
1050
1051 VarHandle i("i", kInt);
1052 auto memcpy_expr = For::make(
1053 i,
1054 0,
1055 N,
1056 c.store(
1057 {i},
1058 CompareSelect::make(
1059 a.load(i), b.load(i), CompareSelectOperation::kLT)));
1060
1061 SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c});
1062 ir_eval(a_buffer, b_buffer, c_buffer);
1063
1064 assertAllEqual(c_buffer, 0);
1065}
1066
1067} // namespace jit
1068} // namespace torch
1069