1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file infer_layout_utils.h |
22 | * \brief Utility functions to alter the layouts of operators or replace primitive operators with |
23 | other expressions. This pass can be used for computing convolution in |
24 | custom layouts or other general weight pre-transformation. |
25 | */ |
26 | |
27 | #ifndef TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTILS_H_ |
28 | #define TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTILS_H_ |
29 | |
30 | #include <tvm/relay/expr.h> |
31 | #include <tvm/relay/op_attr_types.h> |
32 | #include <tvm/tir/data_layout.h> |
33 | |
34 | #include <string> |
35 | #include <tuple> |
36 | #include <utility> |
37 | |
38 | #include "pattern_utils.h" |
39 | |
40 | namespace tvm { |
41 | namespace relay { |
42 | |
/*!
 * \brief Returns a new layout where the subordinate factors are adjusted based on the tensor
 * shape.
 * \param src_layout The layout to be adjusted. NOTE(review): this parameter was previously
 *        undocumented and its role is inferred from its name only, since the definition is
 *        not part of this header -- confirm against the implementation.
 * \param old_layout The old layout before any transformation.
 * \param old_shape The shape of the original tensor.
 * \return The adjusted Layout.
 */
Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layout,
                                const Array<tvm::PrimExpr>& old_shape);
52 | |
/*!
 * \brief Check whether two layouts are isomorphic.
 * NOTE(review): the definition is not part of this header. Presumably "isomorphic" means the
 * two layouts have the same structure and differ only in axis naming -- confirm that this is
 * the exact criterion implemented.
 * \param lhs The first layout.
 * \param rhs The second layout.
 * \return Whether `lhs` and `rhs` are isomorphic.
 */
bool Isomorphic(const Layout& lhs, const Layout& rhs);
54 | |
/*!
 * \brief Try transforming `old` in the same way as `ref_old` is transformed to `ref_new`.
 * `old` and `ref_old` are expected to describe two broadcastable tensors. The layout with the
 * lower rank will be expanded. For example,
 * if old = 'NW', ref_old = 'NC', ref_new = 'NC1c', then the result is 'NW1w';
 * if old = 'W', ref_old = 'NC', ref_new = 'NC1c', then the result is 'NW1w'.
 * When `old` and `ref_old` are isomorphic (same structure, differing only in naming), the
 * transform is guaranteed to succeed, in which case the function simply renames the axes of
 * `ref_new` to conform to `old`'s naming.
 * \param old The layout to be transformed.
 * \param ref_old The reference layout before transform.
 * \param ref_new The reference layout after transform.
 * \return The transformed layout.
 */
Layout TryTransformLike(const Layout& old, const Layout& ref_old, const Layout& ref_new);
70 | |
71 | /* |
72 | * \brief An output structure to hold results from FInferCorrectLayout calls. |
73 | * \tparam input_layouts Inferred input layouts. |
74 | * \tparam output_layouts Inferred output layouts. |
75 | * \tparam new_attrs Updated attributes consistent with inferred layouts. |
76 | */ |
77 | class InferCorrectLayoutOutputNode : public Object { |
78 | public: |
79 | Array<Layout> input_layouts; |
80 | Array<Layout> output_layouts; |
81 | Attrs new_attrs; |
82 | |
83 | void VisitAttrs(tvm::AttrVisitor* v) { |
84 | v->Visit("input_layouts" , &input_layouts); |
85 | v->Visit("output_layouts" , &output_layouts); |
86 | v->Visit("new_attrs" , &new_attrs); |
87 | } |
88 | |
89 | TVM_DECLARE_BASE_OBJECT_INFO(InferCorrectLayoutOutputNode, Object); |
90 | |
91 | static constexpr const char* _type_key = "relay._transform.InferCorrectLayoutOutput" ; |
92 | }; |
93 | |
94 | class InferCorrectLayoutOutput : public ObjectRef { |
95 | public: |
96 | InferCorrectLayoutOutput(Array<Layout> input_layouts, Array<Layout> output_layouts, |
97 | Attrs new_attrs) { |
98 | auto n = make_object<InferCorrectLayoutOutputNode>(); |
99 | n->input_layouts = std::move(input_layouts); |
100 | n->output_layouts = std::move(output_layouts); |
101 | n->new_attrs = std::move(new_attrs); |
102 | data_ = n; |
103 | } |
104 | TVM_DEFINE_OBJECT_REF_METHODS(InferCorrectLayoutOutput, ObjectRef, InferCorrectLayoutOutputNode); |
105 | }; |
106 | |
/*!
 * \brief Infer & correct function of node layout. See \p Layout for layout convention.
 * \param attrs The attributes of the node.
 * \param new_in_layouts The layouts of input arguments after alter_op_layout.
 *                       This can be undefined, which means this function is called before
 *                       altering any operators.
 * \param old_in_layouts The layouts of input arguments before alter_op_layout.
 * \param old_in_types The types of old input arguments.
 * \return infer_layout_output Inferred layouts and updated attributes, stored in an
 *                             InferCorrectLayoutOutput.
 */
using FInferCorrectLayout = runtime::TypedPackedFunc<InferCorrectLayoutOutput(
    const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types)>;
121 | |
122 | inline InferCorrectLayoutOutput ElemwiseArbitraryLayout( |
123 | const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts, |
124 | const Array<tvm::relay::Type>& old_in_types) { |
125 | Layout ret; |
126 | |
127 | if (new_in_layouts.defined()) { |
128 | ICHECK_GE(new_in_layouts.size(), 1); |
129 | ret = new_in_layouts[0]; |
130 | } else { |
131 | for (size_t i = 0; i < old_in_layouts.size(); ++i) { |
132 | if (old_in_layouts[i].defined()) { |
133 | ret = old_in_layouts[i]; |
134 | break; |
135 | } |
136 | } |
137 | } |
138 | |
139 | return InferCorrectLayoutOutput(Array<Layout>(old_in_layouts.size(), ret), {ret}, attrs); |
140 | } |
141 | |
/*!
 * \brief Helper that infers the layouts for binary broadcast operators.
 * The definition is not part of this header. BinaryBroadcastLayout in this file uses the
 * first element of the returned pair as the input layouts and the second element as the
 * output layouts.
 * \param attrs The attributes of the node.
 * \param new_in_layouts The layouts of input arguments after alter_op_layout.
 * \param old_in_layouts The layouts of input arguments before alter_op_layout.
 * \param old_in_types The types of old input arguments.
 * \return A pair of (input layouts, output layouts).
 */
std::pair<Array<Layout>, Array<Layout>> BinaryBroadcastLayoutHelper(
    const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types);
145 | |
146 | /*! \brief Infer layout for binary broadcast operators */ |
147 | inline InferCorrectLayoutOutput BinaryBroadcastLayout(const Attrs& attrs, |
148 | const Array<Layout>& new_in_layouts, |
149 | const Array<Layout>& old_in_layouts, |
150 | const Array<tvm::relay::Type>& old_in_types) { |
151 | auto inferred_layout = |
152 | BinaryBroadcastLayoutHelper(attrs, new_in_layouts, old_in_layouts, old_in_types); |
153 | return InferCorrectLayoutOutput(inferred_layout.first, inferred_layout.second, attrs); |
154 | } |
155 | |
156 | } // namespace relay |
157 | } // namespace tvm |
158 | |
159 | #endif // TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTILS_H_ |
160 | |