1#include "taichi/program/snode_rw_accessors_bank.h"
2
3#include "taichi/program/program.h"
4
5namespace taichi::lang {
6
7namespace {
8void set_kernel_args(const std::vector<int> &I,
9 int num_active_indices,
10 Kernel::LaunchContextBuilder *launch_ctx) {
11 for (int i = 0; i < num_active_indices; i++) {
12 launch_ctx->set_arg_int(i, I[i]);
13 }
14}
15} // namespace
16
17SNodeRwAccessorsBank::Accessors SNodeRwAccessorsBank::get(SNode *snode) {
18 auto &kernels = snode_to_kernels_[snode];
19 if (kernels.reader == nullptr) {
20 kernels.reader = &(program_->get_snode_reader(snode));
21 }
22 if (kernels.writer == nullptr) {
23 kernels.writer = &(program_->get_snode_writer(snode));
24 }
25 return Accessors(snode, kernels, program_);
26}
27
28SNodeRwAccessorsBank::Accessors::Accessors(const SNode *snode,
29 const RwKernels &kernels,
30 Program *prog)
31 : snode_(snode),
32 prog_(prog),
33 reader_(kernels.reader),
34 writer_(kernels.writer) {
35 TI_ASSERT(reader_ != nullptr);
36 TI_ASSERT(writer_ != nullptr);
37}
38void SNodeRwAccessorsBank::Accessors::write_float(const std::vector<int> &I,
39 float64 val) {
40 auto launch_ctx = writer_->make_launch_context();
41 set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
42 launch_ctx.set_arg_float(snode_->num_active_indices, val);
43 prog_->synchronize();
44 (*writer_)(prog_->compile_config(), launch_ctx);
45}
46
47float64 SNodeRwAccessorsBank::Accessors::read_float(const std::vector<int> &I) {
48 prog_->synchronize();
49 auto launch_ctx = reader_->make_launch_context();
50 set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
51 (*reader_)(prog_->compile_config(), launch_ctx);
52 prog_->synchronize();
53 if (arch_uses_llvm(prog_->compile_config().arch)) {
54 return reader_->get_struct_ret_float({0});
55 }
56 auto ret = reader_->get_ret_float(0);
57 return ret;
58}
59
60// for int32 and int64
61void SNodeRwAccessorsBank::Accessors::write_int(const std::vector<int> &I,
62 int64 val) {
63 auto launch_ctx = writer_->make_launch_context();
64 set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
65 launch_ctx.set_arg_int(snode_->num_active_indices, val);
66 prog_->synchronize();
67 (*writer_)(prog_->compile_config(), launch_ctx);
68}
69
70// for int32 and int64
71void SNodeRwAccessorsBank::Accessors::write_uint(const std::vector<int> &I,
72 uint64 val) {
73 auto launch_ctx = writer_->make_launch_context();
74 set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
75 launch_ctx.set_arg_uint(snode_->num_active_indices, val);
76 prog_->synchronize();
77 (*writer_)(prog_->compile_config(), launch_ctx);
78}
79
80int64 SNodeRwAccessorsBank::Accessors::read_int(const std::vector<int> &I) {
81 prog_->synchronize();
82 auto launch_ctx = reader_->make_launch_context();
83 set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
84 (*reader_)(prog_->compile_config(), launch_ctx);
85 prog_->synchronize();
86 if (arch_uses_llvm(prog_->compile_config().arch)) {
87 return reader_->get_struct_ret_int({0});
88 }
89 auto ret = reader_->get_ret_int(0);
90 return ret;
91}
92
93uint64 SNodeRwAccessorsBank::Accessors::read_uint(const std::vector<int> &I) {
94 prog_->synchronize();
95 auto launch_ctx = reader_->make_launch_context();
96 set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
97 (*reader_)(prog_->compile_config(), launch_ctx);
98 prog_->synchronize();
99 if (arch_uses_llvm(prog_->compile_config().arch)) {
100 return reader_->get_struct_ret_uint({0});
101 }
102 auto ret = reader_->get_ret_uint(0);
103 return ret;
104}
105
106} // namespace taichi::lang
107