#include "taichi/ir/ir_builder.h"
#include "taichi/ir/statements.h"
#include "taichi/program/program.h"

#include <iostream>
#include <memory>
#include <utility>
#include <vector>
4 | |
5 | void run_snode() { |
6 | /* |
7 | import taichi as ti, numpy as np |
8 | ti.init() |
9 | #ti.init(print_ir = True) |
10 | |
11 | n = 10 |
12 | place = ti.field(dtype = ti.i32) |
13 | ti.root.pointer(ti.i, n).place(place) |
14 | |
15 | @ti.kernel |
16 | def init(): |
17 | for index in range(n): |
18 | place[index] = index |
19 | |
20 | @ti.kernel |
21 | def ret() -> ti.i32: |
22 | sum = 0 |
23 | for index in place: |
24 | sum = sum + place[index] |
25 | return sum |
26 | |
27 | @ti.kernel |
28 | def ext(ext_arr: ti.ext_arr()): |
29 | for index in place: |
30 | ext_arr[index] = place[index] |
31 | |
32 | init() |
33 | print(ret()) |
34 | ext_arr = np.zeros(n, np.int32) |
35 | ext(ext_arr) |
36 | #ext_arr = place.to_numpy() |
37 | print(ext_arr) |
38 | */ |
39 | |
40 | using namespace taichi; |
41 | using namespace lang; |
42 | auto program = Program(Arch::x64); |
43 | /*CompileConfig config_print_ir; |
44 | config_print_ir.print_ir = true; |
45 | prog_.config = config_print_ir;*/ // print_ir = True |
46 | |
47 | int n = 10; |
48 | program.materialize_runtime(); |
49 | auto *root = new SNode(0, SNodeType::root); |
50 | auto *pointer = &root->pointer(Axis(0), n, false, "" ); |
51 | auto *place = &pointer->insert_children(SNodeType::place); |
52 | place->dt = PrimitiveType::i32; |
53 | program.add_snode_tree(std::unique_ptr<SNode>(root), /*compile_only=*/false); |
54 | |
55 | std::unique_ptr<Kernel> kernel_init, kernel_ret, kernel_ext; |
56 | |
57 | { |
58 | /* |
59 | @ti.kernel |
60 | def init(): |
61 | for index in range(n): |
62 | place[index] = index |
63 | */ |
64 | IRBuilder builder; |
65 | auto *zero = builder.get_int32(0); |
66 | auto *n_stmt = builder.get_int32(n); |
67 | auto *loop = builder.create_range_for(zero, n_stmt, 0, 4); |
68 | { |
69 | auto _ = builder.get_loop_guard(loop); |
70 | auto *index = builder.get_loop_index(loop); |
71 | auto *ptr = builder.create_global_ptr(place, {index}); |
72 | builder.create_global_store(ptr, index); |
73 | } |
74 | |
75 | kernel_init = |
76 | std::make_unique<Kernel>(program, builder.extract_ir(), "init" ); |
77 | } |
78 | |
79 | { |
80 | /* |
81 | @ti.kernel |
82 | def ret(): |
83 | sum = 0 |
84 | for index in place: |
85 | sum = sum + place[index]; |
86 | return sum |
87 | */ |
88 | IRBuilder builder; |
89 | auto *sum = builder.create_local_var(PrimitiveType::i32); |
90 | auto *loop = builder.create_struct_for(pointer, 0, 4); |
91 | { |
92 | auto _ = builder.get_loop_guard(loop); |
93 | auto *index = builder.get_loop_index(loop); |
94 | auto *sum_old = builder.create_local_load(sum); |
95 | auto *place_index = |
96 | builder.create_global_load(builder.create_global_ptr(place, {index})); |
97 | builder.create_local_store(sum, builder.create_add(sum_old, place_index)); |
98 | } |
99 | builder.create_return(builder.create_local_load(sum)); |
100 | |
101 | kernel_ret = std::make_unique<Kernel>(program, builder.extract_ir(), "ret" ); |
102 | } |
103 | |
104 | { |
105 | /* |
106 | @ti.kernel |
107 | def ext(ext: ti.ext_arr()): |
108 | for index in place: |
109 | ext[index] = place[index]; |
110 | # ext = place.to_numpy() |
111 | */ |
112 | IRBuilder builder; |
113 | auto *loop = builder.create_struct_for(pointer, 0, 4); |
114 | { |
115 | auto _ = builder.get_loop_guard(loop); |
116 | auto *index = builder.get_loop_index(loop); |
117 | auto *ext = builder.create_external_ptr( |
118 | builder.create_arg_load(0, PrimitiveType::i32, true), {index}); |
119 | auto *place_index = |
120 | builder.create_global_load(builder.create_global_ptr(place, {index})); |
121 | builder.create_global_store(ext, place_index); |
122 | } |
123 | |
124 | kernel_ext = std::make_unique<Kernel>(program, builder.extract_ir(), "ext" ); |
125 | kernel_ext->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {n}); |
126 | } |
127 | |
128 | auto ctx_init = kernel_init->make_launch_context(); |
129 | auto ctx_ret = kernel_ret->make_launch_context(); |
130 | auto ctx_ext = kernel_ext->make_launch_context(); |
131 | std::vector<int> ext_arr(n); |
132 | ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_arr.data()), |
133 | n, {n}); |
134 | |
135 | (*kernel_init)(ctx_init); |
136 | (*kernel_ret)(ctx_ret); |
137 | std::cout << program.fetch_result<int>(0) << std::endl; |
138 | (*kernel_ext)(ctx_ext); |
139 | for (int i = 0; i < n; i++) |
140 | std::cout << ext_arr[i] << " " ; |
141 | std::cout << std::endl; |
142 | } |
143 | |