1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/te/tensor.h
22 * \brief Dataflow tensor object
23 */
24#ifndef TVM_TE_TENSOR_H_
25#define TVM_TE_TENSOR_H_
26
27#include <tvm/arith/bound.h>
28#include <tvm/tir/expr.h>
29#include <tvm/tir/op.h>
30
31#include <string>
32#include <type_traits>
33#include <utility>
34#include <vector>
35
36namespace tvm {
37namespace te {
38
// Lets declarations in tvm::te refer to arith::IntSet without qualification.
using arith::IntSet;
// NOTE(review): using-directive at header scope; every file that includes this header sees the
// tvm::tir names inside tvm::te. Presumably the declarations below rely on it for unqualified
// tir names -- confirm before narrowing it.
using namespace tvm::tir;

// Forward declarations: Operation (below) mentions both types before they are defined.
// internal node container for Operation
class OperationNode;
class Tensor;
45
46/*! \brief Operation that produces tensors */
47class Operation : public ObjectRef {
48 public:
49 /*! \brief default constructor */
50 Operation() {}
51 explicit Operation(ObjectPtr<Object> n) : ObjectRef(n) {}
52 /*!
53 * \brief access the internal node container
54 * \return the pointer to the internal node container
55 */
56 inline const OperationNode* operator->() const;
57 /*!
58 * \brief get the i-th output of the operation.
59 * \param i the output index.
60 * \return The i-th output.
61 */
62 TVM_DLL Tensor output(size_t i) const;
63 /*! \brief specify container node */
64 using ContainerType = OperationNode;
65};
66
67/*! \brief Node to represent a tensor */
68class TensorNode : public DataProducerNode {
69 public:
70 /*! \brief The shape of the tensor */
71 Array<PrimExpr> shape;
72 /*! \brief data type in the content of the tensor */
73 DataType dtype;
74 /*! \brief the source operation, can be None */
75 Operation op;
76 /*! \brief the output index from source operation */
77 int value_index{0};
78 /*! \brief constructor */
79 TensorNode() {}
80
81 void VisitAttrs(AttrVisitor* v) {
82 v->Visit("shape", &shape);
83 v->Visit("dtype", &dtype);
84 v->Visit("op", &op);
85 v->Visit("value_index", &value_index);
86 }
87
88 Array<PrimExpr> GetShape() const final { return shape; }
89
90 DataType GetDataType() const final { return dtype; }
91
92 TVM_DLL String GetNameHint() const final;
93
94 static constexpr const char* _type_key = "Tensor";
95 TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode);
96};
97
98/*!
99 * \brief Tensor structure representing a possible input,
100 * or intermediate computation result.
101 */
102class Tensor : public DataProducer {
103 private:
104 /*!
105 * \brief Helper for indexing operations into tensors
106 * \param indices The indices
107 * \param support_negative_indices Whether to normalize indices in the case of negative indices.
108 * \return the result expression representing tensor read.
109 */
110 inline PrimExpr IndexTensor(Array<PrimExpr> indices, bool support_negative_indices) const;
111
112 public:
113 TVM_DLL Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index);
114 /*!
115 * \brief check if two tensors equals each other.
116 * \param other tensor to be checked.
117 * \return whether the two tensors equals each other.
118 */
119 inline bool operator==(const Tensor& other) const;
120 /*!
121 * \brief check if two tensors are different.
122 * \param other tensor to be checked.
123 * \return whether the two tensors are different.
124 */
125 inline bool operator!=(const Tensor& other) const;
126 /*! \return The dimension of the tensor */
127 inline size_t ndim() const;
128 /*!
129 * \brief Take elements from the tensor
130 * \param args The indices
131 * \return the result expression representing tensor read.
132 */
133 template <typename... Args>
134 inline PrimExpr operator()(Args&&... args) const {
135 Array<PrimExpr> indices{std::forward<Args>(args)...};
136 return operator()(indices);
137 }
138 /*!
139 * \brief Take elements from the tensor
140 * \param indices the indices.
141 * \return the result expression representing tensor read.
142 */
143 TVM_DLL PrimExpr operator()(Array<PrimExpr> indices) const;
144 /*!
145 * \brief Take elements from the tensor
146 * \param indices the indices.
147 * \return the result expression representing tensor read.
148 */
149 TVM_DLL PrimExpr operator()(Array<Var> indices) const;
150 /*!
151 * \brief Take elements from the tensor with support for negative indices.
152 * \param args The indices
153 * \return the result expression representing tensor read.
154 */
155 template <typename... Args>
156 TVM_DLL PrimExpr IndexWithNegativeIndices(Args&&... args) const {
157 Array<PrimExpr> indices{std::forward<Args>(args)...};
158 return IndexWithNegativeIndices(indices);
159 }
160 /*!
161 * \brief Take elements from the tensor with support for negative indices.
162 * \param indices the indices.
163 * \return the result expression representing tensor read.
164 */
165 TVM_DLL PrimExpr IndexWithNegativeIndices(Array<PrimExpr> indices) const;
166 /*!
167 * \brief Take elements from the tensor with support for negative indices.
168 * \param indices the indices.
169 * \return the result expression representing tensor read.
170 */
171 TVM_DLL PrimExpr IndexWithNegativeIndices(Array<Var> indices) const;
172
173 /*!
174 * \brief data structure to represent a slice that fixes first k coordinates.
175 * This is used to enable syntax sugar of Tensor[x][y][z] to get the element.
176 */
177 class Slice {
178 public:
179 // construct via tensor and indices
180 Slice(const Tensor& tensor, std::vector<PrimExpr> indices)
181 : tensor_(tensor), indices_(indices) {}
182 /*!
183 * \brief get i-th slice from the current slice.
184 * \param i the index of the coordinate
185 * \return the subsequent slice.
186 */
187 inline Slice operator[](PrimExpr i) {
188 std::vector<PrimExpr> other = indices_;
189 other.emplace_back(i);
190 return Slice(tensor_, other);
191 }
192 /*!
193 * \brief Convert slice to expression.
194 * This is only valid when all the coordinates are fully specified.
195 * \return the corresponding expression of this slice.
196 */
197 inline operator PrimExpr() const { return tensor_(indices_); }
198
199 private:
200 const Tensor& tensor_;
201 std::vector<PrimExpr> indices_;
202 };
203 /*!
204 * \brief get i-th slice from the current Tensor.
205 * \param i the index of the coordinate
206 * \return the subsequent slice.
207 */
208 inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); }
209
210 TVM_DEFINE_OBJECT_REF_METHODS(Tensor, DataProducer, TensorNode);
211};
212
213// Implementations of inline functions
214inline size_t Tensor::ndim() const { return (*this)->shape.size(); }
215
216inline bool Tensor::operator==(const Tensor& other) const {
217 if (get() == other.get()) return true;
218 if (get() == nullptr || other.get() == nullptr) return false;
219 if ((*this)->op.defined() || other->op.defined()) {
220 return (*this)->op == other->op && (*this)->value_index == other->value_index;
221 } else {
222 return false;
223 }
224}
225
226inline bool Tensor::operator!=(const Tensor& other) const { return !(*this == other); }
227
// Generates a unary operator overload for Tensor::Slice: the slice is first
// converted to a PrimExpr and the operator is then applied to that expression.
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op)                     \
  inline PrimExpr operator Op(const Tensor::Slice& slice) {    \
    return Op slice.operator PrimExpr();                       \
  }
231
// Generates three binary operator overloads involving Tensor::Slice:
//   slice Op slice, slice Op T, and T Op slice.
// Every slice operand is converted to a PrimExpr before the operator is applied.
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op)                                          \
  inline PrimExpr operator Op(const Tensor::Slice& lhs, const Tensor::Slice& rhs) {  \
    return lhs.operator PrimExpr() Op rhs.operator PrimExpr();                       \
  }                                                                                  \
  template <typename T>                                                              \
  inline PrimExpr operator Op(const Tensor::Slice& lhs, const T& rhs) {              \
    return lhs.operator PrimExpr() Op rhs;                                           \
  }                                                                                  \
  template <typename T>                                                              \
  inline PrimExpr operator Op(const T& lhs, const Tensor::Slice& rhs) {              \
    return lhs Op rhs.operator PrimExpr();                                           \
  }
244
245DEFINE_OVERLOAD_SLICE_UNARY_OP(!);
246DEFINE_OVERLOAD_SLICE_UNARY_OP(-);
247DEFINE_OVERLOAD_SLICE_BINARY_OP(+);
248DEFINE_OVERLOAD_SLICE_BINARY_OP(-);
249DEFINE_OVERLOAD_SLICE_BINARY_OP(*);
250DEFINE_OVERLOAD_SLICE_BINARY_OP(==);
251DEFINE_OVERLOAD_SLICE_BINARY_OP(<=);
252DEFINE_OVERLOAD_SLICE_BINARY_OP(>=);
253DEFINE_OVERLOAD_SLICE_BINARY_OP(!=);
254DEFINE_OVERLOAD_SLICE_BINARY_OP(&&);
255DEFINE_OVERLOAD_SLICE_BINARY_OP(||);
256DEFINE_OVERLOAD_SLICE_BINARY_OP(>>);
257DEFINE_OVERLOAD_SLICE_BINARY_OP(<<);
258DEFINE_OVERLOAD_SLICE_BINARY_OP(>); // NOLINT(*)
259DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*)
260
261} // namespace te
262} // namespace tvm
263
264namespace std {
265template <>
266struct hash<::tvm::te::Operation> : public ::tvm::ObjectPtrHash {};
267
268template <>
269struct hash<::tvm::te::Tensor> {
270 std::size_t operator()(const ::tvm::te::Tensor& k) const {
271 ::tvm::ObjectPtrHash hasher;
272 if (k.defined() && k->op.defined()) {
273 return hasher(k->op);
274 } else {
275 return hasher(k);
276 }
277 }
278};
279} // namespace std
280#endif // TVM_TE_TENSOR_H_
281