1#include "taichi/codegen/spirv/kernel_utils.h"
2
3#include <unordered_map>
4
5#include "taichi/program/kernel.h"
6#define TI_RUNTIME_HOST
7#include "taichi/program/context.h"
8#undef TI_RUNTIME_HOST
9
10namespace taichi::lang {
11namespace spirv {
12
13// static
14std::string TaskAttributes::buffers_name(BufferInfo b) {
15 if (b.type == BufferType::Args) {
16 return "Args";
17 }
18 if (b.type == BufferType::Rets) {
19 return "Rets";
20 }
21 if (b.type == BufferType::GlobalTmps) {
22 return "GlobalTmps";
23 }
24 if (b.type == BufferType::Root) {
25 return std::string("Root: ") + std::to_string(b.root_id);
26 }
27 TI_ERROR("unrecognized buffer type");
28}
29
30std::string TaskAttributes::debug_string() const {
31 std::string result;
32 result += fmt::format(
33 "<TaskAttributes name={} advisory_total_num_threads={} "
34 "task_type={} buffers=[ ",
35 name, advisory_total_num_threads, offloaded_task_type_name(task_type));
36 for (auto b : buffer_binds) {
37 result += buffers_name(b.buffer) + " ";
38 }
39 result += "]"; // closes |buffers|
40 // TODO(k-ye): show range_for
41 result += ">";
42 return result;
43}
44
45std::string TaskAttributes::BufferBind::debug_string() const {
46 return fmt::format("<type={} binding={}>",
47 TaskAttributes::buffers_name(buffer), binding);
48}
49
50KernelContextAttributes::KernelContextAttributes(
51 const Kernel &kernel,
52 const DeviceCapabilityConfig *caps)
53 : args_bytes_(0),
54 rets_bytes_(0),
55 extra_args_bytes_(RuntimeContext::extra_args_size) {
56 arr_access.resize(kernel.parameter_list.size(), irpass::ExternalPtrAccess(0));
57 arg_attribs_vec_.reserve(kernel.parameter_list.size());
58 // TODO: We should be able to limit Kernel args and rets to be primitive types
59 // as well but let's leave that as a followup up PR.
60 for (const auto &ka : kernel.parameter_list) {
61 ArgAttributes aa;
62 aa.dtype = ka.get_element_type()->as<PrimitiveType>()->type;
63 const size_t dt_bytes = ka.get_element_size();
64 aa.is_array = ka.is_array;
65 if (aa.is_array) {
66 aa.field_dim = ka.total_dim - ka.get_element_shape().size();
67 aa.element_shape = ka.get_element_shape();
68 }
69 aa.stride = dt_bytes;
70 aa.index = arg_attribs_vec_.size();
71 arg_attribs_vec_.push_back(aa);
72 }
73 for (const auto &kr : kernel.rets) {
74 RetAttributes ra;
75 size_t dt_bytes{0};
76 if (auto tensor_type = kr.dt->cast<TensorType>()) {
77 auto tensor_dtype = tensor_type->get_element_type();
78 TI_ASSERT(tensor_dtype->is<PrimitiveType>());
79 ra.dtype = tensor_dtype->cast<PrimitiveType>()->type;
80 dt_bytes = data_type_size(tensor_dtype);
81 ra.is_array = true;
82 ra.stride = tensor_type->get_num_elements() * dt_bytes;
83 } else {
84 TI_ASSERT(kr.dt->is<PrimitiveType>());
85 ra.dtype = kr.dt->cast<PrimitiveType>()->type;
86 dt_bytes = data_type_size(kr.dt);
87 ra.is_array = false;
88 ra.stride = dt_bytes;
89 }
90 ra.index = ret_attribs_vec_.size();
91 ret_attribs_vec_.push_back(ra);
92 }
93
94 auto arange_args = [](auto *vec, size_t offset, bool is_ret,
95 bool has_buffer_ptr) -> size_t {
96 size_t bytes = offset;
97 for (int i = 0; i < vec->size(); ++i) {
98 auto &attribs = (*vec)[i];
99 const size_t dt_bytes =
100 (attribs.is_array && !is_ret && has_buffer_ptr)
101 ? sizeof(uint64_t)
102 : data_type_size(PrimitiveType::get(attribs.dtype));
103 // Align bytes to the nearest multiple of dt_bytes
104 bytes = (bytes + dt_bytes - 1) / dt_bytes * dt_bytes;
105 attribs.offset_in_mem = bytes;
106 bytes += is_ret ? attribs.stride : dt_bytes;
107 TI_TRACE(
108 " at={} {} offset_in_mem={} stride={}",
109 (*vec)[i].is_array ? (is_ret ? "array" : "vector ptr") : "scalar", i,
110 attribs.offset_in_mem, attribs.stride);
111 }
112 return bytes - offset;
113 };
114
115 TI_TRACE("args:");
116 args_bytes_ = arange_args(
117 &arg_attribs_vec_, 0, false,
118 caps->get(DeviceCapability::spirv_has_physical_storage_buffer));
119 // Align to extra args
120 args_bytes_ = (args_bytes_ + 4 - 1) / 4 * 4;
121
122 TI_TRACE("rets:");
123 rets_bytes_ = arange_args(&ret_attribs_vec_, 0, true, false);
124
125 TI_TRACE("sizes: args={} rets={}", args_bytes(), rets_bytes());
126 TI_ASSERT(has_rets() == (rets_bytes_ > 0));
127}
128
129} // namespace spirv
130} // namespace taichi::lang
131