1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/contrib/ethosu/passes.cc |
22 | * |
23 | * \brief Passes used in TIR lowering for the microNPU compiler. |
24 | */ |
25 | #include <tvm/tir/builtin.h> |
26 | #include <tvm/tir/function.h> |
27 | #include <tvm/tir/op.h> |
28 | #include <tvm/tir/stmt_functor.h> |
29 | #include <tvm/tir/transform.h> |
30 | |
31 | #include <algorithm> |
32 | #include <unordered_map> |
33 | #include <unordered_set> |
34 | |
35 | namespace tvm { |
36 | |
/*!
 * \brief The maximum number of movements allowed for a copy in the CopyComputeReordering pass.
 * Read from the pass context by the CopyComputeReordering pass; that pass defaults it to 1
 * when neither the explicit argument nor this option is set.
 */
constexpr const char* kCopyComputeReorderingMaxCopyMovements =
    "tir.contrib.ethos-u.copy_compute_reordering_max_copy_movements";
TVM_REGISTER_PASS_CONFIG_OPTION(kCopyComputeReorderingMaxCopyMovements, Integer);

/*!
 * \brief Whether to reorder copies and computes based on cycle count.
 * When enabled, the CopyComputeReordering pass ignores the max-copy-movements option.
 */
constexpr const char* kCopyComputeReorderingReorderByCycles =
    "tir.contrib.ethos-u.copy_compute_reordering_reorder_by_cycles";
TVM_REGISTER_PASS_CONFIG_OPTION(kCopyComputeReorderingReorderByCycles, Bool);
50 | |
51 | namespace tir { |
52 | namespace contrib { |
53 | namespace ethosu { |
54 | |
55 | namespace { |
56 | |
57 | /*! Returns the arguments of the given statement */ |
58 | Array<PrimExpr> GetStmtArgs(const Stmt& stmt) { |
59 | auto attr{stmt.as<AttrStmtNode>()}; |
60 | Stmt eval_stmt{attr ? attr->body : stmt}; |
61 | auto eval{eval_stmt.as<EvaluateNode>()}; |
62 | ICHECK(eval) << "Expected statement to be an evaluate node, but was " << eval_stmt->GetTypeKey(); |
63 | auto call{eval->value.as<CallNode>()}; |
64 | ICHECK(call) << "Expected expression to be a call node, but was " << eval->value->GetTypeKey(); |
65 | return call->args; |
66 | } |
67 | |
/*! Classifies a statement as a DMA copy to a global-scope buffer, a copy to a
 * local-scope buffer, or a compute operation. */
enum class StmtType { global_copy, local_copy, compute };
69 | |
70 | /*! Returns the type of the given statement */ |
71 | StmtType GetStmtType(const Stmt& stmt) { |
72 | Array<PrimExpr> args{GetStmtArgs(stmt)}; |
73 | if (args[0].as<StringImmNode>()->value == "ethosu_copy" ) { |
74 | if (args[3].as<BufferLoadNode>()->buffer.scope() == "global" ) { |
75 | return StmtType::global_copy; |
76 | } else { |
77 | return StmtType::local_copy; |
78 | } |
79 | } |
80 | return StmtType::compute; |
81 | } |
/*! Returns the buffer read by the given copy statement (argument 1: the read address) */
Buffer GetCopyReadBuffer(const Stmt& stmt) {
  Array<PrimExpr> args{GetStmtArgs(stmt)};
  return args[1].as<BufferLoadNode>()->buffer;
}
87 | |
/*! Returns the buffer written by the given copy statement (argument 3: the write address) */
Buffer GetCopyWriteBuffer(const Stmt& stmt) {
  Array<PrimExpr> args{GetStmtArgs(stmt)};
  return args[3].as<BufferLoadNode>()->buffer;
}
93 | |
94 | /*! Returns the length of the given copy statement */ |
95 | int64_t GetCopyLength(const Stmt& stmt) { |
96 | Array<PrimExpr> args{GetStmtArgs(stmt)}; |
97 | return args[2].as<IntImmNode>()->value; |
98 | } |
99 | |
100 | /*! Returns the cycles of the given statement */ |
101 | int64_t GetStmtCycles(const Stmt& stmt) { |
102 | auto attr{stmt.as<AttrStmtNode>()}; |
103 | if (attr && attr->attr_key == "pragma_compute_cycles_hint" ) { |
104 | int64_t cycles{Downcast<Integer>(attr->value)->value}; |
105 | return cycles; |
106 | } |
107 | return 0; |
108 | } |
109 | } // namespace |
110 | |
111 | /*! |
112 | * \brief This mutator moves allocates to the top of the body of the main |
113 | * function. |
114 | * |
115 | * Note: This pass can currently only be run in conjunction with the |
116 | * LowerToTIR() pass as it expects a single primitive function called |
117 | * "main" that is being offloaded to the NPU. |
118 | * |
119 | * For example, |
120 | * Before: |
121 | * allocate { |
122 | * extern_call(...) |
123 | * allocate { |
124 | * extern_call(...) |
125 | * } |
126 | * } |
127 | * |
128 | * After: |
129 | * allocate { |
130 | * allocate { |
131 | * extern_call(...) |
132 | * extern_call(...) |
133 | * } |
134 | * } |
135 | */ |
class HoistAllocatesMutator : public StmtExprMutator {
 public:
  HoistAllocatesMutator() {}

  /*! Strips every Allocate from the body (collecting them in visitation order), then
   * re-wraps the flattened body with the collected allocates, innermost-last, so all
   * allocations sit at the top of the function. */
  PrimFunc operator()(PrimFunc main_func) {
    Stmt new_main_func_body = SeqStmt::Flatten(this->VisitStmt(main_func->body));

    // Write all allocates that were removed in reverse order
    for (auto it = allocates_.rbegin(); it != allocates_.rend(); it++) {
      Allocate current_alloc = *it;
      // NOTE(review): every allocate after the first wraps the accumulated body in a
      // single-element SeqStmt — presumably to reproduce the nesting shape expected by
      // downstream consumers/tests; confirm before changing.
      if (it != allocates_.rbegin()) {
        new_main_func_body = SeqStmt({new_main_func_body});
      }
      new_main_func_body =
          Allocate(current_alloc->buffer_var, current_alloc->dtype, current_alloc->extents,
                   current_alloc->condition, new_main_func_body, current_alloc->annotations,
                   current_alloc->span);
    }

    PrimFunc new_main_func = PrimFunc(main_func->params, new_main_func_body, main_func->ret_type,
                                      main_func->buffer_map, main_func->attrs);
    return new_main_func;
  }

 private:
  /*! Records the allocate and continues with its body, effectively deleting the
   * Allocate node from the tree. */
  Stmt VisitStmt_(const AllocateNode* op) override {
    allocates_.push_back(GetRef<Allocate>(op));
    return VisitStmt(op->body);
  }

  /*! A stack to store allocates as they are visited. */
  std::vector<Allocate> allocates_;
};
169 | |
170 | /*! |
171 | * \brief A pass to hoist allocate nodes to the top of the body of the main function. |
172 | * |
173 | * \return tvm::transform::Pass |
174 | */ |
tvm::transform::Pass HoistAllocates() {
  auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) {
    // The mutator assumes the single-function module shape produced by LowerToTIR().
    ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main"))
        << "Expected a single primitive function called 'main'. Please run the HoistAllocates pass "
           "in conjunction with the LowerToTIR() pass.";
    return HoistAllocatesMutator()(f);
  };
  return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.HoistAllocates",
                                                 {});
}

TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.HoistAllocates").set_body_typed(HoistAllocates);
187 | |
188 | /*! |
189 | * \brief Reorders copy and compute nodes in such a way that independent DMA copies |
190 | * and computes happen in parallel. |
191 | * Copies to buffers with local scope are not reordered since they copy LUT |
192 | * into the SHRAM and that already happens in parallel with copying weights into |
193 | * the weights encoder. |
194 | */ |
195 | class CopyComputeReorderingMutator : public StmtExprMutator { |
196 | public: |
197 | explicit CopyComputeReorderingMutator(int max_copy_movements, bool reorder_by_cycles) |
198 | : _max_copy_movements{max_copy_movements}, _reorder_by_cycles{reorder_by_cycles} {} |
199 | |
200 | PrimFunc operator()(PrimFunc main_func) { |
201 | if (_max_copy_movements > 0) { |
202 | auto prim_func_node{main_func.CopyOnWrite()}; |
203 | prim_func_node->body = this->VisitStmt(main_func->body); |
204 | return GetRef<PrimFunc>(prim_func_node); |
205 | } |
206 | return main_func; |
207 | } |
208 | |
209 | private: |
210 | // A structure to hold a compute op with the corresponding weights/bias copy and LUT copy |
211 | struct OpWithCopies { |
212 | Stmt compute_op{}; |
213 | Stmt global_copy{}; |
214 | Stmt local_copy{}; |
215 | }; |
216 | |
217 | Stmt VisitStmt_(const SeqStmtNode* op) override { |
218 | if (op->size() <= 1) { |
219 | return StmtExprMutator::VisitStmt_(op); |
220 | } |
221 | |
222 | auto seq_stmt{GetRef<SeqStmt>(op)}; |
223 | std::vector<Stmt> new_seq(seq_stmt->size()); |
224 | std::copy(seq_stmt->seq.begin(), seq_stmt->seq.end(), new_seq.begin()); |
225 | |
226 | // Reorder the copies and computes based on the cycle count |
227 | if (_reorder_by_cycles) { |
228 | // We can't hide the first copy, so ignore it for the purpose of hiding copies |
229 | Stmt first_copy{}; |
230 | if (stmt_is_global_copy(new_seq[0]) || |
231 | (stmt_is_local_copy(new_seq[0]) && stmt_is_global_copy(new_seq[1]))) { |
232 | auto copy_position = stmt_is_global_copy(new_seq[0]) ? 0 : 1; |
233 | first_copy = new_seq[copy_position]; |
234 | new_seq.erase(new_seq.begin() + copy_position); |
235 | } |
236 | |
237 | // Build up a list of cells with the compute op and the copy ops that directly preceed it |
238 | std::vector<OpWithCopies> ops{}; |
239 | for (size_t idx = 0; idx < new_seq.size(); ++idx) { |
240 | if (stmt_is_compute_op(new_seq[idx])) { |
241 | OpWithCopies new_op; |
242 | new_op.compute_op = new_seq[idx]; |
243 | if (idx > 0) { |
244 | auto prev_op = new_seq[idx - 1]; |
245 | if (!stmt_is_compute_op(prev_op)) { |
246 | if (stmt_is_local_copy(prev_op)) { |
247 | new_op.local_copy = prev_op; |
248 | } else { |
249 | new_op.global_copy = prev_op; |
250 | } |
251 | if (idx > 1) { |
252 | auto prev_prev_op = new_seq[idx - 2]; |
253 | if (!stmt_is_compute_op(prev_prev_op)) { |
254 | if (stmt_is_local_copy(prev_prev_op)) { |
255 | new_op.local_copy = prev_prev_op; |
256 | } else { |
257 | new_op.global_copy = prev_prev_op; |
258 | } |
259 | } |
260 | } |
261 | } |
262 | } |
263 | ops.push_back(new_op); |
264 | } |
265 | } |
266 | |
267 | // Move the global copies up by one. If in general the computes take longer than the copies, |
268 | // that should be good enough |
269 | for (size_t idx = 1; idx < ops.size(); ++idx) { |
270 | if (ops[idx].global_copy.as<AttrStmtNode>()) { |
271 | ops[idx - 1].global_copy = ops[idx].global_copy; |
272 | ops[idx].global_copy = {}; |
273 | } |
274 | } |
275 | |
276 | // If there are long copies, try to hide them further |
277 | for (size_t idx = ops.size() - 1; idx > 0; --idx) { |
278 | if (ops[idx].global_copy.as<AttrStmtNode>()) { |
279 | // Check whether the copy is hidden |
280 | int64_t copy_cycles{GetStmtCycles(ops[idx].global_copy)}; |
281 | int64_t compute_cycles{GetStmtCycles(ops[idx].compute_op)}; |
282 | bool is_hidden = compute_cycles >= copy_cycles; |
283 | |
284 | // If the previous compute op is not already hiding another copy, move the copy back, so |
285 | // that it would be hidden by multiple computes |
286 | while (!is_hidden && !ops[idx - 1].global_copy.as<AttrStmtNode>() && (idx > 0)) { |
287 | int64_t new_compute_cycles{GetStmtCycles(ops[idx - 1].compute_op)}; |
288 | ops[idx - 1].global_copy = ops[idx].global_copy; |
289 | ops[idx].global_copy = {}; |
290 | compute_cycles += new_compute_cycles; |
291 | is_hidden = compute_cycles >= copy_cycles; |
292 | --idx; |
293 | } |
294 | } |
295 | } |
296 | |
297 | // Reconstruct the op sequence from the vector of OpWithCopies |
298 | new_seq.clear(); |
299 | if (first_copy.as<AttrStmtNode>()) { |
300 | new_seq.push_back(first_copy); |
301 | } |
302 | for (auto& op : ops) { |
303 | if (op.global_copy.as<AttrStmtNode>()) { |
304 | new_seq.push_back(op.global_copy); |
305 | } |
306 | if (op.local_copy.as<EvaluateNode>()) { |
307 | new_seq.push_back(op.local_copy); |
308 | } |
309 | if (op.compute_op.as<AttrStmtNode>()) { |
310 | new_seq.push_back(op.compute_op); |
311 | } |
312 | } |
313 | } else { |
314 | // Each copy statement to a buffer with global scope is moved up |
315 | // at most `_max_copy_movements` times. |
316 | for (size_t index = 0; index < new_seq.size(); ++index) { |
317 | if (GetStmtType(new_seq[index]) == StmtType::global_copy) { |
318 | int lower = std::max(0, static_cast<int>(index) - _max_copy_movements); |
319 | for (int i = index; i > lower && (GetStmtType(new_seq[i - 1]) == StmtType::compute); |
320 | --i) { |
321 | std::swap(new_seq[i - 1], new_seq[i]); |
322 | } |
323 | } |
324 | } |
325 | } |
326 | |
327 | auto seq_stmt_node{CopyOnWrite(op)}; |
328 | seq_stmt_node->seq = std::move(new_seq); |
329 | return Stmt{seq_stmt_node}; |
330 | } |
331 | |
332 | bool stmt_is_global_copy(const Stmt& stmt) { return GetStmtType(stmt) == StmtType::global_copy; } |
333 | |
334 | bool stmt_is_local_copy(const Stmt& stmt) { return GetStmtType(stmt) == StmtType::local_copy; } |
335 | |
336 | bool stmt_is_compute_op(const Stmt& stmt) { return GetStmtType(stmt) == StmtType::compute; } |
337 | |
338 | /*! The maximum number of movements allowed for a copy. */ |
339 | int _max_copy_movements; |
340 | /*! Whether we use the cycle hint to determine the reordering. */ |
341 | bool _reorder_by_cycles; |
342 | }; |
343 | |
344 | /*! |
345 | * \brief A pass to reorder copy and compute nodes in such a way that independent DMA copies |
346 | * and computes happen in parallel. If reorder_by_cycles is set, we will ignore the |
347 | * max_copy_movements value. |
348 | * |
349 | * \param max_copy_movements: The maximum number of movements allowed for a copy. |
350 | * If None, the pass context option tir.contrib.ethos-u.copy_compute_reordering_max_copy_movements |
351 | * is used if provided, otherwise the default value will be 1. |
352 | * |
353 | * \param reorder_by_cycles: Whether to reorder copies and computes by cycles. |
354 | * If None, the pass context option tir.contrib.ethos-u.copy_compute_reordering_reorder_by_cycles |
355 | * is used if provided, otherwise the default value will be False. If the value is True, |
356 | * max_copy_movements will be ignored. |
357 | * \return tvm::transform::Pass |
358 | */ |
tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements,
                                           Optional<Bool> reorder_by_cycles) {
  auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) {
    // The mutator assumes the single-function module shape produced by LowerToTIR().
    ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main"))
        << "Expected a single primitive function called 'main'. Please run the "
           "CopyComputeReordering "
           "pass in conjunction with the LowerToTIR() pass.";

    // Explicit arguments take precedence over the pass-context options; the defaults
    // are 1 copy movement and no cycle-based reordering.
    auto copy_movements = max_copy_movements.value_or(
        ctx->GetConfig(kCopyComputeReorderingMaxCopyMovements, Integer(1)).value());
    auto reorder = reorder_by_cycles.value_or(
        ctx->GetConfig(kCopyComputeReorderingReorderByCycles, Bool(false)).value());
    return CopyComputeReorderingMutator(copy_movements.IntValue(), reorder)(f);
  };
  return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0,
                                                 "tir.contrib.ethos-u.CopyComputeReordering", {});
}

TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
    .set_body_typed(CopyComputeReordering);
379 | |
380 | /*! |
381 | * \brief This mutator removes all allocates. |
382 | */ |
383 | class RemoveAllocatesMutator : public StmtExprMutator { |
384 | public: |
385 | PrimFunc operator()(PrimFunc main_func) { |
386 | auto prim_func_node{main_func.CopyOnWrite()}; |
387 | prim_func_node->body = this->VisitStmt(main_func->body); |
388 | return GetRef<PrimFunc>(prim_func_node); |
389 | } |
390 | |
391 | private: |
392 | Stmt VisitStmt_(const AllocateNode* op) override { return VisitStmt(op->body); } |
393 | }; |
394 | |
395 | /*! |
396 | * \brief This extractor collects information used by the MergeConstantsMutator |
397 | */ |
398 | class : public StmtExprVisitor { |
399 | public: |
400 | class { |
401 | public: |
402 | /*! A stack to store allocates as they are visited. */ |
403 | std::vector<Allocate> {}; |
404 | |
405 | /*! A list that contains in the i-th position the write buffer of the i-th statement |
406 | * if that statement is a copy to a buffer with global scope */ |
407 | std::vector<Optional<Buffer>> {}; |
408 | |
409 | /*! Maps a copy's write buffer to an index representing the |
410 | * new buffer and an offset in that buffer */ |
411 | std::unordered_map<const BufferNode*, std::pair<int /* new buffer index */, int /* offset */>> |
412 | {}; |
413 | |
414 | /*! Maps an index representing a new buffer to the length of that buffer */ |
415 | std::unordered_map<int /* new buffer index */, int /* length */> {}; |
416 | |
417 | /*! Maps an index representing a new buffer to the cycless needed to copy that buffer */ |
418 | std::unordered_map<int /* new buffer index */, int64_t> {}; |
419 | }; |
420 | |
421 | Info (PrimFunc main_func) { |
422 | this->VisitStmt(main_func->body); |
423 | return std::move(_info); |
424 | } |
425 | |
426 | private: |
427 | /*! The information collected by this extractor */ |
428 | Info {}; |
429 | |
430 | void (const AllocateNode* op) override { |
431 | _info.allocates.push_back(GetRef<Allocate>(op)); |
432 | VisitStmt(op->body); |
433 | } |
434 | |
435 | void (const SeqStmtNode* op) override { |
436 | if (op->size() <= 1) { |
437 | StmtExprVisitor::VisitStmt_(op); |
438 | return; |
439 | } |
440 | |
441 | auto seq_stmt{GetRef<SeqStmt>(op)}; |
442 | for (size_t i = 0; i < seq_stmt.size(); ++i) { |
443 | Stmt stmt{seq_stmt[i]}; |
444 | switch (GetStmtType(stmt)) { |
445 | case StmtType::global_copy: { |
446 | Buffer write_buffer{GetCopyWriteBuffer(stmt)}; |
447 | _info.copy_write_buffers.push_back(write_buffer); |
448 | _info.old_to_new_write_buffer[write_buffer.as<BufferNode>()] = std::make_pair(-1, -1); |
449 | break; |
450 | } |
451 | case StmtType::local_copy: { |
452 | _info.copy_write_buffers.push_back(Optional<Buffer>{}); |
453 | break; |
454 | } |
455 | case StmtType::compute: { |
456 | _info.copy_write_buffers.push_back(Optional<Buffer>{}); |
457 | std::vector<Buffer> buffers{GetCopiedBuffersUsedByStmt(stmt)}; |
458 | if (buffers.empty()) { |
459 | continue; |
460 | } |
461 | _info.new_buffers_length[i] = 0; |
462 | for (Buffer buffer : buffers) { |
463 | for (size_t j{i - 1}; j >= 0; --j) { |
464 | if (_info.copy_write_buffers[j] == buffer) { |
465 | _info.old_to_new_write_buffer[buffer.as<BufferNode>()] = |
466 | std::make_pair(i, _info.new_buffers_length[i]); |
467 | _info.new_buffers_length[i] += GetCopyLength(seq_stmt[j]); |
468 | _info.cycless[i] += GetStmtCycles(seq_stmt[j]); |
469 | break; |
470 | } |
471 | } |
472 | } |
473 | break; |
474 | } |
475 | } |
476 | } |
477 | } |
478 | |
479 | /*! Get all buffers written by copies and used by a given statement */ |
480 | std::vector<Buffer> (const Stmt& stmt) { |
481 | std::vector<Buffer> buffers{}; |
482 | for (PrimExpr arg : GetStmtArgs(stmt)) { |
483 | if (auto buffer_load = arg.as<BufferLoadNode>()) { |
484 | Buffer buffer{buffer_load->buffer}; |
485 | // Check if the buffer has already been added |
486 | if (std::find(buffers.begin(), buffers.end(), buffer) == buffers.end()) { |
487 | // Check if the buffer is copied |
488 | if (_info.old_to_new_write_buffer.count(buffer.as<BufferNode>())) { |
489 | buffers.push_back(buffer); |
490 | } |
491 | } |
492 | } |
493 | } |
494 | return buffers; |
495 | } |
496 | }; |
497 | |
498 | /*! |
499 | * \brief This mutator looks for the constants used by each compute operator |
500 | * and merges them into a single buffer. |
501 | * Constants written to a buffer with local scope are not merged. |
502 | */ |
503 | class MergeConstantsMutator : public StmtExprMutator { |
504 | public: |
505 | explicit (MergeConstantsInfoExtractor::Info info) : _info{std::move(info)} {} |
506 | |
  /*!
   * \brief Rewrites the main function so that the constants used by each compute op are
   * read from a single merged buffer.
   * \param main_func The function to rewrite.
   * \param const_dict Keyed by values compared against parameter indices below — presumably
   * parameter-index -> constant data; confirm against the caller.
   * \return The rewritten function with the re-keyed constant dictionary attached as the
   * "ethos-u.const_dict" attribute.
   */
  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
    // Rewrite
    Stmt new_body = RewritePrimFuncBody(main_func->body);
    std::unordered_set<const VarNode*> params_to_delete{};
    Map<Var, Buffer> new_buffer_map{MakeNewBufferMap(main_func->buffer_map, &params_to_delete)};
    Array<Var> new_params{MakeNewParams(main_func->params, params_to_delete)};

    // Make the new const dict
    Array<Array<IntImm>> args_to_merge{GetArgsToMerge(main_func->buffer_map, main_func->params)};
    Map<IntImm, Array<IntImm>> buffers_to_merge{
        GetArgsToMergeWithoutArgsNotInConstDict(args_to_merge, const_dict)};
    Map<IntImm, runtime::NDArray> new_const_dict{MakeNewConstDict(buffers_to_merge, const_dict)};

    // Make the new prim func
    auto prim_func_node{main_func.CopyOnWrite()};
    prim_func_node->body = std::move(new_body);
    prim_func_node->buffer_map = std::move(new_buffer_map);
    prim_func_node->params = std::move(new_params);
    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};

    // Add the new const dict as an attribute
    f = WithAttr(std::move(f), "ethos-u.const_dict", new_const_dict);

    return f;
  }
532 | |
533 | private: |
  /*! The information collected by the MergeConstantsInfoExtractor */
  MergeConstantsInfoExtractor::Info _info;

  /*! Maps an index representing a new buffer to the new buffer (created lazily in
   * RewritePrimFuncBody) */
  std::unordered_map<int /* new buffer index */, Buffer> new_buffers{};

  /*! Maps a copy's read buffer to the new copy's read buffer */
  std::unordered_map<const BufferNode*, Buffer> old_to_new_read_buffers{};

  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new
   * buffer */
  std::unordered_map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};

  /*! A set of buffers to delete (read buffers of copies made redundant by the merge) */
  std::unordered_set<const BufferNode*> buffers_to_delete{};
549 | |
  /*! Rebuilds the function body: re-emits the original allocates except those whose buffer
   * is written by a global copy, creates one allocate per merged buffer, and finally
   * rewrites the operator statements via VisitStmt. */
  Stmt RewritePrimFuncBody(Stmt body) {
    std::unordered_map<const VarNode*, Allocate> var_to_allocate{};

    // Rewrite old allocates
    std::unordered_set<const VarNode*> buffer_vars{GetVarsForWrittenCopyBuffers()};
    for (auto it{_info.allocates.rbegin()}; it != _info.allocates.rend(); ++it) {
      Allocate alloc{*it};
      var_to_allocate[alloc->buffer_var.get()] = alloc;
      // Allocates backing a copy's write buffer are skipped here; they are replaced by
      // the merged-buffer allocates created in the loop below.
      if (buffer_vars.count(alloc->buffer_var.as<VarNode>()) == 0) {
        body = Allocate(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->condition, body,
                        alloc->annotations, alloc->span);
      }
    }

    // Rewrite new allocates
    for (auto it{_info.copy_write_buffers.rbegin()}; it != _info.copy_write_buffers.rend(); ++it) {
      if (Optional<Buffer> buffer_opt = *it) {
        Buffer old_write_buffer{buffer_opt.value()};
        int new_buffer_index{
            _info.old_to_new_write_buffer[old_write_buffer.as<BufferNode>()].first};

        // Check if the allocate has already been created
        if (new_buffers.count(new_buffer_index) == 0) {
          // Reuse the old write buffer's metadata but resize it to the merged length.
          BufferNode* new_buffer{old_write_buffer.CopyOnWrite()};
          new_buffer->shape = {_info.new_buffers_length[new_buffer_index]};

          new_buffers[new_buffer_index] = GetRef<Buffer>(new_buffer);

          Allocate old_allocate{var_to_allocate[old_write_buffer->data.get()]};
          body = Allocate(new_buffer->data, new_buffer->dtype, new_buffer->shape, tir::const_true(),
                          body, old_allocate->annotations, old_allocate->span);
        }
      }
    }

    // Rewrite operators
    return this->VisitStmt(body);
  }
588 | |
589 | Stmt VisitStmt_(const AllocateNode* op) override { |
590 | auto allocate{CopyOnWrite(op)}; |
591 | allocate->body = this->VisitStmt(op->body); |
592 | return Stmt(allocate); |
593 | } |
594 | |
  /*! Rewrites each statement of the sequence: global copies are redirected to (and sized
   * for) the merged buffers, redundant copies are dropped, compute ops get their read
   * addresses updated; local copies pass through unchanged. */
  Stmt VisitStmt_(const SeqStmtNode* op) override {
    if (op->size() <= 1) {
      return StmtExprMutator::VisitStmt_(op);
    }

    Array<Stmt> new_seq{};
    SeqStmt seq_stmt{GetRef<SeqStmt>(op)};
    for (size_t i{0}; i < seq_stmt.size(); ++i) {
      Stmt stmt{seq_stmt[i]};

      switch (GetStmtType(stmt)) {
        case StmtType::global_copy: {
          // Where does this copy's write buffer land inside the merged buffer?
          Buffer old_write_buffer{_info.copy_write_buffers[i].value()};
          std::pair<int, int> pair{
              _info.old_to_new_write_buffer[old_write_buffer.as<BufferNode>()]};
          int new_buffer_index{pair.first};
          int new_buffer_offset{pair.second};
          UpdateBuffersToMergeAndDelete(stmt, new_buffer_index, new_buffer_offset);

          // Only the copy at offset 0 survives: it is rewritten (MakeNewCopyArgs) to
          // transfer the whole merged buffer, making the later-offset copies redundant.
          if (!IsCopyToBeDeleted(new_buffer_offset)) {
            Optional<PrimExpr> cycless{GetMergedCycles(new_buffer_index)};
            new_seq.push_back(MakeNewStmt(
                stmt, MakeNewCopyArgs(stmt, old_write_buffer, new_buffer_index), cycless));
          }
          break;
        }
        case StmtType::local_copy: {
          // Copies to local-scope buffers are not merged.
          new_seq.push_back(stmt);
          break;
        }
        case StmtType::compute: {
          // Redirect the compute op's buffer loads to the merged buffers.
          new_seq.push_back(MakeNewStmt(stmt, MakeNewComputeArgs(stmt)));
          break;
        }
      }
    }
    return SeqStmt(new_seq, op->span);
  }
633 | |
634 | /*! Returns the variables of the buffers written by copies */ |
635 | std::unordered_set<const VarNode*> GetVarsForWrittenCopyBuffers() { |
636 | std::unordered_set<const VarNode*> buffer_vars{}; |
637 | std::transform(_info.old_to_new_write_buffer.begin(), _info.old_to_new_write_buffer.end(), |
638 | std::inserter(buffer_vars, buffer_vars.begin()), |
639 | [](std::pair<const BufferNode*, std::pair<int, int>> pair) -> const VarNode* { |
640 | return pair.first->data.as<VarNode>(); |
641 | }); |
642 | return buffer_vars; |
643 | } |
644 | |
645 | /*! Returns the cycles of the new buffer at the given index */ |
646 | Optional<PrimExpr> GetMergedCycles(int new_buffer_index) { |
647 | auto it = _info.cycless.find(new_buffer_index); |
648 | if (it != _info.cycless.end()) { |
649 | return Integer(it->second); |
650 | } |
651 | return Optional<PrimExpr>{}; |
652 | } |
653 | |
654 | /*! Returns true if a copy must be deleted, false otherwise */ |
655 | bool IsCopyToBeDeleted(int new_buffer_offset) { return new_buffer_offset > 0; } |
656 | |
  /*! Builds the argument list of the rewritten copy: the read side is re-shaped to the
   * merged length, the length argument becomes the merged length, and the write side is
   * redirected to offset 0 of the merged buffer. All other arguments pass through. */
  Array<PrimExpr> MakeNewCopyArgs(const Stmt& stmt, const Buffer& old_write_buffer,
                                  int new_buffer_index) {
    Array<PrimExpr> args{GetStmtArgs(stmt)};
    int new_length{_info.new_buffers_length[new_buffer_index]};

    Array<PrimExpr> new_args{};
    for (size_t i = 0; i < args.size(); ++i) {
      switch (i) {
        case 1: /* read_address */ {
          auto buffer_load = args[1].as<BufferLoadNode>();
          Buffer buffer{buffer_load->buffer};
          // Same underlying data/metadata, but shaped to the merged length.
          Buffer new_buffer{buffer->data,
                            buffer->dtype,
                            {new_length},
                            buffer->strides,
                            buffer->elem_offset,
                            buffer->name,
                            buffer->data_alignment,
                            buffer->offset_factor,
                            buffer->buffer_type,
                            buffer->axis_separators,
                            buffer->span};
          // Remember the mapping so MakeNewBufferMap can fix up the function's buffer_map.
          old_to_new_read_buffers[buffer.as<BufferNode>()] = new_buffer;
          new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->span));
          break;
        }
        case 2: /* length */ {
          new_args.push_back(new_length);
          break;
        }
        case 3: /* write_address */ {
          new_args.push_back(MakeNewBufferLoad(old_write_buffer, 0, true).value());
          break;
        }
        default:
          new_args.push_back(args[i]);
          break;
      }
    }
    return new_args;
  }
698 | |
699 | Array<PrimExpr> MakeNewComputeArgs(const Stmt& stmt) { |
700 | Array<PrimExpr> args{GetStmtArgs(stmt)}; |
701 | Array<PrimExpr> new_args{}; |
702 | for (size_t i = 0; i < args.size(); ++i) { |
703 | if (auto buffer_load = args[i].as<BufferLoadNode>()) { |
704 | BufferLoad new_buffer_load{ |
705 | MakeNewBufferLoad(buffer_load->buffer, buffer_load->indices[0], false) |
706 | .value_or(GetRef<BufferLoad>(buffer_load))}; |
707 | new_args.push_back(new_buffer_load); |
708 | } else { |
709 | new_args.push_back(args[i]); |
710 | } |
711 | } |
712 | return new_args; |
713 | } |
714 | |
  /*! Rebuilds the given statement with new call arguments, preserving a surrounding
   * "pragma_compute_cycles_hint" attribute if present and optionally replacing its value
   * with the merged cycle count. */
  Stmt MakeNewStmt(const Stmt& stmt, const Array<PrimExpr>& new_args,
                   Optional<PrimExpr> cycless = Optional<PrimExpr>{}) {
    auto attr{stmt.as<AttrStmtNode>()};
    Stmt eval_stmt{attr ? attr->body : stmt};
    auto eval{eval_stmt.as<EvaluateNode>()};
    ICHECK(eval) << "Expected statement to be an evaluate node, but was "
                 << eval_stmt->GetTypeKey();
    auto call{eval->value.as<CallNode>()};
    ICHECK(call) << "Expected expression to be a call node, but was " << eval->value->GetTypeKey();

    Call new_call{call->dtype, call->op, new_args, call->span};
    Evaluate new_eval{new_call, eval->span};

    if (attr) {
      // The only attribute expected to wrap these statements is the cycle-count hint.
      ICHECK(attr->attr_key == "pragma_compute_cycles_hint");
      PrimExpr value = cycless.value_or(attr->value);
      return AttrStmt{attr->node, attr->attr_key, value, new_eval, attr->span};
    } else {
      return std::move(new_eval);
    }
  }
736 | |
737 | Optional<BufferLoad> MakeNewBufferLoad(const Buffer& write_buffer, const PrimExpr& old_index, |
738 | bool only_old_index) { |
739 | auto it = _info.old_to_new_write_buffer.find(write_buffer.as<BufferNode>()); |
740 | if (it != _info.old_to_new_write_buffer.end()) { |
741 | std::pair<int, int> pair{it->second}; |
742 | int new_buffer_index{pair.first}; |
743 | PrimExpr new_index{only_old_index ? old_index : (pair.second + old_index)}; |
744 | return BufferLoad{new_buffers[new_buffer_index], {new_index}}; |
745 | } |
746 | return Optional<BufferLoad>{}; |
747 | } |
748 | |
749 | Map<tir::Var, Buffer> MakeNewBufferMap(const Map<tir::Var, Buffer>& buffer_map, |
750 | std::unordered_set<const VarNode*>* params_to_delete) { |
751 | Map<tir::Var, Buffer> new_buffer_map{}; |
752 | for (std::pair<Var, Buffer> pair : buffer_map) { |
753 | Var var{pair.first}; |
754 | Buffer buffer{pair.second}; |
755 | |
756 | if (buffers_to_delete.count(buffer.as<BufferNode>()) == 1) { |
757 | params_to_delete->insert(var.as<VarNode>()); |
758 | } else if (old_to_new_read_buffers.count(buffer.as<BufferNode>()) == 1) { |
759 | new_buffer_map.Set(var, old_to_new_read_buffers[buffer.as<BufferNode>()]); |
760 | } else { |
761 | new_buffer_map.Set(var, buffer); |
762 | } |
763 | } |
764 | return new_buffer_map; |
765 | } |
766 | |
767 | Array<tir::Var> MakeNewParams(const Array<tir::Var>& params, |
768 | const std::unordered_set<const VarNode*>& params_to_delete) { |
769 | std::vector<Var> new_params{}; |
770 | for (Var var : params) { |
771 | if (params_to_delete.count(var.as<VarNode>()) == 0) { |
772 | new_params.push_back(var); |
773 | } |
774 | } |
775 | return new_params; |
776 | } |
777 | |
778 | void UpdateBuffersToMergeAndDelete(const Stmt& stmt, int new_buffer_index, |
779 | int new_buffer_offset) { |
780 | Array<PrimExpr> args{GetStmtArgs(stmt)}; |
781 | Buffer read_buffer{GetCopyReadBuffer(stmt)}; |
782 | |
783 | if (buffers_to_merge.count(new_buffer_index) == 0) { |
784 | buffers_to_merge[new_buffer_index] = std::vector<Buffer>{read_buffer}; |
785 | } else { |
786 | buffers_to_merge[new_buffer_index].push_back(read_buffer); |
787 | } |
788 | |
789 | if (new_buffer_offset > 0) { |
790 | buffers_to_delete.insert(read_buffer.as<BufferNode>()); |
791 | } |
792 | } |
793 | |
  /*! Returns an array whose elements are the indices of the function arguments to be merged.
   * Example: if a function has three arguments and the second and the third ones must
   * be merged then the array is: [[0], [1, 2], [3]] */
  Array<Array<IntImm>> GetArgsToMerge(const Map<Var, Buffer>& buffer_map,
                                      const Array<Var>& params) {
    // Invert buffer_map so each merged buffer can be traced back to its parameter.
    std::unordered_map<const BufferNode*, Var> buffer_to_var{};
    for (std::pair<Var, Buffer> var_buffer : buffer_map) {
      buffer_to_var[var_buffer.second.as<BufferNode>()] = var_buffer.first;
    }

    std::unordered_map<const VarNode*, int> var_to_index{};
    for (int i = 0; i < static_cast<int>(params.size()); ++i) {
      var_to_index[params[i].as<VarNode>()] = i;
    }

    // One group of parameter indices per merged buffer, de-duplicated within each group.
    // Vars are erased from var_to_index as they are consumed, leaving only unmerged params.
    std::vector<Array<IntImm>> vector{};
    for (std::pair<int, std::vector<Buffer>> index_vector : buffers_to_merge) {
      std::vector<IntImm> indices{};
      for (Buffer buffer : index_vector.second) {
        const VarNode* var{buffer_to_var[buffer.as<BufferNode>()].as<VarNode>()};
        IntImm index{DataType::Int(64), var_to_index[var]};
        var_to_index.erase(var);
        auto it = std::find_if(indices.begin(), indices.end(),
                               [&](IntImm value) { return value->value == index->value; });
        if (it == indices.end()) {
          indices.push_back(index);
        }
      }
      vector.push_back(Array<IntImm>{indices});
    }

    // The remaining (unmerged) parameters each form a singleton group.
    for (std::pair<const VarNode*, int> var_index : var_to_index) {
      vector.push_back(Array<IntImm>{IntImm(DataType::Int(64), var_index.second)});
    }
    // Order the groups by their first parameter index.
    std::sort(vector.begin(), vector.end(),
              [](Array<IntImm> a, Array<IntImm> b) { return a[0]->value < b[0]->value; });
    return vector;
  }
832 | |
/*! \brief Keeps only the argument groups whose first argument appears in
 * const_dict, re-keying each kept group.
 *
 * The running key is seeded with the original key of the first group found in
 * const_dict and advances by one for every subsequent group — kept or not — so
 * kept groups receive consecutive positions relative to that first group.
 *
 * \param args_to_merge The argument groups, as produced by GetArgsToMerge.
 * \param const_dict The original map from argument index to constant data.
 * \return Map from the updated const_dict key to the group of argument indices. */
Map<IntImm, Array<IntImm>> GetArgsToMergeWithoutArgsNotInConstDict(
    const Array<Array<IntImm>>& args_to_merge, const Map<IntImm, runtime::NDArray>& const_dict) {
  Map<IntImm, Array<IntImm>> new_args_to_merge{};
  bool first_arg_found = false;
  int64_t new_arg_key = 0;  // the updated key of the merged const_dict
  for (Array<IntImm> args : args_to_merge) {
    IntImm key{args[0]};
    // A group counts as constant when its first argument index is a const_dict key.
    auto it = std::find_if(const_dict.begin(), const_dict.end(),
                           [&](std::pair<tvm::IntImm, runtime::NDArray> pair) {
                             return pair.first->value == key->value;
                           });
    if (it != const_dict.end()) {
      if (first_arg_found == false) {
        first_arg_found = true;
        // Seed the running key with the first constant group's original key.
        new_arg_key = key->value;
      }
      new_args_to_merge.Set(IntImm(DataType::Int(64), new_arg_key), args);
    }
    // After the first constant group, every later group consumes one key slot,
    // whether or not it is itself kept.
    if (first_arg_found) {
      new_arg_key++;
    }
  }
  return new_args_to_merge;
}
857 | |
858 | Map<IntImm, runtime::NDArray> MakeNewConstDict(const Map<IntImm, Array<IntImm>>& args_to_merge, |
859 | Map<IntImm, runtime::NDArray> const_dict) { |
860 | Map<IntImm, runtime::NDArray> new_const_dict{}; |
861 | if (args_to_merge.size() == 0) { |
862 | return new_const_dict; |
863 | } |
864 | |
865 | for (auto const& elem : args_to_merge) { |
866 | IntImm key = elem.first; |
867 | Array<IntImm> args = elem.second; |
868 | int64_t size = 0; |
869 | for (IntImm arg : args) { |
870 | auto it = std::find_if(const_dict.begin(), const_dict.end(), |
871 | [&](auto pair) { return pair.first->value == arg->value; }); |
872 | runtime::NDArray arg_constant{(*it).second}; |
873 | size += runtime::GetDataSize(*arg_constant.operator->()); |
874 | } |
875 | |
876 | runtime::NDArray constant = runtime::NDArray::Empty({size}, DataType::UInt(8), {kDLCPU, 0}); |
877 | |
878 | size_t offset = 0; |
879 | for (IntImm arg : args) { |
880 | auto it = std::find_if(const_dict.begin(), const_dict.end(), |
881 | [&](auto pair) { return pair.first->value == arg->value; }); |
882 | runtime::NDArray arg_constant{(*it).second}; |
883 | size_t nbytes = runtime::GetDataSize(*arg_constant.operator->()); |
884 | arg_constant.CopyToBytes(static_cast<uint8_t*>(constant->data) + offset, nbytes); |
885 | offset += nbytes; |
886 | } |
887 | new_const_dict.Set(key, constant); |
888 | } |
889 | return new_const_dict; |
890 | } |
891 | }; |
892 | |
893 | /*! |
894 | * \brief This pass looks for the constants used by each compute operator |
895 | * and merges them into a single buffer. |
896 | * Constants written to a buffer with local scope are not merged. |
897 | * \return tvm::transform::Pass |
898 | */ |
899 | tvm::transform::Pass MergeConstants() { |
900 | auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) { |
901 | ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main" )) |
902 | << "Expected a single primitive function called 'main'. Please run the " |
903 | "MergeConstants pass in conjunction with the LowerToTIR() pass." ; |
904 | Optional<Map<IntImm, runtime::NDArray>> const_dict{ |
905 | f->attrs.GetAttr("ethos-u.const_dict" , Optional<Map<IntImm, runtime::NDArray>>{})}; |
906 | ICHECK(const_dict) << "Expected a ethos-u.const_dict attribute" ; |
907 | |
908 | MergeConstantsInfoExtractor::Info info{MergeConstantsInfoExtractor()(f)}; |
909 | f = RemoveAllocatesMutator()(f); |
910 | return MergeConstantsMutator(info)(f, const_dict.value()); |
911 | }; |
912 | return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.MergeConstants" , |
913 | {}); |
914 | } |
915 | |
916 | TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.MergeConstants" ).set_body_typed(MergeConstants); |
917 | |
918 | /*! |
919 | * \brief This pass removes the ethos-u.const_dict attribute |
920 | * \return tvm::transform::Pass |
921 | */ |
922 | class RemoveConstDictAttributeMutator : public StmtExprMutator { |
923 | public: |
924 | RemoveConstDictAttributeMutator() {} |
925 | |
926 | PrimFunc operator()(PrimFunc main_func) { |
927 | return WithoutAttr(std::move(main_func), "ethos-u.const_dict" ); |
928 | } |
929 | }; |
930 | |
931 | tvm::transform::Pass RemoveConstDictAttribute() { |
932 | auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) { |
933 | return RemoveConstDictAttributeMutator()(f); |
934 | }; |
935 | return tvm::tir::transform::CreatePrimFuncPass( |
936 | pass_func, 0, "tir.contrib.ethos-u.RemoveConstDictAttribute" , {}); |
937 | } |
938 | |
939 | TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.RemoveConstDictAttribute" ) |
940 | .set_body_typed(RemoveConstDictAttribute); |
941 | |
942 | } // namespace ethosu |
943 | } // namespace contrib |
944 | } // namespace tir |
945 | } // namespace tvm |
946 | |