1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Scan Operator. |
22 | * \file scan_op.cc |
23 | */ |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/te/operation.h> |
26 | #include <tvm/tir/expr.h> |
27 | |
28 | #include "../schedule/graph.h" |
29 | #include "op_utils.h" |
30 | |
31 | namespace tvm { |
32 | namespace te { |
33 | using namespace tir; |
34 | |
// Pretty-printer hook: a ScanOpNode is rendered as "scan(<name>, <node pointer>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<ScanOpNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const ScanOpNode*>(node.get());
      // NOTE: `op` streams the raw pointer, which serves as a unique id for the node.
      p->stream << "scan(" << op->name << ", " << op << ")";
    });
TVM_REGISTER_NODE_TYPE(ScanOpNode);
41 | |
42 | int ScanOpNode::num_outputs() const { return static_cast<int>(update.size()); } |
43 | Array<IterVar> ScanOpNode::root_iter_vars() const { |
44 | Array<IterVar> ret{scan_axis}; |
45 | for (IterVar iv : spatial_axis_) { |
46 | ret.push_back(iv); |
47 | } |
48 | return ret; |
49 | } |
50 | |
51 | DataType ScanOpNode::output_dtype(size_t i) const { return update[i]->dtype; } |
52 | |
53 | Array<PrimExpr> ScanOpNode::output_shape(size_t i) const { |
54 | ICHECK_LT(i, state_placeholder.size()); |
55 | return state_placeholder[i]->shape; |
56 | } |
57 | |
58 | ScanOp::ScanOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, IterVar axis, |
59 | Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder, |
60 | Array<Tensor> inputs) { |
61 | if (!attrs.defined()) { |
62 | attrs = Map<String, ObjectRef>(); |
63 | } |
64 | auto n = make_object<ScanOpNode>(); |
65 | ICHECK_EQ(init.size(), update.size()); |
66 | ICHECK_EQ(init.size(), state_placeholder.size()); |
67 | arith::Analyzer analyzer; |
68 | auto prove_equal = [&](PrimExpr lhs, PrimExpr rhs) { |
69 | return is_zero(analyzer.Simplify(lhs - rhs)); |
70 | }; |
71 | |
72 | for (size_t i = 0; i < init.size(); ++i) { |
73 | ICHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype); |
74 | ICHECK_EQ(init[i]->dtype, update[i]->dtype); |
75 | ICHECK(prove_equal(init[i]->shape[0], axis->dom->min)) |
76 | << "init.shape[0] need to match scan_axis.dom.min" ; |
77 | ICHECK(prove_equal(state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) |
78 | << "state_placeholder.shape[0] need to match" |
79 | << " scan_axis.dom.min + scan_axis.dom.extent" ; |
80 | ICHECK_EQ(state_placeholder[i].ndim(), init[i].ndim()) |
81 | << "The dimension of init need to match state_placeholder" ; |
82 | ICHECK_EQ(update[i].ndim(), state_placeholder[i].ndim()) |
83 | << "The update.ndim need to be state_placeholder.ndim - 1" ; |
84 | for (size_t k = 0; k < update[i].ndim(); ++k) { |
85 | ICHECK(prove_equal(update[i]->shape[k], state_placeholder[i]->shape[k])); |
86 | if (k != 0) { |
87 | // setup spatial axis |
88 | std::ostringstream spatial_name; |
89 | spatial_name << name << ".out" << i << ".i" << k; |
90 | n->spatial_axis_.push_back(IterVar(Range::FromMinExtent(0, update[i]->shape[k]), |
91 | Var(spatial_name.str()), kOpaque)); |
92 | } |
93 | } |
94 | |
95 | for (size_t k = 1; k < init[i].ndim(); ++k) { |
96 | ICHECK(prove_equal(init[i]->shape[k], state_placeholder[i]->shape[k])); |
97 | } |
98 | } |
99 | n->name = std::move(name); |
100 | n->tag = std::move(tag); |
101 | n->attrs = std::move(attrs); |
102 | n->scan_axis = std::move(axis); |
103 | n->init = std::move(init); |
104 | n->update = std::move(update); |
105 | n->state_placeholder = std::move(state_placeholder); |
106 | n->inputs = std::move(inputs); |
107 | data_ = std::move(n); |
108 | } |
109 | |
// FFI entry point: exposes the ScanOp constructor to the Python/driver side
// under the global name "te.ScanOp".
TVM_REGISTER_GLOBAL("te.ScanOp")
    .set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
                       IterVar axis, Array<Tensor> init, Array<Tensor> update,
                       Array<Tensor> state_placeholder, Array<Tensor> inputs) {
      return ScanOp(name, tag, attrs, axis, init, update, state_placeholder, inputs);
    });
116 | |
117 | Array<Tensor> scan(Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder, |
118 | Array<Tensor> inputs, std::string name, std::string tag, |
119 | Map<String, ObjectRef> attrs) { |
120 | IterVar scan_axis = |
121 | IterVar(Range::FromMinExtent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), |
122 | Var(name + ".idx" ), kOrdered); |
123 | Operation op = ScanOp(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs); |
124 | Array<Tensor> res; |
125 | for (int i = 0; i < op->num_outputs(); ++i) { |
126 | res.push_back(op.output(i)); |
127 | } |
128 | return res; |
129 | } |
130 | |
131 | Array<Tensor> ScanOpNode::InputTensors() const { |
132 | Array<Tensor> ret; |
133 | for (Tensor t : init) { |
134 | ret.push_back(t); |
135 | } |
136 | for (Tensor t : update) { |
137 | ret.push_back(t); |
138 | } |
139 | return ret; |
140 | } |
141 | |
142 | Operation ScanOpNode::ReplaceInputs(const Operation& self, |
143 | const std::unordered_map<Tensor, Tensor>& rmap) const { |
144 | ICHECK_EQ(self.operator->(), this); |
145 | auto n = make_object<ScanOpNode>(*this); |
146 | for (size_t i = 0; i < n->init.size(); ++i) { |
147 | if (rmap.count(n->init[i])) { |
148 | n->init.Set(i, rmap.at(n->init[i])); |
149 | } |
150 | if (rmap.count(n->update[i])) { |
151 | n->update.Set(i, rmap.at(n->update[i])); |
152 | } |
153 | } |
154 | if (!n->init.same_as(init) || !n->update.same_as(update)) { |
155 | return Operation(n); |
156 | } else { |
157 | return self; |
158 | } |
159 | } |
160 | |
// Propagate the domains required of this op's outputs back onto the
// init/update input tensors listed in *out_dom_map.
void ScanOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                                   const std::unordered_map<const VarNode*, IntSet>& dom_map,
                                   std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  ICHECK_EQ(self.operator->(), this);
  // sp_idx walks spatial_axis_ continuously across all outputs; it must stay
  // in lock-step with the flattening order used when the axes were created.
  for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) {
    TensorDom* init_dom = nullptr;
    TensorDom* update_dom = nullptr;
    // Only tensors the caller registered in out_dom_map receive bounds.
    if (out_dom_map->count(this->init[i])) {
      init_dom = &out_dom_map->at(this->init[i]);
    }
    if (out_dom_map->count(this->update[i])) {
      update_dom = &out_dom_map->at(this->update[i]);
    }
    // first dimension, always needed.
    if (init_dom) {
      // init covers the whole [0, init.shape[0]) prefix of the time axis.
      init_dom->data[0].push_back(
          IntSet::FromRange(Range::FromMinExtent(0, this->init[i]->shape[0])));
    }
    if (update_dom) {
      // update's time dimension follows the scan axis variable's domain.
      update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get()));
    }
    // The update dimensions
    for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
      IterVar sp_ax = this->spatial_axis_[sp_idx];
      // Spatial dimensions of init and update share the same spatial axis domain.
      if (init_dom) {
        init_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
      }
      if (update_dom) {
        update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
      }
    }
  }
}
194 | |
195 | void ScanOpNode::GatherBound(const Operation& self, |
196 | const std::unordered_map<Tensor, TensorDom>& tensor_dom, |
197 | std::unordered_map<IterVar, Range>* out_dom_map) const { |
198 | ICHECK_EQ(self.operator->(), this); |
199 | ICHECK(!out_dom_map->count(this->scan_axis)); |
200 | std::vector<Tensor> output(this->num_outputs()); |
201 | for (size_t i = 0; i < output.size(); ++i) { |
202 | output[i] = self.output(i); |
203 | } |
204 | // Update for time axis. |
205 | std::vector<IntSet> time_dom; |
206 | for (size_t i = 0; i < output.size(); ++i) { |
207 | const TensorDom& d = tensor_dom.at(output[i]); |
208 | time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end()); |
209 | } |
210 | ICHECK(!out_dom_map->count(this->scan_axis)); |
211 | arith::Analyzer analyzer; |
212 | Range sdom = this->scan_axis->dom; |
213 | Range r = arith::Union(time_dom).CoverRange(sdom); |
214 | (*out_dom_map)[this->scan_axis] = |
215 | Range::FromMinExtent(sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min)); |
216 | Map<IterVar, PrimExpr> fix_pt = ScanFixPointAnalysis(self); |
217 | // Update for spatial axis. |
218 | size_t sp_idx = 0; |
219 | for (size_t i = 0; i < output.size(); ++i) { |
220 | const TensorDom& d = tensor_dom.at(output[i]); |
221 | for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) { |
222 | IterVar sp_ax = this->spatial_axis_[sp_idx]; |
223 | ICHECK(!out_dom_map->count(sp_ax)); |
224 | ICHECK(fix_pt.count(sp_ax)); |
225 | if (fix_pt[sp_ax].as<tir::IntImmNode>()->value) { |
226 | // fix point, we can slice it. |
227 | (*out_dom_map)[sp_ax] = arith::Union(d.data[k]).CoverRange(sp_ax->dom); |
228 | } else { |
229 | // not a fix point, need to include everything. |
230 | (*out_dom_map)[sp_ax] = sp_ax->dom; |
231 | } |
232 | } |
233 | } |
234 | } |
235 | |
// Wrap `body` in one ProducerRealize per output tensor, declaring the region
// each output occupies: the full time extent plus the bound spatial ranges.
Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                              const Stmt& body, String storage_scope) const {
  arith::Analyzer analyzer;
  ICHECK_EQ(stage->op.get(), this);
  Range sdom = dom_map.at(this->scan_axis);
  // Time dimension is realized from 0 through the end of the scan range.
  Range tdom = Range::FromMinExtent(0, analyzer.Simplify(sdom->extent + sdom->min));
  Stmt ret = body;
  // sp_idx walks spatial_axis_ continuously across all outputs, matching the
  // flattening order used at construction time.
  size_t sp_idx = 0;
  for (size_t i = 0; i < update.size(); ++i) {
    Tensor t = stage->op.output(i);
    ICHECK_EQ(static_cast<size_t>(t->value_index), i);
    Region bounds;
    bounds.push_back(tdom);
    for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
      IterVar sp_ax = this->spatial_axis_[sp_idx];
      bounds.push_back(dom_map.at(sp_ax));
    }
    // Realizes nest outermost-last; condition is unconditionally true.
    ret = tir::ProducerRealize(t, bounds, const_true(), ret, storage_scope);
  }
  return ret;
}
257 | |
// Build the provide statement for the scan: a loop nest over the stage's leaf
// iter vars, with scan_init/scan_update attribute markers spliced in for later
// lowering passes to expand.
Stmt ScanOpNode::BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                              bool debug_keep_trivial_loop) const {
  ICHECK_EQ(stage->op.operator->(), this);
  // Marker for the update phase, keyed by the scan axis variable.
  Stmt provide =
      AttrStmt(stage->op, tir::attr::scan_update_scope, this->scan_axis->var, Evaluate(0));
  // Marker for the init phase.
  Stmt init = AttrStmt(stage->op, tir::attr::scan_init_scope, 0, Evaluate(0));
  // Find the end of the leading run of thread-index iter vars; the ICHECK
  // enforces that all kThreadIndex vars form a contiguous prefix.
  size_t begin_scan = 0;
  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
    if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) {
      ICHECK_EQ(begin_scan, i);
      begin_scan = i + 1;
    }
  }
  std::unordered_map<IterVar, PrimExpr> vmap;
  std::unordered_set<IterVar> empty;
  auto nest = MakeLoopNest(stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop);
  // The init marker is attached just inside the thread-index loops, so init
  // runs once per thread binding rather than once per scan step.
  nest[begin_scan].push_back(init);
  nest.push_back(MakeIfNest(MakeBoundCheck(stage, dom_map, vmap, false, empty)));
  return MergeNest(nest, provide);
}
278 | } // namespace te |
279 | } // namespace tvm |
280 | |