1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file remap_thread_axis.cc |
22 | */ |
23 | #include <tvm/runtime/registry.h> |
24 | #include <tvm/tir/expr.h> |
25 | #include <tvm/tir/stmt_functor.h> |
26 | #include <tvm/tir/transform.h> |
27 | |
28 | #include <unordered_map> |
29 | |
30 | namespace tvm { |
31 | namespace tir { |
32 | |
33 | // Mutator to change the read pattern |
34 | class ThreadAxisRewriter : private StmtExprMutator { |
35 | public: |
36 | explicit ThreadAxisRewriter(const std::unordered_map<std::string, IterVar>& tmap) : tmap_(tmap) {} |
37 | |
38 | Stmt Rewrite(Stmt stmt) { return operator()(std::move(stmt)); } |
39 | |
40 | private: |
41 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
42 | if (op->attr_key == attr::thread_extent) { |
43 | IterVar iv = Downcast<IterVar>(op->node); |
44 | ICHECK_NE(iv->thread_tag.length(), 0U); |
45 | auto it = tmap_.find(iv->thread_tag); |
46 | if (it != tmap_.end()) { |
47 | const IterVar& new_iv = it->second; |
48 | const VarNode* v = iv->var.get(); |
49 | if (!vmap_.count(v)) { |
50 | vmap_[v] = new_iv->var; |
51 | } else { |
52 | ICHECK(vmap_[v].same_as(new_iv->var)); |
53 | } |
54 | Stmt body = this->VisitStmt(op->body); |
55 | return AttrStmt(new_iv, op->attr_key, op->value, body); |
56 | } |
57 | } |
58 | return StmtExprMutator::VisitStmt_(op); |
59 | } |
60 | |
61 | PrimExpr VisitExpr_(const VarNode* op) final { |
62 | auto it = vmap_.find(op); |
63 | if (it != vmap_.end()) return it->second; |
64 | return StmtExprMutator::VisitExpr_(op); |
65 | } |
66 | // The thread map |
67 | const std::unordered_map<std::string, IterVar>& tmap_; |
68 | // variable map |
69 | std::unordered_map<const VarNode*, Var> vmap_; |
70 | }; |
71 | |
72 | PrimFunc RemapThreadAxis(PrimFunc&& f, Map<runtime::String, IterVar> thread_map) { |
73 | std::unordered_map<std::string, IterVar> tmap; |
74 | for (const auto& kv : thread_map) { |
75 | tmap[kv.first] = kv.second; |
76 | } |
77 | |
78 | auto opt_thread_axis = f->GetAttr<Array<IterVar>>(tir::attr::kDeviceThreadAxis); |
79 | ICHECK(opt_thread_axis != nullptr) << "Require attribute " << tir::attr::kDeviceThreadAxis; |
80 | auto thread_axis = opt_thread_axis.value(); |
81 | auto* n = f.CopyOnWrite(); |
82 | |
83 | // replace the thread axis |
84 | for (size_t i = 0; i < thread_axis.size(); ++i) { |
85 | auto it = tmap.find(thread_axis[i]->thread_tag); |
86 | if (it != tmap.end()) { |
87 | thread_axis.Set(i, it->second); |
88 | } |
89 | } |
90 | n->body = ThreadAxisRewriter(tmap).Rewrite(std::move(n->body)); |
91 | return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis); |
92 | } |
93 | |
94 | namespace transform { |
95 | |
96 | Pass RemapThreadAxis(Map<runtime::String, IterVar> thread_map) { |
97 | auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) { |
98 | return RemapThreadAxis(std::move(f), thread_map); |
99 | }; |
100 | return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis" , {}); |
101 | } |
102 | |
103 | TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis" ).set_body_typed(RemapThreadAxis); |
104 | |
105 | } // namespace transform |
106 | } // namespace tir |
107 | } // namespace tvm |
108 | |