1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file infer_layout_utils.h
22 * \brief Utility functions to alter the layouts of operators or replace primitive operators with
23 other expressions. This pass can be used for computing convolution in
24 custom layouts or other general weight pre-transformation.
25 */
26
27#ifndef TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTILS_H_
28#define TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTILS_H_
29
30#include <tvm/relay/expr.h>
31#include <tvm/relay/op_attr_types.h>
32#include <tvm/tir/data_layout.h>
33
34#include <string>
35#include <tuple>
36#include <utility>
37
38#include "pattern_utils.h"
39
40namespace tvm {
41namespace relay {
42
43/*!
44 * \brief Returns a new layout where the subordinate factors are adjusted based on the tensor
45 * shape.
46 * \param old_layout The old layout before any transformation.
47 * \param old_shape The shape of the original tensor.
48 * \return The adjusted Layout.
49 */
50Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layout,
51 const Array<tvm::PrimExpr>& old_shape);
52
53bool Isomorphic(const Layout& lhs, const Layout& rhs);
54
55/*!
56 * \brief Try transforming `old` in as the smae way as how`ref_old` is transformed to `ref_new`.
57 * `old` and `ref_old` are expected to describe two broadcastable tensors. Layout with fewer rank
58 * will be expanded. For example,
59 * if old = 'NW', ref_old = 'NC', ref_new = 'NC1c', then the result is 'NW1w';
60 * if old = 'W', ref_old = 'NC', ref_new = 'NC1c', then the result is 'NW1w'.
61 * When `old` and `ref_old` are isomorphic (same structure, only differ in naming), the transform
62 * is guaranteed to succeed, in which case the function is simply renaming the axes of `ref_new`
63 * to conform to `old`'s naming.
64 * \param old The layout to be transformed.
65 * \param ref_old The reference layout before transform.
66 * \param ref_new The reference layout after transform.
67 * \return The transformed layout.
68 */
69Layout TryTransformLike(const Layout& old, const Layout& ref_old, const Layout& ref_new);
70
71/*
72 * \brief An output structure to hold results from FInferCorrectLayout calls.
73 * \tparam input_layouts Inferred input layouts.
74 * \tparam output_layouts Inferred output layouts.
75 * \tparam new_attrs Updated attributes consistent with inferred layouts.
76 */
77class InferCorrectLayoutOutputNode : public Object {
78 public:
79 Array<Layout> input_layouts;
80 Array<Layout> output_layouts;
81 Attrs new_attrs;
82
83 void VisitAttrs(tvm::AttrVisitor* v) {
84 v->Visit("input_layouts", &input_layouts);
85 v->Visit("output_layouts", &output_layouts);
86 v->Visit("new_attrs", &new_attrs);
87 }
88
89 TVM_DECLARE_BASE_OBJECT_INFO(InferCorrectLayoutOutputNode, Object);
90
91 static constexpr const char* _type_key = "relay._transform.InferCorrectLayoutOutput";
92};
93
94class InferCorrectLayoutOutput : public ObjectRef {
95 public:
96 InferCorrectLayoutOutput(Array<Layout> input_layouts, Array<Layout> output_layouts,
97 Attrs new_attrs) {
98 auto n = make_object<InferCorrectLayoutOutputNode>();
99 n->input_layouts = std::move(input_layouts);
100 n->output_layouts = std::move(output_layouts);
101 n->new_attrs = std::move(new_attrs);
102 data_ = n;
103 }
104 TVM_DEFINE_OBJECT_REF_METHODS(InferCorrectLayoutOutput, ObjectRef, InferCorrectLayoutOutputNode);
105};
106
107/*!
108 * \brief Infer & correct function of node layout. See \p Layout for layout convention
109 * \param attrs The attribute of the node.
110 * \param new_in_layouts The layouts of input arguments after alter_op_layout.
111 * This can be undefined, which means we call this function before alternating
112 * any operators.
113 * \param old_in_layouts The layouts of input arguments before alter_op_layout.
114 * \param old_in_types The types of old input arguments.
115 * \return infer_layout_output Inferred layouts and updated attributes stored in
116 * InferCorrectLayoutOutput above.
117 */
118using FInferCorrectLayout = runtime::TypedPackedFunc<InferCorrectLayoutOutput(
119 const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
120 const Array<tvm::relay::Type>& old_in_types)>;
121
122inline InferCorrectLayoutOutput ElemwiseArbitraryLayout(
123 const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
124 const Array<tvm::relay::Type>& old_in_types) {
125 Layout ret;
126
127 if (new_in_layouts.defined()) {
128 ICHECK_GE(new_in_layouts.size(), 1);
129 ret = new_in_layouts[0];
130 } else {
131 for (size_t i = 0; i < old_in_layouts.size(); ++i) {
132 if (old_in_layouts[i].defined()) {
133 ret = old_in_layouts[i];
134 break;
135 }
136 }
137 }
138
139 return InferCorrectLayoutOutput(Array<Layout>(old_in_layouts.size(), ret), {ret}, attrs);
140}
141
142std::pair<Array<Layout>, Array<Layout>> BinaryBroadcastLayoutHelper(
143 const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
144 const Array<tvm::relay::Type>& old_in_types);
145
146/*! \brief Infer layout for binary broadcast operators */
147inline InferCorrectLayoutOutput BinaryBroadcastLayout(const Attrs& attrs,
148 const Array<Layout>& new_in_layouts,
149 const Array<Layout>& old_in_layouts,
150 const Array<tvm::relay::Type>& old_in_types) {
151 auto inferred_layout =
152 BinaryBroadcastLayoutHelper(attrs, new_in_layouts, old_in_layouts, old_in_types);
153 return InferCorrectLayoutOutput(inferred_layout.first, inferred_layout.second, attrs);
154}
155
156} // namespace relay
157} // namespace tvm
158
159#endif // TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTILS_H_
160