/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file transform_layout.h
 * \brief Common infrastructure for transforming layouts. This is used by the AlterOpLayout and
 *        ConvertLayout passes.
 */

#ifndef TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_
#define TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_

#include <tvm/relay/expr.h>
#include <tvm/tir/data_layout.h>

#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "infer_layout_utils.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {

/*!
 * \brief Memorizes layout transformations so that they can be reused.
 */
class TransformMemorizerNode : public Object {
 public:
  /*! \brief The key for the memorizer map is (Expr, src_layout, dst_layout). */
  using TransformKey = std::tuple<const Object*, std::string, std::string>;

  struct key_hash : public std::function<std::size_t(TransformKey)> {
    std::size_t operator()(const TransformKey& k) const {
      return dmlc::HashCombine<std::string>(
          dmlc::HashCombine<std::string>(std::hash<const Object*>()(std::get<0>(k)),
                                         std::get<1>(k)),
          (std::get<2>(k)));
    }
  };

  /*!
   * \brief Defines the call transformation for derived passes. The new layouts are defined for
   *        different targets using a packed func (a hypothetical override is sketched after this
   *        class).
   * \param ref_call The original call.
   * \param new_attrs Updated attributes consistent with new layouts.
   * \param new_args The traversed/recursed args to the call.
   * \return The new Call after calling the packed func.
   */
  virtual Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs,
                                  const std::vector<Expr>& new_args) = 0;

  virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) {
    return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
  }

  /*! \brief The memorizer map. */
  std::unordered_map<TransformKey, Expr, key_hash> memo;

  static constexpr const char* _type_key = "relay.alter_op_layout.TransformMemorizerNode";
  TVM_DECLARE_FINAL_OBJECT_INFO(TransformMemorizerNode, Object);
};
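
// Illustrative only: a concrete pass supplies a derived memorizer that overrides
// CallWithNewLayouts, typically by invoking a per-op packed func. The sketch below is an
// assumption for illustration (the class name is hypothetical; the real derived nodes live in the
// individual pass implementations):
//
//   class MyTransformMemorizerNode : public TransformMemorizerNode {
//    public:
//     Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs,
//                             const std::vector<Expr>& new_args) override {
//       // Look up a packed func such as FTVMAlterOpLayout for ref_call->op and let it build a
//       // Call whose attributes reflect the target-specific layouts; fall back to ref_call's
//       // op/attrs when no func is registered.
//       ...
//     }
//   };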

/*!
 * \brief Container that transforms the layouts and memorizes them.
 */
class TransformMemorizer : public ObjectRef {
 public:
  TransformMemorizer() = default;
  explicit TransformMemorizer(ObjectPtr<Object> n) : ObjectRef(n) {}

  TransformMemorizerNode* operator->() {
    return static_cast<TransformMemorizerNode*>(get_mutable());
  }

  /*!
   * \brief Memorizes and transforms the layout.
   * \param raw The initial expr.
   * \param src_layout The source layout.
   * \param dst_layout The destination layout.
   * \return The new expr with the dst layout.
   */
  Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) {
    if (src_layout.Equals(dst_layout)) {
      return raw;
    }

    std::tuple<const Object*, std::string, std::string> key =
        std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name());
    auto& memo = operator->()->memo;

    auto iter = memo.find(key);
    if (iter != memo.end()) {
      return iter->second;
    } else {
      Expr transform = TransformHelper(raw, src_layout, dst_layout);
      memo[key] = transform;
      return transform;
    }
  }

  /*!
   * \brief Helper to transform the layouts.
   * \param raw The initial expr.
   * \param src_layout The source layout.
   * \param dst_layout The destination layout.
   * \return The new expr with the dst layout.
   * \note It performs the following two operations:
   *       1) If src_layout's ndim is smaller than dst_layout's, expand_dims is inserted to match
   *          the dim size. For example, src_layout = C, dst_layout = NCHW16c. The src is expanded
   *          to NHWC.
   *       2) Call layout transform with the new src layout.
   */
  Expr TransformHelper(Expr raw, Layout src_layout, Layout dst_layout) {
    if (src_layout.Equals(dst_layout)) {
      return raw;
    }

    // 1) Check if the shape lengths are different. If yes, expand dims.
    Expr input_expr = raw;
    Layout new_src_layout = src_layout;
    if (src_layout.ndim_primal() < dst_layout.ndim_primal()) {
      // If the input is a scalar, no layout transformation is needed, as a scalar can be
      // broadcast easily even if the other operand has a transformed layout.
      if (input_expr->checked_type_.defined() && IsScalar(input_expr)) {
        return raw;
      }
      int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal();
      new_src_layout = src_layout.ExpandPrimal(dst_layout);
      input_expr = MakeExpandDims(input_expr, 0, num_new_axis);
      if (new_src_layout.Equals(dst_layout)) {
        return input_expr;
      }
    }

    // 2) Insert a layout transform on the transformed src.
    ICHECK(new_src_layout.defined() && dst_layout.defined())
        << "Cannot insert layout transform because there are undefined layouts";
    ICHECK(tir::BijectiveLayout(new_src_layout, dst_layout).defined())
        << "Cannot insert layout transform because there are inconvertible layouts: "
        << new_src_layout << " v.s. " << dst_layout;
    return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name());
  }

  using ContainerType = TransformMemorizerNode;
};

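// A minimal usage sketch, assuming `memorizer` is a (derived) TransformMemorizer and `expr` is a
// 1-D tensor laid out as "C" (names are illustrative, not part of this header): Transform() first
// checks the memo, then delegates to TransformHelper, which expands the primal dims via
// expand_dims and inserts a layout_transform.
//
//   Layout src("C");
//   Layout dst("NCHW16c");
//   Expr out = memorizer.Transform(expr, src, dst);
//   // A second call with the same (expr, src, dst) key returns the memorized result.
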
/*!
 * \brief TempExprNode during layout transform. Instances of this expr will ultimately be
 *        realized to a normal expr.
 * \tparam TransformMemorizerT The derived TransformMemorizer type.
 */
template <class TransformMemorizerT>
class LayoutAlternatedExprNode : public TempExprNode {
 public:
  Expr value;
  Layout old_layout;
  Layout new_layout;
  TransformMemorizerT memorizer;

  Expr Realize() const final {
    // NOTE: use a copy to discard the "const" qualifier
    TransformMemorizerT tmp_memorizer = memorizer;
    // fallback to old layout
    return tmp_memorizer.Transform(value, new_layout, old_layout);
  }

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("value", &value);
    v->Visit("old_layout", &old_layout);
    v->Visit("new_layout", &new_layout);
  }

  static constexpr const char* _type_key = "relay.alter_op_layout.LayoutAlternatedExprNode";
  TVM_DECLARE_FINAL_OBJECT_INFO(LayoutAlternatedExprNode, TempExprNode);
};

/*!
 * \brief Container for the layout alternated expr.
 * \tparam TransformMemorizerT The derived TransformMemorizer type.
 */
template <class TransformMemorizerT>
class LayoutAlternatedExpr : public ObjectRef {
 public:
  LayoutAlternatedExpr() {}
  explicit LayoutAlternatedExpr(ObjectPtr<Object> n) : ObjectRef(n) {}

  LayoutAlternatedExprNode<TransformMemorizerT>* operator->() {
    return static_cast<LayoutAlternatedExprNode<TransformMemorizerT>*>(get_mutable());
  }

  using ContainerType = LayoutAlternatedExprNode<TransformMemorizerT>;
};

/*!
 * \brief Calls the registered FInferCorrectLayout of an op.
 * Parameters are the same as the parameters for FInferCorrectLayout.
 * Returns the inferred input layouts, inferred output layouts, and updated attributes packed in an
 * InferCorrectLayoutOutput, together with a flag indicating whether layout inference succeeded.
 */
static inline std::tuple<InferCorrectLayoutOutput, bool> InferCorrectLayouts(
    const Call& call, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types) {
  static auto finfer_layout = Op::GetAttrMap<FInferCorrectLayout>("FInferCorrectLayout");
  auto null_res = std::make_tuple(
      InferCorrectLayoutOutput(Array<Layout>(nullptr), Array<Layout>(nullptr), Attrs(nullptr)),
      false);
  if (!call->op.as<OpNode>()) {
    return null_res;
  }

  Op op = Downcast<Op>(call->op);
  if (finfer_layout.count(op)) {
    auto out = finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types);
    for (auto inferred_layouts : {out->input_layouts, out->output_layouts}) {
      for (auto layout : inferred_layouts) {
        if (!layout.defined()) {  // inference fails
          return null_res;
        }
      }
    }
    return std::make_tuple(out, true);
  } else {
    return null_res;
  }
}
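
// Illustrative call pattern, assuming `ref_call`, `old_in_layouts`, and `in_types` come from the
// surrounding rewrite (as they do in LayoutRewriter below): passing an undefined Array for
// new_in_layouts asks the op which input layouts it expects given only the old ones.
//
//   InferCorrectLayoutOutput out;
//   bool success = false;
//   std::tie(out, success) =
//       InferCorrectLayouts(ref_call, Array<Layout>(nullptr), old_in_layouts, in_types);
//   if (success) {
//     Array<Layout> expected_inputs = out->input_layouts;
//   }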

/*!
 * \brief Used with ForwardRewrite to transform the expr. The input args are the same as
 *        FForwardRewrite.
 * \param ref_call The reference old call type to be rewritten.
 *                 We can make use of the op and type information.
 * \param new_args The new arguments (some of them could be TempExpr).
 * \param ctx Optional context information about ref_call.
 * \tparam TransformMemorizerT The derived TransformMemorizer type.
 * \return The rewritten result call. It can also return nullptr,
 *         which indicates that the rewriter should use the default fallback
 *         rule that realizes all its inputs and composes the call.
 *
 * \note The ctx can be used to provide extra information during transformation. The ctx is
 *       templated so that it can be reused across the AlterOpLayout and ConvertLayout passes.
 *       The steps are
 *       - Extract the original layouts.
 *       - Use the ctx transformation to get a Call with new layouts - CallWithNewLayouts.
 *       - Extract the new layouts from the returned Call.
 *       - Transform the original call to reuse the new layouts using TransformMemorizer.
 */
template <class TransformMemorizerT>
Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  std::vector<LayoutAlternatedExpr<TransformMemorizerT>> inputs;
  std::vector<Expr> normal_new_args;

  // NOTE: discard the "const" qualifier
  // TransformMemorizer memorizer = Downcast<TransformMemorizer>(ctx);
  // TransformMemorizerT* ctx_transformer =
  // static_cast<TransformMemorizerT*>(memorizer.operator->());
  TransformMemorizerT memorizer = Downcast<TransformMemorizerT>(ctx);

  // fill incomplete state and flatten tuple
  auto push_back_one_arg = [&inputs, memorizer](Expr arg) {
    // We always expect LayoutAlternatedExpr<TransformMemorizerT>.
    // This is used to convert the normal Expr to LayoutAlternatedExpr<TransformMemorizerT>.
    if (const LayoutAlternatedExprNode<TransformMemorizerT>* inp =
            arg.as<LayoutAlternatedExprNode<TransformMemorizerT>>()) {
      inputs.push_back(GetRef<LayoutAlternatedExpr<TransformMemorizerT>>(inp));
      return inp->value;
    } else {
      auto inode = make_object<LayoutAlternatedExprNode<TransformMemorizerT>>();
      inode->value = arg;
      inode->memorizer = memorizer;
      inputs.push_back(LayoutAlternatedExpr<TransformMemorizerT>(inode));
      return arg;
    }
  };

  for (auto new_arg : new_args) {
    // NOTE: do not support nested tuple
    if (new_arg->IsInstance<TupleNode>()) {
      Tuple tuple_new_arg = Downcast<Tuple>(new_arg);
      Array<Expr> fields;
      fields.reserve(tuple_new_arg->fields.size());
      for (auto x : tuple_new_arg->fields) {
        Expr tmp = push_back_one_arg(x);
        fields.push_back(tmp);
      }
      normal_new_args.push_back(WithFields(tuple_new_arg, fields));
    } else {
      Expr tmp = push_back_one_arg(new_arg);
      normal_new_args.push_back(tmp);
    }
  }

  // If there is no FInferCorrectLayout for the op, we just assume the altered call's layout is
  // correct and return it directly.
  static auto finfer_layout = Op::GetAttrMap<FInferCorrectLayout>("FInferCorrectLayout");
  if (Op::HasAttrMap("FTVMAlterOpLayout")) {
    static auto falter_layout = Op::GetAttrMap<FTVMAlterOpLayout>("FTVMAlterOpLayout");
    if (ref_call->op.as<OpNode>()) {
      Op op = Downcast<Op>(ref_call->op);
      if (falter_layout.count(op) && !finfer_layout.count(op)) {
        return memorizer->CallWithNewLayouts(ref_call, normal_new_args);
      }
    }
  }

  // old_prd, new_prd = state[inputs]
  // Different ops can view a tensor with different layouts, e.g. conv_1 -> transpose(H, W) -> conv_2:
  // transpose views its output as having NCWH layout, but conv_2 still views it as NCHW to operate.
  // old_prd, new_prd: the input layouts from the perspective of the producer (transpose)
  // old_cur, new_cur: the input layouts from the perspective of the current node (conv_2)
  // old_prd -> new_prd tells how the producer changed the layout
  // old_cur -> new_cur tells what change the current node wants to see
  // No layout transforms are needed when they mean the same thing (NCHW->NCHW4c == NCWH->NCWH4c)

  // The workflow:
  // 1. Run InferCorrectLayouts(NULL, old_prd) to get old_cur
  // 2. Run InferCorrectLayouts(new_prd, old_prd) to get new_cur and rewrite the current op
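  //
  // Illustrative only (the layouts below are assumed, not taken from any real op): for the
  // conv_1 -> transpose -> conv_2 chain above, suppose the producer originally yielded NCHW
  // (old_prd) and after rewriting yields NCHW4c (new_prd), while conv_2 originally consumed NCHW
  // (old_cur). Step 1 recovers old_cur from old_prd alone; step 2 asks conv_2 what it wants given
  // new_prd, e.g. new_cur = NCHW4c, in which case no transform is needed. If conv_2 instead
  // insisted on plain NCHW, a layout_transform (NCHW4c -> NCHW) would be inserted by the
  // per-argument logic further below.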

  Array<Layout> old_prd, old_cur, old_out, new_prd, new_out, new_cur;
  for (auto inp : inputs) {
    old_prd.push_back(inp->old_layout);
    new_prd.push_back(inp->new_layout);
  }

  // Collect input types to pass on to Infer Correct Layout.
  tvm::Array<tvm::relay::Type> types;
  for (auto arg : ref_call->args) {
    types.push_back(arg->checked_type());
  }

  bool success = false;
  InferCorrectLayoutOutput infer_out;
  std::tie(infer_out, success) =
      InferCorrectLayouts(ref_call, Array<Layout>(nullptr), old_prd, types);
  old_cur = infer_out->input_layouts;
  old_out = infer_out->output_layouts;
  if (!success) {
    return Expr(nullptr);
  }
  ICHECK_EQ(old_cur.size(), new_prd.size());

  // for backward compatibility of InferCorrectLayouts
  Array<Layout> new_prd_inferred = new_prd;
  // if new_prd_inferred == 'undef': new_prd_inferred = old_cur
  for (size_t i = 0; i < new_prd_inferred.size(); ++i) {
    if (!new_prd_inferred[i].defined()) {
      new_prd_inferred.Set(i, old_cur[i]);
    }
  }
  Array<Layout> old_prd_inferred = old_prd;
  // if old_prd_inferred == 'undef': old_prd_inferred = old_cur
  for (size_t i = 0; i < old_prd_inferred.size(); ++i) {
    if (!old_prd_inferred[i].defined()) {
      old_prd_inferred.Set(i, old_cur[i]);
    }
  }

  // new_op = alter(op)
  Call new_call = memorizer->CallWithNewLayouts(ref_call, infer_out->new_attrs, normal_new_args);

  // new_cur, new_out = op.infer(new_prd)
  if (new_call->op->IsInstance<OpNode>()) {
    success = false;
    std::tie(infer_out, success) =
        InferCorrectLayouts(new_call, new_prd_inferred, old_prd_inferred, types);
    new_cur = infer_out->input_layouts;
    new_out = infer_out->output_layouts;
    if (!success) {
      return Expr(nullptr);
    }
  } else {
    return Expr(nullptr);
  }

  ICHECK_EQ(new_out.size(), old_out.size())
      << "The number of output nodes should remain the same during alter_op_layout";
  ICHECK_EQ(new_prd.size(), new_cur.size())
      << "The number of input nodes should remain the same during alter_op_layout";

  auto transform_layout = [&memorizer](Expr arg_item, const Layout& old_prd, const Layout& old_cur,
                                       const Layout& new_prd, const Layout& new_cur) {
    if (old_cur.Equals(old_prd)) {  // the two transforms can be fused into one
      arg_item = memorizer.Transform(arg_item, new_prd, new_cur);
    } else {
      if (old_prd.defined()) arg_item = memorizer.Transform(arg_item, new_prd, old_prd);
      arg_item = memorizer.Transform(arg_item, old_cur, new_cur);
    }
    return arg_item;
  };
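
  // A worked illustration of the branch above (the layouts are assumed for illustration): if the
  // producer and the current op originally agreed (old_prd == old_cur == NCHW), a single transform
  // new_prd -> new_cur suffices, e.g. NCHW4c -> NCHW8c. If they disagreed, the producer's change
  // is first undone (new_prd -> old_prd) and then the current op's desired change is applied
  // (old_cur -> new_cur). Transform() is a no-op whenever the two layouts are equal.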

  DLOG(INFO) << "Transforming layout for `" << ref_call->op << "`";
  DLOG(INFO) << " old_prd=" << old_prd;
  DLOG(INFO) << " new_prd=" << new_prd;
  DLOG(INFO) << " old_cur=" << old_cur;
  DLOG(INFO) << " new_cur=" << new_cur;

  // if (new_prd != new_cur): insert transform (new_prd -> new_cur)
  Array<Expr> transformed_args;
  size_t pt = 0;
  for (auto arg : new_call->args) {
    if (arg->IsInstance<TupleNode>()) {  // unflatten tuple
      Tuple tuple_arg = Downcast<Tuple>(arg);
      Array<Expr> transformed_tuple_arg;
      transformed_tuple_arg.reserve(tuple_arg->fields.size());
      for (auto arg_item : tuple_arg->fields) {
        transformed_tuple_arg.push_back(
            transform_layout(arg_item, old_prd[pt], old_cur[pt], new_prd[pt], new_cur[pt]));
        pt++;
      }
      transformed_args.push_back(WithFields(tuple_arg, transformed_tuple_arg));
    } else {
      transformed_args.push_back(
          transform_layout(arg, old_prd[pt], old_cur[pt], new_prd[pt], new_cur[pt]));
      pt++;
    }
  }
  ICHECK_EQ(pt, inputs.size());

  // state[node] = (old_out, new_out)
  // (handle tuple output)
  if (ref_call->checked_type()->IsInstance<TupleTypeNode>()) {
    Expr tuple_output = Call(new_call->op, transformed_args, infer_out->new_attrs);
    Array<Expr> fields;
    for (size_t i = 0; i < new_out.size(); ++i) {
      auto rnode = make_object<LayoutAlternatedExprNode<TransformMemorizerT>>();
      rnode->value = TupleGetItem(tuple_output, i);
      rnode->old_layout = old_out[i];
      rnode->new_layout = new_out[i];
      rnode->memorizer = memorizer;
      fields.push_back(Expr(rnode));
    }
    return Tuple(fields);
  } else {
    auto rnode = make_object<LayoutAlternatedExprNode<TransformMemorizerT>>();
    ICHECK_EQ(new_out.size(), 1);
    rnode->value = Call(new_call->op, transformed_args, infer_out->new_attrs, {}, ref_call->span);
    rnode->old_layout = old_out[0];
    rnode->new_layout = new_out[0];
    rnode->memorizer = memorizer;
    return Expr(rnode);
  }
}
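
// A rough sketch of how a pass might wire LayoutRewriter into ForwardRewrite (illustrative only;
// MyTransformMemorizer / MyTransformMemorizerNode are hypothetical derived types, and the exact
// ForwardRewrite overload used by the real passes lives in their respective .cc files):
//
//   MyTransformMemorizer memorizer(make_object<MyTransformMemorizerNode>());
//   auto fcontext = [=](const Call& call) -> ObjectRef { return memorizer; };
//   Expr rewritten = ForwardRewrite(expr, LayoutRewriter<MyTransformMemorizer>, fcontext);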

}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_