#include "taichi/ir/analysis.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/transforms/utils.h"

#include <algorithm>
#include <typeinfo>
#include <utility>
10
11namespace taichi::lang {
12
// Classification result for one candidate loop body.
//   is_ib          - the body is an independent block (IB).
//   is_smallest_ib - the body is an IB that should be differentiated directly
//                    rather than searched for smaller IBs inside it.
// Both flags default to true; IndependentBlocksJudger::run may clear them.
class IndependentBlockMetaData {
 public:
  bool is_ib{true};
  bool is_smallest_ib{true};
};
18
19class NonLinearOps {
20 public:
21 inline static const std::set<TernaryOpType> ternary_collections{
22 TernaryOpType::select};
23 inline static const std::set<UnaryOpType> unary_collections{
24 UnaryOpType::abs, UnaryOpType::sin, UnaryOpType::cos,
25 UnaryOpType::tanh, UnaryOpType::asin, UnaryOpType::acos,
26 UnaryOpType::exp, UnaryOpType::log, UnaryOpType::sqrt};
27 inline static const std::set<BinaryOpType> binary_collections{
28 BinaryOpType::mul, BinaryOpType::div, BinaryOpType::atan2,
29 BinaryOpType::pow};
30};
31
32class IndependentBlocksJudger : public BasicStmtVisitor {
33 public:
34 using BasicStmtVisitor::visit;
35
36 void visit(LocalLoadStmt *stmt) override {
37 TI_ASSERT(stmt->src->is<AllocaStmt>() || stmt->src->is<MatrixPtrStmt>());
38 touched_allocas_.insert(stmt->src);
39 }
40
41 void visit(LocalStoreStmt *stmt) override {
42 TI_ASSERT(stmt->dest->is<AllocaStmt>() || stmt->dest->is<MatrixPtrStmt>());
43 touched_allocas_.insert(stmt->dest);
44 }
45
46 void visit(AtomicOpStmt *stmt) override {
47 // We don't need to check the global atomics inside the range for-loops
48 // because
49 // 1. If the range for-loop is innermost, they will be captured by
50 // MakeAdjoint anyway
51 // 2. If the range for-loop is not innermost, they will be processed by
52 // another IndependentBlocksJudger
53 if (is_inside_loop_)
54 return;
55 TI_ASSERT(stmt->dest->is<GlobalPtrStmt>());
56 if (stmt->dest->as<GlobalPtrStmt>()->snode->has_adjoint()) {
57 qualified_glb_operations_ = true;
58 }
59 }
60
61 void visit(GlobalLoadStmt *stmt) override {
62 // We don't need to check the global load inside the range for-loops
63 // because
64 // 1. If the range for-loop is innermost, they will be captured by
65 // MakeAdjoint anyway
66 // 2. If the range for-loop is not innermost, they will be processed by
67 // another IndependentBlocksJudger
68 if (is_inside_loop_)
69 return;
70 // TODO: handle external ptr stmt after autodiff supporting ndarray
71 if (stmt->src->is<GlobalPtrStmt>() &&
72 stmt->src->as<GlobalPtrStmt>()->snode->has_adjoint()) {
73 qualified_glb_operations_ = true;
74 }
75 }
76
77 void visit(RangeForStmt *stmt) override {
78 inner_most_loop_ = false;
79 is_inside_loop_ = true;
80 stmt->body->accept(this);
81 is_inside_loop_ = false;
82 }
83
84 static void run(IRNode *root, IndependentBlockMetaData &ib_meta_data) {
85 IndependentBlocksJudger Judger;
86 Block *block = root->as<Block>();
87 root->accept(&Judger);
88 std::set<Block *> outside_blocks;
89 // Collect all parent blocks (i.e. outside blocks) of the current block for
90 // local load/store stmt checks
91 for (auto b = block->parent_block(); b; b = b->parent_block()) {
92 if (b)
93 outside_blocks.insert(b);
94 }
95 for (const auto &alloca : Judger.touched_allocas_) {
96 // Test if the alloca belongs to the current block
97 if (outside_blocks.find(alloca->parent) != outside_blocks.end()) {
98 // This block is not an IB since it loads/modifies outside variables
99 ib_meta_data.is_ib = false;
100 }
101 }
102
103 // To judge whether a block is an IB
104 // - No local load/store to allocas *outside* itself has been strictly
105 // enforced
106
107 // To judge whether a block is a smallest IB
108 // - If the #1 is satisfied, either an inner most loop or a block without
109 // global atomics / global load is an IB
110 ib_meta_data.is_smallest_ib =
111 ib_meta_data.is_ib &&
112 (Judger.qualified_glb_operations_ || Judger.inner_most_loop_);
113 }
114
115 private:
116 std::set<Stmt *> touched_allocas_;
117 bool qualified_glb_operations_ = false;
118 bool inner_most_loop_ = true;
119 bool is_inside_loop_ = false;
120};
121
// Remove duplicated IBs, and remove every IB that is nested inside another
// IB, because each block should be processed only once.
124class DuplicateIndependentBlocksCleaner : public BasicStmtVisitor {
125 public:
126 using BasicStmtVisitor::visit;
127
128 void check_children_ib(Block *target_block) {
129 // Remove the block if it is the child of the block being visiting
130 if (independent_blocks_cleaned_.find(target_block) !=
131 independent_blocks_cleaned_.end()) {
132 independent_blocks_cleaned_.erase(target_block);
133 }
134 }
135
136 void visit(StructForStmt *stmt) override {
137 check_children_ib(stmt->body.get());
138 stmt->body->accept(this);
139 }
140 void visit(RangeForStmt *stmt) override {
141 check_children_ib(stmt->body.get());
142 stmt->body->accept(this);
143 }
144
145 static std::set<Block *> run(
146 const std::vector<std::pair<int, Block *>> &raw_IBs) {
147 DuplicateIndependentBlocksCleaner cleaner;
148 // Remove duplicate IBs
149 for (auto const &item : raw_IBs) {
150 cleaner.independent_blocks_cleaned_.insert(item.second);
151 }
152 // No clean is needed if only one IB exists
153 if (cleaner.independent_blocks_cleaned_.size() > 1) {
154 // Check from the block with smallest depth, ensure no duplicate visit
155 // happens
156 for (const auto &block : cleaner.independent_blocks_cleaned_) {
157 block->accept(&cleaner);
158 }
159 }
160 return cleaner.independent_blocks_cleaned_;
161 }
162
163 private:
164 std::set<Block *> independent_blocks_cleaned_;
165};
166
167// Do automatic differentiation pass in the reverse order (reverse-mode AD)
168
169// Independent Block (IB): blocks (i.e. loop bodies) whose iterations are
170// independent of previous iterations and outer scopes. IBs are where the
171// MakeAdjoint pass happens. IBs may contain if's and for-loops.
172
173// IBs are not always the inner-most for loop body. If the inner-most for-loop
174// has iterations that carry iteration-dependent variables, it's not an IB.
175
176// Clearly the outermost level is always an IB, but we want IBs to be as small
177// as possible. Outside IBs, we just need to reverse the for-loop orders.
178
179// Figure out the IBs.
180class IdentifyIndependentBlocks : public BasicStmtVisitor {
181 public:
182 using BasicStmtVisitor::visit;
183
184 void visit(WhileStmt *stmt) override {
185 TI_ERROR("WhileStmt is not supported in AutoDiff.");
186 }
187
188 void visit(ContinueStmt *stmt) override {
189 TI_ERROR("ContinueStmt is not supported in AutoDiff.");
190 }
191
192 void visit(WhileControlStmt *stmt) override {
193 TI_ERROR("WhileControlStmt (break) is not supported in AutoDiff.");
194 }
195
196 void visit_loop_body(Block *block) {
197 auto ib_meta_data = IndependentBlockMetaData();
198 // An IB has no local load/store to allocas *outside* itself
199 // Note:
200 // - Local atomics should have been demoted before this pass.
201 // - It is OK for an IB to have more than two for loops.
202 // - No global load/atomics operations to the global variables which
203 // require gradient
204 if (block->statements.empty()) {
205 // A empty block shoud be a smallest IB
206 ib_meta_data.is_ib = true;
207 ib_meta_data.is_smallest_ib = true;
208 } else {
209 IndependentBlocksJudger::run(block, ib_meta_data);
210 }
211
212 if (ib_meta_data.is_smallest_ib) {
213 independent_blocks_.push_back({depth_, block});
214 } else if (ib_meta_data.is_ib) {
215 current_ib_ = block;
216 block->accept(this);
217 } else {
218 if (depth_ <= 1) {
219 TI_ASSERT(depth_ == 1);
220 // The top level block is already not an IB, store it
221 independent_blocks_.push_back({depth_ - 1, block});
222 } else {
223 independent_blocks_.push_back({depth_ - 1, block->parent_block()});
224 }
225 }
226 }
227
228 void visit(StructForStmt *stmt) override {
229 TI_ASSERT(depth_ == 0);
230 depth_++;
231 current_ib_ = stmt->body.get();
232 visit_loop_body(stmt->body.get());
233 depth_--;
234 }
235
236 void visit(RangeForStmt *stmt) override {
237 if (depth_ == 0) {
238 current_ib_ = stmt->body.get();
239 }
240 depth_++;
241 visit_loop_body(stmt->body.get());
242 depth_--;
243 }
244
245 static std::set<Block *> run(IRNode *root) {
246 IdentifyIndependentBlocks pass;
247 Block *block = root->as<Block>();
248 bool has_for = false;
249 for (auto &s : block->statements) {
250 if (s->is<StructForStmt>() || s->is<RangeForStmt>()) {
251 has_for = true;
252 }
253 }
254 if (!has_for) {
255 // The whole block is an IB
256 pass.independent_blocks_.push_back({0, block});
257 } else {
258 root->accept(&pass);
259 }
260 // Sort the IBs by their depth from shallow to deep
261 std::sort(pass.independent_blocks_.begin(), pass.independent_blocks_.end(),
262 [](const std::pair<int, Block *> &a,
263 const std::pair<int, Block *> &b) -> bool {
264 return a.first < b.first;
265 });
266
267 TI_ASSERT(!pass.independent_blocks_.empty());
268 return DuplicateIndependentBlocksCleaner::run(pass.independent_blocks_);
269 }
270
271 private:
272 std::vector<std::pair<int, Block *>> independent_blocks_;
273 int depth_{0};
274 Block *current_ib_{nullptr};
275};
276
// Note that SSA does not mean that an instruction is executed at most once.
// Instructions that may be executed multiple times (e.g. inside a loop body)
// are treated as mutable local variables.
280class PromoteSSA2LocalVar : public BasicStmtVisitor {
281 using BasicStmtVisitor::visit;
282
283 explicit PromoteSSA2LocalVar(Block *block) {
284 alloca_block_ = block;
285 invoke_default_visitor = true;
286 execute_once_ = true;
287 }
288
289 void visit(Stmt *stmt) override {
290 if (execute_once_)
291 return;
292 if (!(stmt->is<UnaryOpStmt>() || stmt->is<BinaryOpStmt>() ||
293 stmt->is<TernaryOpStmt>() || stmt->is<GlobalLoadStmt>() ||
294 stmt->is<AllocaStmt>())) {
295 // TODO: this list may be incomplete
296 return;
297 }
298
299 if (stmt->is<AllocaStmt>()) {
300 // Create a new alloc at the top of an ib to replace the old alloca
301 auto alloc = Stmt::make<AllocaStmt>(stmt->ret_type);
302 auto alloc_ptr = alloc.get();
303 TI_ASSERT(alloca_block_);
304 alloca_block_->insert(std::move(alloc), 0);
305 // Replace all the usages of the old stmt with that of the new one
306 irpass::replace_all_usages_with(stmt->parent, stmt, alloc_ptr);
307
308 // Replace the old alloca with a local store
309 // and it will be replaced by a AdStackPushStmt in the following
310 // ReplaceLocalVarWithStacks pass
311 auto dtype = stmt->ret_type;
312 auto zero =
313 stmt->insert_after_me(Stmt::make<ConstStmt>(TypedConstant(dtype, 0)));
314 zero->insert_after_me(Stmt::make<LocalStoreStmt>(alloc_ptr, zero));
315 // Remove the old stmt
316 stmt->parent->erase(stmt);
317 } else {
318 // Create a alloc
319 auto alloc = Stmt::make<AllocaStmt>(stmt->ret_type);
320 auto alloc_ptr = alloc.get();
321 TI_ASSERT(alloca_block_);
322 alloca_block_->insert(std::move(alloc), 0);
323 auto load = stmt->insert_after_me(Stmt::make<LocalLoadStmt>(alloc_ptr));
324 irpass::replace_all_usages_with(stmt->parent, stmt, load);
325 // Create the load first so that the operand of the store won't get
326 // replaced
327 stmt->insert_after_me(Stmt::make<LocalStoreStmt>(alloc_ptr, stmt));
328 }
329 }
330
331 void visit(RangeForStmt *stmt) override {
332 auto old_execute_once = execute_once_;
333 execute_once_ = false; // loop body may be executed many times
334 stmt->body->accept(this);
335 execute_once_ = old_execute_once;
336 }
337
338 private:
339 Block *alloca_block_{nullptr};
340 bool execute_once_;
341
342 public:
343 static void run(Block *block) {
344 PromoteSSA2LocalVar pass(block);
345 block->accept(&pass);
346 }
347};
348
349class AdStackAllocaJudger : public BasicStmtVisitor {
350 public:
351 using BasicStmtVisitor::visit;
352 // Find the usage of the stmt recursively along the LocalLoadStmt
353 void visit(LocalLoadStmt *stmt) override {
354 if (stmt->src == target_alloca_) {
355 local_loaded_ = true;
356 target_alloca_ = stmt;
357 }
358 }
359
360 // Check if there is a LocalLoadStmt - LocalStoreStmt cycle for an alloca
361 // Check if the alloca is load only
362 void visit(LocalStoreStmt *stmt) override {
363 if (stmt->dest == target_alloca_backup_)
364 load_only_ = false;
365 if (local_loaded_ && stmt->dest == target_alloca_backup_) {
366 is_stack_needed_ = true;
367 }
368 }
369
370 // Check if the alloca is load only
371 void visit(AtomicOpStmt *stmt) override {
372 if (stmt->dest == target_alloca_backup_)
373 load_only_ = false;
374 }
375
376 // The stack is needed if the alloc serves as the index of any global
377 // variables
378 void visit(GlobalPtrStmt *stmt) override {
379 if (is_stack_needed_)
380 return;
381 for (const auto &index : stmt->indices) {
382 if (index == target_alloca_)
383 is_stack_needed_ = true;
384 }
385 }
386
387 // Check whether the target stmt is used by the UnaryOpStmts who requires the
388 // ad stack
389 void visit(UnaryOpStmt *stmt) override {
390 if (is_stack_needed_)
391 return;
392 if (NonLinearOps::unary_collections.find(stmt->op_type) !=
393 NonLinearOps::unary_collections.end()) {
394 if (stmt->operand == target_alloca_)
395 is_stack_needed_ = true;
396 }
397 }
398
399 // Check whether the target stmt is used by the BinaryOpStmts who requires the
400 // ad stack
401 void visit(BinaryOpStmt *stmt) override {
402 if (is_stack_needed_)
403 return;
404 if (NonLinearOps::binary_collections.find(stmt->op_type) !=
405 NonLinearOps::binary_collections.end()) {
406 if (stmt->lhs == target_alloca_ || stmt->rhs == target_alloca_)
407 is_stack_needed_ = true;
408 }
409 }
410
411 // Check whether the target stmt is used by the TernaryOpStmts who requires
412 // the ad stack
413 void visit(TernaryOpStmt *stmt) override {
414 if (is_stack_needed_)
415 return;
416 if (NonLinearOps::ternary_collections.find(stmt->op_type) !=
417 NonLinearOps::ternary_collections.end()) {
418 if (stmt->op1 == target_alloca_ || stmt->op2 == target_alloca_ ||
419 stmt->op3 == target_alloca_)
420 is_stack_needed_ = true;
421 }
422 }
423
424 // Check whether the target serves as the condition of a if stmt
425 void visit(IfStmt *stmt) override {
426 if (is_stack_needed_)
427 return;
428
429 if (stmt->cond == target_alloca_) {
430 is_stack_needed_ = true;
431 return;
432 }
433
434 if (stmt->true_statements)
435 stmt->true_statements->accept(this);
436 if (stmt->false_statements)
437 stmt->false_statements->accept(this);
438 }
439
440 static bool run(AllocaStmt *target_alloca) {
441 AdStackAllocaJudger judger;
442 judger.target_alloca_ = target_alloca;
443 judger.target_alloca_backup_ = target_alloca;
444 target_alloca->parent->accept(&judger);
445 return (!judger.load_only_) && judger.is_stack_needed_;
446 }
447
448 private:
449 Stmt *target_alloca_;
450 Stmt *target_alloca_backup_;
451 bool is_stack_needed_ = false;
452 bool local_loaded_ = false;
453 bool load_only_ = true;
454};
455
456class ReplaceLocalVarWithStacks : public BasicStmtVisitor {
457 public:
458 using BasicStmtVisitor::visit;
459 int ad_stack_size;
460 explicit ReplaceLocalVarWithStacks(int ad_stack_size)
461 : ad_stack_size(ad_stack_size) {
462 }
463
464 void visit(AllocaStmt *alloc) override {
465 bool is_stack_needed = AdStackAllocaJudger::run(alloc);
466 if (is_stack_needed) {
467 auto dtype = alloc->ret_type;
468 auto stack_alloca = Stmt::make<AdStackAllocaStmt>(dtype, ad_stack_size);
469 auto stack_alloca_ptr = stack_alloca.get();
470
471 alloc->replace_with(VecStatement(std::move(stack_alloca)));
472
473 // Note that unlike AllocaStmt, AdStackAllocaStmt does NOT have an 0 as
474 // initial value. Therefore here we push an initial 0 value.
475 auto zero = stack_alloca_ptr->insert_after_me(
476 Stmt::make<ConstStmt>(TypedConstant(dtype, 0)));
477 zero->insert_after_me(
478 Stmt::make<AdStackPushStmt>(stack_alloca_ptr, zero));
479 }
480 }
481
482 void visit(LocalLoadStmt *stmt) override {
483 if (stmt->src->is<AdStackAllocaStmt>())
484 stmt->replace_with(Stmt::make<AdStackLoadTopStmt>(stmt->src));
485 }
486
487 void visit(LocalStoreStmt *stmt) override {
488 if (stmt->dest->is<AdStackAllocaStmt>())
489 stmt->replace_with(Stmt::make<AdStackPushStmt>(stmt->dest, stmt->val));
490 }
491};
492
493class ReverseOuterLoops : public BasicStmtVisitor {
494 using BasicStmtVisitor::visit;
495
496 private:
497 explicit ReverseOuterLoops(const std::set<Block *> &IB)
498 : loop_depth_(0), ib_(IB) {
499 }
500
501 bool is_ib(Block *block) const {
502 return std::find(ib_.begin(), ib_.end(), block) != ib_.end();
503 }
504
505 void visit(StructForStmt *stmt) override {
506 loop_depth_ += 1;
507 if (!is_ib(stmt->body.get()))
508 stmt->body->accept(this);
509 loop_depth_ -= 1;
510 }
511
512 void visit(RangeForStmt *stmt) override {
513 if (loop_depth_ >= 1) {
514 stmt->reversed = !stmt->reversed;
515 }
516 loop_depth_ += 1;
517 if (!is_ib(stmt->body.get()))
518 stmt->body->accept(this);
519 loop_depth_ -= 1;
520 }
521
522 int loop_depth_;
523 std::set<Block *> ib_;
524
525 public:
526 static void run(IRNode *root, const std::set<Block *> &IB) {
527 ReverseOuterLoops pass(IB);
528 root->accept(&pass);
529 }
530};
531
532// Base class for both reverse (make adjoint) and forward (make dual) mode
533class ADTransform : public IRVisitor {
534 protected:
535 Stmt *constant(float32 x) {
536 return insert<ConstStmt>(TypedConstant(x));
537 }
538
539 // utils
540 Stmt *sgn(Stmt *inp) {
541 return insert<UnaryOpStmt>(UnaryOpType::sgn, load(inp));
542 }
543
544 // utils
545 Stmt *negate(Stmt *inp) {
546 return insert<UnaryOpStmt>(UnaryOpType::neg, load(inp));
547 }
548
549 Stmt *sqrt(Stmt *inp) {
550 return insert<UnaryOpStmt>(UnaryOpType::sqrt, load(inp));
551 }
552
553 Stmt *rsqrt(Stmt *inp) {
554 return insert<UnaryOpStmt>(UnaryOpType::rsqrt, load(inp));
555 }
556
557 Stmt *mul(Stmt *op1, Stmt *op2) {
558 return insert<BinaryOpStmt>(BinaryOpType::mul, load(op1), load(op2));
559 }
560
561 Stmt *sqr(Stmt *op1) {
562 return mul(op1, op1);
563 }
564
565 Stmt *add(Stmt *op1, Stmt *op2) {
566 return insert<BinaryOpStmt>(BinaryOpType::add, load(op1), load(op2));
567 }
568
569 Stmt *cmp_lt(Stmt *op1, Stmt *op2) {
570 return insert<BinaryOpStmt>(BinaryOpType::cmp_lt, load(op1), load(op2));
571 }
572
573 Stmt *sub(Stmt *op1, Stmt *op2) {
574 return insert<BinaryOpStmt>(BinaryOpType::sub, load(op1), load(op2));
575 }
576
577 Stmt *div(Stmt *op1, Stmt *op2) {
578 return insert<BinaryOpStmt>(BinaryOpType::div, load(op1), load(op2));
579 }
580
581 Stmt *sel(Stmt *op1, Stmt *op2, Stmt *op3) {
582 return insert<TernaryOpStmt>(TernaryOpType::select, load(op1), load(op2),
583 load(op3));
584 }
585
586 Stmt *cos(Stmt *op1) {
587 return insert<UnaryOpStmt>(UnaryOpType::cos, load(op1));
588 }
589
590 Stmt *sin(Stmt *op1) {
591 return insert<UnaryOpStmt>(UnaryOpType::sin, load(op1));
592 }
593
594 Stmt *log(Stmt *op1) {
595 return insert<UnaryOpStmt>(UnaryOpType::log, load(op1));
596 }
597
598 Stmt *pow(Stmt *op1, Stmt *op2) {
599 return insert<BinaryOpStmt>(BinaryOpType::pow, load(op1), load(op2));
600 }
601
602 public:
603 virtual Stmt *insert_grad_stmt(std::unique_ptr<Stmt> &&stmt) = 0;
604
605 template <typename T, typename... Args>
606 Stmt *insert(Args &&...args) {
607 return insert_grad_stmt(Stmt::make<T>(args...));
608 }
609
610 void visit(AllocaStmt *alloca) override {
611 // do nothing.
612 }
613
614 void visit(AdStackAllocaStmt *alloca) override {
615 // do nothing.
616 }
617
618 void visit(ArgLoadStmt *stmt) override {
619 // do nothing.
620 }
621
622 void visit(LoopIndexStmt *stmt) override {
623 // do nothing.
624 }
625
626 void visit(MatrixPtrStmt *stmt) override {
627 // do nothing.
628 }
629
630 void visit(PrintStmt *print_stmt) override {
631 // do nothing
632 }
633
634 void visit(ConstStmt *const_stmt) override {
635 // do nothing
636 }
637
638 void visit(WhileControlStmt *stmt) override {
639 TI_NOT_IMPLEMENTED
640 }
641
642 void visit(ContinueStmt *stmt) override {
643 TI_NOT_IMPLEMENTED;
644 }
645
646 void visit(WhileStmt *stmt) override {
647 TI_NOT_IMPLEMENTED
648 }
649
650 void visit(GlobalPtrStmt *stmt) override {
651 // do nothing
652 }
653
654 Stmt *load(Stmt *alloc) {
655 TI_ASSERT(alloc != nullptr);
656 if (alloc->is<AllocaStmt>()) {
657 return insert<LocalLoadStmt>(alloc);
658 } else {
659 // non alloca
660 return alloc;
661 }
662 }
663
664 bool gradients_stopped(GlobalLoadStmt *stmt, SNode *snode) {
665 for (auto block = stmt->parent; block; block = block->parent_block()) {
666 for (auto s : block->stop_gradients) {
667 if (s == snode) {
668 return true;
669 }
670 }
671 }
672 return false;
673 }
674
675 void visit(AssertStmt *stmt) override {
676 // do nothing
677 }
678
679 void visit(RangeAssumptionStmt *stmt) override {
680 // do nothing
681 }
682
683 void visit(LinearizeStmt *stmt) override {
684 // do nothing
685 }
686
687 void visit(IntegerOffsetStmt *stmt) override {
688 // do nothing
689 }
690
691 void visit(RandStmt *stmt) override {
692 TI_ERROR("RandStmt not supported in AutoDiff for now.");
693 }
694};
695
696// Generate the adjoint version of an independent block
697class MakeAdjoint : public ADTransform {
698 public:
699 using ADTransform::visit;
700 Block *current_block;
701 Block *alloca_block;
702 // Backup the forward pass (the forward pass might be modified during the
703 // MakeAdjoint) for search whether a GlobalLoadStmt is inside a for-loop when
704 // allocating adjoint (see the function `adjoint`) Should be stored
705 // 1. Before entering a for-loop body
706 // 2. Before entering a if-stmt
707 // Should be restored after processing every statement in the two cases above
708 Block *forward_backup;
709 std::map<Stmt *, Stmt *> adjoint_stmt;
710
711 explicit MakeAdjoint(Block *block) {
712 current_block = nullptr;
713 alloca_block = block;
714 forward_backup = block;
715 }
716
717 static void run(Block *block) {
718 auto p = MakeAdjoint(block);
719 block->accept(&p);
720 }
721
722 // TODO: current block might not be the right block to insert adjoint
723 // instructions!
724 void visit(Block *block) override {
725 std::vector<Stmt *> statements;
726 // always make a copy since the list can be modified.
727 for (auto &stmt : block->statements) {
728 statements.push_back(stmt.get());
729 }
730 std::reverse(statements.begin(), statements.end()); // reverse-mode AD...
731 for (auto stmt : statements) {
732 current_block = block;
733 stmt->accept(this);
734 }
735 }
736
737 Stmt *insert_grad_stmt(std::unique_ptr<Stmt> &&stmt) override {
738 auto ptr = stmt.get();
739 current_block->insert(std::move(stmt), -1);
740 return ptr;
741 }
742
743 // Accumulate [value] to the adjoint of [primal]
744 void accumulate(Stmt *primal, Stmt *value) {
745 auto alloca_ = adjoint(primal);
746 if (!alloca_ || alloca_->is<ConstStmt>())
747 return; // primal may be int variable
748 if (alloca_->is<AdStackAllocaStmt>()) {
749 auto alloca = alloca_->cast<AdStackAllocaStmt>();
750 if (is_real(alloca->ret_type)) {
751 insert<AdStackAccAdjointStmt>(alloca, load(value));
752 }
753 } else {
754 TI_ASSERT(alloca_->is<AllocaStmt>());
755 auto alloca = alloca_->as<AllocaStmt>();
756 auto local_load = insert<LocalLoadStmt>(alloca);
757 insert<LocalStoreStmt>(alloca, add(local_load, value));
758 }
759 }
760
761 Stmt *adjoint(Stmt *stmt) {
762 if (!is_real(stmt->ret_type) || stmt->is<ConstStmt>()) {
763 return constant(0);
764 }
765 if (adjoint_stmt.find(stmt) == adjoint_stmt.end()) {
766 // normal SSA cases
767
768 // create the alloca
769 // auto alloca =
770 // Stmt::make<AllocaStmt>(get_current_program().config.gradient_dt);
771 // maybe it's better to use the statement data type than the default type
772 auto alloca = Stmt::make<AllocaStmt>(stmt->ret_type);
773 adjoint_stmt[stmt] = alloca.get();
774
775 // We need to insert the alloca in the block of GlobalLoadStmt when the
776 // GlobalLoadStmt is not inside a range-for
777 // Code sample:
778 // a and b require grad
779 // Case 1 (GlobalLoadStmt is outside the for-loop, compute 5 times and
780 // accumulate once, alloca history value is needed):
781 // for i in range(5):
782 // p = a[i]
783 // q = b[i]
784 // for _ in range(5)
785 // q += p
786
787 // Case 2 (GlobalLoadStmt is inside the for-loop, compute once and
788 // accumulate immediately, alloca history value can be discarded):
789 // for i in range(5):
790 // q = b[i]
791 // for _ in range(5)
792 // q += a[i]
793 if (stmt->is<GlobalLoadStmt>() &&
794 (stmt->parent->parent_stmt != nullptr) &&
795 stmt->parent->parent_stmt->is<RangeForStmt>()) {
796 // Check whether this GlobalLoadStmt is in the body of a for-loop by
797 // searching in the backup forward pass If not (Case 1), the alloca
798 // should not be clear every iteration, therefore, we need to insert the
799 // alloca in the block of the GlobalLoadStmt i.e., where GlobalLoadStmt
800 // is defined
801 if (forward_backup->locate(stmt->as<GlobalLoadStmt>()) == -1) {
802 stmt->as<GlobalLoadStmt>()->parent->insert(std::move(alloca), 0);
803 } else {
804 alloca_block->insert(std::move(alloca), 0);
805 }
806 } else {
807 alloca_block->insert(std::move(alloca), 0);
808 }
809 }
810 return adjoint_stmt[stmt];
811 }
812
813 void visit(UnaryOpStmt *stmt) override {
814 if (stmt->op_type == UnaryOpType::floor ||
815 stmt->op_type == UnaryOpType::ceil) {
816 // do nothing
817 } else if (stmt->op_type == UnaryOpType::neg) {
818 accumulate(stmt->operand, negate(adjoint(stmt)));
819 } else if (stmt->op_type == UnaryOpType::abs) {
820 accumulate(stmt->operand, mul(adjoint(stmt), sgn(stmt->operand)));
821 } else if (stmt->op_type == UnaryOpType::sin) {
822 accumulate(stmt->operand, mul(adjoint(stmt), cos(stmt->operand)));
823 } else if (stmt->op_type == UnaryOpType::cos) {
824 accumulate(stmt->operand, negate(mul(adjoint(stmt), sin(stmt->operand))));
825 } else if (stmt->op_type == UnaryOpType::tan) {
826 // The derivative of `tan` is `1 / cos^2`, which has many singular points
827 // causing NaNs. Though the NaNs are expected, it is error prone and hard
828 // to debug. Therefore we currently don't support computing derivative for
829 // `tan`.
830 TI_NOT_IMPLEMENTED;
831 } else if (stmt->op_type == UnaryOpType::tanh) {
832 accumulate(stmt->operand,
833 mul(adjoint(stmt), sub(constant(1), sqr(stmt))));
834 } else if (stmt->op_type == UnaryOpType::asin) {
835 accumulate(
836 stmt->operand,
837 mul(adjoint(stmt),
838 div(constant(1), sqrt(sub(constant(1), sqr(stmt->operand))))));
839 } else if (stmt->op_type == UnaryOpType::acos) {
840 accumulate(stmt->operand,
841 mul(adjoint(stmt),
842 negate(div(constant(1),
843 sqrt(sub(constant(1), sqr(stmt->operand)))))));
844 } else if (stmt->op_type == UnaryOpType::exp) {
845 accumulate(stmt->operand, mul(adjoint(stmt), stmt));
846 } else if (stmt->op_type == UnaryOpType::log) {
847 accumulate(stmt->operand, div(adjoint(stmt), stmt->operand));
848 } else if (stmt->op_type == UnaryOpType::sqrt) {
849 accumulate(stmt->operand,
850 mul(adjoint(stmt), div(constant(0.5f), sqrt(stmt->operand))));
851 } else if (stmt->op_type == UnaryOpType::rsqrt) {
852 accumulate(
853 stmt->operand,
854 mul(adjoint(stmt),
855 mul(constant(-0.5f), pow(rsqrt(stmt->operand), constant(3)))));
856 } else if (stmt->op_type == UnaryOpType::cast_value) {
857 if (is_real(stmt->cast_type) && is_real(stmt->operand->ret_type)) {
858 accumulate(stmt->operand, adjoint(stmt));
859 }
860 } else if (stmt->op_type == UnaryOpType::logic_not) {
861 // do nothing
862 } else {
863 TI_P(unary_op_type_name(stmt->op_type));
864 TI_NOT_IMPLEMENTED;
865 }
866 }
867
868 void visit(BinaryOpStmt *bin) override {
869 if (bin->op_type == BinaryOpType::add) {
870 accumulate(bin->lhs, adjoint(bin));
871 accumulate(bin->rhs, adjoint(bin));
872 } else if (bin->op_type == BinaryOpType::sub) {
873 accumulate(bin->lhs, adjoint(bin));
874 accumulate(bin->rhs, negate(adjoint(bin)));
875 } else if (bin->op_type == BinaryOpType::mul) {
876 // d (x * y) = y * dx + x * dy
877 accumulate(bin->lhs, mul(adjoint(bin), bin->rhs));
878 accumulate(bin->rhs, mul(adjoint(bin), bin->lhs));
879 } else if (bin->op_type == BinaryOpType::mod) {
880 // Do nothing
881 } else if (bin->op_type == BinaryOpType::div) {
882 accumulate(bin->lhs, div(adjoint(bin), bin->rhs));
883 accumulate(bin->rhs, negate(div(mul(adjoint(bin), bin->lhs),
884 mul(bin->rhs, bin->rhs))));
885 } else if (bin->op_type == BinaryOpType::atan2) {
886 auto numerator = add(sqr(bin->lhs), sqr(bin->rhs));
887 accumulate(bin->lhs, div(mul(adjoint(bin), bin->rhs), numerator));
888 accumulate(bin->rhs, negate(div(mul(adjoint(bin), bin->lhs), numerator)));
889 } else if (bin->op_type == BinaryOpType::pow) {
890 // d (x ^ y) = x ^ (y-1) * (y * dx + log(x) * x * dy)
891 auto common_coeff =
892 pow(bin->lhs, sub(bin->rhs, constant(1))); // x ^ (y-1)
893 accumulate(bin->lhs, mul(adjoint(bin), mul(bin->rhs, common_coeff)));
894 accumulate(bin->rhs, mul(adjoint(bin), mul(log(bin->lhs),
895 mul(bin->lhs, common_coeff))));
896 } else if (bin->op_type == BinaryOpType::min ||
897 bin->op_type == BinaryOpType::max) {
898 auto cmp = bin->op_type == BinaryOpType::min ? cmp_lt(bin->lhs, bin->rhs)
899 : cmp_lt(bin->rhs, bin->lhs);
900 auto zero = insert<ConstStmt>(TypedConstant(bin->ret_type));
901 accumulate(bin->lhs, sel(cmp, adjoint(bin), zero));
902 accumulate(bin->rhs, sel(cmp, zero, adjoint(bin)));
903 } else if (bin->op_type == BinaryOpType::floordiv) {
904 // do nothing
905 } else if (is_comparison(bin->op_type) || is_bit_op(bin->op_type)) {
906 // do nothing
907 } else {
908 TI_WARN("gradient of binary op {}\n{}", binary_op_type_name(bin->op_type),
909 bin->tb);
910 TI_NOT_IMPLEMENTED;
911 }
912 }
913
914 void visit(TernaryOpStmt *stmt) override {
915 TI_ASSERT(stmt->op_type == TernaryOpType::select);
916 auto zero = insert<ConstStmt>(TypedConstant(stmt->ret_type));
917 accumulate(stmt->op2,
918 insert<TernaryOpStmt>(TernaryOpType::select, stmt->op1,
919 load(adjoint(stmt)), zero));
920 accumulate(stmt->op3,
921 insert<TernaryOpStmt>(TernaryOpType::select, stmt->op1, zero,
922 load(adjoint(stmt))));
923 }
924
925 void visit(IfStmt *if_stmt) override {
926 auto new_if = Stmt::make_typed<IfStmt>(if_stmt->cond);
927 if (if_stmt->true_statements) {
928 new_if->set_true_statements(std::make_unique<Block>());
929 auto old_current_block = current_block;
930 // Backup forward pass
931 forward_backup = if_stmt->true_statements.get();
932
933 current_block = new_if->true_statements.get();
934 for (int i = if_stmt->true_statements->statements.size() - 1; i >= 0;
935 i--) {
936 if_stmt->true_statements->statements[i]->accept(this);
937 // Restore forward pass
938 forward_backup = if_stmt->true_statements.get();
939 }
940
941 current_block = old_current_block;
942 }
943 if (if_stmt->false_statements) {
944 new_if->set_false_statements(std::make_unique<Block>());
945 auto old_current_block = current_block;
946
947 // Backup forward pass
948 forward_backup = if_stmt->false_statements.get();
949
950 current_block = new_if->false_statements.get();
951 for (int i = if_stmt->false_statements->statements.size() - 1; i >= 0;
952 i--) {
953 if_stmt->false_statements->statements[i]->accept(this);
954 // Restore forward pass
955 forward_backup = if_stmt->false_statements.get();
956 }
957 current_block = old_current_block;
958 }
959 insert_grad_stmt(std::move(new_if));
960 }
961
962 void visit(RangeForStmt *for_stmt) override {
963 auto new_for = for_stmt->clone();
964 auto new_for_ptr = new_for->as<RangeForStmt>();
965 new_for_ptr->reversed = !new_for_ptr->reversed;
966 insert_grad_stmt(std::move(new_for));
967 const int len = new_for_ptr->body->size();
968
969 for (int i = 0; i < len; i++) {
970 new_for_ptr->body->erase(0);
971 }
972
973 std::vector<Stmt *> statements;
974 // always make a copy since the list can be modified.
975 for (auto &stmt : for_stmt->body->statements) {
976 statements.push_back(stmt.get());
977 }
978 std::reverse(statements.begin(), statements.end()); // reverse-mode AD...
979 auto old_alloca_block = alloca_block;
980 auto old_forward_backup =
981 forward_backup; // store the block which is not inside the current IB,
982 // such as outer most loop
983 // Backup the forward pass
984 forward_backup = for_stmt->body.get();
985 for (auto stmt : statements) {
986 alloca_block = new_for_ptr->body.get();
987 current_block = new_for_ptr->body.get();
988 stmt->accept(this);
989 // Restore the forward pass
990 forward_backup = for_stmt->body.get();
991 }
992 forward_backup = old_forward_backup;
993 alloca_block = old_alloca_block;
994 }
995
996 void visit(StructForStmt *for_stmt) override {
997 alloca_block = for_stmt->body.get();
998 for_stmt->body->accept(this);
999 }
1000
1001 // Equivalent to AdStackLoadTopStmt when no stack is needed
1002 void visit(LocalLoadStmt *stmt) override {
1003 // TI_ASSERT(!needs_grad(stmt->ret_type));
1004 if (is_real(stmt->ret_type))
1005 accumulate(stmt->src, load(adjoint(stmt)));
1006 }
1007
1008 // Equivalent to AdStackPushStmt when no stack is needed
1009 void visit(LocalStoreStmt *stmt) override {
1010 accumulate(stmt->val, load(adjoint(stmt->dest)));
1011
1012 // Clear the adjoint of the dest after local store,
1013 // Because LocalStoreStmt overwrites the dest,
1014 // 1. If the alloca is inside a loop, the adjoint of this alloca of this
1015 // iteration should be cleared after this iteration has been done
1016 // 2. If the alloca serves as the dest of multiple LocalStoreStmt, only the
1017 // last LocalStoreStmt should be taken account of
1018 if (is_real(stmt->dest->ret_type)) {
1019 auto dtype = stmt->dest->ret_type;
1020 auto zero = insert<ConstStmt>(TypedConstant(dtype, 0));
1021 insert<LocalStoreStmt>(adjoint(stmt->dest), zero);
1022 }
1023 }
1024
1025 void visit(AdStackLoadTopStmt *stmt) override {
1026 if (is_real(stmt->ret_type))
1027 insert<AdStackAccAdjointStmt>(stmt->stack, load(adjoint(stmt)));
1028 }
1029
1030 void visit(AdStackPushStmt *stmt) override {
1031 accumulate(stmt->v, insert<AdStackLoadTopAdjStmt>(stmt->stack));
1032 insert<AdStackPopStmt>(stmt->stack);
1033 }
1034
1035 void visit(GlobalLoadStmt *stmt) override {
1036 // issue global store to adjoint
1037 if (stmt->src->is<ExternalPtrStmt>()) {
1038 TI_ERROR(
1039 "Importing data from external array (such as numpy array) not "
1040 "supported in AutoDiff for now")
1041 }
1042
1043 GlobalPtrStmt *src = nullptr;
1044 bool is_ptr_offset = false;
1045 if (stmt->src->is<MatrixPtrStmt>()) {
1046 is_ptr_offset = true;
1047 src = stmt->src->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
1048 } else {
1049 src = stmt->src->as<GlobalPtrStmt>();
1050 }
1051
1052 auto snode = src->snode;
1053 if (!snode->has_adjoint()) {
1054 // No adjoint SNode. Do nothing
1055 return;
1056 }
1057 if (gradients_stopped(stmt, snode)) {
1058 // gradients stopped, do nothing.
1059 return;
1060 }
1061 TI_ASSERT(snode->get_adjoint() != nullptr);
1062 snode = snode->get_adjoint();
1063 auto adj_ptr = insert<GlobalPtrStmt>(snode, src->indices);
1064 if (is_ptr_offset) {
1065 adj_ptr = insert<MatrixPtrStmt>(adj_ptr,
1066 stmt->src->as<MatrixPtrStmt>()->offset);
1067 }
1068 insert<AtomicOpStmt>(AtomicOpType::add, adj_ptr, load(adjoint(stmt)));
1069 }
1070
1071 void visit(GlobalStoreStmt *stmt) override {
1072 // erase and replace with global load adjoint
1073 if (stmt->dest->is<ExternalPtrStmt>()) {
1074 TI_ERROR(
1075 "Exporting data to external array (such as numpy array) not "
1076 "supported in AutoDiff for now")
1077 }
1078
1079 GlobalPtrStmt *dest = nullptr;
1080 bool is_ptr_offset = false;
1081 if (stmt->dest->is<MatrixPtrStmt>()) {
1082 is_ptr_offset = true;
1083 dest = stmt->dest->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
1084 } else {
1085 dest = stmt->dest->as<GlobalPtrStmt>();
1086 }
1087
1088 auto snode = dest->snode;
1089 if (!snode->has_adjoint()) {
1090 // no gradient (likely integer types)
1091 return;
1092 }
1093 TI_ASSERT(snode->get_adjoint() != nullptr);
1094 snode = snode->get_adjoint();
1095 auto adjoint_ptr = insert<GlobalPtrStmt>(snode, dest->indices);
1096 if (is_ptr_offset) {
1097 adjoint_ptr = insert<MatrixPtrStmt>(
1098 adjoint_ptr, stmt->dest->as<MatrixPtrStmt>()->offset);
1099 }
1100 accumulate(stmt->val, insert<GlobalLoadStmt>(adjoint_ptr));
1101
1102 // Clear the gradient after accumulation finished.
1103 auto zero = insert<ConstStmt>(
1104 TypedConstant(adjoint_ptr->ret_type.ptr_removed(), 0));
1105 insert<GlobalStoreStmt>(adjoint_ptr, zero);
1106
1107 stmt->parent->erase(stmt);
1108 }
1109
1110 void visit(AtomicOpStmt *stmt) override {
1111 // erase and replace with global load adjoint
1112 GlobalPtrStmt *dest = nullptr;
1113 bool is_ptr_offset = false;
1114 if (stmt->dest->is<MatrixPtrStmt>()) {
1115 is_ptr_offset = true;
1116 dest = stmt->dest->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
1117 } else {
1118 dest = stmt->dest->as<GlobalPtrStmt>();
1119 }
1120
1121 auto snode = dest->snode;
1122 if (!snode->has_adjoint()) {
1123 // no gradient (likely integer types)
1124 return;
1125 }
1126
1127 TI_ASSERT(snode->get_adjoint() != nullptr);
1128 snode = snode->get_adjoint();
1129 auto adjoint_ptr = insert<GlobalPtrStmt>(snode, dest->indices);
1130 if (is_ptr_offset) {
1131 adjoint_ptr = insert<MatrixPtrStmt>(
1132 adjoint_ptr, stmt->dest->as<MatrixPtrStmt>()->offset);
1133 }
1134 accumulate(stmt->val, insert<GlobalLoadStmt>(adjoint_ptr));
1135 stmt->parent->erase(stmt);
1136 }
1137};
1138
1139// Forward mode autodiff
1140class MakeDual : public ADTransform {
1141 public:
1142 using ADTransform::visit;
1143 Stmt *current_stmt;
1144 Block *current_block;
1145 Block *alloca_block;
1146 std::map<Stmt *, Stmt *> dual_stmt;
1147
1148 explicit MakeDual(Block *block) {
1149 current_stmt = nullptr;
1150 alloca_block = block;
1151 current_block = block;
1152 }
1153
1154 static void run(Block *block) {
1155 auto p = MakeDual(block);
1156 block->accept(&p);
1157 }
1158
1159 Stmt *insert_grad_stmt(std::unique_ptr<Stmt> &&stmt) override {
1160 auto ptr = stmt.get();
1161 current_stmt = current_stmt->insert_after_me(std::move(stmt));
1162 return ptr;
1163 }
1164
1165 void visit(Block *block) override {
1166 std::vector<Stmt *> statements;
1167 // always make a copy since the list can be modified.
1168 for (auto &stmt : block->statements) {
1169 statements.push_back(stmt.get());
1170 }
1171 for (auto stmt : statements) {
1172 current_stmt = stmt;
1173 stmt->accept(this);
1174 }
1175 }
1176
1177 // Accumulate [value] to the dual of [primal]
1178 void accumulate(Stmt *primal, Stmt *value) {
1179 auto alloca_ = dual(primal);
1180 if (!alloca_ || alloca_->is<ConstStmt>())
1181 return; // primal may be int variable
1182
1183 TI_ASSERT(alloca_->is<AllocaStmt>());
1184 auto alloca = alloca_->as<AllocaStmt>();
1185 auto local_load = insert<LocalLoadStmt>(alloca);
1186 insert<LocalStoreStmt>(alloca, add(local_load, value));
1187 }
1188
1189 Stmt *dual(Stmt *stmt) {
1190 if (!is_real(stmt->ret_type) || stmt->is<ConstStmt>()) {
1191 return constant(0);
1192 }
1193 if (dual_stmt.find(stmt) == dual_stmt.end()) {
1194 // normal SSA cases
1195
1196 // create the alloca
1197 // auto alloca =
1198 // Stmt::make<AllocaStmt>(get_current_program().config.gradient_dt);
1199 // maybe it's better to use the statement data type than the default type
1200 auto alloca = Stmt::make<AllocaStmt>(stmt->ret_type);
1201 dual_stmt[stmt] = alloca.get();
1202
1203 // TODO: check whether there are any edge cases for the alloca_block
1204 alloca_block->insert(std::move(alloca), 0);
1205 }
1206 return dual_stmt[stmt];
1207 }
1208
1209 void visit(UnaryOpStmt *stmt) override {
1210 if (stmt->op_type == UnaryOpType::neg) {
1211 accumulate(stmt, negate(dual(stmt->operand)));
1212 } else if (stmt->op_type == UnaryOpType::abs) {
1213 accumulate(stmt, mul(sgn(stmt->operand), dual(stmt->operand)));
1214 } else if (stmt->op_type == UnaryOpType::sin) {
1215 accumulate(stmt, mul(cos(stmt->operand), dual(stmt->operand)));
1216 } else if (stmt->op_type == UnaryOpType::cos) {
1217 accumulate(stmt, negate(mul(sin(stmt->operand), dual(stmt->operand))));
1218 } else if (stmt->op_type == UnaryOpType::tan) {
1219 // The derivative of `tan` is `1 / cos^2`, which has many singular points
1220 // causing NaNs. Though the NaNs are expected, it is error prone and hard
1221 // to debug. Therefore we currently don't support computing derivative for
1222 // `tan`.
1223 TI_NOT_IMPLEMENTED;
1224 } else if (stmt->op_type == UnaryOpType::tanh) {
1225 accumulate(stmt, mul(sub(constant(1), sqr(stmt)), dual(stmt->operand)));
1226 } else if (stmt->op_type == UnaryOpType::asin) {
1227 accumulate(stmt, mul(div(constant(1),
1228 sqrt(sub(constant(1), sqr(stmt->operand)))),
1229 dual(stmt->operand)));
1230 } else if (stmt->op_type == UnaryOpType::acos) {
1231 accumulate(stmt,
1232 mul(negate(div(constant(1),
1233 sqrt(sub(constant(1), sqr(stmt->operand))))),
1234 dual(stmt->operand)));
1235 } else if (stmt->op_type == UnaryOpType::exp) {
1236 accumulate(stmt, mul(stmt, dual(stmt->operand)));
1237 } else if (stmt->op_type == UnaryOpType::log) {
1238 accumulate(stmt, div(dual(stmt->operand), stmt->operand));
1239 } else if (stmt->op_type == UnaryOpType::sqrt) {
1240 accumulate(stmt, mul(div(constant(0.5f), sqrt(stmt->operand)),
1241 dual(stmt->operand)));
1242 } else if (stmt->op_type == UnaryOpType::rsqrt) {
1243 accumulate(stmt, mul(mul(constant(-0.5f),
1244 pow(rsqrt(stmt->operand), constant(3))),
1245 dual(stmt->operand)));
1246 } else if (stmt->op_type == UnaryOpType::cast_value) {
1247 if (is_real(stmt->cast_type) && is_real(stmt->operand->ret_type)) {
1248 accumulate(stmt, dual(stmt->operand));
1249 }
1250 } else if (stmt->op_type == UnaryOpType::logic_not) {
1251 // do nothing
1252 } else {
1253 TI_P(unary_op_type_name(stmt->op_type));
1254 TI_NOT_IMPLEMENTED
1255 }
1256 }
1257
1258 void visit(BinaryOpStmt *bin) override {
1259 if (bin->op_type == BinaryOpType::add) {
1260 accumulate(bin, dual(bin->lhs));
1261 accumulate(bin, dual(bin->rhs));
1262 } else if (bin->op_type == BinaryOpType::sub) {
1263 accumulate(bin, dual(bin->lhs));
1264 accumulate(bin, negate(dual(bin->rhs)));
1265 } else if (bin->op_type == BinaryOpType::mul) {
1266 // d (x * y) = y * dx + x * dy
1267 accumulate(bin, mul(bin->lhs, dual(bin->rhs)));
1268 accumulate(bin, mul(bin->rhs, dual(bin->lhs)));
1269 } else if (bin->op_type == BinaryOpType::mod) {
1270 // Do nothing
1271 } else if (bin->op_type == BinaryOpType::div) {
1272 accumulate(bin, div(dual(bin->lhs), bin->rhs));
1273 accumulate(bin, negate(div(mul(dual(bin->rhs), bin->lhs),
1274 mul(bin->rhs, bin->rhs))));
1275 } else if (bin->op_type == BinaryOpType::atan2) {
1276 auto numerator = add(sqr(bin->lhs), sqr(bin->rhs));
1277 accumulate(bin, div(mul(bin->rhs, dual(bin->lhs)), numerator));
1278 accumulate(bin, negate(div(mul(bin->lhs, dual(bin->rhs)), numerator)));
1279 } else if (bin->op_type == BinaryOpType::pow) {
1280 // d (x ^ y) = x ^ (y-1) * (y * dx + log(x) * x * dy)
1281 auto common_coeff =
1282 pow(bin->lhs, sub(bin->rhs, constant(1))); // x ^ (y-1)
1283 accumulate(bin, mul(dual(bin->lhs), mul(bin->rhs, common_coeff)));
1284 accumulate(bin, mul(dual(bin->rhs),
1285 mul(log(bin->lhs), mul(bin->lhs, common_coeff))));
1286 } else if (bin->op_type == BinaryOpType::min ||
1287 bin->op_type == BinaryOpType::max) {
1288 auto cmp = bin->op_type == BinaryOpType::min ? cmp_lt(bin->lhs, bin->rhs)
1289 : cmp_lt(bin->rhs, bin->lhs);
1290 auto zero = insert<ConstStmt>(TypedConstant(bin->ret_type));
1291 accumulate(bin, sel(cmp, dual(bin->lhs), zero));
1292 accumulate(bin, sel(cmp, zero, dual(bin->rhs)));
1293 } else if (bin->op_type == BinaryOpType::floordiv) {
1294 // do nothing
1295 } else if (is_comparison(bin->op_type) || is_bit_op(bin->op_type)) {
1296 // do nothing
1297 } else {
1298 TI_WARN("gradient of binary op {}\n{}", binary_op_type_name(bin->op_type),
1299 bin->tb);
1300 TI_NOT_IMPLEMENTED
1301 }
1302 }
1303
1304 void visit(TernaryOpStmt *stmt) override {
1305 TI_ASSERT(stmt->op_type == TernaryOpType::select);
1306 auto zero = insert<ConstStmt>(TypedConstant(stmt->ret_type));
1307 accumulate(stmt, insert<TernaryOpStmt>(TernaryOpType::select, stmt->op1,
1308 load(dual(stmt->op2)), zero));
1309 accumulate(stmt, insert<TernaryOpStmt>(TernaryOpType::select, stmt->op1,
1310 zero, load(dual(stmt->op3))));
1311 }
1312
1313 void visit(IfStmt *if_stmt) override {
1314 if (if_stmt->true_statements) {
1315 std::vector<Stmt *> true_statements;
1316 for (auto &stmt : if_stmt->true_statements->statements) {
1317 true_statements.push_back(stmt.get());
1318 }
1319
1320 for (auto stmt : true_statements) {
1321 current_stmt = stmt;
1322 stmt->accept(this);
1323 }
1324 }
1325 if (if_stmt->false_statements) {
1326 std::vector<Stmt *> false_statements;
1327 for (auto &stmt : if_stmt->false_statements->statements) {
1328 false_statements.push_back(stmt.get());
1329 }
1330
1331 for (auto stmt : false_statements) {
1332 current_stmt = stmt;
1333 stmt->accept(this);
1334 }
1335 }
1336 }
1337
1338 void visit(RangeForStmt *for_stmt) override {
1339 std::vector<Stmt *> statements;
1340 // always make a copy since the list can be modified.
1341 for (auto &stmt : for_stmt->body->statements) {
1342 statements.push_back(stmt.get());
1343 }
1344 auto previous_alloca_block = alloca_block;
1345 alloca_block = for_stmt->body.get();
1346 for (auto stmt : statements) {
1347 current_stmt = stmt;
1348 stmt->accept(this);
1349 }
1350 alloca_block = previous_alloca_block;
1351 }
1352
1353 void visit(StructForStmt *for_stmt) override {
1354 alloca_block = for_stmt->body.get();
1355 for_stmt->body->accept(this);
1356 }
1357
1358 void visit(LocalLoadStmt *stmt) override {
1359 // TI_ASSERT(!needs_grad(stmt->ret_type));
1360 accumulate(stmt, dual(stmt->src));
1361 }
1362
1363 void visit(LocalStoreStmt *stmt) override {
1364 // Clear the dual of the dest before local store,
1365 // Because LocalStoreStmt overwrites the dest,
1366 // If the alloca serves as the dest of multiple LocalStoreStmt, only the
1367 // last LocalStoreStmt should be taken account of, i.e, its history should
1368 // be cleared
1369 if (is_real(stmt->dest->ret_type)) {
1370 auto dtype = stmt->dest->ret_type;
1371 auto zero = insert<ConstStmt>(TypedConstant(dtype, 0));
1372 insert<LocalStoreStmt>(dual(stmt->dest), zero);
1373 }
1374
1375 accumulate(stmt->dest, dual(stmt->val));
1376 }
1377
1378 void visit(GlobalLoadStmt *stmt) override {
1379 // issue global store to dual
1380 GlobalPtrStmt *src = nullptr;
1381 bool is_ptr_offset = false;
1382 if (stmt->src->is<MatrixPtrStmt>()) {
1383 is_ptr_offset = true;
1384 src = stmt->src->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
1385 } else {
1386 src = stmt->src->as<GlobalPtrStmt>();
1387 }
1388 auto snode = src->snode;
1389 if (!snode->has_dual()) {
1390 // No dual SNode. Do nothing
1391 return;
1392 }
1393 if (gradients_stopped(stmt, snode)) {
1394 // gradients stopped, do nothing.
1395 return;
1396 }
1397 TI_ASSERT(snode->get_dual() != nullptr);
1398 snode = snode->get_dual();
1399 auto dual_ptr = insert<GlobalPtrStmt>(snode, src->indices);
1400 if (is_ptr_offset) {
1401 dual_ptr = insert<MatrixPtrStmt>(dual_ptr,
1402 stmt->src->as<MatrixPtrStmt>()->offset);
1403 }
1404 accumulate(stmt, insert<GlobalLoadStmt>(dual_ptr));
1405 }
1406
1407 void visit(GlobalStoreStmt *stmt) override {
1408 GlobalPtrStmt *dest = nullptr;
1409 bool is_ptr_offset = false;
1410 if (stmt->dest->is<MatrixPtrStmt>()) {
1411 is_ptr_offset = true;
1412 dest = stmt->dest->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
1413 } else {
1414 dest = stmt->dest->as<GlobalPtrStmt>();
1415 }
1416 auto snode = dest->snode;
1417 if (!snode->has_dual()) {
1418 // no gradient (likely integer types)
1419 return;
1420 }
1421 TI_ASSERT(snode->get_dual() != nullptr);
1422 snode = snode->get_dual();
1423 auto dual_ptr = insert<GlobalPtrStmt>(snode, dest->indices);
1424 if (is_ptr_offset) {
1425 dual_ptr = insert<MatrixPtrStmt>(dual_ptr,
1426 stmt->dest->as<MatrixPtrStmt>()->offset);
1427 }
1428 insert<AtomicOpStmt>(AtomicOpType::add, dual_ptr, load(dual(stmt->val)));
1429 }
1430
1431 void visit(AtomicOpStmt *stmt) override {
1432 GlobalPtrStmt *dest = nullptr;
1433 bool is_ptr_offset = false;
1434 if (stmt->dest->is<MatrixPtrStmt>()) {
1435 is_ptr_offset = true;
1436 dest = stmt->dest->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
1437 } else {
1438 dest = stmt->dest->as<GlobalPtrStmt>();
1439 }
1440 auto snode = dest->snode;
1441 if (!snode->has_dual()) {
1442 // no gradient (likely integer types)
1443 return;
1444 }
1445 TI_ASSERT(snode->get_dual() != nullptr);
1446 snode = snode->get_dual();
1447 auto dual_ptr = insert<GlobalPtrStmt>(snode, dest->indices);
1448 if (is_ptr_offset) {
1449 dual_ptr = insert<MatrixPtrStmt>(dual_ptr,
1450 stmt->dest->as<MatrixPtrStmt>()->offset);
1451 }
1452 insert<AtomicOpStmt>(AtomicOpType::add, dual_ptr, load(dual(stmt->val)));
1453 }
1454};
1455
1456class BackupSSA : public BasicStmtVisitor {
1457 public:
1458 using BasicStmtVisitor::visit;
1459
1460 Block *independent_block;
1461 std::map<Stmt *, Stmt *> backup_alloca;
1462
1463 explicit BackupSSA(Block *independent_block)
1464 : independent_block(independent_block) {
1465 allow_undefined_visitor = true;
1466 invoke_default_visitor = true;
1467 }
1468
1469 Stmt *load(Stmt *stmt) {
1470 if (backup_alloca.find(stmt) == backup_alloca.end()) {
1471 auto alloca = Stmt::make<AllocaStmt>(stmt->ret_type);
1472 auto alloca_ptr = alloca.get();
1473 independent_block->insert(std::move(alloca), 0);
1474 auto local_store = Stmt::make<LocalStoreStmt>(alloca_ptr, stmt);
1475 stmt->insert_after_me(std::move(local_store));
1476 backup_alloca[stmt] = alloca_ptr;
1477 }
1478 return backup_alloca[stmt];
1479 }
1480
1481 void generic_visit(Stmt *stmt) {
1482 std::vector<Block *> leaf_to_root;
1483 auto t = stmt->parent;
1484 while (t != nullptr) {
1485 leaf_to_root.push_back(t);
1486 t = t->parent_block();
1487 }
1488 int num_operands = stmt->get_operands().size();
1489 for (int i = 0; i < num_operands; i++) {
1490 auto op = stmt->operand(i);
1491 if (op == nullptr) {
1492 continue;
1493 }
1494 if (std::find(leaf_to_root.begin(), leaf_to_root.end(), op->parent) ==
1495 leaf_to_root.end() &&
1496 !op->is<AllocaStmt>()) {
1497 if (op->is<AdStackLoadTopStmt>()) {
1498 // Just create another AdStackLoadTopStmt
1499 stmt->set_operand(i, stmt->insert_before_me(op->clone()));
1500 } else if (op->is<AdStackAllocaStmt>()) {
1501 // Backup AdStackAllocaStmt because it should not be local stored and
1502 // local loaded
1503 auto stack_alloca = op->as<AdStackAllocaStmt>();
1504 if (backup_alloca.find(op) == backup_alloca.end()) {
1505 auto backup_stack_alloca = Stmt::make<AdStackAllocaStmt>(
1506 stack_alloca->dt, stack_alloca->max_size);
1507 auto backup_stack_alloca_ptr = backup_stack_alloca.get();
1508 independent_block->insert(std::move(backup_stack_alloca), 0);
1509 backup_alloca[op] = backup_stack_alloca_ptr;
1510 // Replace usages of all blocks i.e., the entry point for the
1511 // replace is the top level block
1512 irpass::replace_all_usages_with(leaf_to_root.back(), op,
1513 backup_stack_alloca_ptr);
1514 // Erase the outdated AdStackAllocaStmt
1515 op->parent->erase(op);
1516 }
1517 } else {
1518 auto alloca = load(op);
1519 stmt->set_operand(
1520 i, stmt->insert_before_me(Stmt::make<LocalLoadStmt>(alloca)));
1521 }
1522 }
1523 }
1524 }
1525
1526 void visit(Stmt *stmt) override {
1527 generic_visit(stmt);
1528 }
1529
1530 void visit(IfStmt *stmt) override {
1531 generic_visit(stmt);
1532 BasicStmtVisitor::visit(stmt);
1533 }
1534
1535 // TODO: test operands for statements
1536 void visit(RangeForStmt *stmt) override {
1537 stmt->body->accept(this);
1538 }
1539
1540 void visit(StructForStmt *stmt) override {
1541 stmt->body->accept(this);
1542 }
1543
1544 void visit(WhileStmt *stmt) override {
1545 TI_ERROR("WhileStmt not supported in AutoDiff for now.");
1546 }
1547
1548 void visit(Block *block) override {
1549 std::vector<Stmt *> statements;
1550 // always make a copy since the list can be modified.
1551 for (auto &stmt : block->statements) {
1552 statements.push_back(stmt.get());
1553 }
1554 for (auto stmt : statements) {
1555 TI_ASSERT(!stmt->erased);
1556 stmt->accept(this);
1557 }
1558 }
1559
1560 public:
1561 static void run(Block *block) {
1562 BackupSSA pass(block);
1563 block->accept(&pass);
1564 }
1565};
1566
1567namespace irpass {
1568
1569void auto_diff(IRNode *root,
1570 const CompileConfig &config,
1571 AutodiffMode autodiff_mode,
1572 bool use_stack) {
1573 TI_AUTO_PROF;
1574 if (autodiff_mode == AutodiffMode::kReverse) {
1575 if (use_stack) {
1576 auto IB = IdentifyIndependentBlocks::run(root);
1577 ReverseOuterLoops::run(root, IB);
1578
1579 for (auto ib : IB) {
1580 PromoteSSA2LocalVar::run(ib);
1581 ReplaceLocalVarWithStacks replace(config.ad_stack_size);
1582 ib->accept(&replace);
1583 type_check(root, config);
1584 MakeAdjoint::run(ib);
1585 type_check(root, config);
1586 BackupSSA::run(ib);
1587 irpass::analysis::verify(root);
1588 }
1589 } else {
1590 auto IB = IdentifyIndependentBlocks::run(root);
1591 ReverseOuterLoops::run(root, IB);
1592 type_check(root, config);
1593 for (auto ib : IB) {
1594 MakeAdjoint::run(ib);
1595 }
1596 }
1597 } else if (autodiff_mode == AutodiffMode::kForward) {
1598 // Forward mode autodiff
1599 Block *block = root->as<Block>();
1600 MakeDual::run(block);
1601 }
1602 type_check(root, config);
1603 irpass::analysis::verify(root);
1604}
1605
1606class GloablDataAccessRuleChecker : public BasicStmtVisitor {
1607 public:
1608 using BasicStmtVisitor::visit;
1609
1610 void visit(GlobalLoadStmt *stmt) override {
1611 GlobalPtrStmt *src = stmt->src->as<GlobalPtrStmt>();
1612 auto snode = src->snode;
1613 if (!snode->has_adjoint_checkbit()) {
1614 return;
1615 }
1616 TI_ASSERT(snode->get_adjoint_checkbit() != nullptr);
1617 snode = snode->get_adjoint_checkbit();
1618 auto global_ptr =
1619 stmt->insert_after_me(Stmt::make<GlobalPtrStmt>(snode, src->indices));
1620 auto dtype = global_ptr->ret_type;
1621 auto one = global_ptr->insert_after_me(
1622 Stmt::make<ConstStmt>(TypedConstant(dtype, 1)));
1623 one->insert_after_me(Stmt::make<GlobalStoreStmt>(global_ptr, one));
1624 }
1625
1626 void visit_gloabl_store_stmt_and_atomic_add(Stmt *stmt, GlobalPtrStmt *dest) {
1627 auto snode = dest->snode;
1628 if (!snode->has_adjoint_checkbit()) {
1629 return;
1630 }
1631 TI_ASSERT(snode->get_adjoint_checkbit() != nullptr);
1632 snode = snode->get_adjoint_checkbit();
1633 auto global_ptr =
1634 stmt->insert_before_me(Stmt::make<GlobalPtrStmt>(snode, dest->indices));
1635 auto global_load =
1636 stmt->insert_before_me(Stmt::make<GlobalLoadStmt>(global_ptr));
1637 auto dtype = global_ptr->ret_type;
1638 auto zero =
1639 stmt->insert_before_me(Stmt::make<ConstStmt>(TypedConstant(dtype, 0)));
1640 auto check_equal = stmt->insert_before_me(
1641 Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_eq, global_load, zero));
1642 std::string msg = fmt::format(
1643 "(kernel={}) Breaks the global data access rule. Snode {} is "
1644 "overwritten unexpectedly.",
1645 kernel_name_, dest->snode->get_node_type_name());
1646 msg += "\n" + stmt->tb;
1647
1648 stmt->insert_before_me(
1649 Stmt::make<AssertStmt>(check_equal, msg, std::vector<Stmt *>()));
1650 }
1651
1652 void visit(GlobalStoreStmt *stmt) override {
1653 GlobalPtrStmt *dest = stmt->dest->as<GlobalPtrStmt>();
1654 visit_gloabl_store_stmt_and_atomic_add(stmt, dest);
1655 }
1656
1657 void visit(AtomicOpStmt *stmt) override {
1658 GlobalPtrStmt *dest = stmt->dest->as<GlobalPtrStmt>();
1659 visit_gloabl_store_stmt_and_atomic_add(stmt, dest);
1660 }
1661
1662 static void run(IRNode *root, const std::string &kernel_name) {
1663 GloablDataAccessRuleChecker checker;
1664 checker.kernel_name_ = kernel_name;
1665 root->accept(&checker);
1666 }
1667
1668 private:
1669 std::string kernel_name_;
1670};
1671
1672void differentiation_validation_check(IRNode *root,
1673 const CompileConfig &config,
1674 const std::string &kernel_name) {
1675 return irpass::GloablDataAccessRuleChecker::run(root, kernel_name);
1676}
1677
1678} // namespace irpass
1679
1680} // namespace taichi::lang
1681