1 | #include "taichi/program/snode_expr_utils.h" |
2 | #include "taichi/ir/snode.h" |
3 | #include "taichi/ir/frontend_ir.h" |
4 | |
5 | namespace taichi::lang { |
6 | |
7 | namespace { |
8 | |
9 | class GradInfoImpl final : public SNode::GradInfoProvider { |
10 | public: |
11 | explicit GradInfoImpl(FieldExpression *field) : field_(field) { |
12 | } |
13 | |
14 | bool is_primal() const override { |
15 | return field_->snode_grad_type == SNodeGradType::kPrimal; |
16 | } |
17 | |
18 | SNodeGradType get_snode_grad_type() const override { |
19 | return field_->snode_grad_type; |
20 | } |
21 | |
22 | SNode *adjoint_snode() const override { |
23 | auto &adj = field_->adjoint; |
24 | if (adj.expr == nullptr) { |
25 | return nullptr; |
26 | } |
27 | return adj.snode(); |
28 | } |
29 | |
30 | SNode *dual_snode() const override { |
31 | auto &dual = field_->dual; |
32 | if (dual.expr == nullptr) { |
33 | return nullptr; |
34 | } |
35 | return dual.snode(); |
36 | } |
37 | |
38 | SNode *adjoint_checkbit_snode() const override { |
39 | auto &adjoint_checkbit = field_->adjoint_checkbit; |
40 | if (adjoint_checkbit.expr == nullptr) { |
41 | return nullptr; |
42 | } |
43 | return adjoint_checkbit.snode(); |
44 | } |
45 | |
46 | private: |
47 | FieldExpression *field_; |
48 | }; |
49 | |
50 | } // namespace |
51 | |
52 | void place_child(Expr *expr_arg, |
53 | const std::vector<int> &offset, |
54 | int id_in_bit_struct, |
55 | SNode *parent, |
56 | SNodeFieldMap *snode_to_exprs) { |
57 | if (parent->type == SNodeType::root) { |
58 | // never directly place to root |
59 | auto &ds = parent->dense(std::vector<Axis>(), {}, "" ); |
60 | place_child(expr_arg, offset, id_in_bit_struct, &ds, snode_to_exprs); |
61 | } else { |
62 | TI_ASSERT(expr_arg->is<FieldExpression>()); |
63 | auto field = expr_arg->cast<FieldExpression>(); |
64 | TI_ERROR_IF(field->snode != nullptr, "This variable has been placed." ); |
65 | auto &child = parent->insert_children(SNodeType::place); |
66 | field->set_snode(&child); |
67 | if (field->name == "" ) { |
68 | child.name = field->ident.raw_name(); |
69 | } else { |
70 | child.name = field->name; |
71 | } |
72 | if (field->has_ambient) { |
73 | field->snode->has_ambient = true; |
74 | field->snode->ambient_val = field->ambient_value; |
75 | } |
76 | field->snode->grad_info = std::make_unique<GradInfoImpl>(field.get()); |
77 | (*snode_to_exprs)[field->snode] = field; |
78 | child.dt = field->dt; |
79 | child.id_in_bit_struct = id_in_bit_struct; |
80 | if (!offset.empty()) |
81 | child.set_index_offsets(offset); |
82 | } |
83 | } |
84 | |
85 | void make_lazy_place(SNode *snode, |
86 | SNodeFieldMap *snode_to_fields, |
87 | const std::function<void(std::unique_ptr<SNode> &, |
88 | std::vector<Expr> &)> &collect) { |
89 | if (snode->type == SNodeType::place) |
90 | return; |
91 | for (auto &c : snode->ch) { |
92 | make_lazy_place(c.get(), snode_to_fields, collect); |
93 | } |
94 | std::vector<Expr> new_places; |
95 | for (auto &c : snode->ch) { |
96 | collect(c, new_places); |
97 | } |
98 | for (auto p : new_places) { |
99 | place_child(&p, /*offset=*/{}, -1, snode, snode_to_fields); |
100 | } |
101 | } |
102 | |
103 | } // namespace taichi::lang |
104 | |