1#include <gtest/gtest.h>
2
3#include <stdexcept>
4#include "test/cpp/tensorexpr/test_base.h"
5
6#include <torch/csrc/jit/tensorexpr/expr.h>
7#include <torch/csrc/jit/tensorexpr/ir.h>
8#include <torch/csrc/jit/tensorexpr/ir_printer.h>
9#include <torch/csrc/jit/tensorexpr/loopnest.h>
10#include <torch/csrc/jit/tensorexpr/tensor.h>
11#include <torch/csrc/jit/testing/file_check.h>
12
13#include <sstream>
14namespace torch {
15namespace jit {
16
17using namespace torch::jit::tensorexpr;
18
19TEST(IRPrinter, BasicValueTest) {
20 ExprHandle a = IntImm::make(2), b = IntImm::make(3);
21 ExprHandle c = Add::make(a, b);
22
23 std::stringstream ss;
24 ss << c;
25 ASSERT_EQ(ss.str(), "2 + 3");
26}
27
28TEST(IRPrinter, BasicValueTest02) {
29 ExprHandle a(2.0f);
30 ExprHandle b(3.0f);
31 ExprHandle c(4.0f);
32 ExprHandle d(5.0f);
33 ExprHandle f = (a + b) - (c + d);
34
35 std::stringstream ss;
36 ss << f;
37 ASSERT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)");
38}
39
40TEST(IRPrinter, CastTest) {
41 VarHandle x("x", kHalf);
42 VarHandle y("y", kFloat);
43 ExprHandle body = ExprHandle(2.f) +
44 (Cast::make(kFloat, x) * ExprHandle(3.f) + ExprHandle(4.f) * y);
45
46 std::stringstream ss;
47 ss << body;
48 ASSERT_EQ(ss.str(), "2.f + (float(x) * 3.f + 4.f * y)");
49}
50
51TEST(IRPrinter, FunctionName) {
52 int M = 4;
53 int N = 20;
54
55 Tensor producer = Compute(
56 "producer", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
57 return m * n;
58 });
59
60 Tensor chunk_0 = Compute(
61 "chunk_0", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) {
62 return producer.load(m, n);
63 });
64
65 Tensor chunk_1 = Compute(
66 "chunk_1", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) {
67 return producer.load(m, n + ExprHandle(N / 2));
68 });
69
70 Tensor consumer = Compute(
71 "consumer", {M, N / 2}, [&](const ExprHandle& i, const ExprHandle& j) {
72 return i * chunk_1.load(i, j);
73 });
74
75 LoopNest l({chunk_0, chunk_1, consumer});
76 auto body = LoopNest::sanitizeNames(l.root_stmt());
77
78 std::stringstream ss;
79 ss << *body;
80
81 const std::string& verification_pattern =
82 R"IR(
83 # CHECK: for (int i_2
84 # CHECK: for (int j_2
85 # CHECK: consumer[i_2, j_2] = i_2 * (chunk_1[i_2, j_2])IR";
86
87 torch::jit::testing::FileCheck().run(verification_pattern, ss.str());
88}
89} // namespace jit
90} // namespace torch
91