1 | #include "taichi/ir/ir_builder.h" |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/program/program.h" |
4 | |
5 | void aot_save(taichi::Arch arch) { |
6 | using namespace taichi; |
7 | using namespace lang; |
8 | auto program = Program(arch); |
9 | |
10 | program.this_thread_config().advanced_optimization = false; |
11 | |
12 | int n = 10; |
13 | |
14 | // program.materialize_runtime(); |
15 | auto *root = new SNode(0, SNodeType::root); |
16 | auto *pointer = &root->dense(Axis(0), n, false, "" ); |
17 | auto *place = &pointer->insert_children(SNodeType::place); |
18 | place->dt = PrimitiveType::i32; |
19 | program.add_snode_tree(std::unique_ptr<SNode>(root), /*compile_only=*/true); |
20 | |
21 | auto aot_builder = program.make_aot_module_builder(arch, {}); |
22 | |
23 | std::unique_ptr<Kernel> kernel_init, kernel_ret; |
24 | |
25 | { |
26 | /* |
27 | @ti.kernel |
28 | def init(): |
29 | for index in range(n): |
30 | place[index] = index |
31 | */ |
32 | IRBuilder builder; |
33 | auto *zero = builder.get_int32(0); |
34 | auto *n_stmt = builder.get_int32(n); |
35 | auto *loop = builder.create_range_for(zero, n_stmt, 0, 4); |
36 | { |
37 | auto _ = builder.get_loop_guard(loop); |
38 | auto *index = builder.get_loop_index(loop); |
39 | auto *ptr = builder.create_global_ptr(place, {index}); |
40 | builder.create_global_store(ptr, index); |
41 | } |
42 | |
43 | kernel_init = |
44 | std::make_unique<Kernel>(program, builder.extract_ir(), "init" ); |
45 | } |
46 | |
47 | { |
48 | /* |
49 | @ti.kernel |
50 | def ret(): |
51 | sum = 0 |
52 | for index in place: |
53 | sum = sum + place[index]; |
54 | return sum |
55 | */ |
56 | IRBuilder builder; |
57 | auto *sum = builder.create_local_var(PrimitiveType::i32); |
58 | auto *loop = builder.create_struct_for(pointer, 0, 4); |
59 | { |
60 | auto _ = builder.get_loop_guard(loop); |
61 | auto *index = builder.get_loop_index(loop); |
62 | auto *sum_old = builder.create_local_load(sum); |
63 | auto *place_index = |
64 | builder.create_global_load(builder.create_global_ptr(place, {index})); |
65 | builder.create_local_store(sum, builder.create_add(sum_old, place_index)); |
66 | } |
67 | builder.create_return(builder.create_local_load(sum)); |
68 | |
69 | kernel_ret = std::make_unique<Kernel>(program, builder.extract_ir(), "ret" ); |
70 | kernel_ret->insert_ret(PrimitiveType::i32); |
71 | } |
72 | |
73 | aot_builder->add_field("place" , place, true, place->dt, {n}, 1, 1); |
74 | aot_builder->add("init" , kernel_init.get()); |
75 | aot_builder->add("ret" , kernel_ret.get()); |
76 | aot_builder->dump("." , taichi::arch_name(arch) + "_aot.tcb" ); |
77 | std::cout << "done" << std::endl; |
78 | } |
79 | |