/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tensor.cc
 */
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/tensor.h>
#include <tvm/te/tensor_intrin.h>

#include <memory>

namespace tvm {
namespace te {

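/*!
 * \brief Create an IterVar representing a thread index (e.g. threadIdx.x or blockIdx.x).
 * \param dom The iteration domain; may be undefined, in which case the variable defaults to int32.
 * \param tag The thread tag, also used as the variable's name hint.
 */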
IterVar thread_axis(Range dom, std::string tag) {
  return IterVar(dom, Var(tag, dom.defined() ? dom->extent.dtype() : DataType::Int(32)),
                 kThreadIndex, tag);
}

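/*!
 * \brief Create an IterVar for a commutative reduction axis over the given domain.
 * \param dom The reduction domain; must be defined, since its extent dtype types the variable.
 * \param name The name hint of the reduction variable.
 */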
IterVar reduce_axis(Range dom, std::string name) {
  return IterVar(dom, Var(name, dom->extent.dtype()), kCommReduce);
}

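/*! \brief Create a new variable with the given name hint and data type. */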
Var var(std::string name_hint, DataType t) { return Var(name_hint, t); }

// Tensor
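/*!
 * \brief Build a ProducerLoad expression that reads this tensor at the given indices.
 *
 * When support_negative_indices is true, each index is rewritten to
 * Select(index < 0, index + shape_dim, index), i.e. Python-style wrap-around.
 * Illustrative example: for shape = [4] and index -1, the load resolves to element 3.
 */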
inline PrimExpr Tensor::IndexTensor(Array<PrimExpr> indices, bool support_negative_indices) const {
  Array<PrimExpr> shape = (*this)->shape;

  if (shape.size() != 0) {
    ICHECK_EQ(shape.size(), indices.size())
        << "Tensor dimension mismatch in read "
        << "ndim = " << ndim() << ", indices.size=" << indices.size();
  }

  if (support_negative_indices) {
    for (size_t i = 0; i < shape.size(); i++) {
      PrimExpr new_index =
          Select(indices[i] < make_const(indices[i]->dtype, 0), indices[i] + shape[i], indices[i]);
      indices.Set(i, new_index);
    }
  }
  return ProducerLoad((*this), indices);
}

PrimExpr Tensor::operator()(Array<Var> indices) const {
  Array<PrimExpr> arr(indices.begin(), indices.end());
  return operator()(arr);
}

PrimExpr Tensor::operator()(Array<PrimExpr> indices) const { return IndexTensor(indices, false); }

PrimExpr Tensor::IndexWithNegativeIndices(Array<Var> indices) const {
  Array<PrimExpr> arr(indices.begin(), indices.end());
  return IndexWithNegativeIndices(arr);
}

PrimExpr Tensor::IndexWithNegativeIndices(Array<PrimExpr> indices) const {
  return IndexTensor(indices, true);
}

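/*!
 * \brief Name hint of this tensor: the op name for single-output ops,
 *        otherwise "<op name>.v<value_index>" to disambiguate outputs.
 */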
String TensorNode::GetNameHint() const {
  return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index));
}

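/*! \brief Get the i-th output Tensor of this operation; dtype and shape are queried from the op. */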
Tensor Operation::output(size_t i) const {
  auto node = make_object<TensorNode>();
  node->op = *this;
  node->value_index = i;
  node->dtype = (*this)->output_dtype(i);
  node->shape = (*this)->output_shape(i);
  return Tensor(node);
}

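/*! \brief Construct a Tensor with the given shape and dtype, referring to the value_index-th output of op. */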
Tensor::Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) {
  auto n = make_object<TensorNode>();
  n->shape = std::move(shape);
  n->dtype = dtype;
  n->op = op;
  n->value_index = value_index;
  data_ = std::move(n);
}

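// Reflection and FFI registrations for Tensor, exposing construction and printing to the frontend.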
TVM_REGISTER_GLOBAL("te.Tensor")
    .set_body_typed([](Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) {
      return Tensor(shape, dtype, op, value_index);
    });

TVM_REGISTER_NODE_TYPE(TensorNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<TensorNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* t = static_cast<const TensorNode*>(node.get());
      p->stream << "Tensor(shape=" << t->shape << ", op.name=" << t->op->name << ')';
    });

// TensorIntrin
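/*!
 * \brief Construct a tensor intrinsic declaration (e.g. for tensorize): op describes the
 *        computation, inputs and buffers give the tensor parameters and their buffer bindings,
 *        scalar_params are additional scalar arguments, and body (with optional reduce_init /
 *        reduce_update for reductions) gives the intrinsic implementation.
 */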
TensorIntrin::TensorIntrin(std::string name, Operation op, Array<Tensor> inputs,
                           Array<Buffer> buffers, Array<Var> scalar_params, Stmt body,
                           Stmt reduce_init, Stmt reduce_update) {
  auto n = make_object<TensorIntrinNode>();
  n->name = std::move(name);
  n->op = std::move(op);
  n->inputs = std::move(inputs);
  n->buffers = std::move(buffers);
  n->scalar_params = std::move(scalar_params);
  n->body = std::move(body);
  n->reduce_init = std::move(reduce_init);
  n->reduce_update = std::move(reduce_update);
  data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("te.TensorIntrin")
    .set_body_typed([](std::string name, Operation op, Array<Tensor> inputs, Array<Buffer> buffers,
                       Array<Var> scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update) {
      return TensorIntrin(name, op, inputs, buffers, scalar_params, body, reduce_init,
                          reduce_update);
    });

TVM_REGISTER_NODE_TYPE(TensorIntrinNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<TensorIntrinNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const TensorIntrinNode*>(node.get());
      p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
    });

// TensorIntrinCall
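/*!
 * \brief Construct a call to a tensor intrinsic: the intrinsic itself, the tensors and regions
 *        it reads, the reduction axes it covers, and any scalar inputs it receives.
 */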
TensorIntrinCall::TensorIntrinCall(TensorIntrin intrin, Array<Tensor> tensors,
                                   Array<Region> regions, Array<IterVar> reduce_axis,
                                   Array<PrimExpr> scalar_inputs) {
  auto n = make_object<TensorIntrinCallNode>();
  n->intrin = std::move(intrin);
  n->tensors = std::move(tensors);
  n->regions = std::move(regions);
  n->reduce_axis = std::move(reduce_axis);
  n->scalar_inputs = std::move(scalar_inputs);
  data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
    .set_body_typed([](TensorIntrin intrin, Array<Tensor> tensors, Array<Region> regions,
                       Array<IterVar> reduce_axis, Array<PrimExpr> scalar_inputs) {
      return TensorIntrinCall(intrin, tensors, regions, reduce_axis, scalar_inputs);
    });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* n = static_cast<const TensorIntrinCallNode*>(node.get());
      p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
    });

TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);

// Other tensor ops.
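// FFI helpers for the frontend: tensor equality and hashing, plus queries for an operation's
// outputs, output count, and input tensors.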
TVM_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==);

TVM_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t {
  return static_cast<int64_t>(std::hash<Tensor>()(tensor));
});

TVM_REGISTER_GLOBAL("te.OpGetOutput").set_body_typed([](Operation op, int64_t output) {
  return op.output(static_cast<size_t>(output));
});

TVM_REGISTER_GLOBAL("te.OpNumOutputs").set_body_method<Operation>(&OperationNode::num_outputs);

TVM_REGISTER_GLOBAL("te.OpInputTensors").set_body_method<Operation>(&OperationNode::InputTensors);

}  // namespace te
}  // namespace tvm