1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Scan Operator.
22 * \file scan_op.cc
23 */
24#include <tvm/runtime/registry.h>
25#include <tvm/te/operation.h>
26#include <tvm/tir/expr.h>
27
28#include "../schedule/graph.h"
29#include "op_utils.h"
30
31namespace tvm {
32namespace te {
33using namespace tir;
34
// Pretty-printer hook: a scan op is rendered as "scan(<name>, <node pointer>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<ScanOpNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const ScanOpNode*>(node.get());
      p->stream << "scan(" << op->name << ", " << op << ")";
    });
// Register ScanOpNode with TVM's object/reflection machinery.
TVM_REGISTER_NODE_TYPE(ScanOpNode);
41
42int ScanOpNode::num_outputs() const { return static_cast<int>(update.size()); }
43Array<IterVar> ScanOpNode::root_iter_vars() const {
44 Array<IterVar> ret{scan_axis};
45 for (IterVar iv : spatial_axis_) {
46 ret.push_back(iv);
47 }
48 return ret;
49}
50
51DataType ScanOpNode::output_dtype(size_t i) const { return update[i]->dtype; }
52
53Array<PrimExpr> ScanOpNode::output_shape(size_t i) const {
54 ICHECK_LT(i, state_placeholder.size());
55 return state_placeholder[i]->shape;
56}
57
58ScanOp::ScanOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, IterVar axis,
59 Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
60 Array<Tensor> inputs) {
61 if (!attrs.defined()) {
62 attrs = Map<String, ObjectRef>();
63 }
64 auto n = make_object<ScanOpNode>();
65 ICHECK_EQ(init.size(), update.size());
66 ICHECK_EQ(init.size(), state_placeholder.size());
67 arith::Analyzer analyzer;
68 auto prove_equal = [&](PrimExpr lhs, PrimExpr rhs) {
69 return is_zero(analyzer.Simplify(lhs - rhs));
70 };
71
72 for (size_t i = 0; i < init.size(); ++i) {
73 ICHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
74 ICHECK_EQ(init[i]->dtype, update[i]->dtype);
75 ICHECK(prove_equal(init[i]->shape[0], axis->dom->min))
76 << "init.shape[0] need to match scan_axis.dom.min";
77 ICHECK(prove_equal(state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
78 << "state_placeholder.shape[0] need to match"
79 << " scan_axis.dom.min + scan_axis.dom.extent";
80 ICHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
81 << "The dimension of init need to match state_placeholder";
82 ICHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
83 << "The update.ndim need to be state_placeholder.ndim - 1";
84 for (size_t k = 0; k < update[i].ndim(); ++k) {
85 ICHECK(prove_equal(update[i]->shape[k], state_placeholder[i]->shape[k]));
86 if (k != 0) {
87 // setup spatial axis
88 std::ostringstream spatial_name;
89 spatial_name << name << ".out" << i << ".i" << k;
90 n->spatial_axis_.push_back(IterVar(Range::FromMinExtent(0, update[i]->shape[k]),
91 Var(spatial_name.str()), kOpaque));
92 }
93 }
94
95 for (size_t k = 1; k < init[i].ndim(); ++k) {
96 ICHECK(prove_equal(init[i]->shape[k], state_placeholder[i]->shape[k]));
97 }
98 }
99 n->name = std::move(name);
100 n->tag = std::move(tag);
101 n->attrs = std::move(attrs);
102 n->scan_axis = std::move(axis);
103 n->init = std::move(init);
104 n->update = std::move(update);
105 n->state_placeholder = std::move(state_placeholder);
106 n->inputs = std::move(inputs);
107 data_ = std::move(n);
108}
109
// FFI entry point: exposes the ScanOp constructor to the Python frontend
// as the packed function "te.ScanOp".
TVM_REGISTER_GLOBAL("te.ScanOp")
    .set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
                       IterVar axis, Array<Tensor> init, Array<Tensor> update,
                       Array<Tensor> state_placeholder, Array<Tensor> inputs) {
      return ScanOp(name, tag, attrs, axis, init, update, state_placeholder, inputs);
    });
116
117Array<Tensor> scan(Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
118 Array<Tensor> inputs, std::string name, std::string tag,
119 Map<String, ObjectRef> attrs) {
120 IterVar scan_axis =
121 IterVar(Range::FromMinExtent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
122 Var(name + ".idx"), kOrdered);
123 Operation op = ScanOp(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs);
124 Array<Tensor> res;
125 for (int i = 0; i < op->num_outputs(); ++i) {
126 res.push_back(op.output(i));
127 }
128 return res;
129}
130
131Array<Tensor> ScanOpNode::InputTensors() const {
132 Array<Tensor> ret;
133 for (Tensor t : init) {
134 ret.push_back(t);
135 }
136 for (Tensor t : update) {
137 ret.push_back(t);
138 }
139 return ret;
140}
141
142Operation ScanOpNode::ReplaceInputs(const Operation& self,
143 const std::unordered_map<Tensor, Tensor>& rmap) const {
144 ICHECK_EQ(self.operator->(), this);
145 auto n = make_object<ScanOpNode>(*this);
146 for (size_t i = 0; i < n->init.size(); ++i) {
147 if (rmap.count(n->init[i])) {
148 n->init.Set(i, rmap.at(n->init[i]));
149 }
150 if (rmap.count(n->update[i])) {
151 n->update.Set(i, rmap.at(n->update[i]));
152 }
153 }
154 if (!n->init.same_as(init) || !n->update.same_as(update)) {
155 return Operation(n);
156 } else {
157 return self;
158 }
159}
160
/*!
 * \brief Propagate requested output bounds down to the init and update
 *        input tensors of the scan.
 *
 * For each state i, the time dimension (dim 0) of init is always needed in
 * full, while the time dimension of update is bounded by the scan axis
 * domain. The remaining dimensions are bounded by the corresponding spatial
 * axes. `sp_idx` runs over spatial_axis_ continuously across all states,
 * mirroring the order in which the constructor created them.
 */
void ScanOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                                   const std::unordered_map<const VarNode*, IntSet>& dom_map,
                                   std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  ICHECK_EQ(self.operator->(), this);
  for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) {
    // Only fill domains the caller asked for (present in out_dom_map).
    TensorDom* init_dom = nullptr;
    TensorDom* update_dom = nullptr;
    if (out_dom_map->count(this->init[i])) {
      init_dom = &out_dom_map->at(this->init[i]);
    }
    if (out_dom_map->count(this->update[i])) {
      update_dom = &out_dom_map->at(this->update[i]);
    }
    // first dimension, always needed.
    if (init_dom) {
      init_dom->data[0].push_back(
          IntSet::FromRange(Range::FromMinExtent(0, this->init[i]->shape[0])));
    }
    if (update_dom) {
      // Time steps of update follow the inferred scan-axis domain.
      update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get()));
    }
    // The update dimensions
    for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
      IterVar sp_ax = this->spatial_axis_[sp_idx];
      if (init_dom) {
        init_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
      }
      if (update_dom) {
        update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
      }
    }
  }
}
194
195void ScanOpNode::GatherBound(const Operation& self,
196 const std::unordered_map<Tensor, TensorDom>& tensor_dom,
197 std::unordered_map<IterVar, Range>* out_dom_map) const {
198 ICHECK_EQ(self.operator->(), this);
199 ICHECK(!out_dom_map->count(this->scan_axis));
200 std::vector<Tensor> output(this->num_outputs());
201 for (size_t i = 0; i < output.size(); ++i) {
202 output[i] = self.output(i);
203 }
204 // Update for time axis.
205 std::vector<IntSet> time_dom;
206 for (size_t i = 0; i < output.size(); ++i) {
207 const TensorDom& d = tensor_dom.at(output[i]);
208 time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
209 }
210 ICHECK(!out_dom_map->count(this->scan_axis));
211 arith::Analyzer analyzer;
212 Range sdom = this->scan_axis->dom;
213 Range r = arith::Union(time_dom).CoverRange(sdom);
214 (*out_dom_map)[this->scan_axis] =
215 Range::FromMinExtent(sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min));
216 Map<IterVar, PrimExpr> fix_pt = ScanFixPointAnalysis(self);
217 // Update for spatial axis.
218 size_t sp_idx = 0;
219 for (size_t i = 0; i < output.size(); ++i) {
220 const TensorDom& d = tensor_dom.at(output[i]);
221 for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
222 IterVar sp_ax = this->spatial_axis_[sp_idx];
223 ICHECK(!out_dom_map->count(sp_ax));
224 ICHECK(fix_pt.count(sp_ax));
225 if (fix_pt[sp_ax].as<tir::IntImmNode>()->value) {
226 // fix point, we can slice it.
227 (*out_dom_map)[sp_ax] = arith::Union(d.data[k]).CoverRange(sp_ax->dom);
228 } else {
229 // not a fix point, need to include everything.
230 (*out_dom_map)[sp_ax] = sp_ax->dom;
231 }
232 }
233 }
234}
235
236Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
237 const Stmt& body, String storage_scope) const {
238 arith::Analyzer analyzer;
239 ICHECK_EQ(stage->op.get(), this);
240 Range sdom = dom_map.at(this->scan_axis);
241 Range tdom = Range::FromMinExtent(0, analyzer.Simplify(sdom->extent + sdom->min));
242 Stmt ret = body;
243 size_t sp_idx = 0;
244 for (size_t i = 0; i < update.size(); ++i) {
245 Tensor t = stage->op.output(i);
246 ICHECK_EQ(static_cast<size_t>(t->value_index), i);
247 Region bounds;
248 bounds.push_back(tdom);
249 for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
250 IterVar sp_ax = this->spatial_axis_[sp_idx];
251 bounds.push_back(dom_map.at(sp_ax));
252 }
253 ret = tir::ProducerRealize(t, bounds, const_true(), ret, storage_scope);
254 }
255 return ret;
256}
257
/*!
 * \brief Build the provide (compute) statement of the scan stage.
 *
 * Emits two marker AttrStmts — scan_update_scope around the loop body and
 * scan_init_scope spliced into the loop nest — which later lowering passes
 * expand into the actual init/update computation.
 */
Stmt ScanOpNode::BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                              bool debug_keep_trivial_loop) const {
  ICHECK_EQ(stage->op.operator->(), this);
  Stmt provide =
      AttrStmt(stage->op, tir::attr::scan_update_scope, this->scan_axis->var, Evaluate(0));
  Stmt init = AttrStmt(stage->op, tir::attr::scan_init_scope, 0, Evaluate(0));
  // Find where the scan body starts: thread-index leaf vars must form a
  // contiguous prefix of leaf_iter_vars; the init marker goes right after them.
  size_t begin_scan = 0;
  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
    if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) {
      // Enforce the prefix property: every thread axis so far was consecutive.
      ICHECK_EQ(begin_scan, i);
      begin_scan = i + 1;
    }
  }
  std::unordered_map<IterVar, PrimExpr> vmap;
  std::unordered_set<IterVar> empty;
  // Build the loop nest for all leaf itervars, then splice in the init marker
  // and append bound-check guards before merging everything around `provide`.
  auto nest = MakeLoopNest(stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop);
  nest[begin_scan].push_back(init);
  nest.push_back(MakeIfNest(MakeBoundCheck(stage, dom_map, vmap, false, empty)));
  return MergeNest(nest, provide);
}
278} // namespace te
279} // namespace tvm
280