1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include "./meta_schedule_layout_rewrite.h"
21
22#include <tvm/relay/attrs/transform.h>
23#include <tvm/relay/expr_functor.h>
24#include <tvm/relay/op_attr_types.h>
25#include <tvm/relay/transform.h>
26
27#include <deque>
28#include <mutex>
29#include <vector>
30
31#include "../backend/te_compiler.h"
32
33namespace tvm {
34namespace relay {
35
36class LayoutIndexQueue {
37 public:
38 static LayoutIndexQueue* Global() {
39 static LayoutIndexQueue inst;
40 return &inst;
41 }
42
43 void Clear() {
44 std::lock_guard<std::mutex> lock(mutex_);
45 queue_.clear();
46 }
47
48 private:
49 friend class MetaScheduleLayoutRewriter;
50 std::mutex mutex_;
51 std::deque<tir::IndexMap> queue_;
52};
53
54void MetaScheduleLayoutRewriter::LayoutQueuePush(const tir::IndexMap& index_map) {
55 LayoutIndexQueue* self = LayoutIndexQueue::Global();
56 {
57 std::lock_guard<std::mutex> lock(self->mutex_);
58 self->queue_.push_back(index_map);
59 }
60}
61
62bool IsSupportedOp(const OpNode* op) {
63 static std::vector<std::string> target_ops{
64 "nn.conv2d", //
65 "nn.contrib_conv2d_winograd_without_weight_transform",
66 "nn.conv3d",
67 "nn.matmul",
68 "nn.dense",
69 "nn.batch_matmul",
70 };
71 return std::find(target_ops.begin(), target_ops.end(), op->name) != target_ops.end();
72}
73
// If `Attr` holds an `AttrType`, copy it, set the copy's meta_schedule_original_shape to
// `OriginalShape`, and assign the copy (as Attrs) to `Result`. Otherwise `Result` is untouched.
// The expansion is wrapped in do/while(0) so it behaves as a single statement (no dangling-else
// hazard). Macro arguments are parenthesized. The locals use long, prefixed names so they cannot
// capture an argument expression that happens to use a short name such as `n` or `ptr`.
#define TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(Attr, AttrType, OriginalShape, Result)          \
  do {                                                                                       \
    if (const AttrType* tvm_layout_orig_attrs_ = (Attr).as<AttrType>()) {                    \
      ObjectPtr<AttrType> tvm_layout_new_attrs_ = make_object<AttrType>(*tvm_layout_orig_attrs_); \
      tvm_layout_new_attrs_->meta_schedule_original_shape = (OriginalShape);                 \
      (Result) = Attrs(tvm_layout_new_attrs_);                                               \
    }                                                                                        \
  } while (0)
80
81// Mutate ops in a function
82class MetaScheduleFuncMutator : public ExprMutator {
83 public:
84 explicit MetaScheduleFuncMutator(std::deque<tir::IndexMap>&& layout_queue)
85 : layout_queue_(std::move(layout_queue)) {}
86
87 Expr VisitExpr_(const CallNode* call) {
88 Expr expr = ExprMutator::VisitExpr_(call);
89 if (layout_queue_.empty()) {
90 return expr;
91 }
92 if (const auto* call = expr.as<CallNode>()) {
93 if (const auto* op = call->op.as<OpNode>()) {
94 if (IsSupportedOp(op)) {
95 ICHECK_EQ(call->args.size(), 2);
96 tir::IndexMap index_map = layout_queue_.front();
97 layout_queue_.pop_front();
98 Array<PrimExpr> shape;
99 if (call->args[1]->IsInstance<VarNode>()) {
100 Var var = Downcast<Var>(call->args[1]);
101 shape = Downcast<TensorType>(var->type_annotation)->shape;
102 } else if (const ConstantNode* cnst = call->args[1].as<ConstantNode>()) {
103 shape = cnst->tensor_type()->shape;
104 } else {
105 LOG(FATAL) << "Unexpected input " << call->args[1];
106 }
107 Attrs attrs{nullptr};
108 TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, Conv2DAttrs, shape, attrs);
109 TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, Conv2DWinogradAttrs, shape, attrs);
110 TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, Conv3DAttrs, shape, attrs);
111 TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, MatmulAttrs, shape, attrs);
112 TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, DenseAttrs, shape, attrs);
113 TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, BatchMatmulAttrs, shape, attrs);
114 ICHECK(attrs.defined()) << "TypeError: Unknown attribute: " << call->attrs;
115 expr = Call(call->op,
116 {call->args[0], MakeMetaScheduleLayoutTransform(call->args[1], index_map)},
117 attrs);
118 }
119 }
120 }
121 return expr;
122 }
123
124 private:
125 std::deque<tir::IndexMap> layout_queue_;
126};
127
128#undef TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE
129
130Expr MetaScheduleLayoutRewriter::VisitExpr_(const CallNode* call) {
131 Expr expr = ExprMutator::VisitExpr_(call);
132 call = expr.as<CallNode>();
133 if (call != nullptr) {
134 if (const auto* func = call->op.as<FunctionNode>()) {
135 LayoutIndexQueue* self = LayoutIndexQueue::Global();
136 self->queue_.clear();
137 tec::PrimFuncFor(GetRef<Function>(func), Target::Current());
138 if (!self->queue_.empty()) {
139 std::deque<tir::IndexMap> queue = std::move(self->queue_);
140 self->queue_.clear();
141 return MetaScheduleFuncMutator(std::move(queue)).VisitExpr(expr);
142 }
143 }
144 }
145 return expr;
146}
147
148namespace transform {
149
150Pass MetaScheduleLayoutRewrite() {
151 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
152 [=](Function f, IRModule m, PassContext pc) -> Function {
153 return Downcast<Function>(MetaScheduleLayoutRewriter().Mutate(std::move(f)));
154 };
155 return CreateFunctionPass(pass_func, 3, "MetaScheduleLayoutRewrite", {"InferType"});
156}
157
// If `AttrsRef` holds an `AttrType`, return its meta_schedule_original_shape from the enclosing
// function. Otherwise do nothing. The do/while(0) wrapper makes each expansion a single statement
// (no dangling-else hazard). The argument is parenthesized, and the local uses a prefixed name so
// it cannot capture a caller's identifier.
#define TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(AttrsRef, AttrType) \
  do {                                                                                 \
    if (const auto* tvm_ms_shape_attrs_ = (AttrsRef).as<AttrType>()) {                 \
      return tvm_ms_shape_attrs_->meta_schedule_original_shape;                        \
    }                                                                                  \
  } while (0)
162
163TVM_REGISTER_GLOBAL("relay.attrs.get_meta_schedule_original_shape")
164 .set_body_typed([](const Attrs& attrs) -> Array<PrimExpr> {
165 TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, Conv2DAttrs);
166 TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, Conv2DWinogradAttrs);
167 TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, Conv3DAttrs);
168 TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, MatmulAttrs);
169 TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, DenseAttrs);
170 TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, BatchMatmulAttrs);
171 LOG(FATAL) << "TypeError: Unknown attribute: " << attrs;
172 throw;
173 });
174TVM_REGISTER_GLOBAL("relay._transform.MetaScheduleLayoutRewrite")
175 .set_body_typed(MetaScheduleLayoutRewrite);
176
177#undef TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE
178
179} // namespace transform
180} // namespace relay
181} // namespace tvm
182