1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/te/tensor_intrin.h |
22 | * \brief Tensor intrinsic operations. |
23 | */ |
24 | #ifndef TVM_TE_TENSOR_INTRIN_H_ |
25 | #define TVM_TE_TENSOR_INTRIN_H_ |
26 | |
27 | #include <tvm/te/tensor.h> |
28 | #include <tvm/tir/buffer.h> |
29 | |
30 | #include <string> |
31 | |
32 | namespace tvm { |
33 | namespace te { |
34 | |
35 | /*! \brief Node to represent a Tensor intrinsic operator */ |
36 | class TensorIntrinNode : public Object { |
37 | public: |
38 | /*! \brief The name of the intrinsic */ |
39 | std::string name; |
40 | /*! \brief The operation this intrinsics is carrying out */ |
41 | Operation op; |
42 | /*! \brief List of inputs of operator, placeholder in postdfs order */ |
43 | Array<Tensor> inputs; |
44 | /*! |
45 | * \brief Symbolic buffers of each output/input tensor |
46 | * buffers[0:len(inputs)] are buffers of the inputs. |
47 | * buffers[len(inputs):] are buffers of each output. |
48 | * |
49 | * \note When a field in Buffer is Var, it means we can be flexible |
50 | * wrt that field and Var can occur in body. |
51 | * When it is a constant, it means we can only take data in that shape. |
52 | */ |
53 | Array<Buffer> buffers; |
54 | /*! \brief List of scalar variables, used in body. These placeholders |
55 | * will be bound to expressions passed in when the TensorIntrin is called |
56 | * from a TensorComputeOp. |
57 | */ |
58 | Array<Var> scalar_params; |
59 | /*! \brief The normal statement to execute the intrinsic */ |
60 | Stmt body; |
61 | /*! |
62 | * \brief Special statement for reduction op, can be None |
63 | * reset the value of output buffer to identity value. |
64 | */ |
65 | Stmt reduce_init; |
66 | /*! |
67 | * \brief Special statement for reduction op, can be None |
68 | * Reduce: do a reduction of current output buffer with the result. |
69 | */ |
70 | Stmt reduce_update; |
71 | /*! \brief constructor */ |
72 | TensorIntrinNode() {} |
73 | |
74 | void VisitAttrs(AttrVisitor* v) { |
75 | v->Visit("name" , &name); |
76 | v->Visit("op" , &op); |
77 | v->Visit("inputs" , &inputs); |
78 | v->Visit("buffers" , &buffers); |
79 | v->Visit("scalar_params" , &scalar_params); |
80 | v->Visit("body" , &body); |
81 | v->Visit("reduce_init" , &reduce_init); |
82 | v->Visit("reduce_update" , &reduce_update); |
83 | } |
84 | |
85 | static constexpr const char* _type_key = "TensorIntrin" ; |
86 | TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); |
87 | }; |
88 | |
89 | /*! |
90 | * \brief Managed reference to TensorIntrinNode |
91 | * \sa TensorIntrinNode |
92 | */ |
93 | class TensorIntrin : public ObjectRef { |
94 | public: |
95 | TVM_DLL TensorIntrin(std::string name, Operation op, Array<Tensor> inputs, Array<Buffer> buffers, |
96 | Array<Var> scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update); |
97 | |
98 | TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode); |
99 | }; |
100 | |
101 | class TensorIntrinCallNode : public Object { |
102 | public: |
103 | /*! \brief the tensor intrinsic */ |
104 | TensorIntrin intrin; |
105 | /*! \brief input tensors of the intrinsic */ |
106 | Array<Tensor> tensors; |
107 | /*! \brief regions of input tensors */ |
108 | Array<Region> regions; |
109 | |
110 | /*! |
111 | * \brief IterVar on each reduction axis, if the |
112 | * intrin will use the reduce axis |
113 | */ |
114 | Array<IterVar> reduce_axis; |
115 | |
116 | /*! \brief scalar expression inputs */ |
117 | Array<PrimExpr> scalar_inputs; |
118 | |
119 | void VisitAttrs(AttrVisitor* v) { |
120 | v->Visit("intrin" , &intrin); |
121 | v->Visit("tensors" , &tensors); |
122 | v->Visit("regions" , ®ions); |
123 | v->Visit("reduce_axis" , &reduce_axis); |
124 | v->Visit("scalar_inputs" , &scalar_inputs); |
125 | } |
126 | |
127 | static constexpr const char* _type_key = "TensorIntrinCall" ; |
128 | TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object); |
129 | }; |
130 | |
131 | /*! |
132 | * \brief Managed reference to TensorIntrinCallNode |
133 | * \sa TensorIntrinCallNode |
134 | */ |
135 | class TensorIntrinCall : public ObjectRef { |
136 | public: |
137 | TVM_DLL TensorIntrinCall(TensorIntrin intrin, Array<Tensor> tensors, Array<Region> regions, |
138 | Array<IterVar> reduce_axis, Array<PrimExpr> scalar_inputs); |
139 | |
140 | TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrinCall, ObjectRef, TensorIntrinCallNode); |
141 | }; |
142 | |
143 | } // namespace te |
144 | } // namespace tvm |
145 | #endif // TVM_TE_TENSOR_INTRIN_H_ |
146 | |