#pragma once

#include <memory>
#include <string>
#include <vector>

#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/testing/file_check.h>
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | using namespace torch::jit::tensorexpr; |
13 | |
// Asserts that `node` converts via `to<T>` to a non-null T node.
// On failure the gtest ASSERT_* aborts the enclosing test function.
// The body is wrapped in do/while(false) so the macro expands to exactly one
// statement: it composes safely with if/else and must be terminated with `;`.
#define IS_NODE(T, node)       \
  do {                         \
    auto node_ = to<T>(node);  \
    ASSERT_NE(nullptr, node_); \
  } while (false)
19 | |
// Declares a variable `name` in the caller's scope, initialized to
// `to<T>(node)`, and asserts that it is non-null.
// Because the variable must stay visible after the macro, the expansion is
// not wrapped in braces or do/while; note it already ends with `;`.
#define IS_NODE_WITH_NAME(T, node, name) \
  auto name = to<T>(node);               \
  ASSERT_NE(nullptr, name);
23 | |
// Asserts that `node` is a Cast whose dtype has scalar type ScalarType::Type
// (`Type` is a bare enumerator name, e.g. Float). It then declares `name` in
// the caller's scope as the cast's src_value() converted via `to<T>`, and
// asserts that it is non-null.
// `name` is declared before the inner block so it remains visible to the
// caller, while the helper `node_` stays scoped to the block.
#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type)        \
  NodePtr<T> name = nullptr;                                   \
  {                                                            \
    auto node_ = to<Cast>(node);                               \
    ASSERT_NE(nullptr, node_);                                 \
    ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \
    name = to<T>(node_->src_value());                          \
  }                                                            \
  ASSERT_NE(nullptr, name);
33 | |
// Asserts that `node` converts via `to<T##Imm>` to a non-null immediate
// (e.g. T = Int checks for an IntImm) whose value() equals `val`.
// The body is wrapped in do/while(false) so the macro expands to exactly one
// statement: it composes safely with if/else and must be terminated with `;`.
#define IS_IMM_WITH_VAL(T, node, val) \
  do {                                \
    auto node_ = to<T##Imm>(node);    \
    ASSERT_NE(nullptr, node_);        \
    ASSERT_EQ(node_->value(), val);   \
  } while (false)
40 | |
// Asserts that `node` converts via `to<Var>` to a non-null Var whose
// name_hint() equals `name`.
// The body is wrapped in do/while(false) so the macro expands to exactly one
// statement: it composes safely with if/else and must be terminated with `;`.
#define IS_VAR_WITH_NAME(node, name)     \
  do {                                   \
    auto node_ = to<Var>(node);          \
    ASSERT_NE(nullptr, node_);           \
    ASSERT_EQ(node_->name_hint(), name); \
  } while (false)
47 | |
// Declares `name` (a NodePtr<T>) in the caller's scope, initialized to
// `to<T>(node)`, and asserts that it is non-null. It then checks that the
// lhs() and rhs() operands are Vars named `v1` and `v2`, in that order.
#define IS_BINOP_W_VARS(T, node, name, v1, v2) \
  NodePtr<T> name = to<T>(node);               \
  ASSERT_NE(nullptr, name);                    \
  IS_VAR_WITH_NAME(name->lhs(), v1);           \
  IS_VAR_WITH_NAME(name->rhs(), v2);
56 | |
// Declares `name` (a NodePtr<T>) in the caller's scope, initialized to
// `to<T>(node)`, and asserts that it is non-null. It then checks that lhs()
// is a Var named `v` and that rhs() is an IntImm whose value is `c`.
#define IS_BINOP_W_CONST(T, node, name, v, c) \
  NodePtr<T> name = to<T>(node);              \
  ASSERT_NE(nullptr, name);                   \
  IS_VAR_WITH_NAME(name->lhs(), v);           \
  IS_IMM_WITH_VAL(Int, name->rhs(), c);
65 | |
// Asserts that `node` converts via `to<Intrinsics>` to a non-null intrinsic
// whose op_type() is kRand.
// The body is wrapped in do/while(false) so the macro expands to exactly one
// statement: it composes safely with if/else and must be terminated with `;`.
#define IS_RAND(node)                      \
  do {                                     \
    auto node_ = to<Intrinsics>(node);     \
    ASSERT_NE(nullptr, node_);             \
    ASSERT_EQ(node_->op_type(), kRand);    \
  } while (false)
72 | |
73 | void checkIR(StmtPtr s, const std::string& pattern); |
74 | void checkExprIR(ExprPtr e, const std::string& pattern); |
75 | void checkExprIR(const ExprHandle& e, const std::string& pattern); |
76 | |
77 | } // namespace jit |
78 | } // namespace torch |
79 | |