1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler_layout_rewrite.h |
22 | * \brief Rewrite the layout of "layout free" tensors (e.g., the weight tensors in |
23 | * conv2d and dense layers) according to the tile structure generated by the auto-scheduler. |
24 | */ |
25 | |
#include "auto_scheduler_layout_rewrite.h"

#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <algorithm>
#include <deque>
#include <functional>
#include <string>
#include <vector>

#include "../backend/te_compiler.h"
#include "pattern_utils.h"
39 | |
40 | namespace tvm { |
41 | namespace relay { |
42 | |
43 | // Two global variables for receiving layout information from python |
44 | std::deque<std::string> AutoSchedulerLayoutRewriter::global_ori_layouts_queue; |
45 | std::deque<std::string> AutoSchedulerLayoutRewriter::global_new_layouts_queue; |
46 | |
47 | // Copy an Attrs but with a new auto_scheduler_rewritten_layout filed. |
48 | template <typename T> |
49 | Attrs CopyAttrsWithNewLayout(const T* ptr, const std::string& layout) { |
50 | auto n = make_object<T>(*ptr); |
51 | n->auto_scheduler_rewritten_layout = layout; |
52 | return Attrs(n); |
53 | } |
54 | |
55 | // Mutate ops in a function |
56 | class FuncMutator : public ExprMutator { |
57 | public: |
58 | FuncMutator(const std::deque<std::string>& ori_layouts_queue, |
59 | const std::deque<std::string>& new_layouts_queue) |
60 | : ExprMutator(), |
61 | ori_layouts_queue_(ori_layouts_queue), |
62 | new_layouts_queue_(new_layouts_queue) {} |
63 | |
64 | Expr VisitExpr_(const CallNode* n) { |
65 | auto new_n = ExprMutator::VisitExpr_(n); |
66 | |
67 | const auto* call = new_n.as<CallNode>(); |
68 | if (call && call->op.as<OpNode>() && |
69 | (std::find(target_ops_.begin(), target_ops_.end(), n->op.as<OpNode>()->name) != |
70 | target_ops_.end()) && |
71 | !ori_layouts_queue_.empty() && !new_layouts_queue_.empty()) { |
72 | // Pop a new layout from the queue |
73 | const std::string ori_layout = ori_layouts_queue_.front(); |
74 | const std::string new_layout = new_layouts_queue_.front(); |
75 | ori_layouts_queue_.pop_front(); |
76 | new_layouts_queue_.pop_front(); |
77 | |
78 | // Insert a new op to do layout transform. (This will be simplified by FoldConstant later). |
79 | Expr updated_kernel = MakeAutoSchedulerLayoutTransform(call->args[1], ori_layout, new_layout); |
80 | Array<Expr> updated_args = {call->args[0], updated_kernel}; |
81 | |
82 | // Update the attrs |
83 | Attrs updated_attrs; |
84 | if (auto pattr = call->attrs.as<Conv2DAttrs>()) { |
85 | updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); |
86 | } else if (auto pattr = call->attrs.as<Conv2DWinogradAttrs>()) { |
87 | updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); |
88 | } else if (auto pattr = call->attrs.as<Conv3DAttrs>()) { |
89 | updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); |
90 | } else if (auto pattr = call->attrs.as<MatmulAttrs>()) { |
91 | updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); |
92 | } else if (auto pattr = call->attrs.as<DenseAttrs>()) { |
93 | updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); |
94 | } else if (auto pattr = call->attrs.as<BatchMatmulAttrs>()) { |
95 | updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); |
96 | } else { |
97 | LOG(FATAL) << "Unhandled attribute: " << call->attrs; |
98 | } |
99 | new_n = Call(call->op, updated_args, updated_attrs); |
100 | } |
101 | return new_n; |
102 | } |
103 | |
104 | private: |
105 | std::deque<std::string> ori_layouts_queue_; |
106 | std::deque<std::string> new_layouts_queue_; |
107 | |
108 | std::vector<std::string> target_ops_{ |
109 | "nn.conv2d" , "nn.conv3d" , "nn.contrib_conv2d_winograd_without_weight_transform" , |
110 | "nn.matmul" , "nn.dense" , "nn.batch_matmul" }; |
111 | }; |
112 | |
113 | Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { |
114 | auto new_n = ExprMutator::VisitExpr_(n); |
115 | |
116 | if (const auto* call = new_n.as<CallNode>()) { |
117 | if (const auto* func = call->op.as<FunctionNode>()) { |
118 | global_ori_layouts_queue.clear(); |
119 | global_new_layouts_queue.clear(); |
120 | |
121 | // Use ScheduleGetter to call python lower functions. |
122 | // This is used to get the layout transform information. |
123 | // The layout transformation will be recorded to global_ori_layout_queue |
124 | // and global_new_layouts_queue in ComputeDAG::RewriteLayout. |
125 | auto f = runtime::Registry::Get("auto_scheduler.enter_layout_rewrite" ); |
126 | CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function." ; |
127 | (*f)(); |
128 | |
129 | tec::PrimFuncFor(GetRef<Function>(func), Target::Current()); |
130 | |
131 | f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite" ); |
132 | CHECK(f) << "Could not find ansor.exit_layout_rewrite function." ; |
133 | (*f)(); |
134 | |
135 | // Mutate the called function |
136 | if (!global_ori_layouts_queue.empty() && !global_new_layouts_queue.empty()) { |
137 | auto ret = FuncMutator(global_ori_layouts_queue, global_new_layouts_queue).VisitExpr(new_n); |
138 | return ret; |
139 | } |
140 | } |
141 | } |
142 | |
143 | return new_n; |
144 | } |
145 | |
146 | Expr AutoSchedulerLayoutRewrite(const Expr& expr) { |
147 | return AutoSchedulerLayoutRewriter().Mutate(expr); |
148 | } |
149 | |
150 | namespace transform { |
151 | |
152 | Pass AutoSchedulerLayoutRewrite() { |
153 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
154 | [=](Function f, IRModule m, PassContext pc) { |
155 | return Downcast<Function>(relay::AutoSchedulerLayoutRewrite(f)); |
156 | }; |
157 | return CreateFunctionPass(pass_func, 3, "AutoSchedulerLayoutRewrite" , {"InferType" }); |
158 | } |
159 | |
160 | TVM_REGISTER_GLOBAL("relay._transform.AutoSchedulerLayoutRewrite" ) |
161 | .set_body_typed(AutoSchedulerLayoutRewrite); |
162 | |
163 | TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout" ) |
164 | .set_body_typed([](const Attrs& attrs) { |
165 | if (attrs->IsInstance<Conv2DAttrs>()) { |
166 | return attrs.as<Conv2DAttrs>()->auto_scheduler_rewritten_layout; |
167 | } else if (attrs->IsInstance<Conv2DWinogradAttrs>()) { |
168 | return attrs.as<Conv2DWinogradAttrs>()->auto_scheduler_rewritten_layout; |
169 | } else if (attrs->IsInstance<Conv3DAttrs>()) { |
170 | return attrs.as<Conv3DAttrs>()->auto_scheduler_rewritten_layout; |
171 | } else if (attrs->IsInstance<MatmulAttrs>()) { |
172 | return attrs.as<MatmulAttrs>()->auto_scheduler_rewritten_layout; |
173 | } else if (attrs->IsInstance<DenseAttrs>()) { |
174 | return attrs.as<DenseAttrs>()->auto_scheduler_rewritten_layout; |
175 | } else if (attrs->IsInstance<BatchMatmulAttrs>()) { |
176 | return attrs.as<BatchMatmulAttrs>()->auto_scheduler_rewritten_layout; |
177 | } else { |
178 | LOG(FATAL) << "Unhandled attribute: " << attrs; |
179 | } |
180 | return tvm::String(); |
181 | }); |
182 | |
183 | } // namespace transform |
184 | |
185 | } // namespace relay |
186 | } // namespace tvm |
187 | |