1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/tir/buffer.h
22 * \brief Symbolic n-dimensional array, to represent a memory buffer.
23 */
24#ifndef TVM_TIR_BUFFER_H_
25#define TVM_TIR_BUFFER_H_
26
27#include <tvm/ir/expr.h>
28#include <tvm/runtime/container/array.h>
29#include <tvm/runtime/container/string.h>
30#include <tvm/tir/var.h>
31
32#include <string>
33
34namespace tvm {
35namespace tir {
36
// Forward declaration of Stmt. In this header Stmt is only needed as the
// return type of Buffer::vstore, so its full definition is not required here.
class Stmt;
39
/*! \brief buffer type */
enum BufferType : int {
  /*! \brief Default buffer type. */
  kDefault = 1,
  /*!
   * \brief Auto-broadcast buffer.
   *
   * An index into a dimension whose extent equals 1 is replaced by 0, e.g.
   * buffer[i][j][k] -> buffer[i][0][k] when the extent of the dimension
   * indexed by j (the second dimension) equals 1.
   */
  kAutoBroadcast = 2,
};
46
47/*! \brief Node to represent a buffer */
48class BufferNode : public Object {
49 public:
50 // Data fields.
51 /*!
52 * \brief The pointer to the head of the data
53 * \sa data_alignment The alignment of data in bytes.
54 */
55 Var data;
56 /*! \brief data type in the content of the tensor */
57 DataType dtype;
58 /*! \brief The type of the buffer prior to flattening
59 *
60 * This contains the shape as it is accessed by
61 * BufferLoad/BufferStore nodes, and used by the low-level code
62 * generators.
63 */
64 Array<PrimExpr> shape;
65 /*!
66 * \brief Separators between input axes when generating flattened output axes
67 *
68 * For buffers representing flat 1-d memory (e.g. any buffer in
69 * RAM), this should be an empty array. For buffers representing
70 * non-flat memory, each entry in axis_separators should be the
71 * first input axis that is part of a new flattened axis.
72 */
73 Array<IntImm> axis_separators;
74 /*!
75 * \brief The strides of each dimension
76 * This can be an empty array, indicating array is contiguous
77 */
78 Array<PrimExpr> strides;
79 /*! \brief The offset in terms of number of dtype elements (including lanes) */
80 PrimExpr elem_offset;
81 // Meta data
82 /*! \brief optional name of the buffer */
83 String name;
84 /*! \brief Alignment requirement of data pointer in bytes. */
85 int data_alignment;
86 /*!
87 * \brief Factor of elem_offset field,
88 * elem_offset is guaranteed to be multiple of offset_factor.
89 */
90 int offset_factor;
91 /*! \brief buffer type */
92 BufferType buffer_type;
93 /*!
94 * \brief Span that points to the original source code.
95 * Reserved debug information.
96 */
97 mutable Span span;
98 /*! \brief constructor */
99 BufferNode() {}
100
101 void VisitAttrs(AttrVisitor* v) {
102 v->Visit("data", &data);
103 v->Visit("dtype", &dtype);
104 v->Visit("shape", &shape);
105 v->Visit("strides", &strides);
106 v->Visit("axis_separators", &axis_separators);
107 v->Visit("elem_offset", &elem_offset);
108 v->Visit("name", &name);
109 v->Visit("data_alignment", &data_alignment);
110 v->Visit("offset_factor", &offset_factor);
111 v->Visit("buffer_type", &buffer_type);
112 v->Visit("span", &span);
113 }
114
115 bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
116 // Use DefEqual as buffer can define variables in its semantics,
117 // skip name as name is not important.
118 return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) &&
119 equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) &&
120 equal.DefEqual(axis_separators, other->axis_separators) &&
121 equal.DefEqual(elem_offset, other->elem_offset) &&
122 equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type);
123 }
124
125 void SHashReduce(SHashReducer hash_reduce) const {
126 hash_reduce.DefHash(data);
127 hash_reduce(dtype);
128 hash_reduce.DefHash(shape);
129 hash_reduce.DefHash(strides);
130 hash_reduce.DefHash(elem_offset);
131 hash_reduce.DefHash(axis_separators);
132 hash_reduce(data_alignment);
133 hash_reduce(buffer_type);
134 }
135
136 /*! \return preferred index type for this buffer node */
137 DataType DefaultIndexType() const {
138 return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
139 }
140
141 /*! \brief Determine the offset in the buffer of the given index.
142 *
143 * Returns the buffer offset, in number of elements of type dtype,
144 * without adjusting for number of lanes. (e.g. The number of
145 * float16x4 elements in a buffer of type float16x4.)
146 */
147 Array<PrimExpr> ElemOffset(Array<PrimExpr> index) const;
148
149 static constexpr const char* _type_key = "tir.Buffer";
150 static constexpr const bool _type_has_method_sequal_reduce = true;
151 static constexpr const bool _type_has_method_shash_reduce = true;
152 TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
153};
154
155/*!
156 * \brief Buffer is a symbolic n-darray structure.
157 * It is a composition of primitive symbolic types,
158 * used to specify the memory layout of the Tensor used in program input.
159 */
160class Buffer : public ObjectRef {
161 public:
162 // User can specify data_alignment and offset_factor to be 0
163 // A default value will be picked.
164 TVM_DLL Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
165 PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
166 BufferType buffer_type, Array<IntImm> axis_separators = {}, Span span = Span());
167
168 /*!
169 * \brief Return a new buffer that is equivalent with current one
170 * but always add stride field.
171 * \return The strided version of the buffer.
172 */
173 TVM_DLL Buffer MakeStrideView() const;
174 /*!
175 * \brief Make a new symbolic buffer representing a slice of the buffer.
176 * \param begins The beginning position of each dimension.
177 * \param extents The extent of each dimension.
178 * \note This function will make target buffer as compact as possible.
179 * If stride is not needed in the slice, it won't be presented
180 * \return the result buffer.
181 */
182 TVM_DLL Buffer MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const;
183 /*!
184 * \brief Get access ptr to the entire buffer.
185 * \param access_mask The access mask
186 * \param ptr_type The type of the pointer.
187 * \param content_lanes The number of lanes for the (data) type.
188 * \param offset The offset of ptr.
189 * \param input_extent The extent of ptr.
190 */
191 TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
192 int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0),
193 Optional<PrimExpr> input_extent = NullOpt) const;
194 /*!
195 * \brief Create an Expr that does a vector load at begin index.
196 * \param begin The beginning index
197 * \param dtype The data type to be loaded.
198 */
199 TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype) const;
200 /*!
201 * \brief Create a Stmt that does a vector store at begin index.
202 * \param begin The beginning index
203 * \param value The value to be stored.
204 */
205 TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
206
207 /*!
208 * \brief Get a flattened version of the buffer
209 */
210 Buffer GetFlattenedBuffer() const;
211
212 /*! \brief Determine the offset in the buffer of the given index.
213 *
214 * Returns the buffer offset, in number of elements of type dtype,
215 * without adjusting for number of lanes. (e.g. The number of
216 * float16x4 elements in a buffer of type float16x4.)
217 */
218 Array<PrimExpr> OffsetOf(Array<PrimExpr> index) const;
219
220 /*!
221 * \brief Return the storage scope associated with this buffer.
222 */
223 TVM_DLL String scope() const;
224
225 TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode);
226 TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode);
227};
228
229/*!
230 * \brief Construct a new buffer given shape, and dtype.
231 * \param shape The shape of the buffer,
232 * \param dtype The content data type.
233 * \param name The name of the buffer
234 * \param storage_scope The storage scope associated with this buffer
235 * \param axis_separators Divisions defining the groups of axes that will be flattened together.
236 * \param span The location of this object in the source code.
237 * \return The created buffer.
238 * \sa Buffer for complete constructor.
239 */
240TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
241 String name = "buffer", String storage_scope = "",
242 Array<IntImm> axis_separators = {}, Span span = Span());
243
244/*!
245 * \brief Base node for data producers.
246 *
247 * A DataProducer stores necessary information(e.g. a tensor expression) to produce
248 * a multi-dimensional array. The stored information is opaque to the TIR.
249 * DataProducer can appear in high-level DSLs that are built on top of the TIR.
250 *
251 * A valid TIR PrimFunc should not contain any DataProducer, high level DSLs should lower
252 * all DataProducers to Buffers before TIR transformations.
253 *
254 * \sa tvm::te::Tensor
255 */
256class DataProducerNode : public Object {
257 public:
258 /*! \brief destructor. */
259 virtual ~DataProducerNode() {}
260 /*!
261 * \brief Get the shape of the result.
262 * \return The shape.
263 */
264 virtual Array<PrimExpr> GetShape() const = 0;
265 /*!
266 * \brief Get the data type of the result.
267 * \return The data type.
268 */
269 virtual DataType GetDataType() const = 0;
270 /*!
271 * \brief Get the name hint of the data producer.
272 * \return The data type.
273 */
274 virtual String GetNameHint() const = 0;
275
276 bool SEqualReduce(const DataProducerNode* other, SEqualReducer equal) const {
277 // because buffer producer is opaque, we just do pointer equality.
278 return this == other;
279 }
280
281 void SHashReduce(SHashReducer hash_reduce) const {}
282
283 static constexpr const char* _type_key = "tir.DataProducer";
284 static constexpr const bool _type_has_method_sequal_reduce = true;
285 static constexpr const bool _type_has_method_shash_reduce = true;
286 TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, Object);
287};
288
/*!
 * \brief Managed reference to DataProducerNode.
 *
 * Holds a (possibly null) reference to a DataProducerNode; the producer's
 * shape, data type and name hint are reached through that node.
 *
 * \sa DataProducerNode
 */
class DataProducer : public ObjectRef {
 public:
  // Standard ObjectRef boilerplate tying this reference type to DataProducerNode.
  TVM_DEFINE_OBJECT_REF_METHODS(DataProducer, ObjectRef, DataProducerNode);
};
297
298/*!
299 * \brief Creates TIR Buffer for provided parameters
300 * \param shape shape of the buffer
301 * \param dtype data type
302 * \param name buffer name
303 * \param data_alignment alignment requirement of data pointer in bytes
304 * \param offset_factor Factor of elem_offset field, elem_offset is guaranteed to be
305 * multiple of offset_factor
306 User can specify data_alignment and offset_factor to be 0
307 * A default value will be picked.
308 * \param compact If the statement has already bound to a compact buffer.
309 * \param memory_scope memory scope of the buffer
310 */
311TVM_DLL tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype,
312 std::string name, int data_alignment,
313 int offset_factor, bool compact,
314 std::string memory_scope = "");
315} // namespace tir
316} // namespace tvm
317#endif // TVM_TIR_BUFFER_H_
318