1#include "taichi/program/snode_expr_utils.h"
2#include "taichi/ir/snode.h"
3#include "taichi/ir/frontend_ir.h"
4
5namespace taichi::lang {
6
7namespace {
8
9class GradInfoImpl final : public SNode::GradInfoProvider {
10 public:
11 explicit GradInfoImpl(FieldExpression *field) : field_(field) {
12 }
13
14 bool is_primal() const override {
15 return field_->snode_grad_type == SNodeGradType::kPrimal;
16 }
17
18 SNodeGradType get_snode_grad_type() const override {
19 return field_->snode_grad_type;
20 }
21
22 SNode *adjoint_snode() const override {
23 auto &adj = field_->adjoint;
24 if (adj.expr == nullptr) {
25 return nullptr;
26 }
27 return adj.snode();
28 }
29
30 SNode *dual_snode() const override {
31 auto &dual = field_->dual;
32 if (dual.expr == nullptr) {
33 return nullptr;
34 }
35 return dual.snode();
36 }
37
38 SNode *adjoint_checkbit_snode() const override {
39 auto &adjoint_checkbit = field_->adjoint_checkbit;
40 if (adjoint_checkbit.expr == nullptr) {
41 return nullptr;
42 }
43 return adjoint_checkbit.snode();
44 }
45
46 private:
47 FieldExpression *field_;
48};
49
50} // namespace
51
52void place_child(Expr *expr_arg,
53 const std::vector<int> &offset,
54 int id_in_bit_struct,
55 SNode *parent,
56 SNodeFieldMap *snode_to_exprs) {
57 if (parent->type == SNodeType::root) {
58 // never directly place to root
59 auto &ds = parent->dense(std::vector<Axis>(), {}, "");
60 place_child(expr_arg, offset, id_in_bit_struct, &ds, snode_to_exprs);
61 } else {
62 TI_ASSERT(expr_arg->is<FieldExpression>());
63 auto field = expr_arg->cast<FieldExpression>();
64 TI_ERROR_IF(field->snode != nullptr, "This variable has been placed.");
65 auto &child = parent->insert_children(SNodeType::place);
66 field->set_snode(&child);
67 if (field->name == "") {
68 child.name = field->ident.raw_name();
69 } else {
70 child.name = field->name;
71 }
72 if (field->has_ambient) {
73 field->snode->has_ambient = true;
74 field->snode->ambient_val = field->ambient_value;
75 }
76 field->snode->grad_info = std::make_unique<GradInfoImpl>(field.get());
77 (*snode_to_exprs)[field->snode] = field;
78 child.dt = field->dt;
79 child.id_in_bit_struct = id_in_bit_struct;
80 if (!offset.empty())
81 child.set_index_offsets(offset);
82 }
83}
84
85void make_lazy_place(SNode *snode,
86 SNodeFieldMap *snode_to_fields,
87 const std::function<void(std::unique_ptr<SNode> &,
88 std::vector<Expr> &)> &collect) {
89 if (snode->type == SNodeType::place)
90 return;
91 for (auto &c : snode->ch) {
92 make_lazy_place(c.get(), snode_to_fields, collect);
93 }
94 std::vector<Expr> new_places;
95 for (auto &c : snode->ch) {
96 collect(c, new_places);
97 }
98 for (auto p : new_places) {
99 place_child(&p, /*offset=*/{}, -1, snode, snode_to_fields);
100 }
101}
102
103} // namespace taichi::lang
104