#include <iostream>
#include <memory>
#include <utility>
#include <vector>

#include "taichi/ir/ir_builder.h"
#include "taichi/ir/statements.h"
#include "taichi/program/program.h"
4
5void run_snode() {
6 /*
7 import taichi as ti, numpy as np
8 ti.init()
9 #ti.init(print_ir = True)
10
11 n = 10
12 place = ti.field(dtype = ti.i32)
13 ti.root.pointer(ti.i, n).place(place)
14
15 @ti.kernel
16 def init():
17 for index in range(n):
18 place[index] = index
19
20 @ti.kernel
21 def ret() -> ti.i32:
22 sum = 0
23 for index in place:
24 sum = sum + place[index]
25 return sum
26
27 @ti.kernel
28 def ext(ext_arr: ti.ext_arr()):
29 for index in place:
30 ext_arr[index] = place[index]
31
32 init()
33 print(ret())
34 ext_arr = np.zeros(n, np.int32)
35 ext(ext_arr)
36 #ext_arr = place.to_numpy()
37 print(ext_arr)
38 */
39
40 using namespace taichi;
41 using namespace lang;
42 auto program = Program(Arch::x64);
43 /*CompileConfig config_print_ir;
44 config_print_ir.print_ir = true;
45 prog_.config = config_print_ir;*/ // print_ir = True
46
47 int n = 10;
48 program.materialize_runtime();
49 auto *root = new SNode(0, SNodeType::root);
50 auto *pointer = &root->pointer(Axis(0), n, false, "");
51 auto *place = &pointer->insert_children(SNodeType::place);
52 place->dt = PrimitiveType::i32;
53 program.add_snode_tree(std::unique_ptr<SNode>(root), /*compile_only=*/false);
54
55 std::unique_ptr<Kernel> kernel_init, kernel_ret, kernel_ext;
56
57 {
58 /*
59 @ti.kernel
60 def init():
61 for index in range(n):
62 place[index] = index
63 */
64 IRBuilder builder;
65 auto *zero = builder.get_int32(0);
66 auto *n_stmt = builder.get_int32(n);
67 auto *loop = builder.create_range_for(zero, n_stmt, 0, 4);
68 {
69 auto _ = builder.get_loop_guard(loop);
70 auto *index = builder.get_loop_index(loop);
71 auto *ptr = builder.create_global_ptr(place, {index});
72 builder.create_global_store(ptr, index);
73 }
74
75 kernel_init =
76 std::make_unique<Kernel>(program, builder.extract_ir(), "init");
77 }
78
79 {
80 /*
81 @ti.kernel
82 def ret():
83 sum = 0
84 for index in place:
85 sum = sum + place[index];
86 return sum
87 */
88 IRBuilder builder;
89 auto *sum = builder.create_local_var(PrimitiveType::i32);
90 auto *loop = builder.create_struct_for(pointer, 0, 4);
91 {
92 auto _ = builder.get_loop_guard(loop);
93 auto *index = builder.get_loop_index(loop);
94 auto *sum_old = builder.create_local_load(sum);
95 auto *place_index =
96 builder.create_global_load(builder.create_global_ptr(place, {index}));
97 builder.create_local_store(sum, builder.create_add(sum_old, place_index));
98 }
99 builder.create_return(builder.create_local_load(sum));
100
101 kernel_ret = std::make_unique<Kernel>(program, builder.extract_ir(), "ret");
102 }
103
104 {
105 /*
106 @ti.kernel
107 def ext(ext: ti.ext_arr()):
108 for index in place:
109 ext[index] = place[index];
110 # ext = place.to_numpy()
111 */
112 IRBuilder builder;
113 auto *loop = builder.create_struct_for(pointer, 0, 4);
114 {
115 auto _ = builder.get_loop_guard(loop);
116 auto *index = builder.get_loop_index(loop);
117 auto *ext = builder.create_external_ptr(
118 builder.create_arg_load(0, PrimitiveType::i32, true), {index});
119 auto *place_index =
120 builder.create_global_load(builder.create_global_ptr(place, {index}));
121 builder.create_global_store(ext, place_index);
122 }
123
124 kernel_ext = std::make_unique<Kernel>(program, builder.extract_ir(), "ext");
125 kernel_ext->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {n});
126 }
127
128 auto ctx_init = kernel_init->make_launch_context();
129 auto ctx_ret = kernel_ret->make_launch_context();
130 auto ctx_ext = kernel_ext->make_launch_context();
131 std::vector<int> ext_arr(n);
132 ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_arr.data()),
133 n, {n});
134
135 (*kernel_init)(ctx_init);
136 (*kernel_ret)(ctx_ret);
137 std::cout << program.fetch_result<int>(0) << std::endl;
138 (*kernel_ext)(ctx_ext);
139 for (int i = 0; i < n; i++)
140 std::cout << ext_arr[i] << " ";
141 std::cout << std::endl;
142}
143