1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Tensor Compute Op.
22 * \file tensor_compute_op.cc
23 */
24#include <tvm/arith/analyzer.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/te/operation.h>
27#include <tvm/tir/builtin.h>
28#include <tvm/tir/expr.h>
29#include <tvm/tir/stmt_functor.h>
30
31#include <unordered_set>
32
33#include "./compute_op.h"
34#include "./op_utils.h"
35
36namespace tvm {
37namespace te {
38using namespace tir;
39// TensorComputeOpNode
// Pretty-printer hook: a TensorComputeOpNode is rendered as
// "tensor_compute_op(<name>, <node address>)" in debug/REPL output.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const TensorComputeOpNode*>(node.get());
      // NOTE: the second field deliberately prints the node pointer, which
      // disambiguates distinct ops that share a name.
      p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
    });

// Register the node type with the object system (reflection / FFI / serialization).
TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);
47
48int TensorComputeOpNode::num_outputs() const {
49 return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
50}
51
52DataType TensorComputeOpNode::output_dtype(size_t i) const {
53 return this->intrin->buffers[this->inputs.size() + i]->dtype;
54}
55
56TensorComputeOp::TensorComputeOp(std::string name, std::string tag, Array<IterVar> axis,
57 Array<IterVar> reduce_axis, int schedulable_ndim,
58 TensorIntrin intrin, Array<Tensor> tensors, Array<Region> regions,
59 Array<PrimExpr> scalar_inputs) {
60 auto n = make_object<TensorComputeOpNode>();
61 n->name = std::move(name);
62 n->tag = std::move(tag);
63 n->axis = std::move(axis);
64 n->reduce_axis = std::move(reduce_axis);
65 n->schedulable_ndim = std::move(schedulable_ndim);
66 n->intrin = std::move(intrin);
67 n->inputs = std::move(tensors);
68 n->input_regions = std::move(regions);
69 n->scalar_inputs = std::move(scalar_inputs);
70 data_ = std::move(n);
71}
72
73TVM_REGISTER_GLOBAL("te.TensorComputeOp")
74 .set_body_typed([](std::string name, std::string tag, Array<IterVar> axis,
75 Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
76 Array<Tensor> tensors, Array<Region> regions,
77 Array<PrimExpr> scalar_inputs) {
78 return TensorComputeOp(name, tag, axis, reduce_axis, schedulable_ndim, intrin, tensors,
79 regions, scalar_inputs);
80 });
81
82Array<Tensor> TensorComputeOpNode::InputTensors() const { return inputs; }
83
/*!
 * \brief Replace input tensors of this op according to `rmap`.
 * \param self The operation handle wrapping `this` (must alias this node).
 * \param rmap Map from old tensor to replacement tensor.
 * \return `self` unchanged when no replacement applied; otherwise a new
 *         Operation with the substituted inputs and intrinsic bodies.
 */
Operation TensorComputeOpNode::ReplaceInputs(const Operation& self,
                                             const std::unordered_map<Tensor, Tensor>& rmap) const {
  ICHECK_EQ(self.operator->(), this);
  // Copy the node and the intrinsic so the replacement can be applied
  // without mutating `self`.
  auto n = make_object<TensorComputeOpNode>(*this);
  auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->()));
  intrin->body = ReplaceTensor(this->intrin->body, rmap);
  if (intrin->reduce_init.defined()) {
    intrin->reduce_init = ReplaceTensor(this->intrin->reduce_init, rmap);
  }
  if (intrin->reduce_update.defined()) {
    intrin->reduce_update = ReplaceTensor(this->intrin->reduce_update, rmap);
  }
  // Substitute any direct input tensors that appear in the map.
  for (size_t i = 0; i < n->inputs.size(); ++i) {
    Tensor t = n->inputs[i];
    if (rmap.count(t)) {
      n->inputs.Set(i, rmap.at(t));
    }
  }

  // At this point n->intrin still aliases the ORIGINAL intrinsic (it is only
  // replaced below), so each same_as compares the rewritten statement against
  // the pre-replacement one; all-equal means nothing changed.
  if (intrin->body.same_as(n->intrin->body) &&
      intrin->reduce_init.same_as(n->intrin->reduce_init) &&
      intrin->reduce_update.same_as(n->intrin->reduce_update) && inputs.same_as(n->inputs)) {
    return self;
  } else {
    n->intrin = TensorIntrin(intrin);
    return Operation(n);
  }
}
112
113void TensorComputeOpNode::PropBoundToInputs(
114 const Operation& self, arith::Analyzer* analyzer,
115 const std::unordered_map<const VarNode*, IntSet>& dom_map,
116 std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
117 for (size_t i = 0; i < this->inputs.size(); ++i) {
118 Tensor t = this->inputs[i];
119 Region region = input_regions[i];
120
121 auto it = out_dom_map->find(t);
122 if (it == out_dom_map->end()) continue;
123 TensorDom& dom = it->second;
124 for (size_t j = 0; j < t.ndim(); ++j) {
125 dom.data[j].emplace_back(EvalSet(region[j], dom_map));
126 }
127 }
128}
129
130size_t TensorComputeOpNode::num_schedulable_dims() const { return schedulable_ndim; }
131
132Stmt TensorComputeOpNode::BuildProvide(const Stage& stage,
133 const std::unordered_map<IterVar, Range>& dom_map,
134 bool debug_keep_trivial_loop) const {
135 ICHECK_EQ(stage->op.operator->(), this);
136
137 // Start bind data.
138 Stmt nop = Evaluate(0);
139 std::vector<Stmt> input_bind_nest, output_bind_nest;
140 Array<Tensor> inputs = this->InputTensors();
141
142 // input binding
143 size_t num_inputs = inputs.size();
144 for (size_t i = 0; i < num_inputs; ++i) {
145 Tensor tensor = inputs[i];
146 Region region = this->input_regions[i];
147 Buffer buffer = this->intrin->buffers[i];
148 Array<ObjectRef> bind_spec{buffer, tensor};
149
150 Array<PrimExpr> tuple;
151 for (size_t i = 0; i < region.size(); ++i) {
152 tuple.push_back(region[i]->min);
153 tuple.push_back(region[i]->extent);
154 }
155 input_bind_nest.emplace_back(
156 AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
157 Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop));
158 }
159
160 // output binding
161 for (int i = 0; i < this->num_outputs(); ++i) {
162 Tensor tensor = stage->op.output(i);
163 Buffer buffer = this->intrin->buffers[num_inputs + i];
164 Array<ObjectRef> bind_spec{buffer, tensor};
165
166 Array<PrimExpr> tuple;
167 for (size_t i = 0; i < this->axis.size(); ++i) {
168 auto ivar = this->axis[i];
169 if (i < static_cast<size_t>(this->schedulable_ndim)) {
170 tuple.push_back(ivar->var);
171 tuple.push_back(1);
172 } else {
173 Range dom = ivar->dom;
174 tuple.push_back(dom->min);
175 tuple.push_back(dom->extent);
176 }
177 }
178
179 output_bind_nest.emplace_back(
180 AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
181 Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop));
182 }
183
184 // Check variable remap
185 std::unordered_map<const VarNode*, PrimExpr> vmap;
186 tir::ArgBinder binder(&vmap);
187
188 // Map the expressions passed in the call to the TensorIntrin, to the placeholder
189 // variables
190 Array<PrimExpr> user_expr = this->scalar_inputs;
191 Array<Var> scalar_params = this->intrin->scalar_params;
192 Array<PrimExpr> sp_expr;
193 for (auto sp : scalar_params) {
194 PrimExpr esp = sp;
195 sp_expr.push_back(esp);
196 }
197 ICHECK_EQ(sp_expr.size(), user_expr.size());
198 // TODO(jdavies-huawei): what name should be used here?
199 binder.BindArray(sp_expr, user_expr, this->name);
200
201 size_t tloc = stage->leaf_iter_vars.size();
202 ComputeLoopNest n = ComputeLoopNest::Create(this, stage, dom_map, debug_keep_trivial_loop);
203
204 if (this->reduce_axis.size() == 0) {
205 std::vector<std::vector<Stmt>> nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
206 nest.emplace_back(MakeIfNest(n.main_predicates));
207 ICHECK_EQ(n.init_predicates.size(), 0U);
208 ICHECK(this->intrin->body.defined())
209 << "Normal store op for intrin " << this << " is not defined";
210 Stmt body = MergeNest(output_bind_nest, this->intrin->body);
211 body = MergeNest(input_bind_nest, body);
212 body = tir::Substitute(body, vmap);
213 body = MergeNest(binder.asserts(), body);
214 body = te::Substitute(body, n.main_vmap);
215 Stmt ret = MergeNest(nest, body);
216 return ret;
217 } else {
218 // Need to split reduction
219 ICHECK(this->intrin->reduce_update.defined()) << "Reduction update op is not defined";
220 // Need init and update steps
221 ICHECK_NE(this->reduce_axis.size(), 0U);
222 std::vector<std::vector<Stmt>> common(n.main_nest.begin(),
223 n.main_nest.begin() + n.num_common_loop + 1);
224 std::vector<std::vector<Stmt>> update_nest(n.main_nest.begin() + n.num_common_loop + 1,
225 n.main_nest.begin() + tloc + 1);
226 update_nest.emplace_back(MakeIfNest(n.main_predicates));
227
228 if (this->intrin->reduce_init.defined()) {
229 // init nest
230 std::vector<std::vector<Stmt>> init_nest(n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
231 init_nest.emplace_back(MakeIfNest(n.init_predicates));
232 Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
233 init = te::Substitute(init, n.init_vmap);
234 init = MergeNest(init_nest, init);
235 // The update
236 Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update);
237 update = MergeNest(input_bind_nest, update);
238 update = tir::Substitute(update, vmap);
239 update = MergeNest(binder.asserts(), update);
240 update = te::Substitute(update, n.main_vmap);
241 update = MergeNest(update_nest, update);
242 return MergeNest(common, SeqStmt::Flatten(init, update));
243 } else {
244 // When init op is not available, use body op for reset in the first iter.
245 ICHECK(this->intrin->body.defined()) << "Normal body op is not defined";
246 Stmt update =
247 TransformUpdate(stage, dom_map, n, this->intrin->body, this->intrin->reduce_update);
248 update = MergeNest(output_bind_nest, update);
249 update = MergeNest(input_bind_nest, update);
250 update = tir::Substitute(update, vmap);
251 update = MergeNest(binder.asserts(), update);
252 update = te::Substitute(update, n.main_vmap);
253 update = MergeNest(update_nest, update);
254 return MergeNest(common, update);
255 }
256 }
257}
258} // namespace te
259} // namespace tvm
260