1#include "taichi/ir/expr.h"
2
3#include "taichi/ir/frontend_ir.h"
4#include "taichi/ir/ir.h"
5#include "taichi/program/program.h"
6
7namespace taichi::lang {
8
9void Expr::set_tb(const std::string &tb) {
10 expr->tb = tb;
11}
12
13DataType Expr::get_ret_type() const {
14 return expr->ret_type;
15}
16
17void Expr::type_check(const CompileConfig *config) {
18 expr->type_check(config);
19}
20
21Expr cast(const Expr &input, DataType dt) {
22 return Expr::make<UnaryOpExpression>(UnaryOpType::cast_value, input, dt);
23}
24
25Expr bit_cast(const Expr &input, DataType dt) {
26 return Expr::make<UnaryOpExpression>(UnaryOpType::cast_bits, input, dt);
27}
28
29Expr &Expr::operator=(const Expr &o) {
30 set(o);
31 return *this;
32}
33
34SNode *Expr::snode() const {
35 TI_ASSERT_INFO(is<FieldExpression>(),
36 "Cannot get snode of non-field expressions.");
37 return cast<FieldExpression>()->snode;
38}
39
40void Expr::set_adjoint(const Expr &o) {
41 this->cast<FieldExpression>()->adjoint.set(o);
42}
43
44void Expr::set_dual(const Expr &o) {
45 this->cast<FieldExpression>()->dual.set(o);
46}
47
48void Expr::set_adjoint_checkbit(const Expr &o) {
49 this->cast<FieldExpression>()->adjoint_checkbit.set(o);
50}
51
52Expr::Expr(int16 x) : Expr() {
53 expr = std::make_shared<ConstExpression>(PrimitiveType::i16, x);
54}
55
56Expr::Expr(int32 x) : Expr() {
57 expr = std::make_shared<ConstExpression>(PrimitiveType::i32, x);
58}
59
60Expr::Expr(int64 x) : Expr() {
61 expr = std::make_shared<ConstExpression>(PrimitiveType::i64, x);
62}
63
64Expr::Expr(float32 x) : Expr() {
65 expr = std::make_shared<ConstExpression>(PrimitiveType::f32, x);
66}
67
68Expr::Expr(float64 x) : Expr() {
69 expr = std::make_shared<ConstExpression>(PrimitiveType::f64, x);
70}
71
72Expr::Expr(const Identifier &id) : Expr() {
73 expr = std::make_shared<IdExpression>(id);
74}
75
76Expr expr_rand(DataType dt) {
77 return Expr::make<RandExpression>(dt);
78}
79
80Expr assume_range(const Expr &expr, const Expr &base, int low, int high) {
81 return Expr::make<RangeAssumptionExpression>(expr, base, low, high);
82}
83
84Expr loop_unique(const Expr &input, const std::vector<SNode *> &covers) {
85 return Expr::make<LoopUniqueExpression>(input, covers);
86}
87
88Expr expr_field(Expr id_expr, DataType dt) {
89 TI_ASSERT(id_expr.is<IdExpression>());
90 auto ret = Expr(
91 std::make_shared<FieldExpression>(dt, id_expr.cast<IdExpression>()->id));
92 return ret;
93}
94
95Expr expr_matrix_field(const std::vector<Expr> &fields,
96 const std::vector<int> &element_shape) {
97 return Expr::make<MatrixFieldExpression>(fields, element_shape);
98}
99
100} // namespace taichi::lang
101