1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file auto_scheduler_layout_rewrite.cc
 * \brief Rewrite the layout of "layout free" tensors (e.g., the weight tensors in
 * conv2d and dense layers) according to the tile structure generated by the auto-scheduler.
 */
25
26#include "auto_scheduler_layout_rewrite.h"
27
28#include <tvm/relay/attrs/transform.h>
29#include <tvm/relay/expr_functor.h>
30#include <tvm/relay/op_attr_types.h>
31#include <tvm/relay/transform.h>
32
33#include <deque>
34#include <functional>
35#include <vector>
36
37#include "../backend/te_compiler.h"
38#include "pattern_utils.h"
39
40namespace tvm {
41namespace relay {
42
43// Two global variables for receiving layout information from python
44std::deque<std::string> AutoSchedulerLayoutRewriter::global_ori_layouts_queue;
45std::deque<std::string> AutoSchedulerLayoutRewriter::global_new_layouts_queue;
46
47// Copy an Attrs but with a new auto_scheduler_rewritten_layout filed.
48template <typename T>
49Attrs CopyAttrsWithNewLayout(const T* ptr, const std::string& layout) {
50 auto n = make_object<T>(*ptr);
51 n->auto_scheduler_rewritten_layout = layout;
52 return Attrs(n);
53}
54
55// Mutate ops in a function
56class FuncMutator : public ExprMutator {
57 public:
58 FuncMutator(const std::deque<std::string>& ori_layouts_queue,
59 const std::deque<std::string>& new_layouts_queue)
60 : ExprMutator(),
61 ori_layouts_queue_(ori_layouts_queue),
62 new_layouts_queue_(new_layouts_queue) {}
63
64 Expr VisitExpr_(const CallNode* n) {
65 auto new_n = ExprMutator::VisitExpr_(n);
66
67 const auto* call = new_n.as<CallNode>();
68 if (call && call->op.as<OpNode>() &&
69 (std::find(target_ops_.begin(), target_ops_.end(), n->op.as<OpNode>()->name) !=
70 target_ops_.end()) &&
71 !ori_layouts_queue_.empty() && !new_layouts_queue_.empty()) {
72 // Pop a new layout from the queue
73 const std::string ori_layout = ori_layouts_queue_.front();
74 const std::string new_layout = new_layouts_queue_.front();
75 ori_layouts_queue_.pop_front();
76 new_layouts_queue_.pop_front();
77
78 // Insert a new op to do layout transform. (This will be simplified by FoldConstant later).
79 Expr updated_kernel = MakeAutoSchedulerLayoutTransform(call->args[1], ori_layout, new_layout);
80 Array<Expr> updated_args = {call->args[0], updated_kernel};
81
82 // Update the attrs
83 Attrs updated_attrs;
84 if (auto pattr = call->attrs.as<Conv2DAttrs>()) {
85 updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
86 } else if (auto pattr = call->attrs.as<Conv2DWinogradAttrs>()) {
87 updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
88 } else if (auto pattr = call->attrs.as<Conv3DAttrs>()) {
89 updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
90 } else if (auto pattr = call->attrs.as<MatmulAttrs>()) {
91 updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
92 } else if (auto pattr = call->attrs.as<DenseAttrs>()) {
93 updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
94 } else if (auto pattr = call->attrs.as<BatchMatmulAttrs>()) {
95 updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
96 } else {
97 LOG(FATAL) << "Unhandled attribute: " << call->attrs;
98 }
99 new_n = Call(call->op, updated_args, updated_attrs);
100 }
101 return new_n;
102 }
103
104 private:
105 std::deque<std::string> ori_layouts_queue_;
106 std::deque<std::string> new_layouts_queue_;
107
108 std::vector<std::string> target_ops_{
109 "nn.conv2d", "nn.conv3d", "nn.contrib_conv2d_winograd_without_weight_transform",
110 "nn.matmul", "nn.dense", "nn.batch_matmul"};
111};
112
113Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) {
114 auto new_n = ExprMutator::VisitExpr_(n);
115
116 if (const auto* call = new_n.as<CallNode>()) {
117 if (const auto* func = call->op.as<FunctionNode>()) {
118 global_ori_layouts_queue.clear();
119 global_new_layouts_queue.clear();
120
121 // Use ScheduleGetter to call python lower functions.
122 // This is used to get the layout transform information.
123 // The layout transformation will be recorded to global_ori_layout_queue
124 // and global_new_layouts_queue in ComputeDAG::RewriteLayout.
125 auto f = runtime::Registry::Get("auto_scheduler.enter_layout_rewrite");
126 CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function.";
127 (*f)();
128
129 tec::PrimFuncFor(GetRef<Function>(func), Target::Current());
130
131 f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite");
132 CHECK(f) << "Could not find ansor.exit_layout_rewrite function.";
133 (*f)();
134
135 // Mutate the called function
136 if (!global_ori_layouts_queue.empty() && !global_new_layouts_queue.empty()) {
137 auto ret = FuncMutator(global_ori_layouts_queue, global_new_layouts_queue).VisitExpr(new_n);
138 return ret;
139 }
140 }
141 }
142
143 return new_n;
144}
145
146Expr AutoSchedulerLayoutRewrite(const Expr& expr) {
147 return AutoSchedulerLayoutRewriter().Mutate(expr);
148}
149
150namespace transform {
151
152Pass AutoSchedulerLayoutRewrite() {
153 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
154 [=](Function f, IRModule m, PassContext pc) {
155 return Downcast<Function>(relay::AutoSchedulerLayoutRewrite(f));
156 };
157 return CreateFunctionPass(pass_func, 3, "AutoSchedulerLayoutRewrite", {"InferType"});
158}
159
160TVM_REGISTER_GLOBAL("relay._transform.AutoSchedulerLayoutRewrite")
161 .set_body_typed(AutoSchedulerLayoutRewrite);
162
163TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout")
164 .set_body_typed([](const Attrs& attrs) {
165 if (attrs->IsInstance<Conv2DAttrs>()) {
166 return attrs.as<Conv2DAttrs>()->auto_scheduler_rewritten_layout;
167 } else if (attrs->IsInstance<Conv2DWinogradAttrs>()) {
168 return attrs.as<Conv2DWinogradAttrs>()->auto_scheduler_rewritten_layout;
169 } else if (attrs->IsInstance<Conv3DAttrs>()) {
170 return attrs.as<Conv3DAttrs>()->auto_scheduler_rewritten_layout;
171 } else if (attrs->IsInstance<MatmulAttrs>()) {
172 return attrs.as<MatmulAttrs>()->auto_scheduler_rewritten_layout;
173 } else if (attrs->IsInstance<DenseAttrs>()) {
174 return attrs.as<DenseAttrs>()->auto_scheduler_rewritten_layout;
175 } else if (attrs->IsInstance<BatchMatmulAttrs>()) {
176 return attrs.as<BatchMatmulAttrs>()->auto_scheduler_rewritten_layout;
177 } else {
178 LOG(FATAL) << "Unhandled attribute: " << attrs;
179 }
180 return tvm::String();
181 });
182
183} // namespace transform
184
185} // namespace relay
186} // namespace tvm
187