1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir/contrib/ethosu/passes.cc
22 *
23 * \brief Passes used in TIR lowering for the microNPU compiler.
24 */
25#include <tvm/tir/builtin.h>
26#include <tvm/tir/function.h>
27#include <tvm/tir/op.h>
28#include <tvm/tir/stmt_functor.h>
29#include <tvm/tir/transform.h>
30
31#include <algorithm>
32#include <unordered_map>
33#include <unordered_set>
34
35namespace tvm {
36
37/*!
38 * \brief The maximum number of movements allowed for a copy in the CopyComputeReordering pass.
39 */
40constexpr const char* kCopyComputeReorderingMaxCopyMovements =
41 "tir.contrib.ethos-u.copy_compute_reordering_max_copy_movements";
42TVM_REGISTER_PASS_CONFIG_OPTION(kCopyComputeReorderingMaxCopyMovements, Integer);
43
44/*!
45 * \brief Whether to reorder copies and computes based on cycle count.
46 */
47constexpr const char* kCopyComputeReorderingReorderByCycles =
48 "tir.contrib.ethos-u.copy_compute_reordering_reorder_by_cycles";
49TVM_REGISTER_PASS_CONFIG_OPTION(kCopyComputeReorderingReorderByCycles, Bool);
50
51namespace tir {
52namespace contrib {
53namespace ethosu {
54
55namespace {
56
57/*! Returns the arguments of the given statement */
58Array<PrimExpr> GetStmtArgs(const Stmt& stmt) {
59 auto attr{stmt.as<AttrStmtNode>()};
60 Stmt eval_stmt{attr ? attr->body : stmt};
61 auto eval{eval_stmt.as<EvaluateNode>()};
62 ICHECK(eval) << "Expected statement to be an evaluate node, but was " << eval_stmt->GetTypeKey();
63 auto call{eval->value.as<CallNode>()};
64 ICHECK(call) << "Expected expression to be a call node, but was " << eval->value->GetTypeKey();
65 return call->args;
66}
67
/*! Classification of a statement: a copy writing to a global-scope buffer,
 * a copy writing to any other (local) scope, or a compute operation. */
enum class StmtType { global_copy, local_copy, compute };
69
70/*! Returns the type of the given statement */
71StmtType GetStmtType(const Stmt& stmt) {
72 Array<PrimExpr> args{GetStmtArgs(stmt)};
73 if (args[0].as<StringImmNode>()->value == "ethosu_copy") {
74 if (args[3].as<BufferLoadNode>()->buffer.scope() == "global") {
75 return StmtType::global_copy;
76 } else {
77 return StmtType::local_copy;
78 }
79 }
80 return StmtType::compute;
81}
82/*! Returns the buffer read my the given copy statement */
83Buffer GetCopyReadBuffer(const Stmt& stmt) {
84 Array<PrimExpr> args{GetStmtArgs(stmt)};
85 return args[1].as<BufferLoadNode>()->buffer;
86}
87
88/*! Returns the buffer written my the given copy statement */
89Buffer GetCopyWriteBuffer(const Stmt& stmt) {
90 Array<PrimExpr> args{GetStmtArgs(stmt)};
91 return args[3].as<BufferLoadNode>()->buffer;
92}
93
94/*! Returns the length of the given copy statement */
95int64_t GetCopyLength(const Stmt& stmt) {
96 Array<PrimExpr> args{GetStmtArgs(stmt)};
97 return args[2].as<IntImmNode>()->value;
98}
99
100/*! Returns the cycles of the given statement */
101int64_t GetStmtCycles(const Stmt& stmt) {
102 auto attr{stmt.as<AttrStmtNode>()};
103 if (attr && attr->attr_key == "pragma_compute_cycles_hint") {
104 int64_t cycles{Downcast<Integer>(attr->value)->value};
105 return cycles;
106 }
107 return 0;
108}
109} // namespace
110
111/*!
112 * \brief This mutator moves allocates to the top of the body of the main
113 * function.
114 *
115 * Note: This pass can currently only be run in conjunction with the
116 * LowerToTIR() pass as it expects a single primitive function called
117 * "main" that is being offloaded to the NPU.
118 *
119 * For example,
120 * Before:
121 * allocate {
122 * extern_call(...)
123 * allocate {
124 * extern_call(...)
125 * }
126 * }
127 *
128 * After:
129 * allocate {
130 * allocate {
131 * extern_call(...)
132 * extern_call(...)
133 * }
134 * }
135 */
136class HoistAllocatesMutator : public StmtExprMutator {
137 public:
138 HoistAllocatesMutator() {}
139
140 PrimFunc operator()(PrimFunc main_func) {
141 Stmt new_main_func_body = SeqStmt::Flatten(this->VisitStmt(main_func->body));
142
143 // Write all allocates that were removed in reverse order
144 for (auto it = allocates_.rbegin(); it != allocates_.rend(); it++) {
145 Allocate current_alloc = *it;
146 if (it != allocates_.rbegin()) {
147 new_main_func_body = SeqStmt({new_main_func_body});
148 }
149 new_main_func_body =
150 Allocate(current_alloc->buffer_var, current_alloc->dtype, current_alloc->extents,
151 current_alloc->condition, new_main_func_body, current_alloc->annotations,
152 current_alloc->span);
153 }
154
155 PrimFunc new_main_func = PrimFunc(main_func->params, new_main_func_body, main_func->ret_type,
156 main_func->buffer_map, main_func->attrs);
157 return new_main_func;
158 }
159
160 private:
161 Stmt VisitStmt_(const AllocateNode* op) override {
162 allocates_.push_back(GetRef<Allocate>(op));
163 return VisitStmt(op->body);
164 }
165
166 /*! A stack to store allocates as they are visited. */
167 std::vector<Allocate> allocates_;
168};
169
170/*!
171 * \brief A pass to hoist allocate nodes to the top of the body of the main function.
172 *
173 * \return tvm::transform::Pass
174 */
175tvm::transform::Pass HoistAllocates() {
176 auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) {
177 ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main"))
178 << "Expected a single primitive function called 'main'. Please run the HoistAllocates pass "
179 "in conjunction with the LowerToTIR() pass.";
180 return HoistAllocatesMutator()(f);
181 };
182 return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.HoistAllocates",
183 {});
184}
185
// Expose the HoistAllocates pass constructor through the TVM FFI registry.
TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.HoistAllocates").set_body_typed(HoistAllocates);
187
188/*!
189 * \brief Reorders copy and compute nodes in such a way that independent DMA copies
190 * and computes happen in parallel.
191 * Copies to buffers with local scope are not reordered since they copy LUT
192 * into the SHRAM and that already happens in parallel with copying weights into
193 * the weights encoder.
194 */
195class CopyComputeReorderingMutator : public StmtExprMutator {
196 public:
197 explicit CopyComputeReorderingMutator(int max_copy_movements, bool reorder_by_cycles)
198 : _max_copy_movements{max_copy_movements}, _reorder_by_cycles{reorder_by_cycles} {}
199
200 PrimFunc operator()(PrimFunc main_func) {
201 if (_max_copy_movements > 0) {
202 auto prim_func_node{main_func.CopyOnWrite()};
203 prim_func_node->body = this->VisitStmt(main_func->body);
204 return GetRef<PrimFunc>(prim_func_node);
205 }
206 return main_func;
207 }
208
209 private:
210 // A structure to hold a compute op with the corresponding weights/bias copy and LUT copy
211 struct OpWithCopies {
212 Stmt compute_op{};
213 Stmt global_copy{};
214 Stmt local_copy{};
215 };
216
217 Stmt VisitStmt_(const SeqStmtNode* op) override {
218 if (op->size() <= 1) {
219 return StmtExprMutator::VisitStmt_(op);
220 }
221
222 auto seq_stmt{GetRef<SeqStmt>(op)};
223 std::vector<Stmt> new_seq(seq_stmt->size());
224 std::copy(seq_stmt->seq.begin(), seq_stmt->seq.end(), new_seq.begin());
225
226 // Reorder the copies and computes based on the cycle count
227 if (_reorder_by_cycles) {
228 // We can't hide the first copy, so ignore it for the purpose of hiding copies
229 Stmt first_copy{};
230 if (stmt_is_global_copy(new_seq[0]) ||
231 (stmt_is_local_copy(new_seq[0]) && stmt_is_global_copy(new_seq[1]))) {
232 auto copy_position = stmt_is_global_copy(new_seq[0]) ? 0 : 1;
233 first_copy = new_seq[copy_position];
234 new_seq.erase(new_seq.begin() + copy_position);
235 }
236
237 // Build up a list of cells with the compute op and the copy ops that directly preceed it
238 std::vector<OpWithCopies> ops{};
239 for (size_t idx = 0; idx < new_seq.size(); ++idx) {
240 if (stmt_is_compute_op(new_seq[idx])) {
241 OpWithCopies new_op;
242 new_op.compute_op = new_seq[idx];
243 if (idx > 0) {
244 auto prev_op = new_seq[idx - 1];
245 if (!stmt_is_compute_op(prev_op)) {
246 if (stmt_is_local_copy(prev_op)) {
247 new_op.local_copy = prev_op;
248 } else {
249 new_op.global_copy = prev_op;
250 }
251 if (idx > 1) {
252 auto prev_prev_op = new_seq[idx - 2];
253 if (!stmt_is_compute_op(prev_prev_op)) {
254 if (stmt_is_local_copy(prev_prev_op)) {
255 new_op.local_copy = prev_prev_op;
256 } else {
257 new_op.global_copy = prev_prev_op;
258 }
259 }
260 }
261 }
262 }
263 ops.push_back(new_op);
264 }
265 }
266
267 // Move the global copies up by one. If in general the computes take longer than the copies,
268 // that should be good enough
269 for (size_t idx = 1; idx < ops.size(); ++idx) {
270 if (ops[idx].global_copy.as<AttrStmtNode>()) {
271 ops[idx - 1].global_copy = ops[idx].global_copy;
272 ops[idx].global_copy = {};
273 }
274 }
275
276 // If there are long copies, try to hide them further
277 for (size_t idx = ops.size() - 1; idx > 0; --idx) {
278 if (ops[idx].global_copy.as<AttrStmtNode>()) {
279 // Check whether the copy is hidden
280 int64_t copy_cycles{GetStmtCycles(ops[idx].global_copy)};
281 int64_t compute_cycles{GetStmtCycles(ops[idx].compute_op)};
282 bool is_hidden = compute_cycles >= copy_cycles;
283
284 // If the previous compute op is not already hiding another copy, move the copy back, so
285 // that it would be hidden by multiple computes
286 while (!is_hidden && !ops[idx - 1].global_copy.as<AttrStmtNode>() && (idx > 0)) {
287 int64_t new_compute_cycles{GetStmtCycles(ops[idx - 1].compute_op)};
288 ops[idx - 1].global_copy = ops[idx].global_copy;
289 ops[idx].global_copy = {};
290 compute_cycles += new_compute_cycles;
291 is_hidden = compute_cycles >= copy_cycles;
292 --idx;
293 }
294 }
295 }
296
297 // Reconstruct the op sequence from the vector of OpWithCopies
298 new_seq.clear();
299 if (first_copy.as<AttrStmtNode>()) {
300 new_seq.push_back(first_copy);
301 }
302 for (auto& op : ops) {
303 if (op.global_copy.as<AttrStmtNode>()) {
304 new_seq.push_back(op.global_copy);
305 }
306 if (op.local_copy.as<EvaluateNode>()) {
307 new_seq.push_back(op.local_copy);
308 }
309 if (op.compute_op.as<AttrStmtNode>()) {
310 new_seq.push_back(op.compute_op);
311 }
312 }
313 } else {
314 // Each copy statement to a buffer with global scope is moved up
315 // at most `_max_copy_movements` times.
316 for (size_t index = 0; index < new_seq.size(); ++index) {
317 if (GetStmtType(new_seq[index]) == StmtType::global_copy) {
318 int lower = std::max(0, static_cast<int>(index) - _max_copy_movements);
319 for (int i = index; i > lower && (GetStmtType(new_seq[i - 1]) == StmtType::compute);
320 --i) {
321 std::swap(new_seq[i - 1], new_seq[i]);
322 }
323 }
324 }
325 }
326
327 auto seq_stmt_node{CopyOnWrite(op)};
328 seq_stmt_node->seq = std::move(new_seq);
329 return Stmt{seq_stmt_node};
330 }
331
332 bool stmt_is_global_copy(const Stmt& stmt) { return GetStmtType(stmt) == StmtType::global_copy; }
333
334 bool stmt_is_local_copy(const Stmt& stmt) { return GetStmtType(stmt) == StmtType::local_copy; }
335
336 bool stmt_is_compute_op(const Stmt& stmt) { return GetStmtType(stmt) == StmtType::compute; }
337
338 /*! The maximum number of movements allowed for a copy. */
339 int _max_copy_movements;
340 /*! Whether we use the cycle hint to determine the reordering. */
341 bool _reorder_by_cycles;
342};
343
344/*!
345 * \brief A pass to reorder copy and compute nodes in such a way that independent DMA copies
346 * and computes happen in parallel. If reorder_by_cycles is set, we will ignore the
347 * max_copy_movements value.
348 *
349 * \param max_copy_movements: The maximum number of movements allowed for a copy.
350 * If None, the pass context option tir.contrib.ethos-u.copy_compute_reordering_max_copy_movements
351 * is used if provided, otherwise the default value will be 1.
352 *
353 * \param reorder_by_cycles: Whether to reorder copies and computes by cycles.
354 * If None, the pass context option tir.contrib.ethos-u.copy_compute_reordering_reorder_by_cycles
355 * is used if provided, otherwise the default value will be False. If the value is True,
356 * max_copy_movements will be ignored.
357 * \return tvm::transform::Pass
358 */
359tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements,
360 Optional<Bool> reorder_by_cycles) {
361 auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) {
362 ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main"))
363 << "Expected a single primitive function called 'main'. Please run the "
364 "CopyComputeReordering "
365 "pass in conjunction with the LowerToTIR() pass.";
366
367 auto copy_movements = max_copy_movements.value_or(
368 ctx->GetConfig(kCopyComputeReorderingMaxCopyMovements, Integer(1)).value());
369 auto reorder = reorder_by_cycles.value_or(
370 ctx->GetConfig(kCopyComputeReorderingReorderByCycles, Bool(false)).value());
371 return CopyComputeReorderingMutator(copy_movements.IntValue(), reorder)(f);
372 };
373 return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0,
374 "tir.contrib.ethos-u.CopyComputeReordering", {});
375}
376
// Expose the CopyComputeReordering pass constructor through the TVM FFI registry.
TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
    .set_body_typed(CopyComputeReordering);
379
380/*!
381 * \brief This mutator removes all allocates.
382 */
383class RemoveAllocatesMutator : public StmtExprMutator {
384 public:
385 PrimFunc operator()(PrimFunc main_func) {
386 auto prim_func_node{main_func.CopyOnWrite()};
387 prim_func_node->body = this->VisitStmt(main_func->body);
388 return GetRef<PrimFunc>(prim_func_node);
389 }
390
391 private:
392 Stmt VisitStmt_(const AllocateNode* op) override { return VisitStmt(op->body); }
393};
394
395/*!
396 * \brief This extractor collects information used by the MergeConstantsMutator
397 */
398class MergeConstantsInfoExtractor : public StmtExprVisitor {
399 public:
400 class Info {
401 public:
402 /*! A stack to store allocates as they are visited. */
403 std::vector<Allocate> allocates{};
404
405 /*! A list that contains in the i-th position the write buffer of the i-th statement
406 * if that statement is a copy to a buffer with global scope */
407 std::vector<Optional<Buffer>> copy_write_buffers{};
408
409 /*! Maps a copy's write buffer to an index representing the
410 * new buffer and an offset in that buffer */
411 std::unordered_map<const BufferNode*, std::pair<int /* new buffer index */, int /* offset */>>
412 old_to_new_write_buffer{};
413
414 /*! Maps an index representing a new buffer to the length of that buffer */
415 std::unordered_map<int /* new buffer index */, int /* length */> new_buffers_length{};
416
417 /*! Maps an index representing a new buffer to the cycless needed to copy that buffer */
418 std::unordered_map<int /* new buffer index */, int64_t> cycless{};
419 };
420
421 Info operator()(PrimFunc main_func) {
422 this->VisitStmt(main_func->body);
423 return std::move(_info);
424 }
425
426 private:
427 /*! The information collected by this extractor */
428 Info _info{};
429
430 void VisitStmt_(const AllocateNode* op) override {
431 _info.allocates.push_back(GetRef<Allocate>(op));
432 VisitStmt(op->body);
433 }
434
435 void VisitStmt_(const SeqStmtNode* op) override {
436 if (op->size() <= 1) {
437 StmtExprVisitor::VisitStmt_(op);
438 return;
439 }
440
441 auto seq_stmt{GetRef<SeqStmt>(op)};
442 for (size_t i = 0; i < seq_stmt.size(); ++i) {
443 Stmt stmt{seq_stmt[i]};
444 switch (GetStmtType(stmt)) {
445 case StmtType::global_copy: {
446 Buffer write_buffer{GetCopyWriteBuffer(stmt)};
447 _info.copy_write_buffers.push_back(write_buffer);
448 _info.old_to_new_write_buffer[write_buffer.as<BufferNode>()] = std::make_pair(-1, -1);
449 break;
450 }
451 case StmtType::local_copy: {
452 _info.copy_write_buffers.push_back(Optional<Buffer>{});
453 break;
454 }
455 case StmtType::compute: {
456 _info.copy_write_buffers.push_back(Optional<Buffer>{});
457 std::vector<Buffer> buffers{GetCopiedBuffersUsedByStmt(stmt)};
458 if (buffers.empty()) {
459 continue;
460 }
461 _info.new_buffers_length[i] = 0;
462 for (Buffer buffer : buffers) {
463 for (size_t j{i - 1}; j >= 0; --j) {
464 if (_info.copy_write_buffers[j] == buffer) {
465 _info.old_to_new_write_buffer[buffer.as<BufferNode>()] =
466 std::make_pair(i, _info.new_buffers_length[i]);
467 _info.new_buffers_length[i] += GetCopyLength(seq_stmt[j]);
468 _info.cycless[i] += GetStmtCycles(seq_stmt[j]);
469 break;
470 }
471 }
472 }
473 break;
474 }
475 }
476 }
477 }
478
479 /*! Get all buffers written by copies and used by a given statement */
480 std::vector<Buffer> GetCopiedBuffersUsedByStmt(const Stmt& stmt) {
481 std::vector<Buffer> buffers{};
482 for (PrimExpr arg : GetStmtArgs(stmt)) {
483 if (auto buffer_load = arg.as<BufferLoadNode>()) {
484 Buffer buffer{buffer_load->buffer};
485 // Check if the buffer has already been added
486 if (std::find(buffers.begin(), buffers.end(), buffer) == buffers.end()) {
487 // Check if the buffer is copied
488 if (_info.old_to_new_write_buffer.count(buffer.as<BufferNode>())) {
489 buffers.push_back(buffer);
490 }
491 }
492 }
493 }
494 return buffers;
495 }
496};
497
498/*!
499 * \brief This mutator looks for the constants used by each compute operator
500 * and merges them into a single buffer.
501 * Constants written to a buffer with local scope are not merged.
502 */
503class MergeConstantsMutator : public StmtExprMutator {
504 public:
505 explicit MergeConstantsMutator(MergeConstantsInfoExtractor::Info info) : _info{std::move(info)} {}
506
507 PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
508 // Rewrite
509 Stmt new_body = RewritePrimFuncBody(main_func->body);
510 std::unordered_set<const VarNode*> params_to_delete{};
511 Map<Var, Buffer> new_buffer_map{MakeNewBufferMap(main_func->buffer_map, &params_to_delete)};
512 Array<Var> new_params{MakeNewParams(main_func->params, params_to_delete)};
513
514 // Make the new const dict
515 Array<Array<IntImm>> args_to_merge{GetArgsToMerge(main_func->buffer_map, main_func->params)};
516 Map<IntImm, Array<IntImm>> buffers_to_merge{
517 GetArgsToMergeWithoutArgsNotInConstDict(args_to_merge, const_dict)};
518 Map<IntImm, runtime::NDArray> new_const_dict{MakeNewConstDict(buffers_to_merge, const_dict)};
519
520 // Make the new prim func
521 auto prim_func_node{main_func.CopyOnWrite()};
522 prim_func_node->body = std::move(new_body);
523 prim_func_node->buffer_map = std::move(new_buffer_map);
524 prim_func_node->params = std::move(new_params);
525 PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
526
527 // Add the new const dict as an attribute
528 f = WithAttr(std::move(f), "ethos-u.const_dict", new_const_dict);
529
530 return f;
531 }
532
533 private:
534 /*! The information collected by the MergeConstantsInfoExtractor */
535 MergeConstantsInfoExtractor::Info _info;
536
537 /*! Maps an index representing a new buffer to the new buffer */
538 std::unordered_map<int /* new buffer index */, Buffer> new_buffers{};
539
540 /*! Maps a copy's read buffer to the new copy's read buffer */
541 std::unordered_map<const BufferNode*, Buffer> old_to_new_read_buffers{};
542
543 /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
544 */
545 std::unordered_map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
546
547 /*! A set of buffers to delete */
548 std::unordered_set<const BufferNode*> buffers_to_delete{};
549
550 Stmt RewritePrimFuncBody(Stmt body) {
551 std::unordered_map<const VarNode*, Allocate> var_to_allocate{};
552
553 // Rewrite old allocates
554 std::unordered_set<const VarNode*> buffer_vars{GetVarsForWrittenCopyBuffers()};
555 for (auto it{_info.allocates.rbegin()}; it != _info.allocates.rend(); ++it) {
556 Allocate alloc{*it};
557 var_to_allocate[alloc->buffer_var.get()] = alloc;
558 if (buffer_vars.count(alloc->buffer_var.as<VarNode>()) == 0) {
559 body = Allocate(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->condition, body,
560 alloc->annotations, alloc->span);
561 }
562 }
563
564 // Rewrite new allocates
565 for (auto it{_info.copy_write_buffers.rbegin()}; it != _info.copy_write_buffers.rend(); ++it) {
566 if (Optional<Buffer> buffer_opt = *it) {
567 Buffer old_write_buffer{buffer_opt.value()};
568 int new_buffer_index{
569 _info.old_to_new_write_buffer[old_write_buffer.as<BufferNode>()].first};
570
571 // Check if the allocate has already been created
572 if (new_buffers.count(new_buffer_index) == 0) {
573 BufferNode* new_buffer{old_write_buffer.CopyOnWrite()};
574 new_buffer->shape = {_info.new_buffers_length[new_buffer_index]};
575
576 new_buffers[new_buffer_index] = GetRef<Buffer>(new_buffer);
577
578 Allocate old_allocate{var_to_allocate[old_write_buffer->data.get()]};
579 body = Allocate(new_buffer->data, new_buffer->dtype, new_buffer->shape, tir::const_true(),
580 body, old_allocate->annotations, old_allocate->span);
581 }
582 }
583 }
584
585 // Rewrite operators
586 return this->VisitStmt(body);
587 }
588
589 Stmt VisitStmt_(const AllocateNode* op) override {
590 auto allocate{CopyOnWrite(op)};
591 allocate->body = this->VisitStmt(op->body);
592 return Stmt(allocate);
593 }
594
595 Stmt VisitStmt_(const SeqStmtNode* op) override {
596 if (op->size() <= 1) {
597 return StmtExprMutator::VisitStmt_(op);
598 }
599
600 Array<Stmt> new_seq{};
601 SeqStmt seq_stmt{GetRef<SeqStmt>(op)};
602 for (size_t i{0}; i < seq_stmt.size(); ++i) {
603 Stmt stmt{seq_stmt[i]};
604
605 switch (GetStmtType(stmt)) {
606 case StmtType::global_copy: {
607 Buffer old_write_buffer{_info.copy_write_buffers[i].value()};
608 std::pair<int, int> pair{
609 _info.old_to_new_write_buffer[old_write_buffer.as<BufferNode>()]};
610 int new_buffer_index{pair.first};
611 int new_buffer_offset{pair.second};
612 UpdateBuffersToMergeAndDelete(stmt, new_buffer_index, new_buffer_offset);
613
614 if (!IsCopyToBeDeleted(new_buffer_offset)) {
615 Optional<PrimExpr> cycless{GetMergedCycles(new_buffer_index)};
616 new_seq.push_back(MakeNewStmt(
617 stmt, MakeNewCopyArgs(stmt, old_write_buffer, new_buffer_index), cycless));
618 }
619 break;
620 }
621 case StmtType::local_copy: {
622 new_seq.push_back(stmt);
623 break;
624 }
625 case StmtType::compute: {
626 new_seq.push_back(MakeNewStmt(stmt, MakeNewComputeArgs(stmt)));
627 break;
628 }
629 }
630 }
631 return SeqStmt(new_seq, op->span);
632 }
633
634 /*! Returns the variables of the buffers written by copies */
635 std::unordered_set<const VarNode*> GetVarsForWrittenCopyBuffers() {
636 std::unordered_set<const VarNode*> buffer_vars{};
637 std::transform(_info.old_to_new_write_buffer.begin(), _info.old_to_new_write_buffer.end(),
638 std::inserter(buffer_vars, buffer_vars.begin()),
639 [](std::pair<const BufferNode*, std::pair<int, int>> pair) -> const VarNode* {
640 return pair.first->data.as<VarNode>();
641 });
642 return buffer_vars;
643 }
644
645 /*! Returns the cycles of the new buffer at the given index */
646 Optional<PrimExpr> GetMergedCycles(int new_buffer_index) {
647 auto it = _info.cycless.find(new_buffer_index);
648 if (it != _info.cycless.end()) {
649 return Integer(it->second);
650 }
651 return Optional<PrimExpr>{};
652 }
653
654 /*! Returns true if a copy must be deleted, false otherwise */
655 bool IsCopyToBeDeleted(int new_buffer_offset) { return new_buffer_offset > 0; }
656
657 Array<PrimExpr> MakeNewCopyArgs(const Stmt& stmt, const Buffer& old_write_buffer,
658 int new_buffer_index) {
659 Array<PrimExpr> args{GetStmtArgs(stmt)};
660 int new_length{_info.new_buffers_length[new_buffer_index]};
661
662 Array<PrimExpr> new_args{};
663 for (size_t i = 0; i < args.size(); ++i) {
664 switch (i) {
665 case 1: /* read_address */ {
666 auto buffer_load = args[1].as<BufferLoadNode>();
667 Buffer buffer{buffer_load->buffer};
668 Buffer new_buffer{buffer->data,
669 buffer->dtype,
670 {new_length},
671 buffer->strides,
672 buffer->elem_offset,
673 buffer->name,
674 buffer->data_alignment,
675 buffer->offset_factor,
676 buffer->buffer_type,
677 buffer->axis_separators,
678 buffer->span};
679 old_to_new_read_buffers[buffer.as<BufferNode>()] = new_buffer;
680 new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->span));
681 break;
682 }
683 case 2: /* length */ {
684 new_args.push_back(new_length);
685 break;
686 }
687 case 3: /* write_address */ {
688 new_args.push_back(MakeNewBufferLoad(old_write_buffer, 0, true).value());
689 break;
690 }
691 default:
692 new_args.push_back(args[i]);
693 break;
694 }
695 }
696 return new_args;
697 }
698
699 Array<PrimExpr> MakeNewComputeArgs(const Stmt& stmt) {
700 Array<PrimExpr> args{GetStmtArgs(stmt)};
701 Array<PrimExpr> new_args{};
702 for (size_t i = 0; i < args.size(); ++i) {
703 if (auto buffer_load = args[i].as<BufferLoadNode>()) {
704 BufferLoad new_buffer_load{
705 MakeNewBufferLoad(buffer_load->buffer, buffer_load->indices[0], false)
706 .value_or(GetRef<BufferLoad>(buffer_load))};
707 new_args.push_back(new_buffer_load);
708 } else {
709 new_args.push_back(args[i]);
710 }
711 }
712 return new_args;
713 }
714
715 Stmt MakeNewStmt(const Stmt& stmt, const Array<PrimExpr>& new_args,
716 Optional<PrimExpr> cycless = Optional<PrimExpr>{}) {
717 auto attr{stmt.as<AttrStmtNode>()};
718 Stmt eval_stmt{attr ? attr->body : stmt};
719 auto eval{eval_stmt.as<EvaluateNode>()};
720 ICHECK(eval) << "Expected statement to be an evaluate node, but was "
721 << eval_stmt->GetTypeKey();
722 auto call{eval->value.as<CallNode>()};
723 ICHECK(call) << "Expected expression to be a call node, but was " << eval->value->GetTypeKey();
724
725 Call new_call{call->dtype, call->op, new_args, call->span};
726 Evaluate new_eval{new_call, eval->span};
727
728 if (attr) {
729 ICHECK(attr->attr_key == "pragma_compute_cycles_hint");
730 PrimExpr value = cycless.value_or(attr->value);
731 return AttrStmt{attr->node, attr->attr_key, value, new_eval, attr->span};
732 } else {
733 return std::move(new_eval);
734 }
735 }
736
737 Optional<BufferLoad> MakeNewBufferLoad(const Buffer& write_buffer, const PrimExpr& old_index,
738 bool only_old_index) {
739 auto it = _info.old_to_new_write_buffer.find(write_buffer.as<BufferNode>());
740 if (it != _info.old_to_new_write_buffer.end()) {
741 std::pair<int, int> pair{it->second};
742 int new_buffer_index{pair.first};
743 PrimExpr new_index{only_old_index ? old_index : (pair.second + old_index)};
744 return BufferLoad{new_buffers[new_buffer_index], {new_index}};
745 }
746 return Optional<BufferLoad>{};
747 }
748
749 Map<tir::Var, Buffer> MakeNewBufferMap(const Map<tir::Var, Buffer>& buffer_map,
750 std::unordered_set<const VarNode*>* params_to_delete) {
751 Map<tir::Var, Buffer> new_buffer_map{};
752 for (std::pair<Var, Buffer> pair : buffer_map) {
753 Var var{pair.first};
754 Buffer buffer{pair.second};
755
756 if (buffers_to_delete.count(buffer.as<BufferNode>()) == 1) {
757 params_to_delete->insert(var.as<VarNode>());
758 } else if (old_to_new_read_buffers.count(buffer.as<BufferNode>()) == 1) {
759 new_buffer_map.Set(var, old_to_new_read_buffers[buffer.as<BufferNode>()]);
760 } else {
761 new_buffer_map.Set(var, buffer);
762 }
763 }
764 return new_buffer_map;
765 }
766
767 Array<tir::Var> MakeNewParams(const Array<tir::Var>& params,
768 const std::unordered_set<const VarNode*>& params_to_delete) {
769 std::vector<Var> new_params{};
770 for (Var var : params) {
771 if (params_to_delete.count(var.as<VarNode>()) == 0) {
772 new_params.push_back(var);
773 }
774 }
775 return new_params;
776 }
777
778 void UpdateBuffersToMergeAndDelete(const Stmt& stmt, int new_buffer_index,
779 int new_buffer_offset) {
780 Array<PrimExpr> args{GetStmtArgs(stmt)};
781 Buffer read_buffer{GetCopyReadBuffer(stmt)};
782
783 if (buffers_to_merge.count(new_buffer_index) == 0) {
784 buffers_to_merge[new_buffer_index] = std::vector<Buffer>{read_buffer};
785 } else {
786 buffers_to_merge[new_buffer_index].push_back(read_buffer);
787 }
788
789 if (new_buffer_offset > 0) {
790 buffers_to_delete.insert(read_buffer.as<BufferNode>());
791 }
792 }
793
794 /*! Returns an array whose elements are the indices of the function arguments to be merged.
795 * Example: if a function has three arguments and the second and the third ones must
796 * be merged then the array is: [[0], [1, 2], [3]] */
797 Array<Array<IntImm>> GetArgsToMerge(const Map<Var, Buffer>& buffer_map,
798 const Array<Var>& params) {
799 std::unordered_map<const BufferNode*, Var> buffer_to_var{};
800 for (std::pair<Var, Buffer> var_buffer : buffer_map) {
801 buffer_to_var[var_buffer.second.as<BufferNode>()] = var_buffer.first;
802 }
803
804 std::unordered_map<const VarNode*, int> var_to_index{};
805 for (int i = 0; i < static_cast<int>(params.size()); ++i) {
806 var_to_index[params[i].as<VarNode>()] = i;
807 }
808
809 std::vector<Array<IntImm>> vector{};
810 for (std::pair<int, std::vector<Buffer>> index_vector : buffers_to_merge) {
811 std::vector<IntImm> indices{};
812 for (Buffer buffer : index_vector.second) {
813 const VarNode* var{buffer_to_var[buffer.as<BufferNode>()].as<VarNode>()};
814 IntImm index{DataType::Int(64), var_to_index[var]};
815 var_to_index.erase(var);
816 auto it = std::find_if(indices.begin(), indices.end(),
817 [&](IntImm value) { return value->value == index->value; });
818 if (it == indices.end()) {
819 indices.push_back(index);
820 }
821 }
822 vector.push_back(Array<IntImm>{indices});
823 }
824
825 for (std::pair<const VarNode*, int> var_index : var_to_index) {
826 vector.push_back(Array<IntImm>{IntImm(DataType::Int(64), var_index.second)});
827 }
828 std::sort(vector.begin(), vector.end(),
829 [](Array<IntImm> a, Array<IntImm> b) { return a[0]->value < b[0]->value; });
830 return vector;
831 }
832
833 Map<IntImm, Array<IntImm>> GetArgsToMergeWithoutArgsNotInConstDict(
834 const Array<Array<IntImm>>& args_to_merge, const Map<IntImm, runtime::NDArray>& const_dict) {
835 Map<IntImm, Array<IntImm>> new_args_to_merge{};
836 bool first_arg_found = false;
837 int64_t new_arg_key = 0; // the updated key of the merged const_dict
838 for (Array<IntImm> args : args_to_merge) {
839 IntImm key{args[0]};
840 auto it = std::find_if(const_dict.begin(), const_dict.end(),
841 [&](std::pair<tvm::IntImm, runtime::NDArray> pair) {
842 return pair.first->value == key->value;
843 });
844 if (it != const_dict.end()) {
845 if (first_arg_found == false) {
846 first_arg_found = true;
847 new_arg_key = key->value;
848 }
849 new_args_to_merge.Set(IntImm(DataType::Int(64), new_arg_key), args);
850 }
851 if (first_arg_found) {
852 new_arg_key++;
853 }
854 }
855 return new_args_to_merge;
856 }
857
858 Map<IntImm, runtime::NDArray> MakeNewConstDict(const Map<IntImm, Array<IntImm>>& args_to_merge,
859 Map<IntImm, runtime::NDArray> const_dict) {
860 Map<IntImm, runtime::NDArray> new_const_dict{};
861 if (args_to_merge.size() == 0) {
862 return new_const_dict;
863 }
864
865 for (auto const& elem : args_to_merge) {
866 IntImm key = elem.first;
867 Array<IntImm> args = elem.second;
868 int64_t size = 0;
869 for (IntImm arg : args) {
870 auto it = std::find_if(const_dict.begin(), const_dict.end(),
871 [&](auto pair) { return pair.first->value == arg->value; });
872 runtime::NDArray arg_constant{(*it).second};
873 size += runtime::GetDataSize(*arg_constant.operator->());
874 }
875
876 runtime::NDArray constant = runtime::NDArray::Empty({size}, DataType::UInt(8), {kDLCPU, 0});
877
878 size_t offset = 0;
879 for (IntImm arg : args) {
880 auto it = std::find_if(const_dict.begin(), const_dict.end(),
881 [&](auto pair) { return pair.first->value == arg->value; });
882 runtime::NDArray arg_constant{(*it).second};
883 size_t nbytes = runtime::GetDataSize(*arg_constant.operator->());
884 arg_constant.CopyToBytes(static_cast<uint8_t*>(constant->data) + offset, nbytes);
885 offset += nbytes;
886 }
887 new_const_dict.Set(key, constant);
888 }
889 return new_const_dict;
890 }
891};
892
893/*!
894 * \brief This pass looks for the constants used by each compute operator
895 * and merges them into a single buffer.
896 * Constants written to a buffer with local scope are not merged.
897 * \return tvm::transform::Pass
898 */
899tvm::transform::Pass MergeConstants() {
900 auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) {
901 ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main"))
902 << "Expected a single primitive function called 'main'. Please run the "
903 "MergeConstants pass in conjunction with the LowerToTIR() pass.";
904 Optional<Map<IntImm, runtime::NDArray>> const_dict{
905 f->attrs.GetAttr("ethos-u.const_dict", Optional<Map<IntImm, runtime::NDArray>>{})};
906 ICHECK(const_dict) << "Expected a ethos-u.const_dict attribute";
907
908 MergeConstantsInfoExtractor::Info info{MergeConstantsInfoExtractor()(f)};
909 f = RemoveAllocatesMutator()(f);
910 return MergeConstantsMutator(info)(f, const_dict.value());
911 };
912 return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.MergeConstants",
913 {});
914}
915
// Expose the MergeConstants pass constructor through the TVM FFI registry.
TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.MergeConstants").set_body_typed(MergeConstants);
917
918/*!
919 * \brief This pass removes the ethos-u.const_dict attribute
920 * \return tvm::transform::Pass
921 */
922class RemoveConstDictAttributeMutator : public StmtExprMutator {
923 public:
924 RemoveConstDictAttributeMutator() {}
925
926 PrimFunc operator()(PrimFunc main_func) {
927 return WithoutAttr(std::move(main_func), "ethos-u.const_dict");
928 }
929};
930
931tvm::transform::Pass RemoveConstDictAttribute() {
932 auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) {
933 return RemoveConstDictAttributeMutator()(f);
934 };
935 return tvm::tir::transform::CreatePrimFuncPass(
936 pass_func, 0, "tir.contrib.ethos-u.RemoveConstDictAttribute", {});
937}
938
// Expose the RemoveConstDictAttribute pass constructor through the TVM FFI registry.
TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.RemoveConstDictAttribute")
    .set_body_typed(RemoveConstDictAttribute);
941
942} // namespace ethosu
943} // namespace contrib
944} // namespace tir
945} // namespace tvm
946