1 | #include "taichi/program/snode_rw_accessors_bank.h" |
2 | |
3 | #include "taichi/program/program.h" |
4 | |
5 | namespace taichi::lang { |
6 | |
7 | namespace { |
8 | void set_kernel_args(const std::vector<int> &I, |
9 | int num_active_indices, |
10 | Kernel::LaunchContextBuilder *launch_ctx) { |
11 | for (int i = 0; i < num_active_indices; i++) { |
12 | launch_ctx->set_arg_int(i, I[i]); |
13 | } |
14 | } |
15 | } // namespace |
16 | |
17 | SNodeRwAccessorsBank::Accessors SNodeRwAccessorsBank::get(SNode *snode) { |
18 | auto &kernels = snode_to_kernels_[snode]; |
19 | if (kernels.reader == nullptr) { |
20 | kernels.reader = &(program_->get_snode_reader(snode)); |
21 | } |
22 | if (kernels.writer == nullptr) { |
23 | kernels.writer = &(program_->get_snode_writer(snode)); |
24 | } |
25 | return Accessors(snode, kernels, program_); |
26 | } |
27 | |
28 | SNodeRwAccessorsBank::Accessors::Accessors(const SNode *snode, |
29 | const RwKernels &kernels, |
30 | Program *prog) |
31 | : snode_(snode), |
32 | prog_(prog), |
33 | reader_(kernels.reader), |
34 | writer_(kernels.writer) { |
35 | TI_ASSERT(reader_ != nullptr); |
36 | TI_ASSERT(writer_ != nullptr); |
37 | } |
38 | void SNodeRwAccessorsBank::Accessors::write_float(const std::vector<int> &I, |
39 | float64 val) { |
40 | auto launch_ctx = writer_->make_launch_context(); |
41 | set_kernel_args(I, snode_->num_active_indices, &launch_ctx); |
42 | launch_ctx.set_arg_float(snode_->num_active_indices, val); |
43 | prog_->synchronize(); |
44 | (*writer_)(prog_->compile_config(), launch_ctx); |
45 | } |
46 | |
47 | float64 SNodeRwAccessorsBank::Accessors::read_float(const std::vector<int> &I) { |
48 | prog_->synchronize(); |
49 | auto launch_ctx = reader_->make_launch_context(); |
50 | set_kernel_args(I, snode_->num_active_indices, &launch_ctx); |
51 | (*reader_)(prog_->compile_config(), launch_ctx); |
52 | prog_->synchronize(); |
53 | if (arch_uses_llvm(prog_->compile_config().arch)) { |
54 | return reader_->get_struct_ret_float({0}); |
55 | } |
56 | auto ret = reader_->get_ret_float(0); |
57 | return ret; |
58 | } |
59 | |
60 | // for int32 and int64 |
61 | void SNodeRwAccessorsBank::Accessors::write_int(const std::vector<int> &I, |
62 | int64 val) { |
63 | auto launch_ctx = writer_->make_launch_context(); |
64 | set_kernel_args(I, snode_->num_active_indices, &launch_ctx); |
65 | launch_ctx.set_arg_int(snode_->num_active_indices, val); |
66 | prog_->synchronize(); |
67 | (*writer_)(prog_->compile_config(), launch_ctx); |
68 | } |
69 | |
70 | // for int32 and int64 |
71 | void SNodeRwAccessorsBank::Accessors::write_uint(const std::vector<int> &I, |
72 | uint64 val) { |
73 | auto launch_ctx = writer_->make_launch_context(); |
74 | set_kernel_args(I, snode_->num_active_indices, &launch_ctx); |
75 | launch_ctx.set_arg_uint(snode_->num_active_indices, val); |
76 | prog_->synchronize(); |
77 | (*writer_)(prog_->compile_config(), launch_ctx); |
78 | } |
79 | |
80 | int64 SNodeRwAccessorsBank::Accessors::read_int(const std::vector<int> &I) { |
81 | prog_->synchronize(); |
82 | auto launch_ctx = reader_->make_launch_context(); |
83 | set_kernel_args(I, snode_->num_active_indices, &launch_ctx); |
84 | (*reader_)(prog_->compile_config(), launch_ctx); |
85 | prog_->synchronize(); |
86 | if (arch_uses_llvm(prog_->compile_config().arch)) { |
87 | return reader_->get_struct_ret_int({0}); |
88 | } |
89 | auto ret = reader_->get_ret_int(0); |
90 | return ret; |
91 | } |
92 | |
93 | uint64 SNodeRwAccessorsBank::Accessors::read_uint(const std::vector<int> &I) { |
94 | prog_->synchronize(); |
95 | auto launch_ctx = reader_->make_launch_context(); |
96 | set_kernel_args(I, snode_->num_active_indices, &launch_ctx); |
97 | (*reader_)(prog_->compile_config(), launch_ctx); |
98 | prog_->synchronize(); |
99 | if (arch_uses_llvm(prog_->compile_config().arch)) { |
100 | return reader_->get_struct_ret_uint({0}); |
101 | } |
102 | auto ret = reader_->get_ret_uint(0); |
103 | return ret; |
104 | } |
105 | |
106 | } // namespace taichi::lang |
107 | |