1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/tir/buffer.h |
22 | * \brief Symbolic n-dimensional array, to represent a memory buffer. |
23 | */ |
24 | #ifndef TVM_TIR_BUFFER_H_ |
25 | #define TVM_TIR_BUFFER_H_ |
26 | |
27 | #include <tvm/ir/expr.h> |
28 | #include <tvm/runtime/container/array.h> |
29 | #include <tvm/runtime/container/string.h> |
30 | #include <tvm/tir/var.h> |
31 | |
32 | #include <string> |
33 | |
34 | namespace tvm { |
35 | namespace tir { |
36 | |
// Forward declaration of Stmt: it is used here only as the return type of
// Buffer::vstore, so the full definition is not needed in this header.
class Stmt;
39 | |
/*! \brief Buffer type, controlling how indices into the buffer are mapped. */
enum BufferType : int {
  /*! \brief Default buffer type. */
  kDefault = 1,
  /*!
   * \brief Auto-broadcast buffer type.
   *
   * An index into a dimension whose shape equals 1 is mapped to 0,
   * e.g. buffer[i][j][k] -> buffer[i][0][k] when the second dimension
   * (the one indexed by j) has shape 1.
   */
  kAutoBroadcast = 2,
};
46 | |
47 | /*! \brief Node to represent a buffer */ |
48 | class BufferNode : public Object { |
49 | public: |
50 | // Data fields. |
51 | /*! |
52 | * \brief The pointer to the head of the data |
53 | * \sa data_alignment The alignment of data in bytes. |
54 | */ |
55 | Var data; |
56 | /*! \brief data type in the content of the tensor */ |
57 | DataType dtype; |
58 | /*! \brief The type of the buffer prior to flattening |
59 | * |
60 | * This contains the shape as it is accessed by |
61 | * BufferLoad/BufferStore nodes, and used by the low-level code |
62 | * generators. |
63 | */ |
64 | Array<PrimExpr> shape; |
65 | /*! |
66 | * \brief Separators between input axes when generating flattened output axes |
67 | * |
68 | * For buffers representing flat 1-d memory (e.g. any buffer in |
69 | * RAM), this should be an empty array. For buffers representing |
70 | * non-flat memory, each entry in axis_separators should be the |
71 | * first input axis that is part of a new flattened axis. |
72 | */ |
73 | Array<IntImm> axis_separators; |
74 | /*! |
75 | * \brief The strides of each dimension |
76 | * This can be an empty array, indicating array is contiguous |
77 | */ |
78 | Array<PrimExpr> strides; |
79 | /*! \brief The offset in terms of number of dtype elements (including lanes) */ |
80 | PrimExpr elem_offset; |
81 | // Meta data |
82 | /*! \brief optional name of the buffer */ |
83 | String name; |
84 | /*! \brief Alignment requirement of data pointer in bytes. */ |
85 | int data_alignment; |
86 | /*! |
87 | * \brief Factor of elem_offset field, |
88 | * elem_offset is guaranteed to be multiple of offset_factor. |
89 | */ |
90 | int offset_factor; |
91 | /*! \brief buffer type */ |
92 | BufferType buffer_type; |
93 | /*! |
94 | * \brief Span that points to the original source code. |
95 | * Reserved debug information. |
96 | */ |
97 | mutable Span span; |
98 | /*! \brief constructor */ |
99 | BufferNode() {} |
100 | |
101 | void VisitAttrs(AttrVisitor* v) { |
102 | v->Visit("data" , &data); |
103 | v->Visit("dtype" , &dtype); |
104 | v->Visit("shape" , &shape); |
105 | v->Visit("strides" , &strides); |
106 | v->Visit("axis_separators" , &axis_separators); |
107 | v->Visit("elem_offset" , &elem_offset); |
108 | v->Visit("name" , &name); |
109 | v->Visit("data_alignment" , &data_alignment); |
110 | v->Visit("offset_factor" , &offset_factor); |
111 | v->Visit("buffer_type" , &buffer_type); |
112 | v->Visit("span" , &span); |
113 | } |
114 | |
115 | bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const { |
116 | // Use DefEqual as buffer can define variables in its semantics, |
117 | // skip name as name is not important. |
118 | return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) && |
119 | equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) && |
120 | equal.DefEqual(axis_separators, other->axis_separators) && |
121 | equal.DefEqual(elem_offset, other->elem_offset) && |
122 | equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type); |
123 | } |
124 | |
125 | void SHashReduce(SHashReducer hash_reduce) const { |
126 | hash_reduce.DefHash(data); |
127 | hash_reduce(dtype); |
128 | hash_reduce.DefHash(shape); |
129 | hash_reduce.DefHash(strides); |
130 | hash_reduce.DefHash(elem_offset); |
131 | hash_reduce.DefHash(axis_separators); |
132 | hash_reduce(data_alignment); |
133 | hash_reduce(buffer_type); |
134 | } |
135 | |
136 | /*! \return preferred index type for this buffer node */ |
137 | DataType DefaultIndexType() const { |
138 | return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32); |
139 | } |
140 | |
141 | /*! \brief Determine the offset in the buffer of the given index. |
142 | * |
143 | * Returns the buffer offset, in number of elements of type dtype, |
144 | * without adjusting for number of lanes. (e.g. The number of |
145 | * float16x4 elements in a buffer of type float16x4.) |
146 | */ |
147 | Array<PrimExpr> ElemOffset(Array<PrimExpr> index) const; |
148 | |
149 | static constexpr const char* _type_key = "tir.Buffer" ; |
150 | static constexpr const bool _type_has_method_sequal_reduce = true; |
151 | static constexpr const bool _type_has_method_shash_reduce = true; |
152 | TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object); |
153 | }; |
154 | |
155 | /*! |
156 | * \brief Buffer is a symbolic n-darray structure. |
157 | * It is a composition of primitive symbolic types, |
158 | * used to specify the memory layout of the Tensor used in program input. |
159 | */ |
160 | class Buffer : public ObjectRef { |
161 | public: |
162 | // User can specify data_alignment and offset_factor to be 0 |
163 | // A default value will be picked. |
164 | TVM_DLL Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides, |
165 | PrimExpr elem_offset, String name, int data_alignment, int offset_factor, |
166 | BufferType buffer_type, Array<IntImm> axis_separators = {}, Span span = Span()); |
167 | |
168 | /*! |
169 | * \brief Return a new buffer that is equivalent with current one |
170 | * but always add stride field. |
171 | * \return The strided version of the buffer. |
172 | */ |
173 | TVM_DLL Buffer MakeStrideView() const; |
174 | /*! |
175 | * \brief Make a new symbolic buffer representing a slice of the buffer. |
176 | * \param begins The beginning position of each dimension. |
177 | * \param extents The extent of each dimension. |
178 | * \note This function will make target buffer as compact as possible. |
179 | * If stride is not needed in the slice, it won't be presented |
180 | * \return the result buffer. |
181 | */ |
182 | TVM_DLL Buffer MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const; |
183 | /*! |
184 | * \brief Get access ptr to the entire buffer. |
185 | * \param access_mask The access mask |
186 | * \param ptr_type The type of the pointer. |
187 | * \param content_lanes The number of lanes for the (data) type. |
188 | * \param offset The offset of ptr. |
189 | * \param input_extent The extent of ptr. |
190 | */ |
191 | TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), |
192 | int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0), |
193 | Optional<PrimExpr> input_extent = NullOpt) const; |
194 | /*! |
195 | * \brief Create an Expr that does a vector load at begin index. |
196 | * \param begin The beginning index |
197 | * \param dtype The data type to be loaded. |
198 | */ |
199 | TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype) const; |
200 | /*! |
201 | * \brief Create a Stmt that does a vector store at begin index. |
202 | * \param begin The beginning index |
203 | * \param value The value to be stored. |
204 | */ |
205 | TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const; |
206 | |
207 | /*! |
208 | * \brief Get a flattened version of the buffer |
209 | */ |
210 | Buffer GetFlattenedBuffer() const; |
211 | |
212 | /*! \brief Determine the offset in the buffer of the given index. |
213 | * |
214 | * Returns the buffer offset, in number of elements of type dtype, |
215 | * without adjusting for number of lanes. (e.g. The number of |
216 | * float16x4 elements in a buffer of type float16x4.) |
217 | */ |
218 | Array<PrimExpr> OffsetOf(Array<PrimExpr> index) const; |
219 | |
220 | /*! |
221 | * \brief Return the storage scope associated with this buffer. |
222 | */ |
223 | TVM_DLL String scope() const; |
224 | |
225 | TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode); |
226 | TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode); |
227 | }; |
228 | |
229 | /*! |
230 | * \brief Construct a new buffer given shape, and dtype. |
231 | * \param shape The shape of the buffer, |
232 | * \param dtype The content data type. |
233 | * \param name The name of the buffer |
234 | * \param storage_scope The storage scope associated with this buffer |
235 | * \param axis_separators Divisions defining the groups of axes that will be flattened together. |
236 | * \param span The location of this object in the source code. |
237 | * \return The created buffer. |
238 | * \sa Buffer for complete constructor. |
239 | */ |
240 | TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32), |
241 | String name = "buffer" , String storage_scope = "" , |
242 | Array<IntImm> axis_separators = {}, Span span = Span()); |
243 | |
244 | /*! |
245 | * \brief Base node for data producers. |
246 | * |
247 | * A DataProducer stores necessary information(e.g. a tensor expression) to produce |
248 | * a multi-dimensional array. The stored information is opaque to the TIR. |
249 | * DataProducer can appear in high-level DSLs that are built on top of the TIR. |
250 | * |
251 | * A valid TIR PrimFunc should not contain any DataProducer, high level DSLs should lower |
252 | * all DataProducers to Buffers before TIR transformations. |
253 | * |
254 | * \sa tvm::te::Tensor |
255 | */ |
256 | class DataProducerNode : public Object { |
257 | public: |
258 | /*! \brief destructor. */ |
259 | virtual ~DataProducerNode() {} |
260 | /*! |
261 | * \brief Get the shape of the result. |
262 | * \return The shape. |
263 | */ |
264 | virtual Array<PrimExpr> GetShape() const = 0; |
265 | /*! |
266 | * \brief Get the data type of the result. |
267 | * \return The data type. |
268 | */ |
269 | virtual DataType GetDataType() const = 0; |
270 | /*! |
271 | * \brief Get the name hint of the data producer. |
272 | * \return The data type. |
273 | */ |
274 | virtual String GetNameHint() const = 0; |
275 | |
276 | bool SEqualReduce(const DataProducerNode* other, SEqualReducer equal) const { |
277 | // because buffer producer is opaque, we just do pointer equality. |
278 | return this == other; |
279 | } |
280 | |
281 | void SHashReduce(SHashReducer hash_reduce) const {} |
282 | |
283 | static constexpr const char* _type_key = "tir.DataProducer" ; |
284 | static constexpr const bool _type_has_method_sequal_reduce = true; |
285 | static constexpr const bool _type_has_method_shash_reduce = true; |
286 | TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, Object); |
287 | }; |
288 | |
289 | /*! |
290 | * \brief Managed reference to DataProducerNode. |
291 | * \sa DataProducerNode |
292 | */ |
293 | class DataProducer : public ObjectRef { |
294 | public: |
295 | TVM_DEFINE_OBJECT_REF_METHODS(DataProducer, ObjectRef, DataProducerNode); |
296 | }; |
297 | |
298 | /*! |
299 | * \brief Creates TIR Buffer for provided parameters |
300 | * \param shape shape of the buffer |
301 | * \param dtype data type |
302 | * \param name buffer name |
303 | * \param data_alignment alignment requirement of data pointer in bytes |
304 | * \param offset_factor Factor of elem_offset field, elem_offset is guaranteed to be |
305 | * multiple of offset_factor |
306 | User can specify data_alignment and offset_factor to be 0 |
307 | * A default value will be picked. |
308 | * \param compact If the statement has already bound to a compact buffer. |
309 | * \param memory_scope memory scope of the buffer |
310 | */ |
311 | TVM_DLL tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, |
312 | std::string name, int data_alignment, |
313 | int offset_factor, bool compact, |
314 | std::string memory_scope = "" ); |
315 | } // namespace tir |
316 | } // namespace tvm |
317 | #endif // TVM_TIR_BUFFER_H_ |
318 | |