1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file remap_thread_axis.cc
22 */
23#include <tvm/runtime/registry.h>
24#include <tvm/tir/expr.h>
25#include <tvm/tir/stmt_functor.h>
26#include <tvm/tir/transform.h>
27
28#include <unordered_map>
29
30namespace tvm {
31namespace tir {
32
33// Mutator to change the read pattern
34class ThreadAxisRewriter : private StmtExprMutator {
35 public:
36 explicit ThreadAxisRewriter(const std::unordered_map<std::string, IterVar>& tmap) : tmap_(tmap) {}
37
38 Stmt Rewrite(Stmt stmt) { return operator()(std::move(stmt)); }
39
40 private:
41 Stmt VisitStmt_(const AttrStmtNode* op) final {
42 if (op->attr_key == attr::thread_extent) {
43 IterVar iv = Downcast<IterVar>(op->node);
44 ICHECK_NE(iv->thread_tag.length(), 0U);
45 auto it = tmap_.find(iv->thread_tag);
46 if (it != tmap_.end()) {
47 const IterVar& new_iv = it->second;
48 const VarNode* v = iv->var.get();
49 if (!vmap_.count(v)) {
50 vmap_[v] = new_iv->var;
51 } else {
52 ICHECK(vmap_[v].same_as(new_iv->var));
53 }
54 Stmt body = this->VisitStmt(op->body);
55 return AttrStmt(new_iv, op->attr_key, op->value, body);
56 }
57 }
58 return StmtExprMutator::VisitStmt_(op);
59 }
60
61 PrimExpr VisitExpr_(const VarNode* op) final {
62 auto it = vmap_.find(op);
63 if (it != vmap_.end()) return it->second;
64 return StmtExprMutator::VisitExpr_(op);
65 }
66 // The thread map
67 const std::unordered_map<std::string, IterVar>& tmap_;
68 // variable map
69 std::unordered_map<const VarNode*, Var> vmap_;
70};
71
72PrimFunc RemapThreadAxis(PrimFunc&& f, Map<runtime::String, IterVar> thread_map) {
73 std::unordered_map<std::string, IterVar> tmap;
74 for (const auto& kv : thread_map) {
75 tmap[kv.first] = kv.second;
76 }
77
78 auto opt_thread_axis = f->GetAttr<Array<IterVar>>(tir::attr::kDeviceThreadAxis);
79 ICHECK(opt_thread_axis != nullptr) << "Require attribute " << tir::attr::kDeviceThreadAxis;
80 auto thread_axis = opt_thread_axis.value();
81 auto* n = f.CopyOnWrite();
82
83 // replace the thread axis
84 for (size_t i = 0; i < thread_axis.size(); ++i) {
85 auto it = tmap.find(thread_axis[i]->thread_tag);
86 if (it != tmap.end()) {
87 thread_axis.Set(i, it->second);
88 }
89 }
90 n->body = ThreadAxisRewriter(tmap).Rewrite(std::move(n->body));
91 return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis);
92}
93
94namespace transform {
95
96Pass RemapThreadAxis(Map<runtime::String, IterVar> thread_map) {
97 auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) {
98 return RemapThreadAxis(std::move(f), thread_map);
99 };
100 return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {});
101}
102
103TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis").set_body_typed(RemapThreadAxis);
104
105} // namespace transform
106} // namespace tir
107} // namespace tvm
108