1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include "./meta_schedule_layout_rewrite.h" |
21 | |
22 | #include <tvm/relay/attrs/transform.h> |
23 | #include <tvm/relay/expr_functor.h> |
24 | #include <tvm/relay/op_attr_types.h> |
25 | #include <tvm/relay/transform.h> |
26 | |
27 | #include <deque> |
28 | #include <mutex> |
29 | #include <vector> |
30 | |
31 | #include "../backend/te_compiler.h" |
32 | |
33 | namespace tvm { |
34 | namespace relay { |
35 | |
36 | class LayoutIndexQueue { |
37 | public: |
38 | static LayoutIndexQueue* Global() { |
39 | static LayoutIndexQueue inst; |
40 | return &inst; |
41 | } |
42 | |
43 | void Clear() { |
44 | std::lock_guard<std::mutex> lock(mutex_); |
45 | queue_.clear(); |
46 | } |
47 | |
48 | private: |
49 | friend class MetaScheduleLayoutRewriter; |
50 | std::mutex mutex_; |
51 | std::deque<tir::IndexMap> queue_; |
52 | }; |
53 | |
54 | void MetaScheduleLayoutRewriter::LayoutQueuePush(const tir::IndexMap& index_map) { |
55 | LayoutIndexQueue* self = LayoutIndexQueue::Global(); |
56 | { |
57 | std::lock_guard<std::mutex> lock(self->mutex_); |
58 | self->queue_.push_back(index_map); |
59 | } |
60 | } |
61 | |
62 | bool IsSupportedOp(const OpNode* op) { |
63 | static std::vector<std::string> target_ops{ |
64 | "nn.conv2d" , // |
65 | "nn.contrib_conv2d_winograd_without_weight_transform" , |
66 | "nn.conv3d" , |
67 | "nn.matmul" , |
68 | "nn.dense" , |
69 | "nn.batch_matmul" , |
70 | }; |
71 | return std::find(target_ops.begin(), target_ops.end(), op->name) != target_ops.end(); |
72 | } |
73 | |
74 | #define TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(Attr, AttrType, OriginalShape, Result) \ |
75 | if (const AttrType* ptr = Attr.as<AttrType>()) { \ |
76 | ObjectPtr<AttrType> n = make_object<AttrType>(*ptr); \ |
77 | n->meta_schedule_original_shape = OriginalShape; \ |
78 | Result = Attrs(n); \ |
79 | } |
80 | |
81 | // Mutate ops in a function |
82 | class MetaScheduleFuncMutator : public ExprMutator { |
83 | public: |
84 | explicit MetaScheduleFuncMutator(std::deque<tir::IndexMap>&& layout_queue) |
85 | : layout_queue_(std::move(layout_queue)) {} |
86 | |
87 | Expr VisitExpr_(const CallNode* call) { |
88 | Expr expr = ExprMutator::VisitExpr_(call); |
89 | if (layout_queue_.empty()) { |
90 | return expr; |
91 | } |
92 | if (const auto* call = expr.as<CallNode>()) { |
93 | if (const auto* op = call->op.as<OpNode>()) { |
94 | if (IsSupportedOp(op)) { |
95 | ICHECK_EQ(call->args.size(), 2); |
96 | tir::IndexMap index_map = layout_queue_.front(); |
97 | layout_queue_.pop_front(); |
98 | Array<PrimExpr> shape; |
99 | if (call->args[1]->IsInstance<VarNode>()) { |
100 | Var var = Downcast<Var>(call->args[1]); |
101 | shape = Downcast<TensorType>(var->type_annotation)->shape; |
102 | } else if (const ConstantNode* cnst = call->args[1].as<ConstantNode>()) { |
103 | shape = cnst->tensor_type()->shape; |
104 | } else { |
105 | LOG(FATAL) << "Unexpected input " << call->args[1]; |
106 | } |
107 | Attrs attrs{nullptr}; |
108 | TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, Conv2DAttrs, shape, attrs); |
109 | TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, Conv2DWinogradAttrs, shape, attrs); |
110 | TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, Conv3DAttrs, shape, attrs); |
111 | TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, MatmulAttrs, shape, attrs); |
112 | TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, DenseAttrs, shape, attrs); |
113 | TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, BatchMatmulAttrs, shape, attrs); |
114 | ICHECK(attrs.defined()) << "TypeError: Unknown attribute: " << call->attrs; |
115 | expr = Call(call->op, |
116 | {call->args[0], MakeMetaScheduleLayoutTransform(call->args[1], index_map)}, |
117 | attrs); |
118 | } |
119 | } |
120 | } |
121 | return expr; |
122 | } |
123 | |
124 | private: |
125 | std::deque<tir::IndexMap> layout_queue_; |
126 | }; |
127 | |
128 | #undef TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE |
129 | |
130 | Expr MetaScheduleLayoutRewriter::VisitExpr_(const CallNode* call) { |
131 | Expr expr = ExprMutator::VisitExpr_(call); |
132 | call = expr.as<CallNode>(); |
133 | if (call != nullptr) { |
134 | if (const auto* func = call->op.as<FunctionNode>()) { |
135 | LayoutIndexQueue* self = LayoutIndexQueue::Global(); |
136 | self->queue_.clear(); |
137 | tec::PrimFuncFor(GetRef<Function>(func), Target::Current()); |
138 | if (!self->queue_.empty()) { |
139 | std::deque<tir::IndexMap> queue = std::move(self->queue_); |
140 | self->queue_.clear(); |
141 | return MetaScheduleFuncMutator(std::move(queue)).VisitExpr(expr); |
142 | } |
143 | } |
144 | } |
145 | return expr; |
146 | } |
147 | |
148 | namespace transform { |
149 | |
150 | Pass MetaScheduleLayoutRewrite() { |
151 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
152 | [=](Function f, IRModule m, PassContext pc) -> Function { |
153 | return Downcast<Function>(MetaScheduleLayoutRewriter().Mutate(std::move(f))); |
154 | }; |
155 | return CreateFunctionPass(pass_func, 3, "MetaScheduleLayoutRewrite" , {"InferType" }); |
156 | } |
157 | |
158 | #define TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(Attrs, AttrType) \ |
159 | if (const auto* p = Attrs.as<AttrType>()) { \ |
160 | return p->meta_schedule_original_shape; \ |
161 | } |
162 | |
163 | TVM_REGISTER_GLOBAL("relay.attrs.get_meta_schedule_original_shape" ) |
164 | .set_body_typed([](const Attrs& attrs) -> Array<PrimExpr> { |
165 | TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, Conv2DAttrs); |
166 | TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, Conv2DWinogradAttrs); |
167 | TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, Conv3DAttrs); |
168 | TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, MatmulAttrs); |
169 | TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, DenseAttrs); |
170 | TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, BatchMatmulAttrs); |
171 | LOG(FATAL) << "TypeError: Unknown attribute: " << attrs; |
172 | throw; |
173 | }); |
174 | TVM_REGISTER_GLOBAL("relay._transform.MetaScheduleLayoutRewrite" ) |
175 | .set_body_typed(MetaScheduleLayoutRewrite); |
176 | |
177 | #undef TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE |
178 | |
179 | } // namespace transform |
180 | } // namespace relay |
181 | } // namespace tvm |
182 | |