1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Tensor Compute Op. |
22 | * \file tensor_compute_op.cc |
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/te/operation.h> |
27 | #include <tvm/tir/builtin.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/stmt_functor.h> |
30 | |
31 | #include <unordered_set> |
32 | |
33 | #include "./compute_op.h" |
34 | #include "./op_utils.h" |
35 | |
36 | namespace tvm { |
37 | namespace te { |
38 | using namespace tir; |
39 | // TensorComputeOpNode |
// Pretty-printer hook for TensorComputeOpNode: prints the op's name followed
// by the node's address (the pointer value acts as a unique identifier in
// IR dumps).
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const TensorComputeOpNode*>(node.get());
      p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
    });

// Register the node type so it participates in reflection/serialization.
TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);
47 | |
48 | int TensorComputeOpNode::num_outputs() const { |
49 | return static_cast<int>(this->intrin->buffers.size() - this->inputs.size()); |
50 | } |
51 | |
52 | DataType TensorComputeOpNode::output_dtype(size_t i) const { |
53 | return this->intrin->buffers[this->inputs.size() + i]->dtype; |
54 | } |
55 | |
56 | TensorComputeOp::TensorComputeOp(std::string name, std::string tag, Array<IterVar> axis, |
57 | Array<IterVar> reduce_axis, int schedulable_ndim, |
58 | TensorIntrin intrin, Array<Tensor> tensors, Array<Region> regions, |
59 | Array<PrimExpr> scalar_inputs) { |
60 | auto n = make_object<TensorComputeOpNode>(); |
61 | n->name = std::move(name); |
62 | n->tag = std::move(tag); |
63 | n->axis = std::move(axis); |
64 | n->reduce_axis = std::move(reduce_axis); |
65 | n->schedulable_ndim = std::move(schedulable_ndim); |
66 | n->intrin = std::move(intrin); |
67 | n->inputs = std::move(tensors); |
68 | n->input_regions = std::move(regions); |
69 | n->scalar_inputs = std::move(scalar_inputs); |
70 | data_ = std::move(n); |
71 | } |
72 | |
73 | TVM_REGISTER_GLOBAL("te.TensorComputeOp" ) |
74 | .set_body_typed([](std::string name, std::string tag, Array<IterVar> axis, |
75 | Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin, |
76 | Array<Tensor> tensors, Array<Region> regions, |
77 | Array<PrimExpr> scalar_inputs) { |
78 | return TensorComputeOp(name, tag, axis, reduce_axis, schedulable_ndim, intrin, tensors, |
79 | regions, scalar_inputs); |
80 | }); |
81 | |
82 | Array<Tensor> TensorComputeOpNode::InputTensors() const { return inputs; } |
83 | |
84 | Operation TensorComputeOpNode::ReplaceInputs(const Operation& self, |
85 | const std::unordered_map<Tensor, Tensor>& rmap) const { |
86 | ICHECK_EQ(self.operator->(), this); |
87 | auto n = make_object<TensorComputeOpNode>(*this); |
88 | auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->())); |
89 | intrin->body = ReplaceTensor(this->intrin->body, rmap); |
90 | if (intrin->reduce_init.defined()) { |
91 | intrin->reduce_init = ReplaceTensor(this->intrin->reduce_init, rmap); |
92 | } |
93 | if (intrin->reduce_update.defined()) { |
94 | intrin->reduce_update = ReplaceTensor(this->intrin->reduce_update, rmap); |
95 | } |
96 | for (size_t i = 0; i < n->inputs.size(); ++i) { |
97 | Tensor t = n->inputs[i]; |
98 | if (rmap.count(t)) { |
99 | n->inputs.Set(i, rmap.at(t)); |
100 | } |
101 | } |
102 | |
103 | if (intrin->body.same_as(n->intrin->body) && |
104 | intrin->reduce_init.same_as(n->intrin->reduce_init) && |
105 | intrin->reduce_update.same_as(n->intrin->reduce_update) && inputs.same_as(n->inputs)) { |
106 | return self; |
107 | } else { |
108 | n->intrin = TensorIntrin(intrin); |
109 | return Operation(n); |
110 | } |
111 | } |
112 | |
113 | void TensorComputeOpNode::PropBoundToInputs( |
114 | const Operation& self, arith::Analyzer* analyzer, |
115 | const std::unordered_map<const VarNode*, IntSet>& dom_map, |
116 | std::unordered_map<Tensor, TensorDom>* out_dom_map) const { |
117 | for (size_t i = 0; i < this->inputs.size(); ++i) { |
118 | Tensor t = this->inputs[i]; |
119 | Region region = input_regions[i]; |
120 | |
121 | auto it = out_dom_map->find(t); |
122 | if (it == out_dom_map->end()) continue; |
123 | TensorDom& dom = it->second; |
124 | for (size_t j = 0; j < t.ndim(); ++j) { |
125 | dom.data[j].emplace_back(EvalSet(region[j], dom_map)); |
126 | } |
127 | } |
128 | } |
129 | |
130 | size_t TensorComputeOpNode::num_schedulable_dims() const { return schedulable_ndim; } |
131 | |
132 | Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, |
133 | const std::unordered_map<IterVar, Range>& dom_map, |
134 | bool debug_keep_trivial_loop) const { |
135 | ICHECK_EQ(stage->op.operator->(), this); |
136 | |
137 | // Start bind data. |
138 | Stmt nop = Evaluate(0); |
139 | std::vector<Stmt> input_bind_nest, output_bind_nest; |
140 | Array<Tensor> inputs = this->InputTensors(); |
141 | |
142 | // input binding |
143 | size_t num_inputs = inputs.size(); |
144 | for (size_t i = 0; i < num_inputs; ++i) { |
145 | Tensor tensor = inputs[i]; |
146 | Region region = this->input_regions[i]; |
147 | Buffer buffer = this->intrin->buffers[i]; |
148 | Array<ObjectRef> bind_spec{buffer, tensor}; |
149 | |
150 | Array<PrimExpr> tuple; |
151 | for (size_t i = 0; i < region.size(); ++i) { |
152 | tuple.push_back(region[i]->min); |
153 | tuple.push_back(region[i]->extent); |
154 | } |
155 | input_bind_nest.emplace_back( |
156 | AttrStmt(bind_spec, tir::attr::buffer_bind_scope, |
157 | Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop)); |
158 | } |
159 | |
160 | // output binding |
161 | for (int i = 0; i < this->num_outputs(); ++i) { |
162 | Tensor tensor = stage->op.output(i); |
163 | Buffer buffer = this->intrin->buffers[num_inputs + i]; |
164 | Array<ObjectRef> bind_spec{buffer, tensor}; |
165 | |
166 | Array<PrimExpr> tuple; |
167 | for (size_t i = 0; i < this->axis.size(); ++i) { |
168 | auto ivar = this->axis[i]; |
169 | if (i < static_cast<size_t>(this->schedulable_ndim)) { |
170 | tuple.push_back(ivar->var); |
171 | tuple.push_back(1); |
172 | } else { |
173 | Range dom = ivar->dom; |
174 | tuple.push_back(dom->min); |
175 | tuple.push_back(dom->extent); |
176 | } |
177 | } |
178 | |
179 | output_bind_nest.emplace_back( |
180 | AttrStmt(bind_spec, tir::attr::buffer_bind_scope, |
181 | Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop)); |
182 | } |
183 | |
184 | // Check variable remap |
185 | std::unordered_map<const VarNode*, PrimExpr> vmap; |
186 | tir::ArgBinder binder(&vmap); |
187 | |
188 | // Map the expressions passed in the call to the TensorIntrin, to the placeholder |
189 | // variables |
190 | Array<PrimExpr> user_expr = this->scalar_inputs; |
191 | Array<Var> scalar_params = this->intrin->scalar_params; |
192 | Array<PrimExpr> sp_expr; |
193 | for (auto sp : scalar_params) { |
194 | PrimExpr esp = sp; |
195 | sp_expr.push_back(esp); |
196 | } |
197 | ICHECK_EQ(sp_expr.size(), user_expr.size()); |
198 | // TODO(jdavies-huawei): what name should be used here? |
199 | binder.BindArray(sp_expr, user_expr, this->name); |
200 | |
201 | size_t tloc = stage->leaf_iter_vars.size(); |
202 | ComputeLoopNest n = ComputeLoopNest::Create(this, stage, dom_map, debug_keep_trivial_loop); |
203 | |
204 | if (this->reduce_axis.size() == 0) { |
205 | std::vector<std::vector<Stmt>> nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1); |
206 | nest.emplace_back(MakeIfNest(n.main_predicates)); |
207 | ICHECK_EQ(n.init_predicates.size(), 0U); |
208 | ICHECK(this->intrin->body.defined()) |
209 | << "Normal store op for intrin " << this << " is not defined" ; |
210 | Stmt body = MergeNest(output_bind_nest, this->intrin->body); |
211 | body = MergeNest(input_bind_nest, body); |
212 | body = tir::Substitute(body, vmap); |
213 | body = MergeNest(binder.asserts(), body); |
214 | body = te::Substitute(body, n.main_vmap); |
215 | Stmt ret = MergeNest(nest, body); |
216 | return ret; |
217 | } else { |
218 | // Need to split reduction |
219 | ICHECK(this->intrin->reduce_update.defined()) << "Reduction update op is not defined" ; |
220 | // Need init and update steps |
221 | ICHECK_NE(this->reduce_axis.size(), 0U); |
222 | std::vector<std::vector<Stmt>> common(n.main_nest.begin(), |
223 | n.main_nest.begin() + n.num_common_loop + 1); |
224 | std::vector<std::vector<Stmt>> update_nest(n.main_nest.begin() + n.num_common_loop + 1, |
225 | n.main_nest.begin() + tloc + 1); |
226 | update_nest.emplace_back(MakeIfNest(n.main_predicates)); |
227 | |
228 | if (this->intrin->reduce_init.defined()) { |
229 | // init nest |
230 | std::vector<std::vector<Stmt>> init_nest(n.init_nest.begin(), n.init_nest.begin() + tloc + 1); |
231 | init_nest.emplace_back(MakeIfNest(n.init_predicates)); |
232 | Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init); |
233 | init = te::Substitute(init, n.init_vmap); |
234 | init = MergeNest(init_nest, init); |
235 | // The update |
236 | Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update); |
237 | update = MergeNest(input_bind_nest, update); |
238 | update = tir::Substitute(update, vmap); |
239 | update = MergeNest(binder.asserts(), update); |
240 | update = te::Substitute(update, n.main_vmap); |
241 | update = MergeNest(update_nest, update); |
242 | return MergeNest(common, SeqStmt::Flatten(init, update)); |
243 | } else { |
244 | // When init op is not available, use body op for reset in the first iter. |
245 | ICHECK(this->intrin->body.defined()) << "Normal body op is not defined" ; |
246 | Stmt update = |
247 | TransformUpdate(stage, dom_map, n, this->intrin->body, this->intrin->reduce_update); |
248 | update = MergeNest(output_bind_nest, update); |
249 | update = MergeNest(input_bind_nest, update); |
250 | update = tir::Substitute(update, vmap); |
251 | update = MergeNest(binder.asserts(), update); |
252 | update = te::Substitute(update, n.main_vmap); |
253 | update = MergeNest(update_nest, update); |
254 | return MergeNest(common, update); |
255 | } |
256 | } |
257 | } |
258 | } // namespace te |
259 | } // namespace tvm |
260 | |