#include <iostream>
#include <memory>

#include "taichi/ir/ir_builder.h"
#include "taichi/ir/statements.h"
#include "taichi/program/program.h"
4
5void aot_save(taichi::Arch arch) {
6 using namespace taichi;
7 using namespace lang;
8 auto program = Program(arch);
9
10 program.this_thread_config().advanced_optimization = false;
11
12 int n = 10;
13
14 // program.materialize_runtime();
15 auto *root = new SNode(0, SNodeType::root);
16 auto *pointer = &root->dense(Axis(0), n, false, "");
17 auto *place = &pointer->insert_children(SNodeType::place);
18 place->dt = PrimitiveType::i32;
19 program.add_snode_tree(std::unique_ptr<SNode>(root), /*compile_only=*/true);
20
21 auto aot_builder = program.make_aot_module_builder(arch, {});
22
23 std::unique_ptr<Kernel> kernel_init, kernel_ret;
24
25 {
26 /*
27 @ti.kernel
28 def init():
29 for index in range(n):
30 place[index] = index
31 */
32 IRBuilder builder;
33 auto *zero = builder.get_int32(0);
34 auto *n_stmt = builder.get_int32(n);
35 auto *loop = builder.create_range_for(zero, n_stmt, 0, 4);
36 {
37 auto _ = builder.get_loop_guard(loop);
38 auto *index = builder.get_loop_index(loop);
39 auto *ptr = builder.create_global_ptr(place, {index});
40 builder.create_global_store(ptr, index);
41 }
42
43 kernel_init =
44 std::make_unique<Kernel>(program, builder.extract_ir(), "init");
45 }
46
47 {
48 /*
49 @ti.kernel
50 def ret():
51 sum = 0
52 for index in place:
53 sum = sum + place[index];
54 return sum
55 */
56 IRBuilder builder;
57 auto *sum = builder.create_local_var(PrimitiveType::i32);
58 auto *loop = builder.create_struct_for(pointer, 0, 4);
59 {
60 auto _ = builder.get_loop_guard(loop);
61 auto *index = builder.get_loop_index(loop);
62 auto *sum_old = builder.create_local_load(sum);
63 auto *place_index =
64 builder.create_global_load(builder.create_global_ptr(place, {index}));
65 builder.create_local_store(sum, builder.create_add(sum_old, place_index));
66 }
67 builder.create_return(builder.create_local_load(sum));
68
69 kernel_ret = std::make_unique<Kernel>(program, builder.extract_ir(), "ret");
70 kernel_ret->insert_ret(PrimitiveType::i32);
71 }
72
73 aot_builder->add_field("place", place, true, place->dt, {n}, 1, 1);
74 aot_builder->add("init", kernel_init.get());
75 aot_builder->add("ret", kernel_ret.get());
76 aot_builder->dump(".", taichi::arch_name(arch) + "_aot.tcb");
77 std::cout << "done" << std::endl;
78}
79