1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file transform_layout.h |
 * \brief Common infrastructure for transforming the layouts. This is used by the AlterOpLayout
 * and ConvertLayout passes.
 */
25 | |
26 | #ifndef TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_ |
27 | #define TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_ |
28 | |
29 | #include <tvm/relay/expr.h> |
30 | #include <tvm/tir/data_layout.h> |
31 | |
32 | #include <string> |
33 | #include <tuple> |
34 | #include <unordered_map> |
35 | #include <utility> |
36 | #include <vector> |
37 | |
38 | #include "infer_layout_utils.h" |
39 | #include "pattern_utils.h" |
40 | |
41 | namespace tvm { |
42 | namespace relay { |
43 | |
44 | /*! |
 * \brief Memorizes layout transformations so they can be reused.
46 | */ |
47 | class TransformMemorizerNode : public Object { |
48 | public: |
49 | /*! \brief The key for the memorizer map is (Expr, src_layout, dst_layout). */ |
50 | using TransformKey = std::tuple<const Object*, std::string, std::string>; |
51 | |
  struct key_hash {
53 | std::size_t operator()(const TransformKey& k) const { |
54 | return dmlc::HashCombine<std::string>( |
55 | dmlc::HashCombine<std::string>(std::hash<const Object*>()(std::get<0>(k)), |
56 | std::get<1>(k)), |
57 | (std::get<2>(k))); |
58 | } |
59 | }; |
60 | |
61 | /*! |
   * \brief Defines the call transformation for derived passes. The new layouts are defined
   * for different targets using a packed func.
64 | * \param ref_call The original call. |
65 | * \param new_attrs Updated attributes consistent with new layouts. |
66 | * \param new_args The traversed/recursed args to the call. |
67 | * \return The new Call after calling the packed func. |
68 | */ |
69 | virtual Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs, |
70 | const std::vector<Expr>& new_args) = 0; |
71 | |
72 | virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) { |
73 | return CallWithNewLayouts(ref_call, ref_call->attrs, new_args); |
74 | } |
75 | |
76 | /*! \brief The memorizer map. */ |
77 | std::unordered_map<TransformKey, Expr, key_hash> memo; |
78 | |
79 | static constexpr const char* _type_key = "relay.alter_op_layout.TransformMemorizerNode" ; |
80 | TVM_DECLARE_FINAL_OBJECT_INFO(TransformMemorizerNode, Object); |
81 | }; |
82 | |
83 | /*! |
84 | * \brief Container that transforms the layouts and memorizes them. |
85 | */ |
86 | class TransformMemorizer : public ObjectRef { |
87 | public: |
88 | TransformMemorizer() = default; |
89 | explicit TransformMemorizer(ObjectPtr<Object> n) : ObjectRef(n) {} |
90 | |
91 | TransformMemorizerNode* operator->() { |
92 | return static_cast<TransformMemorizerNode*>(get_mutable()); |
93 | } |
94 | |
95 | /* |
96 | * \brief Memorizes and transforms the layout. |
97 | * \param expr The initial expr. |
98 | * \param src_layout The source layout. |
99 | * \param dst_layout The dest layout. |
100 | * \return The new expr with the dst layout. |
101 | */ |
102 | Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) { |
103 | if (src_layout.Equals(dst_layout)) { |
104 | return raw; |
105 | } |
106 | |
    std::tuple<const Object*, std::string, std::string> key =
        std::make_tuple(raw.get(), src_layout.name(), dst_layout.name());
109 | auto& memo = operator->()->memo; |
110 | |
111 | auto iter = memo.find(key); |
112 | if (iter != memo.end()) { |
113 | return iter->second; |
114 | } else { |
115 | Expr transform = TransformHelper(raw, src_layout, dst_layout); |
116 | memo[key] = transform; |
117 | return transform; |
118 | } |
119 | } |
120 | |
121 | /* |
122 | * \brief Helper to transform the layouts. |
123 | * \param expr The initial expr. |
124 | * \param src_layout The source layout. |
125 | * \param dst_layout The dest layout. |
126 | * \return The new expr with the dst layout. |
127 | * \note It performs following 2 operations |
128 | * 1) If src_layout ndim is smaller then dst_layout, expand_dim is inserted to match the dim |
129 | * size. For example, src_layout = C, dst_layout = NCHW16c. The src is expanded to NHWC. |
130 | * 2) Call layout transform with new src layout. |
131 | */ |
132 | Expr TransformHelper(Expr raw, Layout src_layout, Layout dst_layout) { |
133 | if (src_layout.Equals(dst_layout)) { |
134 | return raw; |
135 | } |
136 | |
137 | // 1) Check if the shape lengths are different. If yes, expand dims. |
138 | Expr input_expr = raw; |
139 | Layout new_src_layout = src_layout; |
140 | if (src_layout.ndim_primal() < dst_layout.ndim_primal()) { |
      // If the input is a scalar, there is no need for a layout transformation, as a scalar
      // can be broadcast easily even if the other operand has a transformed layout.
143 | if (input_expr->checked_type_.defined() && IsScalar(input_expr)) { |
144 | return raw; |
145 | } |
146 | int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal(); |
147 | new_src_layout = src_layout.ExpandPrimal(dst_layout); |
148 | input_expr = MakeExpandDims(input_expr, 0, num_new_axis); |
149 | if (new_src_layout.Equals(dst_layout)) { |
150 | return input_expr; |
151 | } |
152 | } |
153 | |
154 | // 2) Insert layout transform on the transformed src. |
155 | ICHECK(new_src_layout.defined() && dst_layout.defined()) |
156 | << "Cannot insert layout transform because there are undefined layouts" ; |
157 | ICHECK(tir::BijectiveLayout(new_src_layout, dst_layout).defined()) |
158 | << "Cannot insert layout transform because there are inconvertible layouts: " |
159 | << new_src_layout << " v.s. " << dst_layout; |
160 | return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name()); |
161 | } |
162 | |
163 | using ContainerType = TransformMemorizerNode; |
164 | }; |
165 | |
166 | /* |
167 | * \brief TempExprNode during layout transform. Instance of this expr will be Realized to normal |
168 | * expr ultimately. |
169 | * \tparam TransformMemorizerT The derived TransformMemorizer type. |
170 | */ |
171 | template <class TransformMemorizerT> |
172 | class LayoutAlternatedExprNode : public TempExprNode { |
173 | public: |
174 | Expr value; |
175 | Layout old_layout; |
176 | Layout new_layout; |
177 | TransformMemorizerT memorizer; |
178 | |
179 | Expr Realize() const final { |
180 | // NOTE: use a copy to discard the "const" qualifier |
181 | TransformMemorizerT tmp_memorizer = memorizer; |
182 | // fallback to old layout |
183 | return tmp_memorizer.Transform(value, new_layout, old_layout); |
184 | } |
185 | |
186 | void VisitAttrs(AttrVisitor* v) { |
187 | v->Visit("value" , &value); |
188 | v->Visit("old_layout" , &old_layout); |
189 | v->Visit("new_layout" , &new_layout); |
190 | } |
191 | |
192 | static constexpr const char* _type_key = "relay.alter_op_layout.LayoutAlternatedExprNode" ; |
193 | TVM_DECLARE_FINAL_OBJECT_INFO(LayoutAlternatedExprNode, TempExprNode); |
194 | }; |
195 | |
196 | /*! |
197 | * \brief Container for the layout alternated expr. |
198 | * \tparam TransformMemorizerT The derived TransformMemorizer type. |
199 | */ |
200 | template <class TransformMemorizerT> |
201 | class LayoutAlternatedExpr : public ObjectRef { |
202 | public: |
203 | LayoutAlternatedExpr() {} |
204 | explicit LayoutAlternatedExpr(ObjectPtr<Object> n) : ObjectRef(n) {} |
205 | |
206 | LayoutAlternatedExprNode<TransformMemorizerT>* operator->() { |
207 | return static_cast<LayoutAlternatedExprNode<TransformMemorizerT>*>(get_mutable()); |
208 | } |
209 | |
210 | using ContainerType = LayoutAlternatedExprNode<TransformMemorizerT>; |
211 | }; |
212 | |
213 | /*! |
214 | * Call registered FInferCorrectLayout of an op. |
215 | * Parameters are the same as the parameters for FInferCorrectLayout |
216 | * Returns inferred_input_layout, inferred_output_layout, updated attributes, and a flag |
217 | * indicating whether or not layout conversion is successful. |
218 | */ |
219 | static inline std::tuple<InferCorrectLayoutOutput, bool> InferCorrectLayouts( |
220 | const Call& call, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts, |
221 | const Array<tvm::relay::Type>& old_in_types) { |
  static auto finfer_layout = Op::GetAttrMap<FInferCorrectLayout>("FInferCorrectLayout");
223 | auto null_res = std::make_tuple( |
224 | InferCorrectLayoutOutput(Array<Layout>(nullptr), Array<Layout>(nullptr), Attrs(nullptr)), |
225 | false); |
226 | if (!call->op.as<OpNode>()) { |
227 | return null_res; |
228 | } |
229 | |
230 | Op op = Downcast<Op>(call->op); |
231 | if (finfer_layout.count(op)) { |
232 | auto out = finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types); |
233 | for (auto inferred_layouts : {out->input_layouts, out->output_layouts}) { |
234 | for (auto layout : inferred_layouts) { |
235 | if (!layout.defined()) { // inference fails |
236 | return null_res; |
237 | } |
238 | } |
239 | } |
240 | return std::make_tuple(out, true); |
241 | } else { |
242 | return null_res; |
243 | } |
244 | } |
245 | |
246 | /* |
247 | * \brief Used with ForwardRewrite to transform the expr. The input args are same as |
248 | * FForwardRewrite. |
249 | * \param ref_call The reference old call type to be rewritten. |
250 | * We can make use of the op and type information. |
251 | * \param new_args The new arguments (some of them could be TempExpr). |
252 | * \param ctx Optional context information about ref_call. |
253 | * \tparam TransformMemorizerT The derived TransformMemorizer type. |
254 | * \return The rewriten result call, can also return nullptr, |
255 | * which indicate the rewriter should use the default fallback |
256 | * rule that realizes all its input and compose the call. |
257 | * |
258 | * \note The ctx can be used to provide extra information during transformation. The ctx is |
259 | * templated to reuse across AlterOpLayout and ConvertLayout pass. The steps are |
260 | * - Extract the original layouts. |
261 | * - Use ctx transformation to get a Call with new layouts - CallWithNewLayouts. |
262 | * - Extract the new layouts from the returned Call. |
263 | * - Transform the original call to reuse the new layouts using TransformMemorizer. |
264 | */ |
265 | template <class TransformMemorizerT> |
266 | Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) { |
267 | std::vector<LayoutAlternatedExpr<TransformMemorizerT>> inputs; |
268 | std::vector<Expr> normal_new_args; |
269 | |
270 | // NOTE: discard the "const" qualifier |
271 | // TransformMemorizer memorizer = Downcast<TransformMemorizer>(ctx); |
272 | // TransformMemorizerT* ctx_transformer = |
273 | // static_cast<TransformMemorizerT*>(memorizer.operator->()); |
274 | TransformMemorizerT memorizer = Downcast<TransformMemorizerT>(ctx); |
275 | |
276 | // fill incomplete state and flatten tuple |
277 | auto push_back_one_arg = [&inputs, memorizer](Expr arg) { |
278 | // We always expect LayoutAlternatedExpr<TransformMemorizerT>. |
279 | // This is used to convert the normal Expr to LayoutAlternatedExpr<TransformMemorizerT>. |
280 | if (const LayoutAlternatedExprNode<TransformMemorizerT>* inp = |
281 | arg.as<LayoutAlternatedExprNode<TransformMemorizerT>>()) { |
282 | inputs.push_back(GetRef<LayoutAlternatedExpr<TransformMemorizerT>>(inp)); |
283 | return inp->value; |
284 | } else { |
285 | auto inode = make_object<LayoutAlternatedExprNode<TransformMemorizerT>>(); |
286 | inode->value = arg; |
287 | inode->memorizer = memorizer; |
288 | inputs.push_back(LayoutAlternatedExpr<TransformMemorizerT>(inode)); |
289 | return arg; |
290 | } |
291 | }; |
292 | |
293 | for (auto new_arg : new_args) { |
    // NOTE: nested tuples are not supported
295 | if (new_arg->IsInstance<TupleNode>()) { |
296 | Tuple tuple_new_arg = Downcast<Tuple>(new_arg); |
297 | Array<Expr> fields; |
298 | fields.reserve(tuple_new_arg->fields.size()); |
299 | for (auto x : tuple_new_arg->fields) { |
300 | Expr tmp = push_back_one_arg(x); |
301 | fields.push_back(tmp); |
302 | } |
303 | normal_new_args.push_back(WithFields(tuple_new_arg, fields)); |
304 | } else { |
305 | Expr tmp = push_back_one_arg(new_arg); |
306 | normal_new_args.push_back(tmp); |
307 | } |
308 | } |
309 | |
  // If there is no FInferCorrectLayout for the op, we just assume the layout is correct
  // and apply the alteration directly.
  static auto finfer_layout = Op::GetAttrMap<FInferCorrectLayout>("FInferCorrectLayout");
  if (Op::HasAttrMap("FTVMAlterOpLayout")) {
    static auto falter_layout = Op::GetAttrMap<FTVMAlterOpLayout>("FTVMAlterOpLayout");
314 | if (ref_call->op.as<OpNode>()) { |
315 | Op op = Downcast<Op>(ref_call->op); |
316 | if (falter_layout.count(op) && !finfer_layout.count(op)) { |
317 | return memorizer->CallWithNewLayouts(ref_call, normal_new_args); |
318 | } |
319 | } |
320 | } |
321 | |
  // old_prd, new_prd = state[inputs]
  // Different ops can view a tensor with different layouts, e.g. conv_1->transpose(H, W)->conv_2:
  // transpose views its output as having NCWH layout, but conv_2 still views it as NCHW.
  // old_prd, new_prd: the input layouts from the perspective of the producer (transpose)
  // old_cur, new_cur: the input layouts from the perspective of the current node (conv_2)
  // old_prd->new_prd tells how the producer changed the layout
  // old_cur->new_cur tells what change the current node wants to see
  // No layout transforms are needed when they mean the same thing (NCHW->NCHW4c == NCWH->NCWH4c)
330 | |
331 | // The workflow: |
332 | // 1. Run InferCorrectLayouts(NULL, old_prd) to get old_cur |
333 | // 2. Run InferCorrectLayouts(new_prd, old_prd) to get new_cur and rewrite the current op |
334 | |
335 | Array<Layout> old_prd, old_cur, old_out, new_prd, new_out, new_cur; |
336 | for (auto inp : inputs) { |
337 | old_prd.push_back(inp->old_layout); |
338 | new_prd.push_back(inp->new_layout); |
339 | } |
340 | |
  // Collect input types to pass on to InferCorrectLayouts.
342 | tvm::Array<tvm::relay::Type> types; |
343 | for (auto arg : ref_call->args) { |
344 | types.push_back(arg->checked_type()); |
345 | } |
346 | |
347 | bool success = false; |
348 | InferCorrectLayoutOutput infer_out; |
349 | std::tie(infer_out, success) = |
350 | InferCorrectLayouts(ref_call, Array<Layout>(nullptr), old_prd, types); |
351 | old_cur = infer_out->input_layouts; |
352 | old_out = infer_out->output_layouts; |
353 | if (!success) { |
354 | return Expr(nullptr); |
355 | } |
356 | ICHECK_EQ(old_cur.size(), new_prd.size()); |
357 | |
358 | // for backward compatibility of InferCorrectLayouts |
359 | Array<Layout> new_prd_inferred = new_prd; |
360 | // if new_prd_inferred == 'undef': new_prd_inferred = old_cur |
361 | for (size_t i = 0; i < new_prd_inferred.size(); ++i) { |
362 | if (!new_prd_inferred[i].defined()) { |
363 | new_prd_inferred.Set(i, old_cur[i]); |
364 | } |
365 | } |
366 | Array<Layout> old_prd_inferred = old_prd; |
367 | // if old_prd_inferred == 'undef': old_prd_inferred = old_cur |
368 | for (size_t i = 0; i < old_prd_inferred.size(); ++i) { |
369 | if (!old_prd_inferred[i].defined()) { |
370 | old_prd_inferred.Set(i, old_cur[i]); |
371 | } |
372 | } |
373 | |
374 | // new_op = alter(op) |
375 | Call new_call = memorizer->CallWithNewLayouts(ref_call, infer_out->new_attrs, normal_new_args); |
376 | |
377 | // new_cur, new_out = op.infer(new_prd) |
378 | if (new_call->op->IsInstance<OpNode>()) { |
379 | success = false; |
380 | std::tie(infer_out, success) = |
381 | InferCorrectLayouts(new_call, new_prd_inferred, old_prd_inferred, types); |
382 | new_cur = infer_out->input_layouts; |
383 | new_out = infer_out->output_layouts; |
384 | if (!success) { |
385 | return Expr(nullptr); |
386 | } |
387 | } else { |
388 | return Expr(nullptr); |
389 | } |
390 | |
  ICHECK_EQ(new_out.size(), old_out.size())
      << "The number of output nodes should remain the same during alter_op_layout";
  ICHECK_EQ(new_prd.size(), new_cur.size())
      << "The number of input nodes should remain the same during alter_op_layout";
395 | |
396 | auto transform_layout = [&memorizer](Expr arg_item, const Layout& old_prd, const Layout& old_cur, |
397 | const Layout& new_prd, const Layout& new_cur) { |
    if (old_cur.Equals(old_prd)) {  // the two transforms can be fused into one
399 | arg_item = memorizer.Transform(arg_item, new_prd, new_cur); |
400 | } else { |
401 | if (old_prd.defined()) arg_item = memorizer.Transform(arg_item, new_prd, old_prd); |
402 | arg_item = memorizer.Transform(arg_item, old_cur, new_cur); |
403 | } |
404 | return arg_item; |
405 | }; |
406 | |
407 | DLOG(INFO) << "Transforming layout for `" << ref_call->op << "`" ; |
408 | DLOG(INFO) << " old_prd=" << old_prd; |
409 | DLOG(INFO) << " new_prd=" << new_prd; |
410 | DLOG(INFO) << " old_cur=" << old_cur; |
411 | DLOG(INFO) << " new_cur=" << new_cur; |
412 | |
413 | // if (new_prd != new_cur): insert transform (new_prd -> new_cur) |
414 | Array<Expr> transformed_args; |
415 | size_t pt = 0; |
416 | for (auto arg : new_call->args) { |
417 | if (arg->IsInstance<TupleNode>()) { // unflatten tuple |
418 | Tuple tuple_arg = Downcast<Tuple>(arg); |
419 | Array<Expr> transformed_tuple_arg; |
420 | transformed_tuple_arg.reserve(tuple_arg->fields.size()); |
421 | for (auto arg_item : tuple_arg->fields) { |
422 | transformed_tuple_arg.push_back( |
423 | transform_layout(arg_item, old_prd[pt], old_cur[pt], new_prd[pt], new_cur[pt])); |
424 | pt++; |
425 | } |
426 | transformed_args.push_back(WithFields(tuple_arg, transformed_tuple_arg)); |
427 | } else { |
428 | transformed_args.push_back( |
429 | transform_layout(arg, old_prd[pt], old_cur[pt], new_prd[pt], new_cur[pt])); |
430 | pt++; |
431 | } |
432 | } |
433 | ICHECK_EQ(pt, inputs.size()); |
434 | |
435 | // state[node] = (old_out, new_out) |
436 | // (handle tuple output) |
437 | if (ref_call->checked_type()->IsInstance<TupleTypeNode>()) { |
438 | Expr tuple_output = Call(new_call->op, transformed_args, infer_out->new_attrs); |
439 | Array<Expr> fields; |
440 | for (size_t i = 0; i < new_out.size(); ++i) { |
441 | auto rnode = make_object<LayoutAlternatedExprNode<TransformMemorizerT>>(); |
442 | rnode->value = TupleGetItem(tuple_output, i); |
443 | rnode->old_layout = old_out[i]; |
444 | rnode->new_layout = new_out[i]; |
445 | rnode->memorizer = memorizer; |
446 | fields.push_back(Expr(rnode)); |
447 | } |
448 | return Tuple(fields); |
449 | } else { |
450 | auto rnode = make_object<LayoutAlternatedExprNode<TransformMemorizerT>>(); |
451 | ICHECK_EQ(new_out.size(), 1); |
452 | rnode->value = Call(new_call->op, transformed_args, infer_out->new_attrs, {}, ref_call->span); |
453 | rnode->old_layout = old_out[0]; |
454 | rnode->new_layout = new_out[0]; |
455 | rnode->memorizer = memorizer; |
456 | return Expr(rnode); |
457 | } |
458 | } |
459 | |
460 | } // namespace relay |
461 | } // namespace tvm |
462 | |
463 | #endif // TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_ |
464 | |