1#include "taichi/ir/ir_builder.h"
2#include "taichi/ir/statements.h"
3#include "taichi/program/program.h"
4
5void autograd() {
6 /*
7 import taichi as ti, numpy as np
8 ti.init()
9
10 n = 10
11 a = ti.field(ti.f32, n, needs_grad=True)
12 b = ti.field(ti.f32, n, needs_grad=True)
13 c = ti.field(ti.f32, n, needs_grad=True)
14 energy = ti.field(ti.f32, [], needs_grad=True)
15
16 @ti.kernel
17 def init():
18 for i in range(n):
19 a[i] = i
20 b[i] = i + 1
21
22 @ti.kernel
23 def cal():
24 for i in a:
25 c[i] += a[i] + (b[i] + b[i])
26
27 @ti.kernel
28 def support(): # this function will not appear in CHI Builder code
29 for i in a:
30 energy += c[i]
31
32 init()
33 with ti.ad.Tape(energy):
34 cal()
35 support()
36
37 print(a.grad)
38 print(b.grad)
39 print(c.to_numpy())
40 */
41 using namespace taichi;
42 using namespace lang;
43
44 auto program = Program(Arch::x64);
45
46 int n = 10;
47 program.materialize_runtime();
48 auto *root = new SNode(0, SNodeType::root);
49 auto get_snode_grad = [&]() -> SNode * {
50 class GradInfoPrimal final : public SNode::GradInfoProvider {
51 public:
52 SNode *snode;
53 GradInfoPrimal(SNode *_snode) : snode(_snode) {
54 }
55 bool is_primal() const override {
56 return true;
57 }
58 SNodeGradType get_snode_grad_type() const override {
59 return SNodeGradType::kPrimal;
60 }
61 SNode *adjoint_snode() const override {
62 return snode;
63 }
64 SNode *dual_snode() const override {
65 return snode;
66 }
67 SNode *adjoint_checkbit_snode() const override {
68 return nullptr;
69 }
70 };
71 class GradInfoAdjoint final : public SNode::GradInfoProvider {
72 public:
73 GradInfoAdjoint() {
74 }
75 bool is_primal() const override {
76 return false;
77 }
78 SNodeGradType get_snode_grad_type() const override {
79 return SNodeGradType::kAdjoint;
80 }
81 SNode *adjoint_snode() const override {
82 return nullptr;
83 }
84 SNode *dual_snode() const override {
85 return nullptr;
86 }
87 SNode *adjoint_checkbit_snode() const override {
88 return nullptr;
89 }
90 };
91
92 auto *snode =
93 &root->dense(Axis(0), n, false, "").insert_children(SNodeType::place);
94 snode->dt = PrimitiveType::f32;
95 snode->grad_info = std::make_unique<GradInfoPrimal>(
96 &root->dense(Axis(0), n, false, "").insert_children(SNodeType::place));
97 snode->get_adjoint()->dt = PrimitiveType::f32;
98 snode->get_adjoint()->grad_info = std::make_unique<GradInfoAdjoint>();
99 return snode;
100 };
101 auto *a = get_snode_grad(), *b = get_snode_grad(), *c = get_snode_grad();
102 program.add_snode_tree(std::unique_ptr<SNode>(root), /*compile_only=*/false);
103
104 std::unique_ptr<Kernel> kernel_init, kernel_forward, kernel_backward,
105 kernel_ext;
106
107 {
108 IRBuilder builder;
109 auto *zero = builder.get_int32(0);
110 auto *one = builder.get_int32(1);
111 auto *n_stmt = builder.get_int32(n);
112 auto *loop = builder.create_range_for(zero, n_stmt, 0, 4);
113 {
114 auto _ = builder.get_loop_guard(loop);
115 auto *i = builder.get_loop_index(loop);
116 builder.create_global_store(builder.create_global_ptr(a, {i}), i);
117 builder.create_global_store(builder.create_global_ptr(b, {i}),
118 builder.create_add(i, one));
119 builder.create_global_store(builder.create_global_ptr(c, {i}), zero);
120
121 builder.create_global_store(
122 builder.create_global_ptr(a->get_adjoint(), {i}), zero);
123 builder.create_global_store(
124 builder.create_global_ptr(b->get_adjoint(), {i}), zero);
125 builder.create_global_store(
126 builder.create_global_ptr(c->get_adjoint(), {i}), one);
127 }
128
129 kernel_init =
130 std::make_unique<Kernel>(program, builder.extract_ir(), "init");
131 }
132
133 auto get_kernel_cal = [&](AutodiffMode autodiff_mode) -> Kernel * {
134 IRBuilder builder;
135 auto *loop = builder.create_struct_for(a, 0, 4);
136 {
137 auto _ = builder.get_loop_guard(loop);
138 auto *i = builder.get_loop_index(loop);
139 auto *a_i = builder.create_global_load(builder.create_global_ptr(a, {i}));
140 auto *b_i = builder.create_global_load(builder.create_global_ptr(b, {i}));
141 auto *val = builder.create_add(a_i, builder.create_mul(b_i, i));
142 auto *c_i = builder.create_global_ptr(c, {i});
143 builder.insert(
144 std::make_unique<AtomicOpStmt>(AtomicOpType::add, c_i, val));
145 }
146
147 return new Kernel(program, builder.extract_ir(), "cal", autodiff_mode);
148 };
149 kernel_forward = std::unique_ptr<Kernel>(get_kernel_cal(AutodiffMode::kNone));
150 kernel_backward =
151 std::unique_ptr<Kernel>(get_kernel_cal(AutodiffMode::kReverse));
152
153 {
154 IRBuilder builder;
155 auto *loop = builder.create_struct_for(a, 0, 4);
156 {
157 auto _ = builder.get_loop_guard(loop);
158 auto *i = builder.get_loop_index(loop);
159
160 auto *ext_a = builder.create_external_ptr(
161 builder.create_arg_load(0, PrimitiveType::f32, true), {i});
162 auto *a_grad_i = builder.create_global_load(
163 builder.create_global_ptr(a->get_adjoint(), {i}));
164 builder.create_global_store(ext_a, a_grad_i);
165
166 auto *ext_b = builder.create_external_ptr(
167 builder.create_arg_load(1, PrimitiveType::f32, true), {i});
168 auto *b_grad_i = builder.create_global_load(
169 builder.create_global_ptr(b->get_adjoint(), {i}));
170 builder.create_global_store(ext_b, b_grad_i);
171
172 auto *ext_c = builder.create_external_ptr(
173 builder.create_arg_load(2, PrimitiveType::f32, true), {i});
174 auto *c_i = builder.create_global_load(builder.create_global_ptr(c, {i}));
175 builder.create_global_store(ext_c, c_i);
176 }
177
178 kernel_ext = std::make_unique<Kernel>(program, builder.extract_ir(), "ext");
179 kernel_ext->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {n});
180 kernel_ext->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {n});
181 kernel_ext->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {n});
182 }
183
184 auto ctx_init = kernel_init->make_launch_context();
185 auto ctx_forward = kernel_forward->make_launch_context();
186 auto ctx_backward = kernel_backward->make_launch_context();
187 auto ctx_ext = kernel_ext->make_launch_context();
188 std::vector<float> ext_a(n), ext_b(n), ext_c(n);
189 ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_a.data()), n,
190 {n});
191 ctx_ext.set_arg_external_array_with_shape(1, taichi::uint64(ext_b.data()), n,
192 {n});
193 ctx_ext.set_arg_external_array_with_shape(2, taichi::uint64(ext_c.data()), n,
194 {n});
195
196 (*kernel_init)(ctx_init);
197 (*kernel_forward)(ctx_forward);
198 (*kernel_backward)(ctx_backward);
199 (*kernel_ext)(ctx_ext);
200 for (int i = 0; i < n; i++)
201 std::cout << ext_a[i] << " ";
202 std::cout << std::endl;
203 for (int i = 0; i < n; i++)
204 std::cout << ext_b[i] << " ";
205 std::cout << std::endl;
206 for (int i = 0; i < n; i++)
207 std::cout << ext_c[i] << " ";
208 std::cout << std::endl;
209}
210