1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include "infer_layout_utils.h" |
21 | |
22 | #include <tvm/relay/expr.h> |
23 | #include <tvm/relay/op_attr_types.h> |
24 | #include <tvm/tir/data_layout.h> |
25 | |
26 | #include <map> |
27 | #include <string> |
28 | #include <tuple> |
29 | #include <utility> |
30 | #include <vector> |
31 | |
32 | #include "pattern_utils.h" |
33 | #include "tvm/runtime/logging.h" |
34 | |
35 | namespace tvm { |
36 | namespace relay { |
37 | |
38 | Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layout, |
39 | const Array<tvm::PrimExpr>& old_shape) { |
40 | // For each subordinate axis |
41 | // 1) Find the corresponding dual axis. |
42 | // 2) Find the Index of this dual axis in old_layout. |
43 | // 3) Find the shape of the that axis in old_shape. |
44 | // 4) a) Adjust factor to 1, if that shape is 1. b) Else retain the factor. |
45 | DLOG(INFO) << "AdjustSubordinateFactors" |
46 | << "src_layout: " << src_layout << " old_layout: " << old_layout |
47 | << " old_shape: " << old_shape << std::endl; |
48 | std::string new_layout; |
49 | for (auto axis : src_layout->axes) { |
50 | if (!LayoutAxis::Get(axis).IsPrimal()) { |
51 | bool is_shape_one = false; |
52 | // 1) Find the corresponding dual axis |
53 | const auto& dual_axis = LayoutAxis::Get(axis).ToPrimal(); |
54 | |
55 | // 2) Find the index of this dual axis in old_layout |
56 | int old_axis = old_layout.IndexOf(dual_axis); |
57 | |
58 | if (old_axis == -1) { |
59 | new_layout += "1" ; |
60 | is_shape_one = true; |
61 | } else { |
62 | // 3) Find the shape of this index in old_shape |
63 | auto shape_val = old_shape[old_axis]; |
64 | |
65 | // 4) a) Check if this shape element is 1. |
66 | if (auto* shape_int = shape_val.as<IntImmNode>()) { |
67 | // We can treat 1 as broadcast only if axis was not split before |
68 | if (shape_int->value == 1 && old_layout.IndexOf(LayoutAxis::Get(axis)) == -1) { |
69 | new_layout += "1" ; |
70 | is_shape_one = true; |
71 | } |
72 | } |
73 | } |
74 | |
75 | // 4) b) If shape is not 1, retain the factor. |
76 | if (!is_shape_one) { |
77 | auto new_shape_val = src_layout.FactorOf(dual_axis); |
78 | new_layout += std::to_string(new_shape_val); |
79 | } |
80 | } |
81 | new_layout += LayoutAxis::Get(axis).name(); |
82 | } |
83 | return new_layout != "" ? Layout(new_layout) |
84 | : Layout("H" ).SubLayout(0, 0); // hack to create a scalar layout |
85 | } |
86 | |
87 | bool Isomorphic(const Layout& lhs, const Layout& rhs) { |
88 | DLOG(INFO) << "Isomorphic: " |
89 | << "lhs: " << lhs << " rhs: " << rhs << std::endl; |
90 | ICHECK(lhs.defined()); |
91 | ICHECK(rhs.defined()); |
92 | if (lhs->axes.size() != rhs->axes.size()) return false; |
93 | std::map<std::string, std::string> map_to, map_back; |
94 | for (size_t i = 0; i < lhs->axes.size(); ++i) { |
95 | auto& lhs_axis = LayoutAxis::Get(lhs->axes[i]); |
96 | auto& rhs_axis = LayoutAxis::Get(rhs->axes[i]); |
97 | std::string name_lhs = lhs_axis.name(); |
98 | std::string name_rhs = rhs_axis.name(); |
99 | if (lhs_axis.IsPrimal() != rhs_axis.IsPrimal()) return false; |
100 | |
101 | auto it = map_to.find(name_lhs); |
102 | if (it == map_to.end()) |
103 | map_to[name_lhs] = name_rhs; |
104 | else if (it->second != name_rhs) |
105 | return false; |
106 | |
107 | it = map_back.find(name_rhs); |
108 | if (it == map_back.end()) |
109 | map_back[name_rhs] = name_lhs; |
110 | else if (it->second != name_lhs) |
111 | return false; |
112 | if (!lhs_axis.IsPrimal() && lhs.FactorOf(lhs_axis) != rhs.FactorOf(rhs_axis)) return false; |
113 | } |
114 | return true; |
115 | } |
116 | |
// Try to transform layout `old` in the same way that `ref_old` was transformed
// into `ref_new`, and return the resulting layout. Returns Layout::Undef()
// when `old` and `ref_old` are not similar enough for the transform to be
// mirrored. Layout names are single uppercase letters (primal axes) with
// lowercase/digit suffixes for subordinate split axes.
Layout TryTransformLike(const Layout& old, const Layout& ref_old, const Layout& ref_new) {
  DLOG(INFO) << "transform_layout: old = " << old << ", ref_new = " << ref_new
             << ", ref_old = " << ref_old << std::endl;
  ICHECK(ref_old.defined());
  ICHECK(ref_new.defined());
  ICHECK(old.defined());

  {  // check if old and ref_old are similar enough such that it's
     // compatible for the transform ref_old -> ref_new
    const Layout& large = ref_old.ndim() > old.ndim() ? ref_old : old;
    const Layout& small = large == ref_old ? old : ref_old;
    // The trailing `small.ndim()` axes of the larger layout must be isomorphic
    // to the smaller layout, and the remaining leading axes must not reference
    // (as primal or subordinate) any axis appearing in that trailing part.
    Layout large_sublayout = large.SubLayout(large.ndim() - small.ndim(), small.ndim()),
           rest_sublayout = large.SubLayout(0, large.ndim() - small.ndim());
    bool orthorgonal = true;
    for (auto i : rest_sublayout->axes)
      if (large_sublayout.IndexOf(LayoutAxis::Get(i).ToPrimal()) != -1 ||
          large_sublayout.IndexOf(LayoutAxis::Get(i).ToSubordinate()) != -1) {
        orthorgonal = false;
        break;
      }
    if (!orthorgonal || !Isomorphic(large_sublayout, small))
      return Layout::Undef();  // For now this case is not supported.
  }

  // `old` is compatible. Now learn the axis name mapping between `old` and `ref_old`
  if (old.ndim() == 0) return old;  // an optmization for scalar: no-op
  // mapping[c - 'A'] holds the letter (as an offset from 'A') in `old` that
  // corresponds to letter `c` of `ref_old`, or -1 if unassigned; `used` marks
  // letters already taken in `old`'s namespace.
  std::vector<int> mapping(26, -1);
  std::vector<bool> used(26, false);

  // Pick an unused uppercase letter for a `ref_old` axis that has no
  // counterpart in `old`, preferring the original letter itself.
  auto find_unused = [&](char preference) -> char {
    if (!used[preference - 'A']) return preference;  // preference unused
    for (int i = 0; i < 26; ++i)
      if (!used[i]) return 'A' + i;
    LOG(FATAL) << "All letters are used" ;
  };

  // Align the trailing axes of `ref_old` with the axes of `old` (verified
  // isomorphic above) and record the letter correspondence.
  // NOTE(review): this walk assumes ref_old has at least as many axes as old;
  // if `old` were the larger layout, `i` would go negative and index out of
  // bounds — confirm callers guarantee ref_old.ndim() >= old.ndim() here.
  for (int j = old->axes.size() - 1, i = ref_old->axes.size() - 1; j >= 0; --i, --j) {
    char name_ref = LayoutAxis::Get(ref_old->axes[i]).ToPrimal().name()[0];
    char name = LayoutAxis::Get(old->axes[j]).ToPrimal().name()[0];
    mapping[name_ref - 'A'] = name - 'A';
    used[name - 'A'] = true;
  }

  // Assign a fresh, unused letter to every remaining (leading) axis of
  // `ref_old` that did not get mapped above.
  for (int i = ref_old->axes.size() - 1; i >= 0; --i) {
    char name_ref = LayoutAxis::Get(ref_old->axes[i]).ToPrimal().name()[0];
    int name = mapping[name_ref - 'A'];
    if (name == -1) {
      mapping[name_ref - 'A'] = find_unused(name_ref) - 'A';
      used[mapping[name_ref - 'A']] = true;
    }
  }

  // apply the mapping to rename `ref_new`
  std::string new_layout;
  for (auto c : std::string(ref_new->name)) {
    if (c >= 'A' && c <= 'Z') {
      // primal axis letter: rename through the learned mapping
      ICHECK(mapping[c - 'A'] != -1);
      new_layout += mapping[c - 'A'] + 'A';
    } else if (c >= 'a' && c <= 'z') {
      // subordinate axis letter: rename via the same (primal) mapping
      ICHECK(mapping[c - 'a'] != -1);
      new_layout += mapping[c - 'a'] + 'a';
    } else {
      // digits (split factors) pass through untouched
      new_layout += c;
    }
  }

  DLOG(INFO) << "new_layout = " << new_layout << std::endl;
  return Layout(new_layout);
}
186 | |
// Shared layout-inference helper for broadcast binary ops. Given the old input
// layouts/types and (after rewrite) the new producer layouts, returns the pair
// {inferred input layouts, inferred output layouts}. `attrs` is accepted for
// signature compatibility with FInferCorrectLayout but is not read here.
std::pair<Array<Layout>, Array<Layout>> BinaryBroadcastLayoutHelper(
    const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types) {
  // Two steps. Step (2) only executes if the function is called after rewrite.
  // (1) infer input layouts before rewrite
  // (2) if some input layouts are changed by its producer after rewrite, rewrite the other
  // layout to make sure it's changed in the same way, so that they are still broadcastable.
  Array<Layout> layouts;
  Array<Array<IndexExpr>> old_in_shapes;
  for (auto old_in_t : old_in_types) {
    // Broadcast binary ops only operate on tensors.
    ICHECK(old_in_t.as<TensorTypeNode>());
    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
  }
  // Index of the higher-rank operand; its layout dominates the output layout.
  int old_large_idx = old_in_shapes[0].size() >= old_in_shapes[1].size() ? 0 : 1;

  layouts.Assign(old_in_layouts.begin(), old_in_layouts.end());
  // always operate on the original layouts first for consistency

  std::pair<Array<Layout>, Array<Layout>> out,
      out_default{{Layout::Undef(), Layout::Undef()}, {Layout::Undef()}};

  if (!layouts[0].defined() && !layouts[1].defined()) {
    // both undefined, infer fails
    out = out_default;
  } else if (!layouts[0].defined() || !layouts[1].defined()) {
    // only one is defined, use shape information to help infer
    int defined_idx = layouts[0].defined() ? 0 : 1;
    int undef_idx = 1 - defined_idx;

    if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
      // TODO(lazycal): handle the case when the sublayout contains subcoordinate of factor one but
      // the other tensor has the corresponding dimension size other than one.
      // E.g. defined's shape = [x, x, x, x, 1] in NCHW1c and undefined's shape = [3]
      // Give the lower-rank operand the trailing sub-layout of the defined
      // operand, so the two stay aligned for broadcasting.
      layouts.Set(undef_idx, layouts[defined_idx].SubLayout(old_in_shapes[defined_idx].size() -
                                                                old_in_shapes[undef_idx].size(),
                                                            old_in_shapes[undef_idx].size()));
      out = {layouts, {layouts[defined_idx]}};
    } else {
      // only know the tensor with smaller dimensions,
      // so we cannot infer the final broadcasted output.
      // fails in this case.
      out = out_default;
    }
  } else {
    // when both are defined, return the larger one
    out = {layouts, {layouts[old_large_idx]}};
  }
  // Called before rewrite (no producer layouts yet): step (1) result is final.
  if (!new_in_layouts.defined()) return out;
  // Step (2) rewrite the layouts to make them broadcastable again.
  Layout ret = new_in_layouts[old_large_idx];
  // After rewrite, compare by primal rank: the operand with more primal axes
  // drives how the other one is adjusted.
  int large_idx = new_in_layouts[0].ndim_primal() >= new_in_layouts[1].ndim_primal() ? 0 : 1;
  int small_idx = 1 - large_idx;
  // start adjusting

  // Apply a greedy strategy that always transform the small layout in the same way as the
  // large layout is transformed, if possible.
  Layout tgt_layout =
      TryTransformLike(layouts[small_idx], layouts[large_idx], new_in_layouts[large_idx]);
  if (!tgt_layout.defined()) return out_default;  // fallback

  // Support scenarios where original operands were of type [N, H, W, C] and [N, H, W, 1]
  // In this case, we might have NCHW16c coming for 1 operand. However, the other operand does
  // not have enough C dimension. To reuse broadcasting, we would want to use NCHW1c for the
  // second operand. The following section of code walks through the layouts and shapes to
  // perform that operation.
  // a in NCHWC16c
  // b in NHW1
  // b = layout_transform(b) from NHW1 -> NCHW1c
  // add(a, b)
  auto old_small_shape = old_in_shapes[small_idx];
  auto old_small_layout = layouts[small_idx];
  auto new_small_layout = AdjustSubordinateFactors(tgt_layout, old_small_layout, old_small_shape);
  layouts.Set(large_idx, new_in_layouts[large_idx]);
  layouts.Set(small_idx, new_small_layout);
  return {layouts, {ret}};
}
263 | |
264 | } // namespace relay |
265 | } // namespace tvm |
266 | |