1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tensor.cc |
22 | */ |
23 | #include <tvm/runtime/registry.h> |
24 | #include <tvm/te/operation.h> |
25 | #include <tvm/te/tensor.h> |
26 | #include <tvm/te/tensor_intrin.h> |
27 | |
28 | #include <memory> |
29 | |
30 | namespace tvm { |
31 | namespace te { |
32 | |
33 | IterVar thread_axis(Range dom, std::string tag) { |
34 | return IterVar(dom, Var(tag, dom.defined() ? dom->extent.dtype() : DataType::Int(32)), |
35 | kThreadIndex, tag); |
36 | } |
37 | |
38 | IterVar reduce_axis(Range dom, std::string name) { |
39 | return IterVar(dom, Var(name, dom->extent.dtype()), kCommReduce); |
40 | } |
41 | |
42 | Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } |
43 | |
44 | // Tensor |
45 | inline PrimExpr Tensor::IndexTensor(Array<PrimExpr> indices, bool support_negative_indices) const { |
46 | Array<PrimExpr> shape = (*this)->shape; |
47 | |
48 | if (shape.size() != 0) { |
49 | ICHECK_EQ(shape.size(), indices.size()) |
50 | << "Tensor dimension mismatch in read " |
51 | << "ndim = " << ndim() << ", indices.size=" << indices.size(); |
52 | } |
53 | |
54 | if (support_negative_indices) { |
55 | for (size_t i = 0; i < shape.size(); i++) { |
56 | PrimExpr new_index = |
57 | Select(indices[i] < make_const(indices[i]->dtype, 0), indices[i] + shape[i], indices[i]); |
58 | indices.Set(i, new_index); |
59 | } |
60 | } |
61 | return ProducerLoad((*this), indices); |
62 | } |
63 | |
64 | PrimExpr Tensor::operator()(Array<Var> indices) const { |
65 | Array<PrimExpr> arr(indices.begin(), indices.end()); |
66 | return operator()(arr); |
67 | } |
68 | |
69 | PrimExpr Tensor::operator()(Array<PrimExpr> indices) const { return IndexTensor(indices, false); } |
70 | |
71 | PrimExpr Tensor::IndexWithNegativeIndices(Array<Var> indices) const { |
72 | Array<PrimExpr> arr(indices.begin(), indices.end()); |
73 | return IndexWithNegativeIndices(arr); |
74 | } |
75 | |
76 | PrimExpr Tensor::IndexWithNegativeIndices(Array<PrimExpr> indices) const { |
77 | return IndexTensor(indices, true); |
78 | } |
79 | |
80 | String TensorNode::GetNameHint() const { |
81 | return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index)); |
82 | } |
83 | |
84 | Tensor Operation::output(size_t i) const { |
85 | auto node = make_object<TensorNode>(); |
86 | node->op = *this; |
87 | node->value_index = i; |
88 | node->dtype = (*this)->output_dtype(i); |
89 | node->shape = (*this)->output_shape(i); |
90 | return Tensor(node); |
91 | } |
92 | |
93 | Tensor::Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) { |
94 | auto n = make_object<TensorNode>(); |
95 | n->shape = std::move(shape); |
96 | n->dtype = dtype; |
97 | n->op = op; |
98 | n->value_index = value_index; |
99 | data_ = std::move(n); |
100 | } |
101 | |
102 | TVM_REGISTER_GLOBAL("te.Tensor" ) |
103 | .set_body_typed([](Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) { |
104 | return Tensor(shape, dtype, op, value_index); |
105 | }); |
106 | |
107 | TVM_REGISTER_NODE_TYPE(TensorNode); |
108 | |
109 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
110 | .set_dispatch<TensorNode>([](const ObjectRef& node, ReprPrinter* p) { |
111 | auto* t = static_cast<const TensorNode*>(node.get()); |
112 | p->stream << "Tensor(shape=" << t->shape << ", op.name=" << t->op->name << ')'; |
113 | }); |
114 | |
115 | // TensorIntrin |
116 | TensorIntrin::TensorIntrin(std::string name, Operation op, Array<Tensor> inputs, |
117 | Array<Buffer> buffers, Array<Var> scalar_params, Stmt body, |
118 | Stmt reduce_init, Stmt reduce_update) { |
119 | auto n = make_object<TensorIntrinNode>(); |
120 | n->name = std::move(name); |
121 | n->op = std::move(op); |
122 | n->inputs = std::move(inputs); |
123 | n->buffers = std::move(buffers); |
124 | n->scalar_params = std::move(scalar_params); |
125 | n->body = std::move(body); |
126 | n->reduce_init = std::move(reduce_init); |
127 | n->reduce_update = std::move(reduce_update); |
128 | data_ = std::move(n); |
129 | } |
130 | |
131 | TVM_REGISTER_GLOBAL("te.TensorIntrin" ) |
132 | .set_body_typed([](std::string name, Operation op, Array<Tensor> inputs, Array<Buffer> buffers, |
133 | Array<Var> scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update) { |
134 | return TensorIntrin(name, op, inputs, buffers, scalar_params, body, reduce_init, |
135 | reduce_update); |
136 | }); |
137 | |
138 | TVM_REGISTER_NODE_TYPE(TensorIntrinNode); |
139 | |
140 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
141 | .set_dispatch<TensorIntrinNode>([](const ObjectRef& node, ReprPrinter* p) { |
142 | auto* op = static_cast<const TensorIntrinNode*>(node.get()); |
143 | p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")" ; |
144 | }); |
145 | |
146 | // TensorIntrinCall |
147 | TensorIntrinCall::TensorIntrinCall(TensorIntrin intrin, Array<Tensor> tensors, |
148 | Array<Region> regions, Array<IterVar> reduce_axis, |
149 | Array<PrimExpr> scalar_inputs) { |
150 | auto n = make_object<TensorIntrinCallNode>(); |
151 | n->intrin = std::move(intrin); |
152 | n->tensors = std::move(tensors); |
153 | n->regions = std::move(regions); |
154 | n->reduce_axis = std::move(reduce_axis); |
155 | n->scalar_inputs = std::move(scalar_inputs); |
156 | data_ = std::move(n); |
157 | } |
158 | |
159 | TVM_REGISTER_GLOBAL("te.TensorIntrinCall" ) |
160 | .set_body_typed([](TensorIntrin intrin, Array<Tensor> tensors, Array<Region> regions, |
161 | Array<IterVar> reduce_axis, Array<PrimExpr> scalar_inputs) { |
162 | return TensorIntrinCall(intrin, tensors, regions, reduce_axis, scalar_inputs); |
163 | }); |
164 | |
165 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
166 | .set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, ReprPrinter* p) { |
167 | auto* n = static_cast<const TensorIntrinCallNode*>(node.get()); |
168 | p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")" ; |
169 | }); |
170 | |
171 | TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode); |
172 | |
173 | // Other tensor ops. |
174 | TVM_REGISTER_GLOBAL("te.TensorEqual" ).set_body_method(&Tensor::operator==); |
175 | |
176 | TVM_REGISTER_GLOBAL("te.TensorHash" ).set_body_typed([](Tensor tensor) -> int64_t { |
177 | return static_cast<int64_t>(std::hash<Tensor>()(tensor)); |
178 | }); |
179 | |
180 | TVM_REGISTER_GLOBAL("te.OpGetOutput" ).set_body_typed([](Operation op, int64_t output) { |
181 | return op.output(static_cast<size_t>(output)); |
182 | }); |
183 | |
184 | TVM_REGISTER_GLOBAL("te.OpNumOutputs" ).set_body_method<Operation>(&OperationNode::num_outputs); |
185 | |
186 | TVM_REGISTER_GLOBAL("te.OpInputTensors" ).set_body_method<Operation>(&OperationNode::InputTensors); |
187 | |
188 | } // namespace te |
189 | } // namespace tvm |
190 | |