/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include "create_primfunc.h"

#include <tvm/arith/analyzer.h>
#include <tvm/ir/name_supply.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>

#include <algorithm>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../../tir/ir/functor_common.h"
#include "../../tir/transforms/ir_utils.h"
#include "../schedule/graph.h"

namespace tvm {
namespace tir {

/*! \brief The helper mutator that transforms ProducerLoad to BufferLoad */
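// For illustration only: assuming `tensor2buffers` maps a te::Tensor `A` to a Buffer `A_buf`,
// a ProducerLoad such as `A[i, j]` inside a compute expression is rewritten into the equivalent
// BufferLoad `A_buf[i, j]`, with the indices left untouched.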
class ProducerToBufferTransformer : public StmtExprMutator {
 public:
  explicit ProducerToBufferTransformer(
      const std::unordered_map<te::Tensor, Buffer>& tensor2buffers)
      : tensor2buffers_(tensor2buffers) {}

  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
    auto visited_op = Downcast<ProducerLoad>(StmtExprMutator::VisitExpr_(op));
    te::Tensor tensor = Downcast<te::Tensor>(visited_op->producer);
    auto it = tensor2buffers_.find(tensor);
    ICHECK(it != tensor2buffers_.end()) << "IndexError: Cannot find the tensor " << tensor;
    const Buffer& buffer = it->second;
    return BufferLoad(buffer, visited_op->indices);
  }

 private:
  /*! \brief The map from each Tensor to its corresponding buffer. */
  const std::unordered_map<te::Tensor, Buffer>& tensor2buffers_;
};

/*! \brief The helper mutator that rewrites the buffers and buffer vars accessed by a block body. */
class BufferSubstituter : public StmtExprMutator {
 public:
  explicit BufferSubstituter(const std::unordered_map<const VarNode*, PrimExpr>& var_map,
                             const std::unordered_map<const BufferNode*, Buffer>& buffer_map)
      : var_map_(var_map), buffer_map_(buffer_map) {}

  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = var_map_.find(op);
    if (it != var_map_.end()) {
      return it->second;
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_map_.find(load->buffer.get());
    if (it != buffer_map_.end()) {
      return BufferLoad(it->second, load->indices, load->span);
    }
    return load;
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_map_.find(store->buffer.get());
    if (it != buffer_map_.end()) {
      return BufferStore(it->second, store->value, store->indices, store->span);
    }
    return store;
  }

 private:
  const std::unordered_map<const VarNode*, PrimExpr>& var_map_;
  const std::unordered_map<const BufferNode*, Buffer>& buffer_map_;
};

/*! \brief Helper data structure to store information used while constructing the PrimFunc. */
struct CreateFuncInfo {
  /*! \brief The Tensor arg_list. */
  Array<te::Tensor> arg_list;
  /*! \brief The map from each Tensor to its corresponding buffer. */
  std::unordered_map<te::Tensor, Buffer> tensor2buffers;
  /*! \brief The transformer from ProducerLoad to BufferLoad. */
  ProducerToBufferTransformer transformer;
  /*! \brief The buffers that should be allocated at the function root. */
  Array<Buffer> root_alloc;
  /*! \brief The NameSupply to make block names unique. */
  NameSupply name_supply = NameSupply("");

  String FreshName(String base_name) { return name_supply->FreshName(base_name); }

  explicit CreateFuncInfo(Array<te::Tensor> arg_list)
      : arg_list(std::move(arg_list)), transformer(tensor2buffers) {}

  bool IsArg(const te::Tensor& tensor) const {
    return std::any_of(arg_list.begin(), arg_list.end(),
                       [&tensor](const te::Tensor& arg) { return tensor == arg; });
  }
};

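/*!
 * \brief Normalize the "layout_free_placeholders" (topi) block annotation into the
 * function-level "layout_free_buffers" attribute.
 *
 * Rough example: if a block carries layout_free_placeholders = [B] and B is the buffer bound to
 * the function parameter at index 1, the entry is removed from the block annotation and the
 * function gains the attribute "layout_free_buffers" = [1]. A small blocklist of
 * auto-scheduler-specific annotations is also stripped from every visited block.
 */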
class LayoutFreePlaceholdersNormalizer : public StmtMutator {
 public:
  PrimFunc Process(PrimFunc func) {
    for (int i = 0, n = func->params.size(); i < n; ++i) {
      if (const auto* v = func->params[i].as<VarNode>()) {
        if (Optional<Buffer> buffer = func->buffer_map.Get(GetRef<Var>(v))) {
          buffer2index_[buffer.value()] = i;
        }
      }
    }
    PrimFuncNode* f = func.CopyOnWrite();
    f->body = VisitStmt(std::move(f->body));
    if (this->layout_free_buffer_indices_.empty()) {
      return func;
    }
    Array<Integer> indices;
    indices.reserve(this->layout_free_buffer_indices_.size());
    for (int i : this->layout_free_buffer_indices_) {
      indices.push_back(Integer(i));
    }
    return WithAttr(std::move(func), tir::attr::layout_free_buffers, indices);
  }

  Stmt VisitStmt_(const BlockNode* _block) final {
    Block block = Downcast<Block>(StmtMutator::VisitStmt_(_block));
    BlockNode* n = block.CopyOnWrite();
    if (Optional<ObjectRef> ann = n->annotations.Get(topi_attr)) {
      Array<Buffer> new_buffers;
      for (Buffer buffer : Downcast<Array<Buffer>>(ann)) {
        auto it = buffer2index_.find(buffer);
        if (it != buffer2index_.end()) {
          layout_free_buffer_indices_.insert(it->second);
        } else {
          new_buffers.push_back(buffer);
        }
      }
      if (new_buffers.empty()) {
        n->annotations.erase(topi_attr);
      } else {
        n->annotations.Set(topi_attr, new_buffers);
      }
    }
    for (const String& attr : this->blocklist) {
      auto it = n->annotations.find(attr);
      if (it != n->annotations.end()) {
        n->annotations.erase(attr);
      }
    }
    return std::move(block);
  }

  std::unordered_map<tir::Buffer, int, ObjectPtrHash, ObjectPtrEqual> buffer2index_;
  std::set<int> layout_free_buffer_indices_;
  String topi_attr = "layout_free_placeholders";
  std::vector<String> blocklist = {"const_matrix", "auto_scheduler_simplify_const_tensor_indices",
                                   "workload"};
};

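// Rough sketch of the result (hypothetical tensors): for a data-parallel compute such as
//   C = te.compute((m, n), lambda i, j: A[i, j] + B[i, j])
// the returned BlockRealize binds the surrounding loop vars to block vars v_i and v_j, and the
// block body is essentially
//   C[v_i, v_j] = A[v_i, v_j] + B[v_i, v_j]
// Read/write regions are left empty here and filled in later by the script.Complete step.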
BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
                                      const Array<te::Tensor>& tensors, Array<PrimExpr> bindings,
                                      PrimExpr expr_body, CreateFuncInfo* info,
                                      arith::Analyzer* analyzer) {
  // Step 1. Push the data-parallel axes and the reduce axes into block_vars.
  Array<IterVar> iter_vars;
  std::unordered_map<const VarNode*, PrimExpr> var_map;
  iter_vars.reserve(compute_op->axis.size() + compute_op->reduce_axis.size());
  auto f_push_block_vars = [&iter_vars, &var_map, &analyzer](const Array<IterVar>& iters) {
    for (IterVar iter_var : iters) {
      // Create a new var
      Var new_var("v_" + iter_var->var->name_hint, iter_var->var->dtype);
      var_map[iter_var->var.get()] = new_var;

      PrimExpr dom_min = analyzer->Simplify(iter_var->dom->min);
      PrimExpr dom_extent = analyzer->Simplify(iter_var->dom->extent);
      iter_vars.push_back(IterVar(Range::FromMinExtent(dom_min, dom_extent), new_var,
                                  iter_var->iter_type, iter_var->thread_tag, iter_var->span));
    }
  };
  f_push_block_vars(compute_op->axis);
  f_push_block_vars(compute_op->reduce_axis);

  // Step 2.
  //  - Declare buffers
  //  - Update `tensor2buffers`
  //  - Add the non-argument tensors to the `alloc_buffer` of the root block
  Array<Buffer> buffers;
  for (const te::Tensor& tensor : tensors) {
    Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global");
    info->tensor2buffers[tensor] = buffer;
    buffers.push_back(buffer);

    if (!info->IsArg(tensor)) {
      info->root_alloc.push_back(info->tensor2buffers[tensor]);
    }
  }

  // Step 3. Calculate the indices for BufferStore.
  Array<PrimExpr> indices;
  indices.reserve(compute_op->axis.size());
  for (const IterVar& iter_var : compute_op->axis) {
    auto it = var_map.find(iter_var->var.get());
    ICHECK(it != var_map.end());
    indices.push_back(it->second);
  }

  // Step 4. Create the block body.
  String block_name{nullptr};
  Optional<Stmt> init = NullOpt;
  Stmt body;
  if (const auto* reduce = expr_body.as<ReduceNode>()) {
    // Case 1. Reduce compute
    block_name = info->FreshName(compute_op->name);
    int n_buffers = buffers.size();

    Array<PrimExpr> lhs;
    Array<PrimExpr> rhs;
    lhs.reserve(n_buffers);
    rhs.reserve(n_buffers);

    // Make the LHS operands and RHS operands:
    //  - An LHS operand is the buffer storing the reduction result, with the corresponding indices.
    //  - An RHS operand is the value to be reduced.
    for (int i = 0; i < n_buffers; ++i) {
      const PrimExpr& left = BufferLoad(buffers[i], indices);
      const PrimExpr& right =
          analyzer->Simplify(Substitute(info->transformer(reduce->source[i]), var_map));
      lhs.push_back(left);
      rhs.push_back(right);
      ICHECK_EQ(left->dtype, right->dtype);
    }

    Array<Var> temp_vars;
    Array<Stmt> body_stmts;
    Array<Stmt> init_stmts;
    temp_vars.reserve(n_buffers);
    body_stmts.reserve(n_buffers);
    init_stmts.reserve(n_buffers);

    // - When there is only one buffer, we directly create a BufferStore which stores
    //   "combiner(lhs, rhs)" into the target buffer position.
    // - When there are multiple buffers, to avoid incorrect results, we create intermediate
    //   variables and use LetStmts to bind the variables to "combiner(lhs, rhs)". After that, we
    //   store the values of the variables into the target buffer positions.
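    // Illustrative shape of the final body for a two-buffer reduction such as argmax
    // (hypothetical buffers B0 and B1):
    //   let v_B0 = combiner(lhs, rhs)[0] in
    //     let v_B1 = combiner(lhs, rhs)[1] in
    //       { B0[indices] = v_B0; B1[indices] = v_B1 }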
    for (int i = 0; i < n_buffers; ++i) {
      const Buffer& buffer = buffers[i];
      init_stmts.push_back(BufferStore(buffer, reduce->combiner->identity_element[i], indices));
      PrimExpr value{nullptr};
      if (n_buffers > 1) {
        temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype())));
        value = temp_vars.back();
      } else {
        value = reduce->combiner.get()->operator()(lhs, rhs)[i];
      }
      body_stmts.push_back(BufferStore(buffer, value, indices));
    }

    init = SeqStmt::Flatten(init_stmts);
    body = SeqStmt::Flatten(body_stmts);
    if (n_buffers > 1) {
      // When there are multiple buffers, we wrap the body with LetStmts.
      for (int i = n_buffers - 1; i >= 0; --i) {
        PrimExpr value = reduce->combiner.get()->operator()(lhs, rhs)[i];
        body = LetStmt(temp_vars[i], std::move(value), std::move(body));
      }
    }
  } else {
    // Case 2. Data parallel compute
    ICHECK_EQ(tensors.size(), 1);
    block_name = info->FreshName(tensors[0]->GetNameHint());
    const PrimExpr& compute_body = Substitute(info->transformer(expr_body), var_map);
    body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices);
  }

  // Step 5. Add the script_parsing_detect_access attr so that the whole IR can be auto-completed.
  Map<String, ObjectRef> annotations;
  auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
    if (const auto* tensor_value = value.as<te::TensorNode>()) {
      return info->tensor2buffers.at(GetRef<te::Tensor>(tensor_value));
    } else {
      return value;
    }
  };

  for (const auto& pair : compute_op->attrs) {
    const String& key = pair.first;
    const ObjectRef& value = pair.second;
    // TensorIR does not allow the te::Tensor data structure in attributes.
    if (value->IsInstance<ArrayNode>()) {
      const auto array_value = Downcast<Array<ObjectRef>>(value);
      annotations.Set(key, array_value.Map(mutate_attr));
    } else {
      annotations.Set(key, mutate_attr(value));
    }
  }
  // Set script_parsing_detect_access
  annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));
  if (iter_vars.empty()) {
    IterVar iter(Range::FromMinExtent(0, 1), Var("vi", DataType::Int(32)), IterVarType::kDataPar);
    PrimExpr binding(0);
    iter_vars.push_back(iter);
    bindings.push_back(binding);
  }

  // Step 6. Create the Block and BlockRealize.
  return BlockRealize(/*iter_values=*/std::move(bindings),
                      /*predicate=*/Bool(true),
                      /*block=*/
                      Block(/*iter_vars=*/std::move(iter_vars),
                            /*reads=*/{},
                            /*writes=*/{},
                            /*name_hint=*/block_name,
                            /*body=*/std::move(body),
                            /*init=*/std::move(init),
                            /*alloc_buffers=*/{},
                            /*match_buffers=*/{},
                            /*annotations=*/std::move(annotations)));
}

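// Rough sketch (hypothetical axes): for a compute op with a data-parallel axis i in [0, m) and a
// reduce axis k in [0, r), this generates the loop nest
//   for (i, 0, m) { for (k, 0, r) { BlockRealize(...) } }
// where each loop var is bound to the corresponding block var created in GenerateBlockFromTensors.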
Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info,
                             arith::Analyzer* analyzer) {
  // Step 1. Create loop vars for block bindings.
  Array<IterVar> axes = compute_op->axis;
  axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end());

  Array<PrimExpr> bindings = axes.Map([&](IterVar iter_var) -> PrimExpr {
    int bits = std::max(iter_var->dom->min.dtype().bits(), iter_var->dom->extent.dtype().bits());
    return Var(iter_var->var->name_hint, runtime::DataType::Int(bits));
  });

  // Step 2. Generate block bodies.
  Array<Stmt> seq_stmt;
  if (compute_op->body[0]->IsInstance<ReduceNode>()) {
    auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool {
      return a->combiner.same_as(b->combiner) &&    //
             a->source.same_as(b->source) &&        //
             a->axis.same_as(b->axis) &&            //
             a->condition.same_as(b->condition) &&  //
             ((a->init.empty() && b->init.empty()) || a->init.same_as(b->init));
    };

    PrimExpr expr_body = compute_op->body[0];
    Array<te::Tensor> tensors = {compute_op.output(0)};
    const tir::ReduceNode* reduce = expr_body.as<tir::ReduceNode>();
    // Specially handle reductions with multiple outputs: every Reduce body must share the same
    // attributes except value_index.
    for (size_t k = 1; k < compute_op->body.size(); ++k) {
      const tir::ReduceNode* reduce_ = compute_op->body[k].as<tir::ReduceNode>();
      ICHECK(reduce_);
      ICHECK(f_reducer_equal(reduce_, reduce))
          << "The Reduce inputs of ComputeOp should have the same attribute except value_index";
      tensors.push_back(compute_op.output(k));
    }

    seq_stmt.push_back(GenerateBlockFromTensors(compute_op, tensors, bindings, std::move(expr_body),
                                                info, analyzer));
  } else {
    for (int i = 0; i < compute_op->num_outputs(); ++i) {
      const te::Tensor& tensor = compute_op.output(i);
      PrimExpr expr_body = compute_op->body[i];
      seq_stmt.push_back(GenerateBlockFromTensors(compute_op, {tensor}, bindings,
                                                  std::move(expr_body), info, analyzer));
    }
  }

  Stmt body = SeqStmt::Flatten(seq_stmt);

  // Step 3. Generate the loop nesting.
  for (size_t i = axes.size(); i > 0; --i) {
    const IterVar& axis = axes[i - 1];
    PrimExpr dom_min = analyzer->Simplify(axis->dom->min);
    PrimExpr dom_extent = analyzer->Simplify(axis->dom->extent);
    const Var& loop_var = Downcast<Var>(bindings[i - 1]);
    body = For(loop_var, dom_min, dom_extent, ForKind::kSerial, body);
  }

  return body;
}

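// Note: a te.extern stage carries an opaque body written against placeholder buffers. The body is
// kept as-is inside an opaque block (no block iter vars); only the placeholder buffers and their
// backing vars are substituted with the buffers already recorded for the inputs, and full-region
// reads/writes are declared for the inputs and outputs respectively.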
Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* info) {
  // Step 1. Check that all inputs have been visited before, and update var_map.
  std::unordered_map<const VarNode*, PrimExpr> var_map;
  std::unordered_map<const BufferNode*, Buffer> input_buffer_map;
  ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size());
  for (size_t i = 0; i < extern_op->inputs.size(); ++i) {
    const Buffer& placeholder = extern_op->input_placeholders[i];
    const te::Tensor& input_tensor = extern_op->inputs[i];
    auto it = info->tensor2buffers.find(input_tensor);
    ICHECK(it != info->tensor2buffers.end());
    var_map[placeholder->data.get()] = it->second->data;
    input_buffer_map[placeholder.get()] = it->second;
  }

  // Step 2. Update info with the output tensors and their placeholder buffers.
  ICHECK_EQ(extern_op->num_outputs(), extern_op->output_placeholders.size());
  for (int i = 0; i < extern_op->num_outputs(); ++i) {
    const Buffer& placeholder = extern_op->output_placeholders[i];
    const te::Tensor& output_tensor = extern_op.output(i);
    info->tensor2buffers[output_tensor] = placeholder;
    if (!info->IsArg(output_tensor)) {
      info->root_alloc.push_back(placeholder);
    }
  }

  // Step 3. Collect the access regions.
  Array<BufferRegion> reads, writes;
  for (const te::Tensor& tensor : extern_op->inputs) {
    // Existence in tensor2buffers is guaranteed by the ICHECK in Step 1, so no check is needed.
    reads.push_back(BufferRegion::FullRegion(info->tensor2buffers[tensor]));
  }
  for (const Buffer& buffer : extern_op->output_placeholders) {
    writes.push_back(BufferRegion::FullRegion(buffer));
  }

  BufferSubstituter substituter(var_map, input_buffer_map);
  Stmt body = substituter(extern_op->body);

  // Step 4. Generate an opaque block as the body.
  return BlockRealize(/*iter_values=*/{},
                      /*predicate=*/Bool(true),
                      /*block=*/
                      Block(/*iter_vars=*/{},
                            /*reads=*/std::move(reads),
                            /*writes=*/std::move(writes),
                            /*name_hint=*/info->FreshName(extern_op->name),
                            /*body=*/std::move(body),
                            /*init=*/NullOpt,
                            /*alloc_buffers=*/{},
                            /*match_buffers=*/{},
                            /*annotations=*/extern_op->attrs));
}

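// Illustrative example (hypothetical tensors): for C = te.compute(..., lambda i: A[i] + B[i]) with
// placeholders A and B, and an arg_list such as {A, B, C}, CollectOrderedOps returns the post-DFS
// order [A->op, B->op, C->op], so every producer is visited before its consumers.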
Array<te::Operation> CollectOrderedOps(const Array<te::Tensor>& arg_list) {
  Array<te::Operation> arg_ops;
  for (const te::Tensor& arg : arg_list) {
    arg_ops.push_back(arg->op);
  }
  te::ReadGraph g = te::CreateReadGraph(arg_ops);
  Array<te::Operation> order = te::PostDFSOrder(arg_ops, g);

  for (const te::Operation& op : order) {
    if (!(op->IsInstance<te::PlaceholderOpNode>() || op->IsInstance<te::ComputeOpNode>() ||
          op->IsInstance<te::ExternOpNode>()))
      LOG(FATAL) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". "
                 << "Only te.placeholder, te.compute and te.extern are allowed for now.";
  }
  return order;
}

void InitializeBufferBinds(const Array<te::Operation>& ordered_ops, CreateFuncInfo* info) {
  // Process any TE operations which contain user-defined buffers
  for (const auto& op : ordered_ops) {
    // Initialize the tensor2buffers binds map with buffers defined by te.extern
    if (const auto* extern_op = op.as<te::ExternOpNode>()) {
      ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size());
      for (size_t i = 0; i < extern_op->inputs.size(); ++i) {
        const te::Tensor& input = extern_op->inputs[i];
        const Buffer& buffer = extern_op->input_placeholders[i];
        info->tensor2buffers[input] = buffer;
      }
    }
  }
}

void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, Array<Stmt>* root_stmts,
                         arith::Analyzer* analyzer) {
  if (const auto* placeholder = op.as<te::PlaceholderOpNode>()) {
    // Case 1. PlaceholderOp (te.placeholder)
    ICHECK_EQ(op->num_outputs(), 1);
    const te::Tensor& tensor = op.output(0);
    // Check that the tensor is in the arg list.
    ICHECK(info->IsArg(tensor));
    // Declare a buffer for any argument tensor without a pre-existing
    // buffer declaration recorded in the tensor2buffers binds map
    if (info->tensor2buffers.count(tensor) == 0) {
      const Buffer& buffer =
          decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name, "global");
      info->tensor2buffers[tensor] = buffer;
    }
  } else if (const auto* compute_op = op.as<te::ComputeOpNode>()) {
    // Case 2. ComputeOp (te.compute)
    root_stmts->push_back(
        GenerateStmtFromCompute(GetRef<te::ComputeOp>(compute_op), info, analyzer));
  } else if (const auto* extern_op = op.as<te::ExternOpNode>()) {
    // Case 3. ExternOp (te.extern)
    root_stmts->push_back(GenerateStmtFromExternOp(GetRef<te::ExternOp>(extern_op), info));
  } else {
    ICHECK(false) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". "
                  << "Only te.placeholder, te.compute and te.extern are allowed for now.";
  }
}

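// Illustrative result (hypothetical arg_list = {A, C}): the generated PrimFunc has the signature
//   main(var_A: handle, var_C: handle)
// with a buffer_map binding var_A/var_C to their buffers and the attributes
// "global_symbol" = "main" and "tir.noalias" = True. The "script.Complete" packed function is
// then expected to auto-complete the function: wrap the body in a root block, attach the root
// allocations, and infer block read/write regions.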
PrimFunc GenerateAndCompletePrimFunc(const Array<te::Tensor>& arg_list,
                                     const Array<Stmt>& root_stmts, CreateFuncInfo* info) {
  Array<Var> parameters;
  Map<Var, Buffer> buffer_map;
  for (const te::Tensor& tensor : arg_list) {
    Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle()));
    parameters.push_back(arg);
    auto it = info->tensor2buffers.find(tensor);
    ICHECK(it != info->tensor2buffers.end());
    buffer_map.Set(arg, it->second);
  }
  PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters),
                                     /*body=*/SeqStmt::Flatten(root_stmts),
                                     /*ret_type=*/VoidType(),
                                     /*buffer_map=*/std::move(buffer_map)),
                            {{"global_symbol", String("main")}, {"tir.noalias", Bool(true)}});
  const auto* complete = runtime::Registry::Get("script.Complete");
  ICHECK(complete);
  func = (*complete)(std::move(func), info->root_alloc);
  return func;
}

PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
                                     const Array<runtime::NDArray>& constants,
                                     std::optional<DataType> index_dtype_override) {
  // Information used in CreatePrimFunc and its sub-functions.
  CreateFuncInfo info(arg_list);
  // Root body stmts.
  Array<Stmt> root_stmts;
  // Analyzer
  arith::Analyzer analyzer;

  // Step 1. Create an ordered array of operations and validate that they are supported.
  Array<te::Operation> order = CollectOrderedOps(arg_list);

  // Step 2. Initialize the buffer binds map.
  InitializeBufferBinds(order, &info);

  // Step 3. Rewrite compute stages into blocks.
  for (const te::Operation& op : order) {
    RewriteStageToBlock(op, &info, &root_stmts, &analyzer);
  }

  // Step 4. Create and complete the PrimFunc.
  auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info);
  func = tir::BindParams(func, constants);
  if (index_dtype_override.has_value()) {
    func = IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func));
  }
  auto result = LayoutFreePlaceholdersNormalizer().Process(std::move(func));
  return result;
}

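// Rough usage sketch from C++ (names, shapes and the lambda are hypothetical):
//   te::Tensor A = te::placeholder({128, 128}, DataType::Float(32), "A");
//   te::Tensor B = te::compute(
//       {128, 128}, [&](const Array<Var>& i) { return A(i[0], i[1]) + 1.0f; }, "B");
//   PrimFunc func = CreatePrimFunc({A, B}, DataType::Int(64));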
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
                        std::optional<DataType> index_dtype_override) {
  return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override);
}

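// FFI wrapper, registered as the "te.CreatePrimFunc" packed function (typically reached from
// Python via te.create_prim_func). The second argument may be None, in which case no index dtype
// override is applied.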
TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, TVMRetValue* ret) {
  Array<te::Tensor> arg_list = args[0];
  std::optional<DataType> index_dtype_override{std::nullopt};
  // Add conversion to make std::optional compatible with FFI.
  ICHECK_EQ(args.size(), 2);
  if (args[1].type_code() != kTVMNullptr) {
    index_dtype_override = args[1].operator DataType();
  }
  *ret = CreatePrimFunc(arg_list, index_dtype_override);
});

}  // namespace tir
}  // namespace tvm