1 | #include <gtest/gtest.h> |
2 | |
3 | #include <stdexcept> |
4 | #include "test/cpp/tensorexpr/test_base.h" |
5 | |
6 | #include <torch/csrc/jit/tensorexpr/expr.h> |
7 | #include <torch/csrc/jit/tensorexpr/ir.h> |
8 | #include <torch/csrc/jit/tensorexpr/ir_printer.h> |
9 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
10 | #include <torch/csrc/jit/tensorexpr/tensor.h> |
11 | #include <torch/csrc/jit/testing/file_check.h> |
12 | |
13 | #include <sstream> |
14 | namespace torch { |
15 | namespace jit { |
16 | |
17 | using namespace torch::jit::tensorexpr; |
18 | |
19 | TEST(IRPrinter, BasicValueTest) { |
20 | ExprHandle a = IntImm::make(2), b = IntImm::make(3); |
21 | ExprHandle c = Add::make(a, b); |
22 | |
23 | std::stringstream ss; |
24 | ss << c; |
25 | ASSERT_EQ(ss.str(), "2 + 3" ); |
26 | } |
27 | |
28 | TEST(IRPrinter, BasicValueTest02) { |
29 | ExprHandle a(2.0f); |
30 | ExprHandle b(3.0f); |
31 | ExprHandle c(4.0f); |
32 | ExprHandle d(5.0f); |
33 | ExprHandle f = (a + b) - (c + d); |
34 | |
35 | std::stringstream ss; |
36 | ss << f; |
37 | ASSERT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)" ); |
38 | } |
39 | |
40 | TEST(IRPrinter, CastTest) { |
41 | VarHandle x("x" , kHalf); |
42 | VarHandle y("y" , kFloat); |
43 | ExprHandle body = ExprHandle(2.f) + |
44 | (Cast::make(kFloat, x) * ExprHandle(3.f) + ExprHandle(4.f) * y); |
45 | |
46 | std::stringstream ss; |
47 | ss << body; |
48 | ASSERT_EQ(ss.str(), "2.f + (float(x) * 3.f + 4.f * y)" ); |
49 | } |
50 | |
51 | TEST(IRPrinter, FunctionName) { |
52 | int M = 4; |
53 | int N = 20; |
54 | |
55 | Tensor producer = Compute( |
56 | "producer" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
57 | return m * n; |
58 | }); |
59 | |
60 | Tensor chunk_0 = Compute( |
61 | "chunk_0" , {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) { |
62 | return producer.load(m, n); |
63 | }); |
64 | |
65 | Tensor chunk_1 = Compute( |
66 | "chunk_1" , {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) { |
67 | return producer.load(m, n + ExprHandle(N / 2)); |
68 | }); |
69 | |
70 | Tensor consumer = Compute( |
71 | "consumer" , {M, N / 2}, [&](const ExprHandle& i, const ExprHandle& j) { |
72 | return i * chunk_1.load(i, j); |
73 | }); |
74 | |
75 | LoopNest l({chunk_0, chunk_1, consumer}); |
76 | auto body = LoopNest::sanitizeNames(l.root_stmt()); |
77 | |
78 | std::stringstream ss; |
79 | ss << *body; |
80 | |
81 | const std::string& verification_pattern = |
82 | R"IR( |
83 | # CHECK: for (int i_2 |
84 | # CHECK: for (int j_2 |
85 | # CHECK: consumer[i_2, j_2] = i_2 * (chunk_1[i_2, j_2])IR" ; |
86 | |
87 | torch::jit::testing::FileCheck().run(verification_pattern, ss.str()); |
88 | } |
89 | } // namespace jit |
90 | } // namespace torch |
91 | |