#include "taichi/ir/ir_builder.h"
#include "taichi/ir/statements.h"
#include "taichi/program/program.h"

#include <iostream>
#include <memory>
#include <utility>
#include <vector>
4 | |
5 | void autograd() { |
6 | /* |
7 | import taichi as ti, numpy as np |
8 | ti.init() |
9 | |
10 | n = 10 |
11 | a = ti.field(ti.f32, n, needs_grad=True) |
12 | b = ti.field(ti.f32, n, needs_grad=True) |
13 | c = ti.field(ti.f32, n, needs_grad=True) |
14 | energy = ti.field(ti.f32, [], needs_grad=True) |
15 | |
16 | @ti.kernel |
17 | def init(): |
18 | for i in range(n): |
19 | a[i] = i |
20 | b[i] = i + 1 |
21 | |
22 | @ti.kernel |
23 | def cal(): |
24 | for i in a: |
25 | c[i] += a[i] + (b[i] + b[i]) |
26 | |
27 | @ti.kernel |
28 | def support(): # this function will not appear in CHI Builder code |
29 | for i in a: |
30 | energy += c[i] |
31 | |
32 | init() |
33 | with ti.ad.Tape(energy): |
34 | cal() |
35 | support() |
36 | |
37 | print(a.grad) |
38 | print(b.grad) |
39 | print(c.to_numpy()) |
40 | */ |
41 | using namespace taichi; |
42 | using namespace lang; |
43 | |
44 | auto program = Program(Arch::x64); |
45 | |
46 | int n = 10; |
47 | program.materialize_runtime(); |
48 | auto *root = new SNode(0, SNodeType::root); |
49 | auto get_snode_grad = [&]() -> SNode * { |
50 | class GradInfoPrimal final : public SNode::GradInfoProvider { |
51 | public: |
52 | SNode *snode; |
53 | GradInfoPrimal(SNode *_snode) : snode(_snode) { |
54 | } |
55 | bool is_primal() const override { |
56 | return true; |
57 | } |
58 | SNodeGradType get_snode_grad_type() const override { |
59 | return SNodeGradType::kPrimal; |
60 | } |
61 | SNode *adjoint_snode() const override { |
62 | return snode; |
63 | } |
64 | SNode *dual_snode() const override { |
65 | return snode; |
66 | } |
67 | SNode *adjoint_checkbit_snode() const override { |
68 | return nullptr; |
69 | } |
70 | }; |
71 | class GradInfoAdjoint final : public SNode::GradInfoProvider { |
72 | public: |
73 | GradInfoAdjoint() { |
74 | } |
75 | bool is_primal() const override { |
76 | return false; |
77 | } |
78 | SNodeGradType get_snode_grad_type() const override { |
79 | return SNodeGradType::kAdjoint; |
80 | } |
81 | SNode *adjoint_snode() const override { |
82 | return nullptr; |
83 | } |
84 | SNode *dual_snode() const override { |
85 | return nullptr; |
86 | } |
87 | SNode *adjoint_checkbit_snode() const override { |
88 | return nullptr; |
89 | } |
90 | }; |
91 | |
92 | auto *snode = |
93 | &root->dense(Axis(0), n, false, "" ).insert_children(SNodeType::place); |
94 | snode->dt = PrimitiveType::f32; |
95 | snode->grad_info = std::make_unique<GradInfoPrimal>( |
96 | &root->dense(Axis(0), n, false, "" ).insert_children(SNodeType::place)); |
97 | snode->get_adjoint()->dt = PrimitiveType::f32; |
98 | snode->get_adjoint()->grad_info = std::make_unique<GradInfoAdjoint>(); |
99 | return snode; |
100 | }; |
101 | auto *a = get_snode_grad(), *b = get_snode_grad(), *c = get_snode_grad(); |
102 | program.add_snode_tree(std::unique_ptr<SNode>(root), /*compile_only=*/false); |
103 | |
104 | std::unique_ptr<Kernel> kernel_init, kernel_forward, kernel_backward, |
105 | kernel_ext; |
106 | |
107 | { |
108 | IRBuilder builder; |
109 | auto *zero = builder.get_int32(0); |
110 | auto *one = builder.get_int32(1); |
111 | auto *n_stmt = builder.get_int32(n); |
112 | auto *loop = builder.create_range_for(zero, n_stmt, 0, 4); |
113 | { |
114 | auto _ = builder.get_loop_guard(loop); |
115 | auto *i = builder.get_loop_index(loop); |
116 | builder.create_global_store(builder.create_global_ptr(a, {i}), i); |
117 | builder.create_global_store(builder.create_global_ptr(b, {i}), |
118 | builder.create_add(i, one)); |
119 | builder.create_global_store(builder.create_global_ptr(c, {i}), zero); |
120 | |
121 | builder.create_global_store( |
122 | builder.create_global_ptr(a->get_adjoint(), {i}), zero); |
123 | builder.create_global_store( |
124 | builder.create_global_ptr(b->get_adjoint(), {i}), zero); |
125 | builder.create_global_store( |
126 | builder.create_global_ptr(c->get_adjoint(), {i}), one); |
127 | } |
128 | |
129 | kernel_init = |
130 | std::make_unique<Kernel>(program, builder.extract_ir(), "init" ); |
131 | } |
132 | |
133 | auto get_kernel_cal = [&](AutodiffMode autodiff_mode) -> Kernel * { |
134 | IRBuilder builder; |
135 | auto *loop = builder.create_struct_for(a, 0, 4); |
136 | { |
137 | auto _ = builder.get_loop_guard(loop); |
138 | auto *i = builder.get_loop_index(loop); |
139 | auto *a_i = builder.create_global_load(builder.create_global_ptr(a, {i})); |
140 | auto *b_i = builder.create_global_load(builder.create_global_ptr(b, {i})); |
141 | auto *val = builder.create_add(a_i, builder.create_mul(b_i, i)); |
142 | auto *c_i = builder.create_global_ptr(c, {i}); |
143 | builder.insert( |
144 | std::make_unique<AtomicOpStmt>(AtomicOpType::add, c_i, val)); |
145 | } |
146 | |
147 | return new Kernel(program, builder.extract_ir(), "cal" , autodiff_mode); |
148 | }; |
149 | kernel_forward = std::unique_ptr<Kernel>(get_kernel_cal(AutodiffMode::kNone)); |
150 | kernel_backward = |
151 | std::unique_ptr<Kernel>(get_kernel_cal(AutodiffMode::kReverse)); |
152 | |
153 | { |
154 | IRBuilder builder; |
155 | auto *loop = builder.create_struct_for(a, 0, 4); |
156 | { |
157 | auto _ = builder.get_loop_guard(loop); |
158 | auto *i = builder.get_loop_index(loop); |
159 | |
160 | auto *ext_a = builder.create_external_ptr( |
161 | builder.create_arg_load(0, PrimitiveType::f32, true), {i}); |
162 | auto *a_grad_i = builder.create_global_load( |
163 | builder.create_global_ptr(a->get_adjoint(), {i})); |
164 | builder.create_global_store(ext_a, a_grad_i); |
165 | |
166 | auto *ext_b = builder.create_external_ptr( |
167 | builder.create_arg_load(1, PrimitiveType::f32, true), {i}); |
168 | auto *b_grad_i = builder.create_global_load( |
169 | builder.create_global_ptr(b->get_adjoint(), {i})); |
170 | builder.create_global_store(ext_b, b_grad_i); |
171 | |
172 | auto *ext_c = builder.create_external_ptr( |
173 | builder.create_arg_load(2, PrimitiveType::f32, true), {i}); |
174 | auto *c_i = builder.create_global_load(builder.create_global_ptr(c, {i})); |
175 | builder.create_global_store(ext_c, c_i); |
176 | } |
177 | |
178 | kernel_ext = std::make_unique<Kernel>(program, builder.extract_ir(), "ext" ); |
179 | kernel_ext->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {n}); |
180 | kernel_ext->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {n}); |
181 | kernel_ext->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {n}); |
182 | } |
183 | |
184 | auto ctx_init = kernel_init->make_launch_context(); |
185 | auto ctx_forward = kernel_forward->make_launch_context(); |
186 | auto ctx_backward = kernel_backward->make_launch_context(); |
187 | auto ctx_ext = kernel_ext->make_launch_context(); |
188 | std::vector<float> ext_a(n), ext_b(n), ext_c(n); |
189 | ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_a.data()), n, |
190 | {n}); |
191 | ctx_ext.set_arg_external_array_with_shape(1, taichi::uint64(ext_b.data()), n, |
192 | {n}); |
193 | ctx_ext.set_arg_external_array_with_shape(2, taichi::uint64(ext_c.data()), n, |
194 | {n}); |
195 | |
196 | (*kernel_init)(ctx_init); |
197 | (*kernel_forward)(ctx_forward); |
198 | (*kernel_backward)(ctx_backward); |
199 | (*kernel_ext)(ctx_ext); |
200 | for (int i = 0; i < n; i++) |
201 | std::cout << ext_a[i] << " " ; |
202 | std::cout << std::endl; |
203 | for (int i = 0; i < n; i++) |
204 | std::cout << ext_b[i] << " " ; |
205 | std::cout << std::endl; |
206 | for (int i = 0; i < n; i++) |
207 | std::cout << ext_c[i] << " " ; |
208 | std::cout << std::endl; |
209 | } |
210 | |