1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/te/tensor.h |
22 | * \brief Dataflow tensor object |
23 | */ |
24 | #ifndef TVM_TE_TENSOR_H_ |
25 | #define TVM_TE_TENSOR_H_ |
26 | |
27 | #include <tvm/arith/bound.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/op.h> |
30 | |
31 | #include <string> |
32 | #include <type_traits> |
33 | #include <utility> |
34 | #include <vector> |
35 | |
36 | namespace tvm { |
37 | namespace te { |
38 | |
39 | using arith::IntSet; |
40 | using namespace tvm::tir; |
41 | |
42 | // internal node container for Operation |
43 | class OperationNode; |
44 | class Tensor; |
45 | |
46 | /*! \brief Operation that produces tensors */ |
47 | class Operation : public ObjectRef { |
48 | public: |
49 | /*! \brief default constructor */ |
50 | Operation() {} |
51 | explicit Operation(ObjectPtr<Object> n) : ObjectRef(n) {} |
52 | /*! |
53 | * \brief access the internal node container |
54 | * \return the pointer to the internal node container |
55 | */ |
56 | inline const OperationNode* operator->() const; |
57 | /*! |
58 | * \brief get the i-th output of the operation. |
59 | * \param i the output index. |
60 | * \return The i-th output. |
61 | */ |
62 | TVM_DLL Tensor output(size_t i) const; |
63 | /*! \brief specify container node */ |
64 | using ContainerType = OperationNode; |
65 | }; |
66 | |
67 | /*! \brief Node to represent a tensor */ |
68 | class TensorNode : public DataProducerNode { |
69 | public: |
70 | /*! \brief The shape of the tensor */ |
71 | Array<PrimExpr> shape; |
72 | /*! \brief data type in the content of the tensor */ |
73 | DataType dtype; |
74 | /*! \brief the source operation, can be None */ |
75 | Operation op; |
76 | /*! \brief the output index from source operation */ |
77 | int value_index{0}; |
78 | /*! \brief constructor */ |
79 | TensorNode() {} |
80 | |
81 | void VisitAttrs(AttrVisitor* v) { |
82 | v->Visit("shape" , &shape); |
83 | v->Visit("dtype" , &dtype); |
84 | v->Visit("op" , &op); |
85 | v->Visit("value_index" , &value_index); |
86 | } |
87 | |
88 | Array<PrimExpr> GetShape() const final { return shape; } |
89 | |
90 | DataType GetDataType() const final { return dtype; } |
91 | |
92 | TVM_DLL String GetNameHint() const final; |
93 | |
94 | static constexpr const char* _type_key = "Tensor" ; |
95 | TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode); |
96 | }; |
97 | |
98 | /*! |
99 | * \brief Tensor structure representing a possible input, |
100 | * or intermediate computation result. |
101 | */ |
102 | class Tensor : public DataProducer { |
103 | private: |
104 | /*! |
105 | * \brief Helper for indexing operations into tensors |
106 | * \param indices The indices |
107 | * \param support_negative_indices Whether to normalize indices in the case of negative indices. |
108 | * \return the result expression representing tensor read. |
109 | */ |
110 | inline PrimExpr IndexTensor(Array<PrimExpr> indices, bool support_negative_indices) const; |
111 | |
112 | public: |
113 | TVM_DLL Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index); |
114 | /*! |
115 | * \brief check if two tensors equals each other. |
116 | * \param other tensor to be checked. |
117 | * \return whether the two tensors equals each other. |
118 | */ |
119 | inline bool operator==(const Tensor& other) const; |
120 | /*! |
121 | * \brief check if two tensors are different. |
122 | * \param other tensor to be checked. |
123 | * \return whether the two tensors are different. |
124 | */ |
125 | inline bool operator!=(const Tensor& other) const; |
126 | /*! \return The dimension of the tensor */ |
127 | inline size_t ndim() const; |
128 | /*! |
129 | * \brief Take elements from the tensor |
130 | * \param args The indices |
131 | * \return the result expression representing tensor read. |
132 | */ |
133 | template <typename... Args> |
134 | inline PrimExpr operator()(Args&&... args) const { |
135 | Array<PrimExpr> indices{std::forward<Args>(args)...}; |
136 | return operator()(indices); |
137 | } |
138 | /*! |
139 | * \brief Take elements from the tensor |
140 | * \param indices the indices. |
141 | * \return the result expression representing tensor read. |
142 | */ |
143 | TVM_DLL PrimExpr operator()(Array<PrimExpr> indices) const; |
144 | /*! |
145 | * \brief Take elements from the tensor |
146 | * \param indices the indices. |
147 | * \return the result expression representing tensor read. |
148 | */ |
149 | TVM_DLL PrimExpr operator()(Array<Var> indices) const; |
150 | /*! |
151 | * \brief Take elements from the tensor with support for negative indices. |
152 | * \param args The indices |
153 | * \return the result expression representing tensor read. |
154 | */ |
155 | template <typename... Args> |
156 | TVM_DLL PrimExpr IndexWithNegativeIndices(Args&&... args) const { |
157 | Array<PrimExpr> indices{std::forward<Args>(args)...}; |
158 | return IndexWithNegativeIndices(indices); |
159 | } |
160 | /*! |
161 | * \brief Take elements from the tensor with support for negative indices. |
162 | * \param indices the indices. |
163 | * \return the result expression representing tensor read. |
164 | */ |
165 | TVM_DLL PrimExpr IndexWithNegativeIndices(Array<PrimExpr> indices) const; |
166 | /*! |
167 | * \brief Take elements from the tensor with support for negative indices. |
168 | * \param indices the indices. |
169 | * \return the result expression representing tensor read. |
170 | */ |
171 | TVM_DLL PrimExpr IndexWithNegativeIndices(Array<Var> indices) const; |
172 | |
173 | /*! |
174 | * \brief data structure to represent a slice that fixes first k coordinates. |
175 | * This is used to enable syntax sugar of Tensor[x][y][z] to get the element. |
176 | */ |
177 | class Slice { |
178 | public: |
179 | // construct via tensor and indices |
180 | Slice(const Tensor& tensor, std::vector<PrimExpr> indices) |
181 | : tensor_(tensor), indices_(indices) {} |
182 | /*! |
183 | * \brief get i-th slice from the current slice. |
184 | * \param i the index of the coordinate |
185 | * \return the subsequent slice. |
186 | */ |
187 | inline Slice operator[](PrimExpr i) { |
188 | std::vector<PrimExpr> other = indices_; |
189 | other.emplace_back(i); |
190 | return Slice(tensor_, other); |
191 | } |
192 | /*! |
193 | * \brief Convert slice to expression. |
194 | * This is only valid when all the coordinates are fully specified. |
195 | * \return the corresponding expression of this slice. |
196 | */ |
197 | inline operator PrimExpr() const { return tensor_(indices_); } |
198 | |
199 | private: |
200 | const Tensor& tensor_; |
201 | std::vector<PrimExpr> indices_; |
202 | }; |
203 | /*! |
204 | * \brief get i-th slice from the current Tensor. |
205 | * \param i the index of the coordinate |
206 | * \return the subsequent slice. |
207 | */ |
208 | inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); } |
209 | |
210 | TVM_DEFINE_OBJECT_REF_METHODS(Tensor, DataProducer, TensorNode); |
211 | }; |
212 | |
213 | // Implementations of inline functions |
214 | inline size_t Tensor::ndim() const { return (*this)->shape.size(); } |
215 | |
216 | inline bool Tensor::operator==(const Tensor& other) const { |
217 | if (get() == other.get()) return true; |
218 | if (get() == nullptr || other.get() == nullptr) return false; |
219 | if ((*this)->op.defined() || other->op.defined()) { |
220 | return (*this)->op == other->op && (*this)->value_index == other->value_index; |
221 | } else { |
222 | return false; |
223 | } |
224 | } |
225 | |
226 | inline bool Tensor::operator!=(const Tensor& other) const { return !(*this == other); } |
227 | |
228 | // macro to turn every operation of slice to expression |
229 | #define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ |
230 | inline PrimExpr operator Op(const Tensor::Slice& a) { return Op a.operator PrimExpr(); } |
231 | |
232 | #define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ |
233 | template <typename T> \ |
234 | inline PrimExpr operator Op(const Tensor::Slice& a, const T& b) { \ |
235 | return a.operator PrimExpr() Op b; \ |
236 | } \ |
237 | template <typename T> \ |
238 | inline PrimExpr operator Op(const T& a, const Tensor::Slice& b) { \ |
239 | return a Op b.operator PrimExpr(); \ |
240 | } \ |
241 | inline PrimExpr operator Op(const Tensor::Slice& a, const Tensor::Slice& b) { \ |
242 | return a.operator PrimExpr() Op b.operator PrimExpr(); \ |
243 | } |
244 | |
245 | DEFINE_OVERLOAD_SLICE_UNARY_OP(!); |
246 | DEFINE_OVERLOAD_SLICE_UNARY_OP(-); |
247 | DEFINE_OVERLOAD_SLICE_BINARY_OP(+); |
248 | DEFINE_OVERLOAD_SLICE_BINARY_OP(-); |
249 | DEFINE_OVERLOAD_SLICE_BINARY_OP(*); |
250 | DEFINE_OVERLOAD_SLICE_BINARY_OP(==); |
251 | DEFINE_OVERLOAD_SLICE_BINARY_OP(<=); |
252 | DEFINE_OVERLOAD_SLICE_BINARY_OP(>=); |
253 | DEFINE_OVERLOAD_SLICE_BINARY_OP(!=); |
254 | DEFINE_OVERLOAD_SLICE_BINARY_OP(&&); |
255 | DEFINE_OVERLOAD_SLICE_BINARY_OP(||); |
256 | DEFINE_OVERLOAD_SLICE_BINARY_OP(>>); |
257 | DEFINE_OVERLOAD_SLICE_BINARY_OP(<<); |
258 | DEFINE_OVERLOAD_SLICE_BINARY_OP(>); // NOLINT(*) |
259 | DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*) |
260 | |
261 | } // namespace te |
262 | } // namespace tvm |
263 | |
264 | namespace std { |
265 | template <> |
266 | struct hash<::tvm::te::Operation> : public ::tvm::ObjectPtrHash {}; |
267 | |
268 | template <> |
269 | struct hash<::tvm::te::Tensor> { |
270 | std::size_t operator()(const ::tvm::te::Tensor& k) const { |
271 | ::tvm::ObjectPtrHash hasher; |
272 | if (k.defined() && k->op.defined()) { |
273 | return hasher(k->op); |
274 | } else { |
275 | return hasher(k); |
276 | } |
277 | } |
278 | }; |
279 | } // namespace std |
280 | #endif // TVM_TE_TENSOR_H_ |
281 | |