1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/te/operation.h
22 * \brief Operation node can generate one or multiple Tensors
23 */
24#ifndef TVM_TE_OPERATION_H_
25#define TVM_TE_OPERATION_H_
26
27#include <tvm/arith/analyzer.h>
28#include <tvm/te/schedule.h>
29#include <tvm/te/tensor.h>
30#include <tvm/tir/buffer.h>
31#include <tvm/tir/expr.h>
32#include <tvm/tir/op.h>
33
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
37
38namespace tvm {
39/*! \brief Tensor expression language DSL. */
40namespace te {
41
42/*!
43 * \brief Temporary data structure to store union
44 * of bounds of each axis of Tensor.
45 */
46struct TensorDom {
47 // constructor
48 explicit TensorDom(int ndim) : data(ndim) {}
49 /*! \brief The domain data */
50 std::vector<std::vector<IntSet>> data;
51};
52
53/*!
54 * \brief Base class of all operation nodes
55 */
56class TVM_DLL OperationNode : public Object {
57 public:
58 /*! \brief optional name of the operation */
59 std::string name;
60 /*! \brief optional tag of the operation */
61 std::string tag;
62 /*! \brief additional attributes of the operation*/
63 Map<String, ObjectRef> attrs;
64 // virtual destructor.
65 virtual ~OperationNode() {}
66 /*! \return number of outputs */
67 virtual int num_outputs() const = 0;
68 /*!
69 * \return The list of iteration variable at root
70 * \note root_iter_vars decides the shape of the outputs.
71 */
72 virtual Array<IterVar> root_iter_vars() const = 0;
73 /*!
74 * \brief Get data type. i-th output tensor.
75 * \param i The output index.
76 * \return type of i-th output.
77 */
78 virtual DataType output_dtype(size_t i) const = 0;
79 /*!
80 * \brief Get shape of i-th output tensor.
81 * \param i The output index.
82 * \return shape of i-th output.
83 */
84 virtual Array<PrimExpr> output_shape(size_t i) const = 0;
85 /*!
86 * \brief List all the input Tensors.
87 * \return List of input tensors.
88 */
89 virtual Array<Tensor> InputTensors() const = 0;
90 /*!
91 * \brief Replace the input of the operation by pattern specified by rmap.
92 *
93 * \param self The reference to self.
94 * \param rmap The replacement map.
95 * \return self if nothing is replaced, otherwise return replaced op.
96 */
97 virtual Operation ReplaceInputs(const Operation& self,
98 const std::unordered_map<Tensor, Tensor>& rmap) const = 0;
99 /*!
100 * \brief Propagate the bounds to inputs
101 * \param self The reference to self.
102 * \param analyzer The analyzer to be used in the function.
103 * \param dom_map the domain map of Variables(corresponds to root_iter_vars)
104 * \param out_dom_map The output domain.
105 * The function is only asked to fill the bounds for Tensors that
106 * is already in the out_dom_map
107 */
108 virtual void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
109 const std::unordered_map<const VarNode*, IntSet>& dom_map,
110 std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
111 /*!
112 * \brief Gather the bound from output tensor.
113 * Set the range of each root_iter_vars in the op to out_dom_map
114 *
115 * \param self The reference to self.
116 * \param tensor_dom Domain map of Tensor->access set of each dimension.
117 * \param out_dom_map The output domain map of each IterVar to be setted.
118 */
119 virtual void GatherBound(const Operation& self,
120 const std::unordered_map<Tensor, TensorDom>& tensor_dom,
121 std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
122 /*!
123 * \brief Build the Realize statement that realizes
124 * the op's output tensors.
125 * \param stage the op's stage.
126 * \param realize_map The realization domain map of the operators.
127 * \param body The body that is going to get
128 * \param storage_scope The storage scope associated with this realization
129 * \return A realization statement that wraps body.
130 */
131 virtual Stmt BuildRealize(const Stage& stage,
132 const std::unordered_map<IterVar, Range>& realize_map, const Stmt& body,
133 String storage_scope = "") const = 0;
134 /*!
135 * \brief Build the statement that provide the output tensors.
136 * \param stage The schedule stage of the op.
137 * \param dom_map The domain map of all iteration domains.
138 * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
139 * \return A statement that add production and wraps consumer.
140 */
141 virtual Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
142 bool debug_keep_trivial_loop) const = 0;
143
144 static constexpr const char* _type_key = "Operation";
145
146 TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object);
147};
148
149/*!
150 * \brief A placeholder op represents an input placeholder.
151 */
152class PlaceholderOpNode : public OperationNode {
153 public:
154 /*! \brief The shape of the input */
155 Array<PrimExpr> shape;
156 /*! \brief The data type of the input. */
157 DataType dtype;
158 // override behavior.
159 int num_outputs() const final;
160 Array<IterVar> root_iter_vars() const final;
161 DataType output_dtype(size_t i) const final;
162 Array<PrimExpr> output_shape(size_t i) const final;
163 Array<Tensor> InputTensors() const final;
164 Operation ReplaceInputs(const Operation& self,
165 const std::unordered_map<Tensor, Tensor>& rmap) const final;
166 void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
167 const std::unordered_map<const VarNode*, IntSet>& dom_map,
168 std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
169 void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
170 std::unordered_map<IterVar, Range>* out_dom_map) const final;
171 Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
172 const Stmt& body, String storage_scope = "") const final;
173 Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
174 bool debug_keep_trivial_loop) const final;
175
176 void VisitAttrs(AttrVisitor* v) {
177 v->Visit("name", &name);
178 v->Visit("tag", &tag);
179 v->Visit("attrs", &attrs);
180 v->Visit("shape", &shape);
181 v->Visit("dtype", &dtype);
182 }
183
184 static constexpr const char* _type_key = "PlaceholderOp";
185 TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
186};
187
188/*!
189 * \brief Managed reference to PlaceholderOpNode
190 * \sa PlaceholderOpNode
191 */
192class PlaceholderOp : public Operation {
193 public:
194 TVM_DLL PlaceholderOp(std::string name, Array<PrimExpr> shape, DataType dtype);
195
196 TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode);
197};
198
199/*!
200 * \brief A Compute op that compute a tensor on certain domain.
201 * This is the base class for ComputeOp (operating on a scalar at a time) and
202 * TensorComputeOp (operating on a TensorSlice at a time)
203 */
204class TVM_DLL BaseComputeOpNode : public OperationNode {
205 public:
206 /*! \brief IterVar on each axis */
207 Array<IterVar> axis;
208 /*! \brief IterVar on each reduction axis, if the body is a Reduce */
209 Array<IterVar> reduce_axis;
210 // override functions
211 Array<IterVar> root_iter_vars() const final;
212 Array<PrimExpr> output_shape(size_t idx) const final;
213 void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
214 std::unordered_map<IterVar, Range>* out_dom_map) const final;
215 Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
216 const Stmt& body, String storage_scope = "") const final;
217 virtual size_t num_schedulable_dims() const = 0;
218
219 static constexpr const char* _type_key = "BaseComputeOp";
220 TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode);
221};
222
223/*!
224 * \brief A Compute op that compute a tensor on certain domain.
225 */
226class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
227 public:
228 /*! \brief the compute expression */
229 Array<PrimExpr> body;
230 /*! \brief constructor */
231 ComputeOpNode() {}
232 // override functions
233 int num_outputs() const final;
234 DataType output_dtype(size_t i) const final;
235 Array<Tensor> InputTensors() const final;
236 Operation ReplaceInputs(const Operation& self,
237 const std::unordered_map<Tensor, Tensor>& rmap) const final;
238 void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
239 const std::unordered_map<const VarNode*, IntSet>& dom_map,
240 std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
241 Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
242 bool debug_keep_trivial_loop) const final;
243 size_t num_schedulable_dims() const final;
244
245 void VisitAttrs(AttrVisitor* v) {
246 v->Visit("name", &name);
247 v->Visit("tag", &tag);
248 v->Visit("attrs", &attrs);
249 v->Visit("axis", &axis);
250 v->Visit("reduce_axis", &reduce_axis);
251 v->Visit("body", &body);
252 }
253
254 static constexpr const char* _type_key = "ComputeOp";
255 TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
256};
257
258/*!
259 * \brief Managed reference to ComputeOpNode
260 * \sa ComputeOpNode
261 */
262class ComputeOp : public Operation {
263 public:
264 TVM_DLL ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
265 Array<IterVar> axis, Array<PrimExpr> body);
266
267 TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
268 TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode);
269};
270
271/*!
272 * \brief A TenorCompute op that compute a tensor with an tensor intrinsic.
273 */
274class TensorComputeOpNode : public BaseComputeOpNode {
275 public:
276 /*! \brief number of axes that can be scheduled */
277 int schedulable_ndim;
278 /*! \brief TensorIntrin used to compute */
279 TensorIntrin intrin;
280 /*! \brief input tensors of intrin */
281 Array<Tensor> inputs;
282 /*! \brief region of input tensors */
283 Array<Region> input_regions;
284 /*! \brief scalar expression inputs */
285 Array<PrimExpr> scalar_inputs;
286 /*! \brief constructor */
287 TensorComputeOpNode() {}
288 // override functions
289 int num_outputs() const final;
290 DataType output_dtype(size_t i) const final;
291 Array<Tensor> InputTensors() const final;
292 Operation ReplaceInputs(const Operation& self,
293 const std::unordered_map<Tensor, Tensor>& rmap) const final;
294 void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
295 const std::unordered_map<const VarNode*, IntSet>& dom_map,
296 std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
297 Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
298 bool debug_keep_trivial_loop) const final;
299 size_t num_schedulable_dims() const final;
300
301 void VisitAttrs(AttrVisitor* v) {
302 v->Visit("name", &name);
303 v->Visit("tag", &tag);
304 v->Visit("axis", &axis);
305 v->Visit("reduce_axis", &reduce_axis);
306 v->Visit("schedulable_ndim", &schedulable_ndim);
307 v->Visit("intrin", &intrin);
308 v->Visit("inputs", &inputs);
309 v->Visit("input_regions", &input_regions);
310 v->Visit("scalar_inputs", &scalar_inputs);
311 }
312
313 static constexpr const char* _type_key = "TensorComputeOp";
314 TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
315};
316
317/*!
318 * \brief Managed reference to TensorComputeOpNode
319 * \sa TensorComputeOpNode
320 */
321class TensorComputeOp : public Operation {
322 public:
323 TVM_DLL TensorComputeOp(std::string name, std::string tag, Array<IterVar> axis,
324 Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
325 Array<Tensor> tensors, Array<Region> regions,
326 Array<PrimExpr> scalar_inputs);
327
328 TVM_DEFINE_OBJECT_REF_METHODS(TensorComputeOp, Operation, TensorComputeOpNode);
329};
330
331/*!
332 * \brief Symbolic scan.
333 */
334class ScanOpNode : public OperationNode {
335 public:
336 /*! \brief IterVar to scan over */
337 IterVar scan_axis;
338 /*! \brief the initialization tensors */
339 Array<Tensor> init;
340 /*! \brief the update function represented by tensor */
341 Array<Tensor> update;
342 /*! \brief The placeholder to refer as states in update. */
343 Array<Tensor> state_placeholder;
344 /*!
345 * \brief the inputs to the scan, these are optionally provided
346 * But they can be helpful to provide hints to speedup get of scan body.
347 */
348 Array<Tensor> inputs;
349 /*!
350 * \brief Spatial axis to indicate spatial dimension of each output.
351 * They corresponds to flattened spatial axis of the outputs.
352 *
353 * [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
354 * These are auxiliary data structure for storing result of bound inference.
355 * They do not corresponds to splittable iterations, thus the name comes
356 * with underscore.
357 */
358 Array<IterVar> spatial_axis_;
359 /*! \brief constructor */
360 ScanOpNode() {}
361 // override behavior.
362 int num_outputs() const final;
363 Array<IterVar> root_iter_vars() const final;
364 DataType output_dtype(size_t i) const final;
365 Array<PrimExpr> output_shape(size_t i) const final;
366 Array<Tensor> InputTensors() const final;
367 Operation ReplaceInputs(const Operation& self,
368 const std::unordered_map<Tensor, Tensor>& rmap) const final;
369 void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
370 const std::unordered_map<const VarNode*, IntSet>& dom_map,
371 std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
372 void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
373 std::unordered_map<IterVar, Range>* out_dom_map) const final;
374 Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
375 const Stmt& body, String storage_scope = "") const final;
376 Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
377 bool debug_keep_trivial_loop) const final;
378
379 void VisitAttrs(AttrVisitor* v) {
380 v->Visit("name", &name);
381 v->Visit("tag", &tag);
382 v->Visit("attrs", &attrs);
383 v->Visit("scan_axis", &scan_axis);
384 v->Visit("init", &init);
385 v->Visit("update", &update);
386 v->Visit("state_placeholder", &state_placeholder);
387 v->Visit("inputs", &inputs);
388 v->Visit("spatial_axis_", &spatial_axis_);
389 }
390
391 static constexpr const char* _type_key = "ScanOp";
392 TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
393};
394
395/*!
396 * \brief Managed reference to ScanOpNode
397 * \sa ScanOpNode
398 */
399class ScanOp : public Operation {
400 public:
401 TVM_DLL ScanOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, IterVar axis,
402 Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
403 Array<Tensor> input);
404
405 TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode);
406};
407
408/*!
409 * \brief External computation that cannot be splitted.
410 */
411class ExternOpNode : public OperationNode {
412 public:
413 /*! \brief The input tensors */
414 Array<Tensor> inputs;
415 /*! \brief Symbolic placeholder representation of inputs */
416 Array<Buffer> input_placeholders;
417 /*! \brief Symbolic placeholder representation of outputs */
418 Array<Buffer> output_placeholders;
419 /*! \brief the statement that generates the computation. */
420 Stmt body;
421
422 /*! \brief constructor */
423 ExternOpNode() {}
424 // override functions
425 int num_outputs() const final;
426 Array<IterVar> root_iter_vars() const final;
427 DataType output_dtype(size_t i) const final;
428 Array<PrimExpr> output_shape(size_t i) const final;
429 Array<Tensor> InputTensors() const final;
430 Operation ReplaceInputs(const Operation& self,
431 const std::unordered_map<Tensor, Tensor>& rmap) const final;
432 void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
433 const std::unordered_map<const VarNode*, IntSet>& dom_map,
434 std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
435 void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
436 std::unordered_map<IterVar, Range>* out_dom_map) const final;
437 Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
438 const Stmt& body, String storage_scope = "") const final;
439 Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
440 bool debug_keep_trivial_loop) const final;
441
442 void VisitAttrs(AttrVisitor* v) {
443 v->Visit("name", &name);
444 v->Visit("tag", &tag);
445 v->Visit("attrs", &attrs);
446 v->Visit("inputs", &inputs);
447 v->Visit("input_placeholders", &input_placeholders);
448 v->Visit("output_placeholders", &output_placeholders);
449 v->Visit("body", &body);
450 }
451
452 static constexpr const char* _type_key = "ExternOp";
453 TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
454};
455
456/*!
457 * \brief Managed reference to ExternOpNode
458 * \sa ExternOpNode
459 */
460class ExternOp : public Operation {
461 public:
462 TVM_DLL ExternOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
463 Array<Tensor> inputs, Array<Buffer> input_placeholders,
464 Array<Buffer> output_placeholders, Stmt body);
465
466 TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode);
467};
468
469/*!
470 * \brief A computation operator that generated by hybrid script.
471 */
472class HybridOpNode : public OperationNode {
473 public:
474 /*! \brief The input tensors */
475 Array<Tensor> inputs;
476 /*! \brief Symbolic placeholder representation of outputs */
477 Array<Tensor> outputs;
478 /*! \brief The axis of iterations */
479 Array<IterVar> axis;
480 /*! \brief the statement that generates the computation. This is
481 * slightly different from the body in ExternOpNode. All the output
482 * tensors keep its own name specified by users in the script.
483 * However, when compilation, these tensors will be placed by those
484 * actual output tensors. */
485 Stmt body;
486
487 /*! \brief constructor */
488 HybridOpNode() {}
489 // override functions
490 int num_outputs() const final;
491 Array<IterVar> root_iter_vars() const final;
492 DataType output_dtype(size_t i) const final;
493 Array<PrimExpr> output_shape(size_t i) const final;
494 Array<Tensor> InputTensors() const final;
495 Operation ReplaceInputs(const Operation& self,
496 const std::unordered_map<Tensor, Tensor>& rmap) const final;
497 void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
498 const std::unordered_map<const VarNode*, IntSet>& dom_map,
499 std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
500 void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
501 std::unordered_map<IterVar, Range>* out_dom_map) const final;
502 Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
503 const Stmt& body, String storage_scope = "") const final;
504 Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
505 bool debug_keep_trivial_loop) const final;
506
507 void VisitAttrs(AttrVisitor* v) {
508 v->Visit("name", &name);
509 v->Visit("tag", &tag);
510 v->Visit("attrs", &attrs);
511 v->Visit("inputs", &inputs);
512 v->Visit("outputs", &outputs);
513 v->Visit("axis", &axis);
514 v->Visit("body", &body);
515 }
516
517 static constexpr const char* _type_key = "HybridOp";
518 TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
519};
520
521/*!
522 * \brief Managed reference to HybridOpNode
523 * \sa HybridOpNode
524 */
525class HybridOp : public Operation {
526 public:
527 TVM_DLL HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
528 Array<Tensor> inputs, Array<Tensor> outputs, Stmt body);
529
530 TVM_DEFINE_OBJECT_REF_METHODS(HybridOp, Operation, HybridOpNode);
531};
532
533/*!
534 * \brief Construct a new Var expression
535 * \param name_hint The name hint for the expression
536 * \param t The type of the expression
537 */
538TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32));
539
540/*!
541 * \brief Create a new IterVar that represents an axis in thread.
542 *
543 * \param dom Optional, domain of the thread axis.
544 * \param tag The thread tag of the axis.
545 */
546TVM_DLL IterVar thread_axis(Range dom, std::string tag);
547
548/*!
549 * \brief Create a new IterVar for reduction operations.
550 *
551 * \param dom The domain of the reduction axis.
552 * \param name The name of the reduction axis.
553 */
554TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
555
556/*! \brief The compute function to specify the input source of a Tensor */
557using FCompute = std::function<PrimExpr(const Array<Var>& i)>;
558
559/*! \brief The compute function to specify the inputs source of Tensors */
560using FBatchCompute = std::function<Array<PrimExpr>(const Array<Var>& i)>;
561
562/*!
563 * \brief create a place holder tensor.
564 * \param shape The shape of the tensor.
565 * \param dtype the data type of the tensor.
566 * \param name The name of the Tensor.
567 */
568TVM_DLL Tensor placeholder(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
569 std::string name = "placeholder");
570
571/*!
572 * \brief Construct a new tensor by computing over shape,
573 * using the computation rule: result_tensor[axis] = fcompute(axis)
574 * \param shape Shape of the tensor.
575 * \param fcompute The compute function to create the tensor.
576 * \param name The optional name of the tensor.
577 * \param tag The optional tag of the tensor.
578 * \param attrs Optional additional attributes of the compute.
579 */
580TVM_DLL Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name = "tensor",
581 std::string tag = "", Map<String, ObjectRef> attrs = {});
582
583/*!
584 * \brief Construct a new tensor by computing over shape,
585 * using the computation rule: result_tensor[axis] = fcompute(axis)
586 * \param shape Shape of the tensor.
587 * \param fcompute The compute function to create the tensors.
588 * \param name The optional name of the tensor.
589 * \param tag The optional tag of the tensor.
590 * \param attrs Optional additional attributes of the compute.
591 */
592TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute,
593 std::string name = "tensor", std::string tag = "",
594 Map<String, ObjectRef> attrs = {});
595
596/*!
597 * \brief Construct new tensors by scan.
598 *
599 * \param init The intialize tensor of first K steps.
600 * \param update The update tensor indicated the updated result after each timestamp.
601 * \param state_placeholder The placeholder for the states.
602 * \param inputs The inputs to the scan body, this is optional,
603 * but recommended to provide concrete information about scan body.
604 * \param name The optional name of the tensor.
605 * \param tag The optional tag of the tensor.
606 * \param attrs Optional additional attributes of the compute.
607 */
608TVM_DLL Array<Tensor> scan(Array<Tensor> init, Array<Tensor> update,
609 Array<Tensor> state_placeholder, Array<Tensor> inputs = Array<Tensor>(),
610 std::string name = "scan", std::string tag = "",
611 Map<String, ObjectRef> attrs = {});
612
613// same as compute, specialized for different fcompute function
614inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var)> f,
615 std::string name = "tensor", std::string tag = "",
616 Map<String, ObjectRef> attrs = {}) {
617 FCompute fc = [f](const Array<Var>& i) { return f(i[0]); };
618 return compute(shape, fc, name, tag, attrs);
619}
620inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var)> f,
621 std::string name = "tensor", std::string tag = "",
622 Map<String, ObjectRef> attrs = {}) {
623 FCompute fc = [f](const Array<Var>& i) { return f(i[0], i[1]); };
624 return compute(shape, fc, name, tag, attrs);
625}
626inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var)> f,
627 std::string name = "tensor", std::string tag = "",
628 Map<String, ObjectRef> attrs = {}) {
629 FCompute fc = [f](const Array<Var>& i) { return f(i[0], i[1], i[2]); };
630 return compute(shape, fc, name, tag, attrs);
631}
632inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var, Var)> f,
633 std::string name = "tensor", std::string tag = "",
634 Map<String, ObjectRef> attrs = {}) {
635 FCompute fc = [f](const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
636 return compute(shape, fc, name, tag, attrs);
637}
638
639// inline function.
640inline const OperationNode* Operation::operator->() const {
641 return static_cast<const OperationNode*>(get());
642}
643} // namespace te
644} // namespace tvm
645#endif // TVM_TE_OPERATION_H_
646