1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/te/operation.h |
22 | * \brief Operation node can generate one or multiple Tensors |
23 | */ |
24 | #ifndef TVM_TE_OPERATION_H_ |
25 | #define TVM_TE_OPERATION_H_ |
26 | |
#include <tvm/arith/analyzer.h>
#include <tvm/te/schedule.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
37 | |
38 | namespace tvm { |
39 | /*! \brief Tensor expression language DSL. */ |
40 | namespace te { |
41 | |
/*!
 * \brief Temporary data structure to store union
 *  of bounds of each axis of Tensor.
 *
 * The outer vector of `data` is sized once at construction; each inner
 * vector starts empty.
 */
struct TensorDom {
  /*!
   * \brief Create a domain record with `ndim` empty per-axis entries.
   * \param ndim Number of outer entries to allocate in `data`.
   * \note `ndim` is passed straight to the vector size constructor; no
   *       range check is performed here.
   */
  explicit TensorDom(int ndim) : data(ndim) {}
  /*! \brief The domain data: one list of IntSet per outer entry. */
  std::vector<std::vector<IntSet>> data;
};
52 | |
53 | /*! |
54 | * \brief Base class of all operation nodes |
55 | */ |
56 | class TVM_DLL OperationNode : public Object { |
57 | public: |
58 | /*! \brief optional name of the operation */ |
59 | std::string name; |
60 | /*! \brief optional tag of the operation */ |
61 | std::string tag; |
62 | /*! \brief additional attributes of the operation*/ |
63 | Map<String, ObjectRef> attrs; |
64 | // virtual destructor. |
65 | virtual ~OperationNode() {} |
66 | /*! \return number of outputs */ |
67 | virtual int num_outputs() const = 0; |
68 | /*! |
69 | * \return The list of iteration variable at root |
70 | * \note root_iter_vars decides the shape of the outputs. |
71 | */ |
72 | virtual Array<IterVar> root_iter_vars() const = 0; |
73 | /*! |
74 | * \brief Get data type. i-th output tensor. |
75 | * \param i The output index. |
76 | * \return type of i-th output. |
77 | */ |
78 | virtual DataType output_dtype(size_t i) const = 0; |
79 | /*! |
80 | * \brief Get shape of i-th output tensor. |
81 | * \param i The output index. |
82 | * \return shape of i-th output. |
83 | */ |
84 | virtual Array<PrimExpr> output_shape(size_t i) const = 0; |
85 | /*! |
86 | * \brief List all the input Tensors. |
87 | * \return List of input tensors. |
88 | */ |
89 | virtual Array<Tensor> InputTensors() const = 0; |
90 | /*! |
91 | * \brief Replace the input of the operation by pattern specified by rmap. |
92 | * |
93 | * \param self The reference to self. |
94 | * \param rmap The replacement map. |
95 | * \return self if nothing is replaced, otherwise return replaced op. |
96 | */ |
97 | virtual Operation ReplaceInputs(const Operation& self, |
98 | const std::unordered_map<Tensor, Tensor>& rmap) const = 0; |
99 | /*! |
100 | * \brief Propagate the bounds to inputs |
101 | * \param self The reference to self. |
102 | * \param analyzer The analyzer to be used in the function. |
103 | * \param dom_map the domain map of Variables(corresponds to root_iter_vars) |
104 | * \param out_dom_map The output domain. |
105 | * The function is only asked to fill the bounds for Tensors that |
106 | * is already in the out_dom_map |
107 | */ |
108 | virtual void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, |
109 | const std::unordered_map<const VarNode*, IntSet>& dom_map, |
110 | std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0; |
111 | /*! |
112 | * \brief Gather the bound from output tensor. |
113 | * Set the range of each root_iter_vars in the op to out_dom_map |
114 | * |
115 | * \param self The reference to self. |
116 | * \param tensor_dom Domain map of Tensor->access set of each dimension. |
117 | * \param out_dom_map The output domain map of each IterVar to be setted. |
118 | */ |
119 | virtual void GatherBound(const Operation& self, |
120 | const std::unordered_map<Tensor, TensorDom>& tensor_dom, |
121 | std::unordered_map<IterVar, Range>* out_dom_map) const = 0; |
122 | /*! |
123 | * \brief Build the Realize statement that realizes |
124 | * the op's output tensors. |
125 | * \param stage the op's stage. |
126 | * \param realize_map The realization domain map of the operators. |
127 | * \param body The body that is going to get |
128 | * \param storage_scope The storage scope associated with this realization |
129 | * \return A realization statement that wraps body. |
130 | */ |
131 | virtual Stmt BuildRealize(const Stage& stage, |
132 | const std::unordered_map<IterVar, Range>& realize_map, const Stmt& body, |
133 | String storage_scope = "" ) const = 0; |
134 | /*! |
135 | * \brief Build the statement that provide the output tensors. |
136 | * \param stage The schedule stage of the op. |
137 | * \param dom_map The domain map of all iteration domains. |
138 | * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 |
139 | * \return A statement that add production and wraps consumer. |
140 | */ |
141 | virtual Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map, |
142 | bool debug_keep_trivial_loop) const = 0; |
143 | |
144 | static constexpr const char* _type_key = "Operation" ; |
145 | |
146 | TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object); |
147 | }; |
148 | |
149 | /*! |
150 | * \brief A placeholder op represents an input placeholder. |
151 | */ |
152 | class PlaceholderOpNode : public OperationNode { |
153 | public: |
154 | /*! \brief The shape of the input */ |
155 | Array<PrimExpr> shape; |
156 | /*! \brief The data type of the input. */ |
157 | DataType dtype; |
158 | // override behavior. |
159 | int num_outputs() const final; |
160 | Array<IterVar> root_iter_vars() const final; |
161 | DataType output_dtype(size_t i) const final; |
162 | Array<PrimExpr> output_shape(size_t i) const final; |
163 | Array<Tensor> InputTensors() const final; |
164 | Operation ReplaceInputs(const Operation& self, |
165 | const std::unordered_map<Tensor, Tensor>& rmap) const final; |
166 | void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, |
167 | const std::unordered_map<const VarNode*, IntSet>& dom_map, |
168 | std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; |
169 | void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom, |
170 | std::unordered_map<IterVar, Range>* out_dom_map) const final; |
171 | Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map, |
172 | const Stmt& body, String storage_scope = "" ) const final; |
173 | Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map, |
174 | bool debug_keep_trivial_loop) const final; |
175 | |
176 | void VisitAttrs(AttrVisitor* v) { |
177 | v->Visit("name" , &name); |
178 | v->Visit("tag" , &tag); |
179 | v->Visit("attrs" , &attrs); |
180 | v->Visit("shape" , &shape); |
181 | v->Visit("dtype" , &dtype); |
182 | } |
183 | |
184 | static constexpr const char* _type_key = "PlaceholderOp" ; |
185 | TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode); |
186 | }; |
187 | |
/*!
 * \brief Managed reference to PlaceholderOpNode
 * \sa PlaceholderOpNode
 */
class PlaceholderOp : public Operation {
 public:
  /*!
   * \brief Construct a placeholder operation (defined out of line).
   * \param name The name of the operation.
   * \param shape The shape of the input.
   * \param dtype The data type of the input.
   * \note Presumably the arguments populate the same-named
   *       PlaceholderOpNode fields -- confirm against the definition.
   */
  TVM_DLL PlaceholderOp(std::string name, Array<PrimExpr> shape, DataType dtype);

  // Reference boilerplate tying this class to PlaceholderOpNode (macro defined elsewhere).
  TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode);
};
198 | |
/*!
 * \brief A Compute op that computes a tensor on a certain domain.
 * This is the base class for ComputeOp (operating on a scalar at a time) and
 * TensorComputeOp (operating on a TensorSlice at a time)
 */
class TVM_DLL BaseComputeOpNode : public OperationNode {
 public:
  /*! \brief IterVar on each axis */
  Array<IterVar> axis;
  /*! \brief IterVar on each reduction axis, if the body is a Reduce */
  Array<IterVar> reduce_axis;
  // Overrides shared by all compute ops; declared final here, so the derived
  // compute nodes cannot override them again.
  Array<IterVar> root_iter_vars() const final;
  Array<PrimExpr> output_shape(size_t idx) const final;
  void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                   std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
                    const Stmt& body, String storage_scope = "") const final;
  /*!
   * \brief Number of schedulable dimensions of the op.
   * \return The count, supplied by the concrete compute node
   *         (ComputeOpNode and TensorComputeOpNode both override this).
   */
  virtual size_t num_schedulable_dims() const = 0;

  /*! \brief Type key under which this node type is declared. */
  static constexpr const char* _type_key = "BaseComputeOp";
  TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode);
};
222 | |
223 | /*! |
224 | * \brief A Compute op that compute a tensor on certain domain. |
225 | */ |
226 | class TVM_DLL ComputeOpNode : public BaseComputeOpNode { |
227 | public: |
228 | /*! \brief the compute expression */ |
229 | Array<PrimExpr> body; |
230 | /*! \brief constructor */ |
231 | ComputeOpNode() {} |
232 | // override functions |
233 | int num_outputs() const final; |
234 | DataType output_dtype(size_t i) const final; |
235 | Array<Tensor> InputTensors() const final; |
236 | Operation ReplaceInputs(const Operation& self, |
237 | const std::unordered_map<Tensor, Tensor>& rmap) const final; |
238 | void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, |
239 | const std::unordered_map<const VarNode*, IntSet>& dom_map, |
240 | std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; |
241 | Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map, |
242 | bool debug_keep_trivial_loop) const final; |
243 | size_t num_schedulable_dims() const final; |
244 | |
245 | void VisitAttrs(AttrVisitor* v) { |
246 | v->Visit("name" , &name); |
247 | v->Visit("tag" , &tag); |
248 | v->Visit("attrs" , &attrs); |
249 | v->Visit("axis" , &axis); |
250 | v->Visit("reduce_axis" , &reduce_axis); |
251 | v->Visit("body" , &body); |
252 | } |
253 | |
254 | static constexpr const char* _type_key = "ComputeOp" ; |
255 | TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode); |
256 | }; |
257 | |
/*!
 * \brief Managed reference to ComputeOpNode
 * \sa ComputeOpNode
 */
class ComputeOp : public Operation {
 public:
  /*!
   * \brief Construct a compute operation (defined out of line).
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param attrs Additional attributes of the operation.
   * \param axis IterVar on each axis.
   * \param body The compute expression(s).
   * \note There is no reduce_axis parameter; how ComputeOpNode::reduce_axis
   *       gets filled is not visible in this header -- see the definition.
   */
  TVM_DLL ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                    Array<IterVar> axis, Array<PrimExpr> body);

  // Reference boilerplate tying this class to ComputeOpNode (macros defined elsewhere).
  TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode);
};
270 | |
271 | /*! |
272 | * \brief A TenorCompute op that compute a tensor with an tensor intrinsic. |
273 | */ |
274 | class TensorComputeOpNode : public BaseComputeOpNode { |
275 | public: |
276 | /*! \brief number of axes that can be scheduled */ |
277 | int schedulable_ndim; |
278 | /*! \brief TensorIntrin used to compute */ |
279 | TensorIntrin intrin; |
280 | /*! \brief input tensors of intrin */ |
281 | Array<Tensor> inputs; |
282 | /*! \brief region of input tensors */ |
283 | Array<Region> input_regions; |
284 | /*! \brief scalar expression inputs */ |
285 | Array<PrimExpr> scalar_inputs; |
286 | /*! \brief constructor */ |
287 | TensorComputeOpNode() {} |
288 | // override functions |
289 | int num_outputs() const final; |
290 | DataType output_dtype(size_t i) const final; |
291 | Array<Tensor> InputTensors() const final; |
292 | Operation ReplaceInputs(const Operation& self, |
293 | const std::unordered_map<Tensor, Tensor>& rmap) const final; |
294 | void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, |
295 | const std::unordered_map<const VarNode*, IntSet>& dom_map, |
296 | std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; |
297 | Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map, |
298 | bool debug_keep_trivial_loop) const final; |
299 | size_t num_schedulable_dims() const final; |
300 | |
301 | void VisitAttrs(AttrVisitor* v) { |
302 | v->Visit("name" , &name); |
303 | v->Visit("tag" , &tag); |
304 | v->Visit("axis" , &axis); |
305 | v->Visit("reduce_axis" , &reduce_axis); |
306 | v->Visit("schedulable_ndim" , &schedulable_ndim); |
307 | v->Visit("intrin" , &intrin); |
308 | v->Visit("inputs" , &inputs); |
309 | v->Visit("input_regions" , &input_regions); |
310 | v->Visit("scalar_inputs" , &scalar_inputs); |
311 | } |
312 | |
313 | static constexpr const char* _type_key = "TensorComputeOp" ; |
314 | TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode); |
315 | }; |
316 | |
/*!
 * \brief Managed reference to TensorComputeOpNode
 * \sa TensorComputeOpNode
 */
class TensorComputeOp : public Operation {
 public:
  /*!
   * \brief Construct a tensor compute operation (defined out of line).
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param axis IterVar on each axis.
   * \param reduce_axis IterVar on each reduction axis.
   * \param schedulable_ndim Number of axes that can be scheduled.
   * \param intrin TensorIntrin used to compute.
   * \param tensors Input tensors of intrin (presumably stored as
   *        TensorComputeOpNode::inputs -- confirm against the definition).
   * \param regions Regions of the input tensors (presumably stored as
   *        TensorComputeOpNode::input_regions -- confirm against the definition).
   * \param scalar_inputs Scalar expression inputs.
   */
  TVM_DLL TensorComputeOp(std::string name, std::string tag, Array<IterVar> axis,
                          Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
                          Array<Tensor> tensors, Array<Region> regions,
                          Array<PrimExpr> scalar_inputs);

  // Reference boilerplate tying this class to TensorComputeOpNode (macro defined elsewhere).
  TVM_DEFINE_OBJECT_REF_METHODS(TensorComputeOp, Operation, TensorComputeOpNode);
};
330 | |
331 | /*! |
332 | * \brief Symbolic scan. |
333 | */ |
334 | class ScanOpNode : public OperationNode { |
335 | public: |
336 | /*! \brief IterVar to scan over */ |
337 | IterVar scan_axis; |
338 | /*! \brief the initialization tensors */ |
339 | Array<Tensor> init; |
340 | /*! \brief the update function represented by tensor */ |
341 | Array<Tensor> update; |
342 | /*! \brief The placeholder to refer as states in update. */ |
343 | Array<Tensor> state_placeholder; |
344 | /*! |
345 | * \brief the inputs to the scan, these are optionally provided |
346 | * But they can be helpful to provide hints to speedup get of scan body. |
347 | */ |
348 | Array<Tensor> inputs; |
349 | /*! |
350 | * \brief Spatial axis to indicate spatial dimension of each output. |
351 | * They corresponds to flattened spatial axis of the outputs. |
352 | * |
353 | * [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...] |
354 | * These are auxiliary data structure for storing result of bound inference. |
355 | * They do not corresponds to splittable iterations, thus the name comes |
356 | * with underscore. |
357 | */ |
358 | Array<IterVar> spatial_axis_; |
359 | /*! \brief constructor */ |
360 | ScanOpNode() {} |
361 | // override behavior. |
362 | int num_outputs() const final; |
363 | Array<IterVar> root_iter_vars() const final; |
364 | DataType output_dtype(size_t i) const final; |
365 | Array<PrimExpr> output_shape(size_t i) const final; |
366 | Array<Tensor> InputTensors() const final; |
367 | Operation ReplaceInputs(const Operation& self, |
368 | const std::unordered_map<Tensor, Tensor>& rmap) const final; |
369 | void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, |
370 | const std::unordered_map<const VarNode*, IntSet>& dom_map, |
371 | std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; |
372 | void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom, |
373 | std::unordered_map<IterVar, Range>* out_dom_map) const final; |
374 | Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map, |
375 | const Stmt& body, String storage_scope = "" ) const final; |
376 | Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map, |
377 | bool debug_keep_trivial_loop) const final; |
378 | |
379 | void VisitAttrs(AttrVisitor* v) { |
380 | v->Visit("name" , &name); |
381 | v->Visit("tag" , &tag); |
382 | v->Visit("attrs" , &attrs); |
383 | v->Visit("scan_axis" , &scan_axis); |
384 | v->Visit("init" , &init); |
385 | v->Visit("update" , &update); |
386 | v->Visit("state_placeholder" , &state_placeholder); |
387 | v->Visit("inputs" , &inputs); |
388 | v->Visit("spatial_axis_" , &spatial_axis_); |
389 | } |
390 | |
391 | static constexpr const char* _type_key = "ScanOp" ; |
392 | TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode); |
393 | }; |
394 | |
/*!
 * \brief Managed reference to ScanOpNode
 * \sa ScanOpNode
 */
class ScanOp : public Operation {
 public:
  /*!
   * \brief Construct a scan operation (defined out of line).
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param attrs Additional attributes of the operation.
   * \param axis The IterVar to scan over (presumably stored as
   *        ScanOpNode::scan_axis -- confirm against the definition).
   * \param init The initialization tensors.
   * \param update The update function represented by tensors.
   * \param state_placeholder The placeholders referred to as states in update.
   * \param input The inputs to the scan (presumably stored as
   *        ScanOpNode::inputs -- confirm against the definition).
   */
  TVM_DLL ScanOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, IterVar axis,
                 Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
                 Array<Tensor> input);

  // Reference boilerplate tying this class to ScanOpNode (macro defined elsewhere).
  TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode);
};
407 | |
408 | /*! |
409 | * \brief External computation that cannot be splitted. |
410 | */ |
411 | class ExternOpNode : public OperationNode { |
412 | public: |
413 | /*! \brief The input tensors */ |
414 | Array<Tensor> inputs; |
415 | /*! \brief Symbolic placeholder representation of inputs */ |
416 | Array<Buffer> input_placeholders; |
417 | /*! \brief Symbolic placeholder representation of outputs */ |
418 | Array<Buffer> output_placeholders; |
419 | /*! \brief the statement that generates the computation. */ |
420 | Stmt body; |
421 | |
422 | /*! \brief constructor */ |
423 | ExternOpNode() {} |
424 | // override functions |
425 | int num_outputs() const final; |
426 | Array<IterVar> root_iter_vars() const final; |
427 | DataType output_dtype(size_t i) const final; |
428 | Array<PrimExpr> output_shape(size_t i) const final; |
429 | Array<Tensor> InputTensors() const final; |
430 | Operation ReplaceInputs(const Operation& self, |
431 | const std::unordered_map<Tensor, Tensor>& rmap) const final; |
432 | void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, |
433 | const std::unordered_map<const VarNode*, IntSet>& dom_map, |
434 | std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; |
435 | void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom, |
436 | std::unordered_map<IterVar, Range>* out_dom_map) const final; |
437 | Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map, |
438 | const Stmt& body, String storage_scope = "" ) const final; |
439 | Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map, |
440 | bool debug_keep_trivial_loop) const final; |
441 | |
442 | void VisitAttrs(AttrVisitor* v) { |
443 | v->Visit("name" , &name); |
444 | v->Visit("tag" , &tag); |
445 | v->Visit("attrs" , &attrs); |
446 | v->Visit("inputs" , &inputs); |
447 | v->Visit("input_placeholders" , &input_placeholders); |
448 | v->Visit("output_placeholders" , &output_placeholders); |
449 | v->Visit("body" , &body); |
450 | } |
451 | |
452 | static constexpr const char* _type_key = "ExternOp" ; |
453 | TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode); |
454 | }; |
455 | |
/*!
 * \brief Managed reference to ExternOpNode
 * \sa ExternOpNode
 */
class ExternOp : public Operation {
 public:
  /*!
   * \brief Construct an extern operation (defined out of line).
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param attrs Additional attributes of the operation.
   * \param inputs The input tensors.
   * \param input_placeholders Symbolic placeholder representation of inputs.
   * \param output_placeholders Symbolic placeholder representation of outputs.
   * \param body The statement that generates the computation.
   */
  TVM_DLL ExternOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                   Array<Tensor> inputs, Array<Buffer> input_placeholders,
                   Array<Buffer> output_placeholders, Stmt body);

  // Reference boilerplate tying this class to ExternOpNode (macro defined elsewhere).
  TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode);
};
468 | |
469 | /*! |
470 | * \brief A computation operator that generated by hybrid script. |
471 | */ |
472 | class HybridOpNode : public OperationNode { |
473 | public: |
474 | /*! \brief The input tensors */ |
475 | Array<Tensor> inputs; |
476 | /*! \brief Symbolic placeholder representation of outputs */ |
477 | Array<Tensor> outputs; |
478 | /*! \brief The axis of iterations */ |
479 | Array<IterVar> axis; |
480 | /*! \brief the statement that generates the computation. This is |
481 | * slightly different from the body in ExternOpNode. All the output |
482 | * tensors keep its own name specified by users in the script. |
483 | * However, when compilation, these tensors will be placed by those |
484 | * actual output tensors. */ |
485 | Stmt body; |
486 | |
487 | /*! \brief constructor */ |
488 | HybridOpNode() {} |
489 | // override functions |
490 | int num_outputs() const final; |
491 | Array<IterVar> root_iter_vars() const final; |
492 | DataType output_dtype(size_t i) const final; |
493 | Array<PrimExpr> output_shape(size_t i) const final; |
494 | Array<Tensor> InputTensors() const final; |
495 | Operation ReplaceInputs(const Operation& self, |
496 | const std::unordered_map<Tensor, Tensor>& rmap) const final; |
497 | void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, |
498 | const std::unordered_map<const VarNode*, IntSet>& dom_map, |
499 | std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; |
500 | void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom, |
501 | std::unordered_map<IterVar, Range>* out_dom_map) const final; |
502 | Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map, |
503 | const Stmt& body, String storage_scope = "" ) const final; |
504 | Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map, |
505 | bool debug_keep_trivial_loop) const final; |
506 | |
507 | void VisitAttrs(AttrVisitor* v) { |
508 | v->Visit("name" , &name); |
509 | v->Visit("tag" , &tag); |
510 | v->Visit("attrs" , &attrs); |
511 | v->Visit("inputs" , &inputs); |
512 | v->Visit("outputs" , &outputs); |
513 | v->Visit("axis" , &axis); |
514 | v->Visit("body" , &body); |
515 | } |
516 | |
517 | static constexpr const char* _type_key = "HybridOp" ; |
518 | TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode); |
519 | }; |
520 | |
/*!
 * \brief Managed reference to HybridOpNode
 * \sa HybridOpNode
 */
class HybridOp : public Operation {
 public:
  /*!
   * \brief Construct a hybrid operation (defined out of line).
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param attrs Additional attributes of the operation.
   * \param inputs The input tensors.
   * \param outputs Symbolic placeholder representation of outputs.
   * \param body The statement that generates the computation.
   * \note HybridOpNode::axis is not a constructor argument; how it is
   *       filled is not visible in this header -- see the definition.
   */
  TVM_DLL HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                   Array<Tensor> inputs, Array<Tensor> outputs, Stmt body);

  // Reference boilerplate tying this class to HybridOpNode (macro defined elsewhere).
  TVM_DEFINE_OBJECT_REF_METHODS(HybridOp, Operation, HybridOpNode);
};
532 | |
/*!
 * \brief Construct a new Var expression
 * \param name_hint The name hint for the expression
 * \param t The type of the expression; defaults to DataType::Int(32)
 * \return The constructed Var.
 */
TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32));
539 | |
/*!
 * \brief Create a new IterVar that represents an axis in thread.
 *
 * \param dom Optional, domain of the thread axis.
 * \param tag The thread tag of the axis.
 * \return The created IterVar.
 */
TVM_DLL IterVar thread_axis(Range dom, std::string tag);
547 | |
/*!
 * \brief Create a new IterVar for reduction operations.
 *
 * \param dom The domain of the reduction axis.
 * \param name The name of the reduction axis; defaults to "rv".
 * \return The created IterVar.
 */
TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
555 | |
/*!
 * \brief The compute function to specify the input source of a Tensor.
 *  Takes the array of index variables `i` and returns a single PrimExpr.
 */
using FCompute = std::function<PrimExpr(const Array<Var>& i)>;

/*!
 * \brief The compute function to specify the input sources of Tensors.
 *  Takes the array of index variables `i` and returns an array of PrimExpr.
 */
using FBatchCompute = std::function<Array<PrimExpr>(const Array<Var>& i)>;
561 | |
/*!
 * \brief Create a placeholder tensor.
 * \param shape The shape of the tensor.
 * \param dtype The data type of the tensor; defaults to DataType::Float(32).
 * \param name The name of the Tensor; defaults to "placeholder".
 * \return The created Tensor.
 */
TVM_DLL Tensor placeholder(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
                           std::string name = "placeholder");
570 | |
/*!
 * \brief Construct a new tensor by computing over shape,
 *  using the computation rule: result_tensor[axis] = fcompute(axis)
 * \param shape Shape of the tensor.
 * \param fcompute The compute function to create the tensor.
 * \param name The optional name of the tensor.
 * \param tag The optional tag of the tensor.
 * \param attrs Optional additional attributes of the compute.
 * \return The constructed tensor.
 */
TVM_DLL Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name = "tensor",
                       std::string tag = "", Map<String, ObjectRef> attrs = {});
582 | |
/*!
 * \brief Construct new tensors by computing over shape,
 *  using the computation rule: result_tensor[axis] = fcompute(axis)
 * \param shape Shape of the tensors.
 * \param fcompute The compute function to create the tensors; it returns
 *        an array of expressions.
 * \param name The optional name of the tensor.
 * \param tag The optional tag of the tensor.
 * \param attrs Optional additional attributes of the compute.
 * \return The constructed tensors.
 */
TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute,
                              std::string name = "tensor", std::string tag = "",
                              Map<String, ObjectRef> attrs = {});
595 | |
/*!
 * \brief Construct new tensors by scan.
 *
 * \param init The initial tensors of the first K steps.
 * \param update The update tensors indicating the updated result after each timestamp.
 * \param state_placeholder The placeholder for the states.
 * \param inputs The inputs to the scan body, this is optional,
 *    but recommended to provide concrete information about scan body.
 * \param name The optional name of the tensor.
 * \param tag The optional tag of the tensor.
 * \param attrs Optional additional attributes of the compute.
 * \return The tensors constructed by the scan.
 */
TVM_DLL Array<Tensor> scan(Array<Tensor> init, Array<Tensor> update,
                           Array<Tensor> state_placeholder, Array<Tensor> inputs = Array<Tensor>(),
                           std::string name = "scan", std::string tag = "",
                           Map<String, ObjectRef> attrs = {});
612 | |
613 | // same as compute, specialized for different fcompute function |
614 | inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var)> f, |
615 | std::string name = "tensor" , std::string tag = "" , |
616 | Map<String, ObjectRef> attrs = {}) { |
617 | FCompute fc = [f](const Array<Var>& i) { return f(i[0]); }; |
618 | return compute(shape, fc, name, tag, attrs); |
619 | } |
620 | inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var)> f, |
621 | std::string name = "tensor" , std::string tag = "" , |
622 | Map<String, ObjectRef> attrs = {}) { |
623 | FCompute fc = [f](const Array<Var>& i) { return f(i[0], i[1]); }; |
624 | return compute(shape, fc, name, tag, attrs); |
625 | } |
626 | inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var)> f, |
627 | std::string name = "tensor" , std::string tag = "" , |
628 | Map<String, ObjectRef> attrs = {}) { |
629 | FCompute fc = [f](const Array<Var>& i) { return f(i[0], i[1], i[2]); }; |
630 | return compute(shape, fc, name, tag, attrs); |
631 | } |
632 | inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var, Var)> f, |
633 | std::string name = "tensor" , std::string tag = "" , |
634 | Map<String, ObjectRef> attrs = {}) { |
635 | FCompute fc = [f](const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); }; |
636 | return compute(shape, fc, name, tag, attrs); |
637 | } |
638 | |
639 | // inline function. |
640 | inline const OperationNode* Operation::operator->() const { |
641 | return static_cast<const OperationNode*>(get()); |
642 | } |
643 | } // namespace te |
644 | } // namespace tvm |
645 | #endif // TVM_TE_OPERATION_H_ |
646 | |