1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir/ir/script/script_complete.cc
22 * \brief Used by TVM Script parser to expand incomplete TIR input
23 */
24
25#include "./script_complete.h"
26
27#include <tvm/arith/int_set.h>
28#include <tvm/tir/analysis.h>
29
30#include <utility>
31
32namespace tvm {
33namespace tir {
34
35/*! \brief Generate surrounding loops automatically */
36class ScriptCompleter : public StmtMutator {
37 public:
38 explicit ScriptCompleter(Map<Var, Buffer>* buffer_var_map) : buffer_var_map_(buffer_var_map) {}
39 /*! \brief Whether the stmt contains at least one block. */
40 bool contains_block = false;
41
42 private:
43 Map<Var, Buffer>* buffer_var_map_;
44 Stmt VisitStmt_(const BlockRealizeNode* op) override {
45 contains_block = true;
46 for (const PrimExpr& value : op->iter_values) {
47 CHECK(value.dtype().is_int())
48 << "BlockRealize iter_value expected a IntImm, but got " << value.dtype();
49 }
50 return StmtMutator::VisitStmt_(op);
51 }
52
53 Stmt VisitStmt_(const BlockNode* op) override {
54 // Buffers allocated in the block can be accessed by its body.
55 for (const auto& alloc_buffer : op->alloc_buffers) {
56 buffer_var_map_->Set(alloc_buffer->data, alloc_buffer);
57 }
58 for (const auto& match_buffer : op->match_buffers) {
59 const Buffer& target_buffer = match_buffer->buffer;
60 buffer_var_map_->Set(target_buffer->data, target_buffer);
61 }
62 Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
63 // Remove buffers allocated inside block to detect its access region
64 for (const auto& alloc_buffer : op->alloc_buffers) {
65 buffer_var_map_->erase(alloc_buffer->data);
66 }
67 for (const auto& match_buffer : op->match_buffers) {
68 const Buffer& target_buffer = match_buffer->buffer;
69 buffer_var_map_->erase(target_buffer->data);
70 }
71 // Get access detection mask
72 // 0 for provided region, 1 and 3 for need detect read, 2 and 3 for need detect write
73 int mask = 0;
74 auto it = op->annotations.find(attr::script_parsing_detect_access);
75 if (it != op->annotations.end()) {
76 mask = Downcast<IntImm>((*it).second)->value;
77 }
78 // ignore root block or blocks which already has reads/writes regions
79 if (mask != 0) {
80 auto access_region = GetBlockAccessRegion(block, *buffer_var_map_);
81 const Array<BufferRegion>& reads = access_region[0];
82 const Array<BufferRegion>& writes = access_region[1];
83 const Array<BufferRegion>& opaque = access_region[2];
84 CHECK(opaque.empty())
85 << "ValueError: Can not auto detect buffer access region from tir.Load, tir.Store or "
86 "direct access by buffer data. Please annotation the access region manually";
87 auto n = CopyOnWrite(block.operator->());
88 if (mask & 1) n->reads = reads;
89 if (mask & 2) n->writes = writes;
90 n->annotations = op->annotations;
91 n->annotations.erase(attr::script_parsing_detect_access);
92 return Block(n);
93 } else {
94 return std::move(block);
95 }
96 }
97};
98
99PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates) {
100 Map<Var, Buffer> buffer_var_map;
101 for (const auto& pair : func->buffer_map) {
102 const Buffer& buffer = pair.second;
103 buffer_var_map.Set(buffer->data, buffer);
104 }
105 for (const auto& alloc : root_allocates) {
106 buffer_var_map.Set(alloc->data, alloc);
107 }
108
109 Stmt res = func->body;
110
111 // Generate root block automatically. This is done before
112 // ScriptCompleter, in order to fill the root block's T.reads() and
113 // T.writes() annotations, as if it had been explicitly written.
114 bool should_insert_root = [&]() -> bool {
115 if (root_allocates.size()) {
116 return true;
117 }
118 auto* block_realize = func->body.as<BlockRealizeNode>();
119 if (block_realize && block_realize->block->iter_vars.size()) {
120 return true;
121 }
122 if (!block_realize && ContainsNode<BlockRealizeNode>(func->body)) {
123 return true;
124 }
125 return false;
126 }();
127
128 if (should_insert_root) {
129 Block root_block({}, {}, {}, "root", std::move(res), NullOpt, root_allocates);
130 res = BlockRealize({}, Bool(true), std::move(root_block));
131 }
132
133 // generate surrounding loops automatically
134 ScriptCompleter script_completer(&buffer_var_map);
135 res = script_completer(std::move(res));
136
137 if (func->body.same_as(res)) {
138 return func;
139 } else {
140 auto fptr = func.CopyOnWrite();
141 fptr->body = res;
142 return func;
143 }
144}
145
146TVM_REGISTER_GLOBAL("script.Complete").set_body_typed(ScriptComplete);
147
148} // namespace tir
149} // namespace tvm
150