/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include "infer_layout_utils.h"

#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/data_layout.h>

#include <map>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "pattern_utils.h"

namespace tvm {
namespace relay {

Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layout,
                                const Array<tvm::PrimExpr>& old_shape) {
  // For each subordinate axis:
  //   1) Find the corresponding dual (primal) axis.
  //   2) Find the index of this dual axis in old_layout.
  //   3) Find the shape of that axis in old_shape.
  //   4) a) Adjust the factor to 1 if that shape is 1; b) else retain the factor.
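  //
  // Illustrative example (axis names arbitrary): with src_layout = NCHW16c,
  // old_layout = NCHW, and old_shape = [1, 1, 224, 224], the subordinate axis c
  // duals to C, whose old extent is 1 and which was never split, so the factor is
  // adjusted to 1 and the result is NCHW1c. With old_shape = [1, 32, 224, 224],
  // the factor 16 is retained and the result stays NCHW16c.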
  DLOG(INFO) << "AdjustSubordinateFactors: src_layout: " << src_layout
             << " old_layout: " << old_layout << " old_shape: " << old_shape << std::endl;
  std::string new_layout;
  for (auto axis : src_layout->axes) {
    if (!LayoutAxis::Get(axis).IsPrimal()) {
      bool is_shape_one = false;
      // 1) Find the corresponding dual axis.
      const auto& dual_axis = LayoutAxis::Get(axis).ToPrimal();

      // 2) Find the index of this dual axis in old_layout.
      int old_axis = old_layout.IndexOf(dual_axis);

      if (old_axis == -1) {
        new_layout += "1";
        is_shape_one = true;
      } else {
        // 3) Look up the shape of that axis in old_shape.
        auto shape_val = old_shape[old_axis];

        // 4) a) Check whether this shape element is 1.
        if (auto* shape_int = shape_val.as<IntImmNode>()) {
          // We can treat 1 as broadcast only if the axis was not split before.
          if (shape_int->value == 1 && old_layout.IndexOf(LayoutAxis::Get(axis)) == -1) {
            new_layout += "1";
            is_shape_one = true;
          }
        }
      }

      // 4) b) If the shape is not 1, retain the factor.
      if (!is_shape_one) {
        auto new_shape_val = src_layout.FactorOf(dual_axis);
        new_layout += std::to_string(new_shape_val);
      }
    }
    new_layout += LayoutAxis::Get(axis).name();
  }
  return !new_layout.empty() ? Layout(new_layout)
                             : Layout("H").SubLayout(0, 0);  // hack to create a scalar layout
}

bool Isomorphic(const Layout& lhs, const Layout& rhs) {
  DLOG(INFO) << "Isomorphic: lhs: " << lhs << " rhs: " << rhs << std::endl;
  ICHECK(lhs.defined());
  ICHECK(rhs.defined());
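  // Illustrative example (hypothetical layouts): NCHW16c is isomorphic to ABCD16d
  // under the renaming N->A, C->B, H->C, W->D, c->d with matching subordinate
  // factor 16, but not to NCHW8c (factor mismatch) or to NCHW (different rank).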
  if (lhs->axes.size() != rhs->axes.size()) return false;
  // Look for a bijective renaming of axes: map_to records lhs -> rhs and map_back
  // records rhs -> lhs.
  std::map<std::string, std::string> map_to, map_back;
  for (size_t i = 0; i < lhs->axes.size(); ++i) {
    const auto& lhs_axis = LayoutAxis::Get(lhs->axes[i]);
    const auto& rhs_axis = LayoutAxis::Get(rhs->axes[i]);
    std::string name_lhs = lhs_axis.name();
    std::string name_rhs = rhs_axis.name();
    if (lhs_axis.IsPrimal() != rhs_axis.IsPrimal()) return false;

    auto it = map_to.find(name_lhs);
    if (it == map_to.end())
      map_to[name_lhs] = name_rhs;
    else if (it->second != name_rhs)
      return false;

    it = map_back.find(name_rhs);
    if (it == map_back.end())
      map_back[name_rhs] = name_lhs;
    else if (it->second != name_lhs)
      return false;

    // Matched subordinate axes must also agree on their split factors.
    if (!lhs_axis.IsPrimal() && lhs.FactorOf(lhs_axis) != rhs.FactorOf(rhs_axis)) return false;
  }
  return true;
}

Layout TryTransformLike(const Layout& old, const Layout& ref_old, const Layout& ref_new) {
  DLOG(INFO) << "TryTransformLike: old = " << old << ", ref_new = " << ref_new
             << ", ref_old = " << ref_old << std::endl;
  ICHECK(ref_old.defined());
  ICHECK(ref_new.defined());
  ICHECK(old.defined());
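  // Illustrative example (hypothetical layouts): with old = AB, ref_old = NCHW, and
  // ref_new = NCHW16c, the trailing sublayout HW of ref_old matches AB up to the
  // renaming H -> A, W -> B; the leading axes N and C do not occur in HW and keep
  // their own names, so the transform applies and yields NCAB16c.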
  {  // Check whether `old` and `ref_old` are similar enough that the transform
     // ref_old -> ref_new is also applicable to `old`.
    const Layout& large = ref_old.ndim() > old.ndim() ? ref_old : old;
    const Layout& small = large == ref_old ? old : ref_old;
    Layout large_sublayout = large.SubLayout(large.ndim() - small.ndim(), small.ndim()),
           rest_sublayout = large.SubLayout(0, large.ndim() - small.ndim());
    // The leading (rest) axes must not interact with the trailing sublayout, and the
    // trailing sublayout must match `small` up to renaming.
    bool orthogonal = true;
    for (auto i : rest_sublayout->axes) {
      if (large_sublayout.IndexOf(LayoutAxis::Get(i).ToPrimal()) != -1 ||
          large_sublayout.IndexOf(LayoutAxis::Get(i).ToSubordinate()) != -1) {
        orthogonal = false;
        break;
      }
    }
    if (!orthogonal || !Isomorphic(large_sublayout, small))
      return Layout::Undef();  // For now this case is not supported.
  }

  // `old` is compatible. Now learn the axis-name mapping between `old` and `ref_old`.
  if (old.ndim() == 0) return old;  // an optimization for scalars: no-op

  // mapping[i] records, as an offset from 'A', the letter that ref_old's letter
  // ('A' + i) maps to; used[i] marks letters already taken on the `old` side.
  std::vector<int> mapping(26, -1);
  std::vector<bool> used(26, false);

  auto find_unused = [&](char preference) -> char {
    if (!used[preference - 'A']) return preference;  // preference unused
    for (int i = 0; i < 26; ++i)
      if (!used[i]) return 'A' + i;
    LOG(FATAL) << "All letters are used";
  };

  // Align `old` and `ref_old` from the right and record the induced letter mapping.
  for (int j = old->axes.size() - 1, i = ref_old->axes.size() - 1; j >= 0; --i, --j) {
    char name_ref = LayoutAxis::Get(ref_old->axes[i]).ToPrimal().name()[0];
    char name = LayoutAxis::Get(old->axes[j]).ToPrimal().name()[0];
    mapping[name_ref - 'A'] = name - 'A';
    used[name - 'A'] = true;
  }
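  // For example (illustrative): aligning old = AB with ref_old = NCHW from the right
  // records W -> B and H -> A; the unmatched letters N and C are handled below.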
  // Letters of ref_old that were not matched above (ref_old has more axes than old)
  // are assigned fresh unused letters, preferring their own names.
  for (int i = ref_old->axes.size() - 1; i >= 0; --i) {
    char name_ref = LayoutAxis::Get(ref_old->axes[i]).ToPrimal().name()[0];
    int name = mapping[name_ref - 'A'];
    if (name == -1) {
      mapping[name_ref - 'A'] = find_unused(name_ref) - 'A';
      used[mapping[name_ref - 'A']] = true;
    }
  }

  // Apply the mapping to rename `ref_new`; digits (split factors) pass through unchanged.
  std::string new_layout;
  for (auto c : std::string(ref_new->name)) {
    if (c >= 'A' && c <= 'Z') {
      ICHECK(mapping[c - 'A'] != -1);
      new_layout += mapping[c - 'A'] + 'A';
    } else if (c >= 'a' && c <= 'z') {
      ICHECK(mapping[c - 'a'] != -1);
      new_layout += mapping[c - 'a'] + 'a';
    } else {
      new_layout += c;
    }
  }

  DLOG(INFO) << "new_layout = " << new_layout << std::endl;
  return Layout(new_layout);
}

std::pair<Array<Layout>, Array<Layout>> BinaryBroadcastLayoutHelper(
    const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types) {
  // Two steps; step (2) runs only when the function is called after rewrite.
  // (1) Infer the input layouts before rewrite.
  // (2) If some input layouts were changed by their producers during rewrite, rewrite the
  //     other layout in the same way so that the two inputs remain broadcastable.
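  //
  // For example (illustrative): with old layouts [NCHW, undefined] and old shapes
  // [(1, 32, 224, 224), (32, 224, 224)], step (1) assigns the trailing sublayout CHW
  // to the second input. If the first layout later becomes NCHW16c, step (2) rewrites
  // the second one to NCHW16c as well (or to NCHW1c if its channel extent were 1).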
  Array<Layout> layouts;
  Array<Array<IndexExpr>> old_in_shapes;
  for (auto old_in_t : old_in_types) {
    const auto* tensor_type = old_in_t.as<TensorTypeNode>();
    ICHECK(tensor_type);
    old_in_shapes.push_back(tensor_type->shape);
  }
  int old_large_idx = old_in_shapes[0].size() >= old_in_shapes[1].size() ? 0 : 1;

  // Always operate on the original layouts first for consistency.
  layouts.Assign(old_in_layouts.begin(), old_in_layouts.end());

  std::pair<Array<Layout>, Array<Layout>> out,
      out_default{{Layout::Undef(), Layout::Undef()}, {Layout::Undef()}};

  if (!layouts[0].defined() && !layouts[1].defined()) {
    // Both are undefined; inference fails.
    out = out_default;
  } else if (!layouts[0].defined() || !layouts[1].defined()) {
    // Only one is defined; use shape information to help infer the other.
    int defined_idx = layouts[0].defined() ? 0 : 1;
    int undef_idx = 1 - defined_idx;
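    // For example (illustrative): a defined NCHW input against an undefined 2-d input
    // assigns the undefined one the trailing sublayout HW.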
    if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
      // TODO(lazycal): handle the case where the sublayout contains a subordinate axis of
      // factor one while the other tensor's corresponding dimension has a size other than
      // one, e.g. defined's shape = [x, x, x, x, 1] in NCHW1c and undefined's shape = [3].
      layouts.Set(undef_idx, layouts[defined_idx].SubLayout(old_in_shapes[defined_idx].size() -
                                                                old_in_shapes[undef_idx].size(),
                                                            old_in_shapes[undef_idx].size()));
      out = {layouts, {layouts[defined_idx]}};
    } else {
      // Only the layout of the lower-dimensional tensor is known, so the layout of the
      // broadcast output cannot be inferred; fail in this case.
      out = out_default;
    }
  } else {
    // When both are defined, return the layout of the higher-dimensional input.
    out = {layouts, {layouts[old_large_idx]}};
  }
  if (!new_in_layouts.defined()) return out;

  // Step (2): rewrite the layouts to make them broadcastable again.
  // The output layout follows the input that was larger before rewrite.
  Layout ret = new_in_layouts[old_large_idx];
  int large_idx = new_in_layouts[0].ndim_primal() >= new_in_layouts[1].ndim_primal() ? 0 : 1;
  int small_idx = 1 - large_idx;

  // Start adjusting. Apply a greedy strategy that transforms the small layout in the same
  // way that the large layout was transformed, whenever possible.
  Layout tgt_layout =
      TryTransformLike(layouts[small_idx], layouts[large_idx], new_in_layouts[large_idx]);
  if (!tgt_layout.defined()) return out_default;  // fallback

  // Support scenarios where the original operands were of shapes [N, H, W, C] and
  // [N, H, W, 1]. In this case, NCHW16c might come in for one operand while the other
  // operand does not have enough elements along the C dimension. To reuse broadcasting,
  // we would want NCHW1c for the second operand. The following section walks through the
  // layouts and shapes to perform that adjustment, e.g.
  //   a in NCHW16c
  //   b in NHW1
  //   b = layout_transform(b) from NHW1 -> NCHW1c
  //   add(a, b)
  auto old_small_shape = old_in_shapes[small_idx];
  auto old_small_layout = layouts[small_idx];
  auto new_small_layout = AdjustSubordinateFactors(tgt_layout, old_small_layout, old_small_shape);
  layouts.Set(large_idx, new_in_layouts[large_idx]);
  layouts.Set(small_idx, new_small_layout);
  return {layouts, {ret}};
}

}  // namespace relay
}  // namespace tvm