/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include "create_primfunc.h"

#include <tvm/arith/analyzer.h>
#include <tvm/ir/name_supply.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>

#include <algorithm>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../../tir/ir/functor_common.h"
#include "../../tir/transforms/ir_utils.h"
#include "../schedule/graph.h"

namespace tvm {
namespace tir {

/*! \brief The helper mutator that transforms ProducerLoad to BufferLoad */
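// For example, a te.compute body such as `B(i, j) = A(i, j) + 1` reads A
// through a ProducerLoad `A(i, j)`; this mutator rewrites that access into
// `BufferLoad(A_buf, [i, j])` using the tensor-to-buffer map built below
// (`A_buf` here is a hypothetical buffer name for illustration).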
class ProducerToBufferTransformer : public StmtExprMutator {
 public:
  explicit ProducerToBufferTransformer(const std::unordered_map<te::Tensor, Buffer>& tensor2buffers)
      : tensor2buffers_(tensor2buffers) {}

  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
    auto visited_op = Downcast<ProducerLoad>(StmtExprMutator::VisitExpr_(op));
    te::Tensor tensor = Downcast<te::Tensor>(visited_op->producer);
    auto it = tensor2buffers_.find(tensor);
    ICHECK(it != tensor2buffers_.end()) << "IndexError: Cannot find the tensor " << tensor;
    const Buffer& buffer = it->second;
    return BufferLoad(buffer, visited_op->indices);
  }

 private:
  /*! \brief The map from Tensors to their corresponding buffers. */
  const std::unordered_map<te::Tensor, Buffer>& tensor2buffers_;
};

/*! \brief The helper mutator that rewrites the buffers and buffer vars accessed by a block body */
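// Used when lowering te.extern ops: the placeholder buffers (and their backing
// data vars) declared by the extern op are redirected to the buffers already
// recorded in CreateFuncInfo (see GenerateStmtFromExternOp below).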
class BufferSubstituter : public StmtExprMutator {
 public:
  explicit BufferSubstituter(const std::unordered_map<const VarNode*, PrimExpr>& var_map,
                             const std::unordered_map<const BufferNode*, Buffer>& buffer_map)
      : var_map_(var_map), buffer_map_(buffer_map) {}

  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = var_map_.find(op);
    if (it != var_map_.end()) {
      return it->second;
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_map_.find(load->buffer.get());
    if (it != buffer_map_.end()) {
      return BufferLoad(it->second, load->indices, load->span);
    }
    return load;
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_map_.find(store->buffer.get());
    if (it != buffer_map_.end()) {
      return BufferStore(it->second, store->value, store->indices, store->span);
    }
    return store;
  }

 private:
  const std::unordered_map<const VarNode*, PrimExpr>& var_map_;
  const std::unordered_map<const BufferNode*, Buffer>& buffer_map_;
};

/*! \brief Helper data structure shared across the steps of CreatePrimFunc. */
struct CreateFuncInfo {
  /*! \brief The Tensor arg_list. */
  Array<te::Tensor> arg_list;
  /*! \brief The map from each Tensor to its corresponding buffer. */
  std::unordered_map<te::Tensor, Buffer> tensor2buffers;
  /*! \brief The transformer from ProducerLoad to BufferLoad. */
  ProducerToBufferTransformer transformer;
  /*! \brief The buffers that should be allocated at the function root. */
  Array<Buffer> root_alloc;
  /*! \brief The NameSupply that keeps block names unique. */
  NameSupply name_supply = NameSupply("");

  String FreshName(String base_name) { return name_supply->FreshName(base_name); }

  explicit CreateFuncInfo(Array<te::Tensor> arg_list)
      : arg_list(std::move(arg_list)), transformer(tensor2buffers) {}

  bool IsArg(const te::Tensor& tensor) const {
    return std::any_of(arg_list.begin(), arg_list.end(),
                       [&tensor](const te::Tensor& arg) { return tensor == arg; });
  }
};

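/*!
 * \brief Lift the block-level "layout_free_placeholders" annotation (an array
 * of buffers) into a function-level `layout_free_buffers` attribute holding the
 * indices of the matching function parameters, and drop blocklisted
 * scheduler-only block annotations.
 */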
class LayoutFreePlaceholdersNormalizer : public StmtMutator {
 public:
  PrimFunc Process(PrimFunc func) {
    for (int i = 0, n = func->params.size(); i < n; ++i) {
      if (const auto* v = func->params[i].as<VarNode>()) {
        if (Optional<Buffer> buffer = func->buffer_map.Get(GetRef<Var>(v))) {
          buffer2index_[buffer.value()] = i;
        }
      }
    }
    PrimFuncNode* f = func.CopyOnWrite();
    f->body = VisitStmt(std::move(f->body));
    if (this->layout_free_buffer_indices_.empty()) {
      return func;
    }
    Array<Integer> indices;
    indices.reserve(this->layout_free_buffer_indices_.size());
    for (int i : this->layout_free_buffer_indices_) {
      indices.push_back(Integer(i));
    }
    return WithAttr(std::move(func), tir::attr::layout_free_buffers, indices);
  }

  Stmt VisitStmt_(const BlockNode* _block) final {
    Block block = Downcast<Block>(StmtMutator::VisitStmt_(_block));
    BlockNode* n = block.CopyOnWrite();
    if (Optional<ObjectRef> ann = n->annotations.Get(topi_attr)) {
      Array<Buffer> new_buffers;
      for (Buffer buffer : Downcast<Array<Buffer>>(ann)) {
        auto it = buffer2index_.find(buffer);
        if (it != buffer2index_.end()) {
          layout_free_buffer_indices_.insert(it->second);
        } else {
          new_buffers.push_back(buffer);
        }
      }
      if (new_buffers.empty()) {
        n->annotations.erase(topi_attr);
      } else {
        n->annotations.Set(topi_attr, new_buffers);
      }
    }
    for (const String& attr : this->blocklist) {
      auto it = n->annotations.find(attr);
      if (it != n->annotations.end()) {
        n->annotations.erase(attr);
      }
    }
    return std::move(block);
  }

  std::unordered_map<tir::Buffer, int, ObjectPtrHash, ObjectPtrEqual> buffer2index_;
  std::set<int> layout_free_buffer_indices_;
  String topi_attr = "layout_free_placeholders";
  std::vector<String> blocklist = {"const_matrix", "auto_scheduler_simplify_const_tensor_indices",
                                   "workload"};
};

BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
                                      const Array<te::Tensor>& tensors, Array<PrimExpr> bindings,
                                      PrimExpr expr_body, CreateFuncInfo* info,
                                      arith::Analyzer* analyzer) {
  // Step 1. Push the data-parallel axes and reduce axes into block_vars.
  Array<IterVar> iter_vars;
  std::unordered_map<const VarNode*, PrimExpr> var_map;
  iter_vars.reserve(compute_op->axis.size() + compute_op->reduce_axis.size());
  auto f_push_block_vars = [&iter_vars, &var_map, &analyzer](const Array<IterVar>& iters) {
    for (IterVar iter_var : iters) {
      // Create a new var
      Var new_var("v_" + iter_var->var->name_hint, iter_var->var->dtype);
      var_map[iter_var->var.get()] = new_var;

      PrimExpr dom_min = analyzer->Simplify(iter_var->dom->min);
      PrimExpr dom_extent = analyzer->Simplify(iter_var->dom->extent);
      iter_vars.push_back(IterVar(Range::FromMinExtent(dom_min, dom_extent), new_var,
                                  iter_var->iter_type, iter_var->thread_tag, iter_var->span));
    }
  };
  f_push_block_vars(compute_op->axis);
  f_push_block_vars(compute_op->reduce_axis);

  // Step 2.
  //  - Declare buffers
  //  - Update `tensor2buffers`
  //  - Add the non-argument tensors to the `alloc_buffer` of the root block
  Array<Buffer> buffers;
  for (const te::Tensor& tensor : tensors) {
    Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global");
    info->tensor2buffers[tensor] = buffer;
    buffers.push_back(buffer);

    if (!info->IsArg(tensor)) {
      info->root_alloc.push_back(info->tensor2buffers[tensor]);
    }
  }

  // Step 3. Calculate the indices for the BufferStore.
  Array<PrimExpr> indices;
  indices.reserve(compute_op->axis.size());
  for (const IterVar& iter_var : compute_op->axis) {
    auto it = var_map.find(iter_var->var.get());
    ICHECK(it != var_map.end());
    indices.push_back(it->second);
  }

  // Step 4. Create the block body.
  String block_name{nullptr};
  Optional<Stmt> init = NullOpt;
  Stmt body;
  if (const auto* reduce = expr_body.as<ReduceNode>()) {
    // Case 1. Reduce compute
    block_name = info->FreshName(compute_op->name);
    int n_buffers = buffers.size();

    Array<PrimExpr> lhs;
    Array<PrimExpr> rhs;
    lhs.reserve(n_buffers);
    rhs.reserve(n_buffers);

    // Make the LHS operands and RHS operands:
    //  - An LHS operand is the buffer storing the reduction result, with the corresponding indices.
    //  - An RHS operand is the value to be reduced.
    for (int i = 0; i < n_buffers; ++i) {
      const PrimExpr& left = BufferLoad(buffers[i], indices);
      const PrimExpr& right =
          analyzer->Simplify(Substitute(info->transformer(reduce->source[i]), var_map));
      lhs.push_back(left);
      rhs.push_back(right);
      ICHECK_EQ(left->dtype, right->dtype);
    }

    Array<Var> temp_vars;
    Array<Stmt> body_stmts;
    Array<Stmt> init_stmts;
    temp_vars.reserve(n_buffers);
    body_stmts.reserve(n_buffers);
    init_stmts.reserve(n_buffers);

    // - When there is only one buffer, we directly create a BufferStore which stores
    //   "combiner(lhs, rhs)" into the target buffer position.
    // - When there are multiple buffers, to avoid incorrect results, we create intermediate
    //   variables and use LetStmts to bind them to "combiner(lhs, rhs)". After that, we
    //   store the values of the variables into the target buffer positions.
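    //   For example, with two buffers the resulting body has the shape:
    //     LetStmt(v_0, combiner(lhs, rhs)[0],
    //       LetStmt(v_1, combiner(lhs, rhs)[1],
    //         { buf_0[indices] = v_0; buf_1[indices] = v_1 }))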
    for (int i = 0; i < n_buffers; ++i) {
      const Buffer& buffer = buffers[i];
      init_stmts.push_back(BufferStore(buffer, reduce->combiner->identity_element[i], indices));
      PrimExpr value{nullptr};
      if (n_buffers > 1) {
        temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype())));
        value = temp_vars.back();
      } else {
        value = reduce->combiner.get()->operator()(lhs, rhs)[i];
      }
      body_stmts.push_back(BufferStore(buffer, value, indices));
    }

    init = SeqStmt::Flatten(init_stmts);
    body = SeqStmt::Flatten(body_stmts);
    if (n_buffers > 1) {
      // When there are multiple buffers, wrap the body with LetStmts.
      for (int i = n_buffers - 1; i >= 0; --i) {
        PrimExpr value = reduce->combiner.get()->operator()(lhs, rhs)[i];
        body = LetStmt(temp_vars[i], std::move(value), std::move(body));
      }
    }
  } else {
    // Case 2. Data-parallel compute
    ICHECK_EQ(tensors.size(), 1);
    block_name = info->FreshName(tensors[0]->GetNameHint());
    const PrimExpr& compute_body = Substitute(info->transformer(expr_body), var_map);
    body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices);
  }

  // Step 5. Add the script_parsing_detect_access attr so that the whole IR can be auto-completed.
  Map<String, ObjectRef> annotations;
  auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
    if (const auto* tensor_value = value.as<te::TensorNode>()) {
      return info->tensor2buffers.at(GetRef<te::Tensor>(tensor_value));
    } else {
      return value;
    }
  };

  for (const auto& pair : compute_op->attrs) {
    const String& key = pair.first;
    const ObjectRef& value = pair.second;
    // TensorIR does not allow the Tensor data structure, so replace any Tensor
    // attribute values with their corresponding buffers.
    if (value->IsInstance<ArrayNode>()) {
      const auto array_value = Downcast<Array<ObjectRef>>(value);
      annotations.Set(key, array_value.Map(mutate_attr));
    } else {
      annotations.Set(key, mutate_attr(value));
    }
  }
  // Ask the completer to auto-detect the block's access regions.
  annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));
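  // A compute without any axes (e.g. one producing a 0-dim tensor) would
  // otherwise yield a block with no block vars, which TIR treats as opaque;
  // add a dummy unit data-parallel var bound to 0 instead.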
  if (iter_vars.empty()) {
    IterVar iter(Range::FromMinExtent(0, 1), Var("vi", DataType::Int(32)), IterVarType::kDataPar);
    PrimExpr binding(0);
    iter_vars.push_back(iter);
    bindings.push_back(binding);
  }

  // Step 6. Create the Block and BlockRealize.
  return BlockRealize(/*iter_values=*/std::move(bindings),
                      /*predicate=*/Bool(true),
                      /*block=*/
                      Block(/*iter_vars=*/std::move(iter_vars),
                            /*reads=*/{},
                            /*writes=*/{},
                            /*name_hint=*/block_name,
                            /*body=*/std::move(body),
                            /*init=*/std::move(init),
                            /*alloc_buffers=*/{},
                            /*match_buffers=*/{},
                            /*annotations=*/std::move(annotations)));
}

Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info,
                             arith::Analyzer* analyzer) {
  // Step 1. Create loop vars for the block bindings.
  Array<IterVar> axes = compute_op->axis;
  axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end());

  Array<PrimExpr> bindings = axes.Map([&](IterVar iter_var) -> PrimExpr {
    int bits = std::max(iter_var->dom->min.dtype().bits(), iter_var->dom->extent.dtype().bits());
    return Var(iter_var->var->name_hint, runtime::DataType::Int(bits));
  });

  // Step 2. Generate the block bodies.
  Array<Stmt> seq_stmt;
  if (compute_op->body[0]->IsInstance<ReduceNode>()) {
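    // Multi-output reductions (e.g. argmax) emit one Reduce node per output;
    // they may share a single block only when every field except value_index
    // matches, which the checks below enforce.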
    auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool {
      return a->combiner.same_as(b->combiner) &&    //
             a->source.same_as(b->source) &&        //
             a->axis.same_as(b->axis) &&            //
             a->condition.same_as(b->condition) &&  //
             ((a->init.empty() && b->init.empty()) || a->init.same_as(b->init));
    };

    PrimExpr expr_body = compute_op->body[0];
    Array<te::Tensor> tensors = {compute_op.output(0)};
    const tir::ReduceNode* reduce = expr_body.as<tir::ReduceNode>();
    // Specially handle multi-value reductions: all outputs must be generated in a single block.
    for (size_t k = 1; k < compute_op->body.size(); ++k) {
      const tir::ReduceNode* reduce_ = compute_op->body[k].as<tir::ReduceNode>();
      ICHECK(reduce_);
      ICHECK(f_reducer_equal(reduce_, reduce))
          << "The Reduce inputs of ComputeOp should have the same attributes except value_index";
      tensors.push_back(compute_op.output(k));
    }

    seq_stmt.push_back(GenerateBlockFromTensors(compute_op, tensors, bindings, std::move(expr_body),
                                                info, analyzer));
  } else {
    for (int i = 0; i < compute_op->num_outputs(); ++i) {
      const te::Tensor& tensor = compute_op.output(i);
      PrimExpr expr_body = compute_op->body[i];
      seq_stmt.push_back(GenerateBlockFromTensors(compute_op, {tensor}, bindings,
                                                  std::move(expr_body), info, analyzer));
    }
  }

  Stmt body = SeqStmt::Flatten(seq_stmt);

  // Step 3. Generate the loop nesting.
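  // Build the loops from the innermost axis outwards: each iteration wraps the
  // current body in a new serial For.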
  for (size_t i = axes.size(); i > 0; --i) {
    const IterVar& axis = axes[i - 1];
    PrimExpr dom_min = analyzer->Simplify(axis->dom->min);
    PrimExpr dom_extent = analyzer->Simplify(axis->dom->extent);
    const Var& loop_var = Downcast<Var>(bindings[i - 1]);
    body = For(loop_var, dom_min, dom_extent, ForKind::kSerial, body);
  }

  return body;
}

Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* info) {
  // Step 1. Check that all inputs have been visited, and update var_map.
  std::unordered_map<const VarNode*, PrimExpr> var_map;
  std::unordered_map<const BufferNode*, Buffer> input_buffer_map;
  ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size());
  for (size_t i = 0; i < extern_op->inputs.size(); ++i) {
    const Buffer& placeholder = extern_op->input_placeholders[i];
    const te::Tensor& input_tensor = extern_op->inputs[i];
    auto it = info->tensor2buffers.find(input_tensor);
    ICHECK(it != info->tensor2buffers.end());
    var_map[placeholder->data.get()] = it->second->data;
    input_buffer_map[placeholder.get()] = it->second;
  }

  // Step 2. Update info with the output tensors and placeholder buffers.
  ICHECK_EQ(extern_op->num_outputs(), extern_op->output_placeholders.size());
  for (int i = 0; i < extern_op->num_outputs(); ++i) {
    const Buffer& placeholder = extern_op->output_placeholders[i];
    const te::Tensor& output_tensor = extern_op.output(i);
    info->tensor2buffers[output_tensor] = placeholder;
    if (!info->IsArg(output_tensor)) {
      info->root_alloc.push_back(placeholder);
    }
  }

  // Step 3. Collect the access regions.
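  // The extern body is opaque to TIR analysis, so conservatively mark the full
  // region of every input and output buffer as read and written, respectively.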
  Array<BufferRegion> reads, writes;
  for (const te::Tensor& tensor : extern_op->inputs) {
    // The lookup is guaranteed to succeed by the ICHECK in Step 1.
    reads.push_back(BufferRegion::FullRegion(info->tensor2buffers[tensor]));
  }
  for (const Buffer& buffer : extern_op->output_placeholders) {
    writes.push_back(BufferRegion::FullRegion(buffer));
  }

  BufferSubstituter substituter(var_map, input_buffer_map);
  Stmt body = substituter(extern_op->body);

  // Step 4. Generate an opaque block as the body.
  return BlockRealize(/*iter_values=*/{},
                      /*predicate=*/Bool(true),
                      /*block=*/
                      Block(/*iter_vars=*/{},
                            /*reads=*/std::move(reads),
                            /*writes=*/std::move(writes),
                            /*name_hint=*/info->FreshName(extern_op->name),
                            /*body=*/std::move(body),
                            /*init=*/NullOpt,
                            /*alloc_buffers=*/{},
                            /*match_buffers=*/{},
                            /*annotations=*/extern_op->attrs));
}

Array<te::Operation> CollectOrderedOps(const Array<te::Tensor>& arg_list) {
  Array<te::Operation> arg_ops;
  for (const te::Tensor& arg : arg_list) {
    arg_ops.push_back(arg->op);
  }
  te::ReadGraph g = te::CreateReadGraph(arg_ops);
  Array<te::Operation> order = te::PostDFSOrder(arg_ops, g);
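  // PostDFSOrder visits producers before consumers, so by the time an op is
  // lowered, the buffers of all of its inputs are already in tensor2buffers.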

  for (const te::Operation& op : order) {
    if (!(op->IsInstance<te::PlaceholderOpNode>() || op->IsInstance<te::ComputeOpNode>() ||
          op->IsInstance<te::ExternOpNode>()))
      LOG(FATAL) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". "
                 << "Only te.placeholder, te.compute and te.extern are allowed for now.";
  }
  return order;
}

void InitializeBufferBinds(const Array<te::Operation>& ordered_ops, CreateFuncInfo* info) {
  // Process any TE operations which contain user-defined buffers
  for (const auto& op : ordered_ops) {
    // Initialize the tensor2buffers binds map with buffers defined by the te.extern
    if (const auto* extern_op = op.as<te::ExternOpNode>()) {
      ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size());
      for (size_t i = 0; i < extern_op->inputs.size(); ++i) {
        const te::Tensor& input = extern_op->inputs[i];
        const Buffer& buffer = extern_op->input_placeholders[i];
        info->tensor2buffers[input] = buffer;
      }
    }
  }
}

void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, Array<Stmt>* root_stmts,
                         arith::Analyzer* analyzer) {
  if (const auto* placeholder = op.as<te::PlaceholderOpNode>()) {
    // Case 1. PlaceholderOp (te.placeholder)
    ICHECK_EQ(op->num_outputs(), 1);
    const te::Tensor& tensor = op.output(0);
    // Check that the op is in the arg list
    ICHECK(info->IsArg(tensor));
    // Declare a buffer for any argument tensor without a pre-existing
    // buffer declaration recorded in the tensor2buffers binds map
    if (info->tensor2buffers.count(tensor) == 0) {
      const Buffer& buffer =
          decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name, "global");
      info->tensor2buffers[tensor] = buffer;
    }
  } else if (const auto* compute_op = op.as<te::ComputeOpNode>()) {
    // Case 2. ComputeOp (te.compute)
    root_stmts->push_back(
        GenerateStmtFromCompute(GetRef<te::ComputeOp>(compute_op), info, analyzer));
  } else if (const auto* extern_op = op.as<te::ExternOpNode>()) {
    // Case 3. ExternOp (te.extern)
    root_stmts->push_back(GenerateStmtFromExternOp(GetRef<te::ExternOp>(extern_op), info));
  } else {
    ICHECK(false) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". "
                  << "Only te.placeholder, te.compute and te.extern are allowed for now.";
  }
}

PrimFunc GenerateAndCompletePrimFunc(const Array<te::Tensor>& arg_list,
                                     const Array<Stmt>& root_stmts, CreateFuncInfo* info) {
  Array<Var> parameters;
  Map<Var, Buffer> buffer_map;
  for (const te::Tensor& tensor : arg_list) {
    Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle()));
    parameters.push_back(arg);
    auto it = info->tensor2buffers.find(tensor);
    ICHECK(it != info->tensor2buffers.end());
    buffer_map.Set(arg, it->second);
  }
  PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters),
                                     /*body=*/SeqStmt::Flatten(root_stmts),
                                     /*ret_type=*/VoidType(),
                                     /*buffer_map=*/std::move(buffer_map)),
                            {{"global_symbol", String("main")}, {"tir.noalias", Bool(true)}});
  const auto* complete = runtime::Registry::Get("script.Complete");
  ICHECK(complete);
  func = (*complete)(std::move(func), info->root_alloc);
  return func;
}

PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
                                     const Array<runtime::NDArray>& constants,
                                     std::optional<DataType> index_dtype_override) {
  // Information used in CreatePrimFunc and its sub-functions.
  CreateFuncInfo info(arg_list);
  // Root body stmts.
  Array<Stmt> root_stmts;
  // Analyzer
  arith::Analyzer analyzer;

  // Step 1. Create an ordered array of operations and validate that they are supported.
  Array<te::Operation> order = CollectOrderedOps(arg_list);

  // Step 2. Initialize the buffer binds map.
  InitializeBufferBinds(order, &info);

  // Step 3. Rewrite the compute stages into blocks.
  for (const te::Operation& op : order) {
    RewriteStageToBlock(op, &info, &root_stmts, &analyzer);
  }

  // Step 4. Create and complete the PrimFunc.
  auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info);
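  // Embed the given constants into the function body so they no longer need to
  // be passed in as runtime arguments.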
  func = tir::BindParams(func, constants);
  if (index_dtype_override.has_value()) {
    func = IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func));
  }
  auto result = LayoutFreePlaceholdersNormalizer().Process(std::move(func));
  return result;
}

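// Example (Python front end), which routes through the registration below:
//   A = te.placeholder((128, 128), name="A")
//   B = te.compute((128, 128), lambda i, j: A[i, j] * 2.0, name="B")
//   func = te.create_prim_func([A, B])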
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
                        std::optional<DataType> index_dtype_override) {
  return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override);
}

TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, TVMRetValue* ret) {
  Array<te::Tensor> arg_list = args[0];
  std::optional<DataType> index_dtype_override{std::nullopt};
  // Add a conversion to make std::optional compatible with the FFI.
  ICHECK_EQ(args.size(), 2);
  if (args[1].type_code() != kTVMNullptr) {
    index_dtype_override = args[1].operator DataType();
  }
  *ret = CreatePrimFunc(arg_list, index_dtype_override);
});

}  // namespace tir
}  // namespace tvm