1#pragma once
2
#include <memory>
#include <string>
#include <vector>

#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/testing/file_check.h>
9
10namespace torch {
11namespace jit {
12using namespace torch::jit::tensorexpr;
13
// IS_NODE(T, node): checks that `node` converts to node type T, i.e. that
// to<T>(node) is not null. The converted pointer is discarded; the check runs
// in its own braced scope, so nothing leaks into the caller's scope.
// Uses ASSERT_* (fatal googletest checks, presumably provided via
// test_base.h), so it belongs in a function returning void.
// NOTE: expands to a bare `{ ... }` block rather than do { } while (0), so
// `IS_NODE(...);` directly before an `else` will not compile.
#define IS_NODE(T, node)              \
  {                                   \
    const auto node_ = to<T>(node);   \
    ASSERT_NE(nullptr, node_);        \
  }
19
// IS_NODE_WITH_NAME(T, node, name): declares `name` in the CALLER's scope,
// initialised with to<T>(node). It then asserts (fatal ASSERT_NE) that the
// conversion produced a non-null pointer, so later statements can
// dereference `name` safely.
// Expands to two plain statements, not a block, so it must be used at
// statement level in a void-returning function.
#define IS_NODE_WITH_NAME(T, node, name) \
  auto name = to<T>(node);               \
  ASSERT_NE(nullptr, name);
23
// IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type): checks that `node` is a
// Cast whose dtype's scalar type is ScalarType::Type, and that the Cast's
// source value converts to node type T. The converted source is exposed to
// the caller as `name` (a NodePtr<T>) in the caller's scope.
//
// Every check is a fatal ASSERT_*, so the macro must be used at statement
// level in a void-returning function.
//
// Why the odd temporary name: the Cast pointer lives in an inner block. It
// used to be called `node_`. A caller passing an expression that mentions
// `node_` (e.g. a fixture member) would then hit the self-initialisation
// `auto node_ = to<Cast>(node_);`, which is undefined behaviour. A caller
// naming the result `node_` would be shadowed by the temporary instead.
// The macro-private identifier below avoids both captures.
#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type)             \
  NodePtr<T> name = nullptr;                                        \
  {                                                                 \
    auto is_node_with_name_and_cast_tmp_ = to<Cast>(node);          \
    ASSERT_NE(nullptr, is_node_with_name_and_cast_tmp_);            \
    ASSERT_EQ(                                                      \
        is_node_with_name_and_cast_tmp_->dtype().scalar_type(),     \
        ScalarType::Type);                                          \
    name = to<T>(is_node_with_name_and_cast_tmp_->src_value());     \
  }                                                                 \
  ASSERT_NE(nullptr, name);
33
// IS_IMM_WITH_VAL(T, node, val): checks that `node` is an immediate of kind
// T##Imm (T is the bare prefix, e.g. `Int` selects IntImm), and that its
// value() compares equal to `val` under ASSERT_EQ. For floating-point
// immediates this is an exact comparison.
// Fatal ASSERT_* checks, so use it in a void-returning function. The
// temporary is scoped to the braced block.
#define IS_IMM_WITH_VAL(T, node, val)       \
  {                                         \
    const auto node_ = to<T##Imm>(node);    \
    ASSERT_NE(nullptr, node_);              \
    ASSERT_EQ(node_->value(), val);         \
  }
40
// IS_VAR_WITH_NAME(node, name): checks that `node` is a Var and that its
// name_hint() equals `name`. Callers typically pass a string literal.
// Fatal ASSERT_* checks, so use it in a void-returning function. The
// temporary is scoped to the braced block.
#define IS_VAR_WITH_NAME(node, name)          \
  {                                           \
    const auto node_ = to<Var>(node);         \
    ASSERT_NE(nullptr, node_);                \
    ASSERT_EQ(node_->name_hint(), name);      \
  }
47
// IS_BINOP_W_VARS(T, node, name, v1, v2): declares `name` (a NodePtr<T>) in
// the caller's scope, holding to<T>(node). It asserts that pointer is
// non-null, then checks that lhs() is a Var named v1 and rhs() is a Var
// named v2, via IS_VAR_WITH_NAME.
// Fatal ASSERT_* checks, so use it at statement level in a void-returning
// function.
#define IS_BINOP_W_VARS(T, node, name, v1, v2) \
  NodePtr<T> name = to<T>(node);               \
  ASSERT_NE(nullptr, name);                    \
  IS_VAR_WITH_NAME(name->lhs(), v1);           \
  IS_VAR_WITH_NAME(name->rhs(), v2);
56
// IS_BINOP_W_CONST(T, node, name, v, c): declares `name` (a NodePtr<T>) in
// the caller's scope, holding to<T>(node). It asserts that pointer is
// non-null, then checks that lhs() is a Var named `v` and that rhs() is an
// IntImm equal to `c`. The constant side is always checked as Int.
// Fatal ASSERT_* checks, so use it at statement level in a void-returning
// function.
#define IS_BINOP_W_CONST(T, node, name, v, c) \
  NodePtr<T> name = to<T>(node);              \
  ASSERT_NE(nullptr, name);                   \
  IS_VAR_WITH_NAME(name->lhs(), v);           \
  IS_IMM_WITH_VAL(Int, name->rhs(), c);
65
// IS_RAND(node): checks that `node` is an Intrinsics node whose op_type() is
// kRand.
// Fatal ASSERT_* checks, so use it in a void-returning function. The
// temporary is scoped to the braced block.
#define IS_RAND(node)                          \
  {                                            \
    const auto node_ = to<Intrinsics>(node);   \
    ASSERT_NE(nullptr, node_);                 \
    ASSERT_EQ(node_->op_type(), kRand);        \
  }
72
73void checkIR(StmtPtr s, const std::string& pattern);
74void checkExprIR(ExprPtr e, const std::string& pattern);
75void checkExprIR(const ExprHandle& e, const std::string& pattern);
76
77} // namespace jit
78} // namespace torch
79