1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/tir/data_layout.h
 * \brief Layout expression to describe the data organization of a tensor,
 *        and BijectiveLayout to map between two data layouts.
24 */
25#ifndef TVM_TIR_DATA_LAYOUT_H_
26#define TVM_TIR_DATA_LAYOUT_H_
27
28#include <tvm/tir/expr.h>
29#include <tvm/tir/op.h>
30
31#include <algorithm>
32#include <sstream>
33#include <string>
34#include <utility>
35#include <vector>
36
37namespace tvm {
38namespace tir {
39
40class Layout;
41
42class LayoutAxis {
43 public:
44 static const LayoutAxis& Get(const char name);
45
46 // Get the singleton LayoutAxis using itvar->var->name_hint
47 static const LayoutAxis& Get(const tir::IterVar& itvar);
48
49 // Get the singleton LayoutAxis using name[0] (size of name must be 1).
50 static const LayoutAxis& Get(const std::string& name);
51
52 inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; }
53 inline std::string name() const { return std::string(1, name_); }
54
55 // if current axis is primal, switch the axis to its subordinate one,
56 // else switch to the primal.
57 inline const LayoutAxis& ToDual() const {
58 if (name_ >= 'A' && name_ <= 'Z') {
59 return LayoutAxis::Get(name_ - 'A' + 'a');
60 } else {
61 return LayoutAxis::Get(name_ - 'a' + 'A');
62 }
63 }
64
65 // return the primal axis. If it is already primal, return itself.
66 const LayoutAxis& ToPrimal() const { return IsPrimal() ? *this : ToDual(); }
67
68 // return the subordinate axis. If it is already subordinate, return itself.
69 const LayoutAxis& ToSubordinate() const { return IsPrimal() ? ToDual() : *this; }
70
71 inline bool operator==(const LayoutAxis& rhs) const { return name_ == rhs.name_; }
72
73 friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) {
74 os << l.name();
75 return os;
76 }
77
78 private:
79 static const LayoutAxis UPPER_CASE[];
80 static const LayoutAxis LOWER_CASE[];
81 LayoutAxis(const LayoutAxis&);
82 LayoutAxis& operator=(const LayoutAxis&);
83 explicit LayoutAxis(const char name) : name_(name) {}
84
85 const char name_;
86};
87
88/*!
 * \brief Layout describes how data is organized within an N-dimensional tensor.
90 * It is composed of upper cases, lower cases and numbers,
91 * where upper case indicates a primal axis and
92 * the corresponding lower case with factor size indicates the subordinate axis.
93 * For example, NCHW16c can describe a 5-D tensor of
94 * [batch_size, channel, height, width, channel_block].
95 * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
96 * Layout for scalar is defined, while both its name and axes have size 0.
97 */
98class LayoutNode : public Object {
99 public:
100 /*! \brief string representation of layout, "" for scalar. */
101 String name;
102 /*! \brief specify each axis of the layout,
103 * in which the variable name is the name of the axis.
104 * The IterVar's extent indicates the size of the axis,
105 * it is a variable for a primal axis, but a constant for a subordinate axis.
106 * Empty for scalar's layout.
107 */
108 Array<tir::IterVar> axes;
109
110 void VisitAttrs(AttrVisitor* v) {
111 v->Visit("name", &name);
112 v->Visit("axes", &axes);
113 }
114
115 static constexpr const char* _type_key = "tir.Layout";
116 TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object);
117};
118
119/*!
120 * \brief Managed reference to LayoutNode
121 * \sa LayoutNode
122 */
123class Layout : public ObjectRef {
124 public:
125 explicit Layout(const Array<tir::IterVar>& axes);
126
127 /*! \brief construct from a string */
128 Layout(const tvm::String& name) : Layout(name.operator std::string()) {} // NOLINT(*)
129
130 /*! \brief construct from a string */
131 Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
132
133 /*!
134 * \brief construct from a string.
135 * \param name input in layout convention:
136 * upper case indicates a dimension and
137 * the corresponding lower case with factor size
138 * indicates the split dimension.
139 * return undefined layout if "__undef__" is passed.
140 * \param dtype The dtype of generated axes vars in the returned layout.
141 * It is required to be integer type.
142 */
143 TVM_DLL Layout(const std::string& name, DataType dtype = DataType::Int(32)); // NOLINT(*)
144
145 /*!
146 * \brief access the internal node container
147 * \return the pointer to the internal node container
148 */
149 LayoutNode* operator->() { return static_cast<LayoutNode*>(get_mutable()); }
150
151 /*!
152 * \brief Return an undefined layout.
153 * \return a (global) undefined layout.
154 */
155 static const Layout& Undef() {
156 static Layout undef;
157 return undef;
158 }
159
160 /*!
161 * \brief Returns a sub-layout which is the portion of the object
162 * that starts at dimension \p pos and spans \p len dimensions
163 * (or until the end of the layout, whichever comes first).
164 * \param pos The start position.
165 * \param len The length of the sub-layout. if 0, return layout of scalar
166 * \return A newly constructed Layout object.
167 */
168 Layout SubLayout(size_t pos, size_t len) const;
169
170 /*!
171 * \brief Split \p axis by \p size and put the sub-axis to position \p target_pos.
172 * \param axis The source axis to be split. It must be a primal-axis;
173 * \param target_pos The target position of the newly split subordinate-axis.
174 * \param factor size of the sub-dimension.
175 * \return A newly constructed Layout object.
176 */
177 Layout Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const;
178
179 /*! \return number of dimensions */
180 inline size_t ndim() const {
181 if (!defined()) return 0;
182 return operator->()->axes.size();
183 }
184
185 /*! \return number of super dimensions */
186 inline size_t ndim_primal() const {
187 if (!defined()) return 0;
188 size_t ct = 0;
189 for (auto x : operator->()->axes) {
190 if (LayoutAxis::Get(x).IsPrimal()) {
191 ct++;
192 }
193 }
194 return ct;
195 }
196
197 /*!
198 * \brief Returns a new layout where the dims have been expanded to match the primal dimensions.
199 * \param dst_layout The dst layout to which current layout has to be expanded.
200 * \return The expanded Layout.
201 */
202 inline Layout ExpandPrimal(const Layout& dst_layout) {
203 Layout new_src_layout;
204 // 1) Find the axis which are missing in the current layout. Make them the prefix.
205 std::string new_src_layout_str = "";
206 for (auto dst_axis : dst_layout->axes) {
207 if (LayoutAxis::Get(dst_axis).IsPrimal()) {
208 if (!this->Contains(LayoutAxis::Get(dst_axis))) {
209 new_src_layout_str += dst_axis->var->name_hint;
210 }
211 }
212 }
213 // 2) Now, add the primal axis of the current layout.
214 new_src_layout_str += this->name();
215 new_src_layout = Layout(new_src_layout_str);
216 return new_src_layout;
217 }
218
219 /*!
220 * \brief return the index of the input axis.
221 * If it is not found in the layout or the layout is undefined,
222 * return -1.
223 * \param axis the input axis.
224 * \return the index or -1 if not found.
225 */
226 inline int32_t IndexOf(const LayoutAxis& axis) const {
227 if (!this->defined()) return -1;
228 const auto axes = operator->()->axes;
229 for (size_t i = 0; i < axes.size(); ++i) {
230 if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i);
231 }
232 return -1;
233 }
234
235 /*!
236 * \brief Get the factor size of the subordinate axis.
237 * \param axis the input primal-axis or subordinate-axis.
238 * \return the size of the subordinate-axis of \p axis (if \p axis is a primal-axis),
239 * or the size of \p axis itself (if \p axis is a subordinate-axis).
240 * Return -1 if \p axis is not in the layout the layout is undefined.
241 */
242 int32_t FactorOf(const LayoutAxis& axis) const;
243
244 /*!
245 * \brief Whether the layout contains an axis.
246 * \param axis axis to be checked.
247 * \return Whether the layout contains the axis.
248 */
249 bool Contains(const LayoutAxis& axis) const {
250 if (!defined()) return false;
251 for (const tir::IterVar var : operator->()->axes) {
252 if (var->var->name_hint == axis.name()) {
253 return true;
254 }
255 }
256 return false;
257 }
258
259 const LayoutAxis& operator[](int32_t i) const {
260 ICHECK(defined()) << "Try to access axis from an undefined layout.";
261 int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
262 ICHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i;
263 const tir::IterVar axis = operator->()->axes[index];
264 return LayoutAxis::Get(axis);
265 }
266
267 /*! \return the string description of the layout */
268 inline std::string name() const {
269 if (!defined()) return "__undef__";
270 return operator->()->name;
271 }
272
273 /*!
274 * \brief Whether the two layouts are equal.
275 * \param rhs Another layout.
276 * \return whether the two layouts are equal.
277 */
278 inline bool Equals(const Layout& rhs) const { return name() == rhs.name(); }
279
280 /*!
281 * \brief allow output string of layout to ostream
282 * \param os the output stream
283 * \param l the layout
284 * \return the ostream
285 */
286 friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
287 os << l.name();
288 return os;
289 }
290
291 TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode);
292};
293
294// Internal node container BijectiveLayout
295class BijectiveLayoutNode : public Object {
296 public:
297 /*! \brief Describes how source axes can be mapped to the destination axes,
298 * e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n
299 */
300 Array<PrimExpr> index_forward_rule;
301 /*! \brief Describes how destination axes can be mapped to the source axes */
302 Array<PrimExpr> index_backward_rule;
303 /*! \brief Describes how source shapes can be mapped to the destination shapes */
304 Array<PrimExpr> shape_forward_rule;
305 /*! \brief Describes how destination shapes can be mapped to the source shapes */
306 Array<PrimExpr> shape_backward_rule;
307
308 /*! \brief The source layout */
309 Layout src_layout;
310 /*! \brief The destination layout */
311 Layout dst_layout;
312
313 void VisitAttrs(AttrVisitor* v) {
314 v->Visit("src_layout", &src_layout);
315 v->Visit("dst_layout", &dst_layout);
316 v->Visit("index_forward_rule", &index_forward_rule);
317 v->Visit("index_backward_rule", &index_backward_rule);
318 v->Visit("shape_forward_rule", &shape_forward_rule);
319 v->Visit("shape_backward_rule", &shape_backward_rule);
320 }
321
322 static constexpr const char* _type_key = "tir.BijectiveLayout";
323 TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object);
324};
325
326/*!
327 * \brief Bijective function mapping for data layout transformation.
 * Given two Layouts, BijectiveLayout builds and stores the mapping rules and
 * provides an API to transform an N-dimensional tensor from the source indices
 * (i0, i1, .., im) to the destination indices (j0, j1, .., jm).
331 */
332class BijectiveLayout : public ObjectRef {
333 public:
334 /*!
335 * \brief The constructor
336 * \param src_layout The source layout
337 * \param dst_layout The destination layout
338 */
339 TVM_DLL BijectiveLayout(Layout src_layout, Layout dst_layout);
340
341 // Given the source shape, infer the destination shape.
342 TVM_DLL Array<PrimExpr> ForwardShape(const Array<PrimExpr>& shape) const;
343 // Given the destination shape, recover the source shape.
344 TVM_DLL Array<PrimExpr> BackwardShape(const Array<PrimExpr>& dst_shape) const;
345 // Given the destination indices, infer the destination indices.
346 TVM_DLL Array<PrimExpr> ForwardIndex(const Array<PrimExpr>& index) const;
347 // Given the destination indices, recover the source indices.
348 TVM_DLL Array<PrimExpr> BackwardIndex(const Array<PrimExpr>& dst_index) const;
349
350 TVM_DEFINE_OBJECT_REF_METHODS(BijectiveLayout, ObjectRef, BijectiveLayoutNode);
351};
352
353} // namespace tir
354} // namespace tvm
355
356#endif // TVM_TIR_DATA_LAYOUT_H_
357