1#pragma once
2
3#include "taichi/ir/snode.h"
4
5namespace taichi::lang {
6
7class Stmt;
8
/// Bit flags describing the kind of access performed on an element.
///
/// Each enumerator is a distinct single bit so values can be combined with
/// the bitwise operators defined below. Note that the lowest bit (1 << 0) is
/// not assigned to any enumerator; the numeric values are kept as-is.
enum AccessFlag : unsigned int {
  read = 1 << 1,
  write = 1 << 2,
  accumulate = 1 << 3
};

/// Returns the union of the bits set in `a` and `b`.
inline AccessFlag operator|(AccessFlag a, AccessFlag b) {
  return static_cast<AccessFlag>(static_cast<unsigned>(a) |
                                 static_cast<unsigned>(b));
}

/// Returns the bits that are set in both `a` and `b`.
inline AccessFlag operator&(AccessFlag a, AccessFlag b) {
  return static_cast<AccessFlag>(static_cast<unsigned>(a) &
                                 static_cast<unsigned>(b));
}

/// ORs the bits of `b` into `a` and returns a reference to `a`.
///
/// `b` is taken by value so that temporaries and const values (for example
/// `flags |= AccessFlag::read` or `flags |= (x | y)`) are accepted, and the
/// result is returned by reference to follow the usual compound-assignment
/// convention (allows chaining).
inline AccessFlag &operator|=(AccessFlag &a, AccessFlag b) {
  a = a | b;
  return a;
}
29
30class ScratchPad {
31 public:
32 // The lowest and highest index in each dimension.
33 struct BoundRange {
34 int low{0};
35 int high{0};
36
37 int range() const {
38 return high - low;
39 }
40
41 TI_IO_DEF(low, high);
42 };
43
44 SNode *snode{nullptr};
45 using AccessFlag = taichi::lang::AccessFlag;
46
47 std::vector<int> coefficients;
48 std::vector<BoundRange> bounds;
49 // pad_size[i] := bounds[i].high - bounds[i].low
50 // TODO: This can be replaced by a function call to bounds[i].range()
51 std::vector<int> pad_size;
52 // block_size[i] := (1 << snode.extractor[i].num_bits)
53 std::vector<int> block_size;
54 bool finalized;
55 int dim;
56 bool empty;
57
58 AccessFlag total_flags;
59 std::vector<AccessFlag> flags;
60 std::vector<std::pair<std::vector<int>, AccessFlag>> accesses;
61
62 ScratchPad() = default;
63
64 explicit ScratchPad(SNode *snode) : snode(snode) {
65 TI_ASSERT(snode != nullptr);
66 dim = snode->num_active_indices;
67 coefficients.resize(dim);
68 bounds.resize(dim);
69 pad_size.resize(dim);
70
71 finalized = false;
72
73 total_flags = AccessFlag(0);
74 BoundRange init_bound;
75 init_bound.low = std::numeric_limits<int>::max();
76 init_bound.high = std::numeric_limits<int>::min();
77 std::fill(bounds.begin(), bounds.end(), init_bound);
78 empty = false;
79 }
80
81 void access(const std::vector<int> &coeffs,
82 const std::vector<int> &indices,
83 AccessFlag flags) {
84 TI_ASSERT(!finalized);
85 empty = true;
86 TI_ASSERT((int)indices.size() == dim);
87 for (int i = 0; i < dim; i++) {
88 coefficients[i] = coeffs[i];
89 bounds[i].low = std::min(bounds[i].low, indices[i]);
90 bounds[i].high = std::max(bounds[i].high, indices[i] + 1);
91 pad_size[i] = bounds[i].range();
92 }
93 accesses.emplace_back(indices, flags);
94 }
95
96 void finalize() {
97 int size = 1;
98 for (int i = 0; i < dim; i++) {
99 size *= pad_size[i];
100 }
101 flags.resize(size);
102
103 block_size.resize(dim);
104 for (int i = 0; i < dim; i++) {
105 block_size[i] =
106 snode->parent->extractors[snode->physical_index_position[i]].shape;
107 TI_ASSERT(bounds[i].low != std::numeric_limits<int>::max());
108 TI_ASSERT(bounds[i].high != std::numeric_limits<int>::min());
109 }
110
111 finalized = true;
112 flags = std::vector<AccessFlag>(pad_size_linear(), AccessFlag(0));
113
114 for (auto &acc : accesses) {
115 total_flags |= acc.second;
116 flags[linearized_index(acc.first)] |= acc.second;
117 }
118 }
119
120 void codegen_cpu() {
121 }
122
123 std::string name() {
124 return snode->node_type_name + "_scratch_pad";
125 }
126
127 bool is_pure() const {
128 return bit::is_power_of_two((unsigned)total_flags);
129 }
130
131 int pad_size_linear() {
132 TI_ASSERT(finalized);
133 int s = 1;
134 for (int i = 0; i < dim; i++) {
135 s *= pad_size[i];
136 }
137 return s;
138 }
139
140 int block_size_linear() {
141 TI_ASSERT(finalized);
142 int s = 1;
143 for (int i = 0; i < dim; i++) {
144 s *= block_size[i];
145 }
146 return s;
147 }
148
149 int linearized_index(const std::vector<int> &indices) {
150 int ret = 0;
151 TI_ASSERT(finalized);
152 for (int i = 0; i < dim; i++) {
153 ret *= (bounds[i].high - bounds[i].low);
154 ret += indices[i] - bounds[i].low;
155 }
156 return ret;
157 }
158
159 std::string extract_offset(std::string var, int d) const {
160 auto div = 1;
161 for (int i = d + 1; i < dim; i++) {
162 div *= pad_size[i];
163 }
164 return fmt::format("({} / {} % {} + {})", var, div, pad_size[d],
165 bounds[d].low);
166 }
167
168 /*
169 std::string array_dimensions_str() const {
170 std::string ret = "";
171 for (int i = 0; i < dim; i++) {
172 ret += fmt::format("[{}]", bounds[1][i] - bounds[0][i]);
173 }
174 return ret;
175 }
176 */
177
178 std::string global_to_linearized_local(const std::vector<Stmt *> &loop_vars,
179 const std::vector<Stmt *> &indices);
180};
181
/// Integer division rounded toward negative infinity.
///
/// Built-in `/` truncates toward zero, so the quotient is lowered by one
/// whenever the division is inexact and the remainder's sign differs from the
/// divisor's sign. This is correct for either sign of `a` and `b` and avoids
/// the intermediate `a - b + 1`, which could overflow for `a` near INT_MIN.
///
/// Preconditions: `b != 0`, and not (`a == INT_MIN` and `b == -1`).
inline int div_floor(int a, int b) {
  const int quotient = a / b;
  const int remainder = a % b;
  const bool round_down = remainder != 0 && ((remainder < 0) != (b < 0));
  return round_down ? quotient - 1 : quotient;
}
185
186class ScratchPads {
187 public:
188 std::map<SNode *, ScratchPad> pads;
189
190 using AccessFlag = ScratchPad::AccessFlag;
191
192 void insert(SNode *snode) {
193 if (pads.find(snode) == pads.end()) {
194 pads.emplace(std::piecewise_construct, std::forward_as_tuple(snode),
195 std::forward_as_tuple(snode));
196 } else {
197 TI_ERROR("ScratchPad for {} already exists.", snode->node_type_name);
198 }
199 }
200
201 void access(SNode *snode,
202 const std::vector<int> &coeffs,
203 const std::vector<int> &indices,
204 AccessFlag flags) {
205 TI_ASSERT(snode != nullptr);
206 if (pads.find(snode) == pads.end())
207 return;
208 pads.find(snode)->second.access(coeffs, indices, flags);
209 /*
210 if (snode->parent->type != SNodeType::root) {
211 auto parent_indices = indices;
212 for (int i = 0; i < snode->parent->num_active_indices; i++) {
213 int block_dim =
214 snode->parent->extractors[snode->parent->physical_index_position[i]]
215 .dimension;
216 parent_indices[i] = div_floor(parent_indices[i], block_dim);
217 }
218 access(snode->parent, parent_indices, flags);
219 }
220 */
221 }
222
223 void finalize() {
224 for (auto &pad : pads) {
225 pad.second.finalize();
226 }
227 }
228
229 void generate_address_code(SNode *snode, const std::vector<int> &indices) {
230 if (pads.find(snode) != pads.end()) {
231 auto &pad = pads[snode];
232 int offset = 0;
233 for (int i = 0; i < pad.dim; i++) {
234 offset = offset + (indices[i] - pad.bounds[i].low);
235 if (i > 0)
236 offset = offset * pad.pad_size[i - 1];
237 }
238 } else if (pads.find(snode->parent) != pads.end()) {
239 } else {
240 TI_NOT_IMPLEMENTED
241 }
242 }
243
244 void print() {
245 for (auto &it : pads) {
246 TI_P(it.first->node_type_name);
247 TI_P(it.second.bounds);
248 }
249 }
250
251 bool has(SNode *snode) {
252 return pads.find(snode) != pads.end();
253 }
254
255 ScratchPad &get(SNode *snode) {
256 TI_ASSERT(pads.find(snode) != pads.end());
257 return pads[snode];
258 }
259};
260
261} // namespace taichi::lang
262