1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/tir/data_layout.h |
22 | * \brief Layout expression to describe the data organization of a tensor. |
 *        And BijectiveLayout to map between two data layouts.
24 | */ |
25 | #ifndef TVM_TIR_DATA_LAYOUT_H_ |
26 | #define TVM_TIR_DATA_LAYOUT_H_ |
27 | |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/op.h> |
30 | |
31 | #include <algorithm> |
32 | #include <sstream> |
33 | #include <string> |
34 | #include <utility> |
35 | #include <vector> |
36 | |
37 | namespace tvm { |
38 | namespace tir { |
39 | |
40 | class Layout; |
41 | |
42 | class LayoutAxis { |
43 | public: |
44 | static const LayoutAxis& Get(const char name); |
45 | |
46 | // Get the singleton LayoutAxis using itvar->var->name_hint |
47 | static const LayoutAxis& Get(const tir::IterVar& itvar); |
48 | |
49 | // Get the singleton LayoutAxis using name[0] (size of name must be 1). |
50 | static const LayoutAxis& Get(const std::string& name); |
51 | |
52 | inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; } |
53 | inline std::string name() const { return std::string(1, name_); } |
54 | |
55 | // if current axis is primal, switch the axis to its subordinate one, |
56 | // else switch to the primal. |
57 | inline const LayoutAxis& ToDual() const { |
58 | if (name_ >= 'A' && name_ <= 'Z') { |
59 | return LayoutAxis::Get(name_ - 'A' + 'a'); |
60 | } else { |
61 | return LayoutAxis::Get(name_ - 'a' + 'A'); |
62 | } |
63 | } |
64 | |
65 | // return the primal axis. If it is already primal, return itself. |
66 | const LayoutAxis& ToPrimal() const { return IsPrimal() ? *this : ToDual(); } |
67 | |
68 | // return the subordinate axis. If it is already subordinate, return itself. |
69 | const LayoutAxis& ToSubordinate() const { return IsPrimal() ? ToDual() : *this; } |
70 | |
71 | inline bool operator==(const LayoutAxis& rhs) const { return name_ == rhs.name_; } |
72 | |
73 | friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) { |
74 | os << l.name(); |
75 | return os; |
76 | } |
77 | |
78 | private: |
79 | static const LayoutAxis UPPER_CASE[]; |
80 | static const LayoutAxis LOWER_CASE[]; |
81 | LayoutAxis(const LayoutAxis&); |
82 | LayoutAxis& operator=(const LayoutAxis&); |
83 | explicit LayoutAxis(const char name) : name_(name) {} |
84 | |
85 | const char name_; |
86 | }; |
87 | |
88 | /*! |
 * \brief Layout is to describe how data is organized within an N-dimensional tensor.
90 | * It is composed of upper cases, lower cases and numbers, |
91 | * where upper case indicates a primal axis and |
92 | * the corresponding lower case with factor size indicates the subordinate axis. |
93 | * For example, NCHW16c can describe a 5-D tensor of |
94 | * [batch_size, channel, height, width, channel_block]. |
95 | * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). |
96 | * Layout for scalar is defined, while both its name and axes have size 0. |
97 | */ |
98 | class LayoutNode : public Object { |
99 | public: |
100 | /*! \brief string representation of layout, "" for scalar. */ |
101 | String name; |
102 | /*! \brief specify each axis of the layout, |
103 | * in which the variable name is the name of the axis. |
104 | * The IterVar's extent indicates the size of the axis, |
105 | * it is a variable for a primal axis, but a constant for a subordinate axis. |
106 | * Empty for scalar's layout. |
107 | */ |
108 | Array<tir::IterVar> axes; |
109 | |
110 | void VisitAttrs(AttrVisitor* v) { |
111 | v->Visit("name" , &name); |
112 | v->Visit("axes" , &axes); |
113 | } |
114 | |
115 | static constexpr const char* _type_key = "tir.Layout" ; |
116 | TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object); |
117 | }; |
118 | |
119 | /*! |
120 | * \brief Managed reference to LayoutNode |
121 | * \sa LayoutNode |
122 | */ |
123 | class Layout : public ObjectRef { |
124 | public: |
125 | explicit Layout(const Array<tir::IterVar>& axes); |
126 | |
127 | /*! \brief construct from a string */ |
128 | Layout(const tvm::String& name) : Layout(name.operator std::string()) {} // NOLINT(*) |
129 | |
130 | /*! \brief construct from a string */ |
131 | Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) |
132 | |
133 | /*! |
134 | * \brief construct from a string. |
135 | * \param name input in layout convention: |
136 | * upper case indicates a dimension and |
137 | * the corresponding lower case with factor size |
138 | * indicates the split dimension. |
139 | * return undefined layout if "__undef__" is passed. |
140 | * \param dtype The dtype of generated axes vars in the returned layout. |
141 | * It is required to be integer type. |
142 | */ |
143 | TVM_DLL Layout(const std::string& name, DataType dtype = DataType::Int(32)); // NOLINT(*) |
144 | |
145 | /*! |
146 | * \brief access the internal node container |
147 | * \return the pointer to the internal node container |
148 | */ |
149 | LayoutNode* operator->() { return static_cast<LayoutNode*>(get_mutable()); } |
150 | |
151 | /*! |
152 | * \brief Return an undefined layout. |
153 | * \return a (global) undefined layout. |
154 | */ |
155 | static const Layout& Undef() { |
156 | static Layout undef; |
157 | return undef; |
158 | } |
159 | |
160 | /*! |
161 | * \brief Returns a sub-layout which is the portion of the object |
162 | * that starts at dimension \p pos and spans \p len dimensions |
163 | * (or until the end of the layout, whichever comes first). |
164 | * \param pos The start position. |
165 | * \param len The length of the sub-layout. if 0, return layout of scalar |
166 | * \return A newly constructed Layout object. |
167 | */ |
168 | Layout SubLayout(size_t pos, size_t len) const; |
169 | |
170 | /*! |
171 | * \brief Split \p axis by \p size and put the sub-axis to position \p target_pos. |
172 | * \param axis The source axis to be split. It must be a primal-axis; |
173 | * \param target_pos The target position of the newly split subordinate-axis. |
174 | * \param factor size of the sub-dimension. |
175 | * \return A newly constructed Layout object. |
176 | */ |
177 | Layout Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const; |
178 | |
179 | /*! \return number of dimensions */ |
180 | inline size_t ndim() const { |
181 | if (!defined()) return 0; |
182 | return operator->()->axes.size(); |
183 | } |
184 | |
185 | /*! \return number of super dimensions */ |
186 | inline size_t ndim_primal() const { |
187 | if (!defined()) return 0; |
188 | size_t ct = 0; |
189 | for (auto x : operator->()->axes) { |
190 | if (LayoutAxis::Get(x).IsPrimal()) { |
191 | ct++; |
192 | } |
193 | } |
194 | return ct; |
195 | } |
196 | |
197 | /*! |
198 | * \brief Returns a new layout where the dims have been expanded to match the primal dimensions. |
199 | * \param dst_layout The dst layout to which current layout has to be expanded. |
200 | * \return The expanded Layout. |
201 | */ |
202 | inline Layout ExpandPrimal(const Layout& dst_layout) { |
203 | Layout new_src_layout; |
204 | // 1) Find the axis which are missing in the current layout. Make them the prefix. |
205 | std::string new_src_layout_str = "" ; |
206 | for (auto dst_axis : dst_layout->axes) { |
207 | if (LayoutAxis::Get(dst_axis).IsPrimal()) { |
208 | if (!this->Contains(LayoutAxis::Get(dst_axis))) { |
209 | new_src_layout_str += dst_axis->var->name_hint; |
210 | } |
211 | } |
212 | } |
213 | // 2) Now, add the primal axis of the current layout. |
214 | new_src_layout_str += this->name(); |
215 | new_src_layout = Layout(new_src_layout_str); |
216 | return new_src_layout; |
217 | } |
218 | |
219 | /*! |
220 | * \brief return the index of the input axis. |
221 | * If it is not found in the layout or the layout is undefined, |
222 | * return -1. |
223 | * \param axis the input axis. |
224 | * \return the index or -1 if not found. |
225 | */ |
226 | inline int32_t IndexOf(const LayoutAxis& axis) const { |
227 | if (!this->defined()) return -1; |
228 | const auto axes = operator->()->axes; |
229 | for (size_t i = 0; i < axes.size(); ++i) { |
230 | if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i); |
231 | } |
232 | return -1; |
233 | } |
234 | |
235 | /*! |
236 | * \brief Get the factor size of the subordinate axis. |
237 | * \param axis the input primal-axis or subordinate-axis. |
238 | * \return the size of the subordinate-axis of \p axis (if \p axis is a primal-axis), |
239 | * or the size of \p axis itself (if \p axis is a subordinate-axis). |
240 | * Return -1 if \p axis is not in the layout the layout is undefined. |
241 | */ |
242 | int32_t FactorOf(const LayoutAxis& axis) const; |
243 | |
244 | /*! |
245 | * \brief Whether the layout contains an axis. |
246 | * \param axis axis to be checked. |
247 | * \return Whether the layout contains the axis. |
248 | */ |
249 | bool Contains(const LayoutAxis& axis) const { |
250 | if (!defined()) return false; |
251 | for (const tir::IterVar var : operator->()->axes) { |
252 | if (var->var->name_hint == axis.name()) { |
253 | return true; |
254 | } |
255 | } |
256 | return false; |
257 | } |
258 | |
259 | const LayoutAxis& operator[](int32_t i) const { |
260 | ICHECK(defined()) << "Try to access axis from an undefined layout." ; |
261 | int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i; |
262 | ICHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i; |
263 | const tir::IterVar axis = operator->()->axes[index]; |
264 | return LayoutAxis::Get(axis); |
265 | } |
266 | |
267 | /*! \return the string description of the layout */ |
268 | inline std::string name() const { |
269 | if (!defined()) return "__undef__" ; |
270 | return operator->()->name; |
271 | } |
272 | |
273 | /*! |
274 | * \brief Whether the two layouts are equal. |
275 | * \param rhs Another layout. |
276 | * \return whether the two layouts are equal. |
277 | */ |
278 | inline bool Equals(const Layout& rhs) const { return name() == rhs.name(); } |
279 | |
280 | /*! |
281 | * \brief allow output string of layout to ostream |
282 | * \param os the output stream |
283 | * \param l the layout |
284 | * \return the ostream |
285 | */ |
286 | friend std::ostream& operator<<(std::ostream& os, const Layout& l) { |
287 | os << l.name(); |
288 | return os; |
289 | } |
290 | |
291 | TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode); |
292 | }; |
293 | |
294 | // Internal node container BijectiveLayout |
295 | class BijectiveLayoutNode : public Object { |
296 | public: |
297 | /*! \brief Describes how source axes can be mapped to the destination axes, |
298 | * e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n |
299 | */ |
300 | Array<PrimExpr> index_forward_rule; |
301 | /*! \brief Describes how destination axes can be mapped to the source axes */ |
302 | Array<PrimExpr> index_backward_rule; |
303 | /*! \brief Describes how source shapes can be mapped to the destination shapes */ |
304 | Array<PrimExpr> shape_forward_rule; |
305 | /*! \brief Describes how destination shapes can be mapped to the source shapes */ |
306 | Array<PrimExpr> shape_backward_rule; |
307 | |
308 | /*! \brief The source layout */ |
309 | Layout src_layout; |
310 | /*! \brief The destination layout */ |
311 | Layout dst_layout; |
312 | |
313 | void VisitAttrs(AttrVisitor* v) { |
314 | v->Visit("src_layout" , &src_layout); |
315 | v->Visit("dst_layout" , &dst_layout); |
316 | v->Visit("index_forward_rule" , &index_forward_rule); |
317 | v->Visit("index_backward_rule" , &index_backward_rule); |
318 | v->Visit("shape_forward_rule" , &shape_forward_rule); |
319 | v->Visit("shape_backward_rule" , &shape_backward_rule); |
320 | } |
321 | |
322 | static constexpr const char* _type_key = "tir.BijectiveLayout" ; |
323 | TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object); |
324 | }; |
325 | |
326 | /*! |
 * \brief Bijective function mapping for data layout transformation.
 *   Given two Layouts, BijectiveLayout builds and stores the mapping rules, and provides
 *   an API to transform an N-dimensional tensor from the source indices (i0, i1, .., im)
 *   to the destination indices (j0, j1, .., jm).
331 | */ |
// Managed reference to BijectiveLayoutNode; declares the shape and index
// conversions between a source layout and a destination layout.
class BijectiveLayout : public ObjectRef {
 public:
  /*!
   * \brief The constructor
   * \param src_layout The source layout
   * \param dst_layout The destination layout
   */
  TVM_DLL BijectiveLayout(Layout src_layout, Layout dst_layout);

  // Given the source shape, infer the destination shape.
  TVM_DLL Array<PrimExpr> ForwardShape(const Array<PrimExpr>& shape) const;
  // Given the destination shape, recover the source shape.
  TVM_DLL Array<PrimExpr> BackwardShape(const Array<PrimExpr>& dst_shape) const;
  // Given the source indices, infer the destination indices.
  TVM_DLL Array<PrimExpr> ForwardIndex(const Array<PrimExpr>& index) const;
  // Given the destination indices, recover the source indices.
  TVM_DLL Array<PrimExpr> BackwardIndex(const Array<PrimExpr>& dst_index) const;

  TVM_DEFINE_OBJECT_REF_METHODS(BijectiveLayout, ObjectRef, BijectiveLayoutNode);
};
352 | |
353 | } // namespace tir |
354 | } // namespace tvm |
355 | |
356 | #endif // TVM_TIR_DATA_LAYOUT_H_ |
357 | |