1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/script/ir_builder/tir/frame.h>
20#include <tvm/tir/function.h>
21
22#include "../../../tir/ir/script/script_complete.h"
23#include "./utils.h"
24
25namespace tvm {
26namespace script {
27namespace ir_builder {
28namespace tir {
29
30void PrimFuncFrameNode::ExitWithScope() {
31 TIRFrameNode::ExitWithScope();
32 tvm::tir::PrimFunc func(
33 /*params=*/args,
34 /*body=*/AsStmt(stmts),
35 /*ret_type=*/ret_type.value_or(TupleType::Empty()),
36 /*buffer_map=*/buffer_map,
37 /*attrs=*/attrs.defined() ? DictAttrs(attrs.value()) : NullValue<DictAttrs>());
38 func = tvm::tir::ScriptComplete(func, root_alloc_buffers);
39 IRBuilder builder = IRBuilder::Current();
40 if (builder->frames.empty()) {
41 ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
42 builder->result = func;
43 } else if (Optional<ir::IRModuleFrame> opt_frame = builder->FindFrame<ir::IRModuleFrame>()) {
44 ir::IRModuleFrame frame = opt_frame.value();
45 frame->global_vars.push_back(GlobalVar(name.value_or("")));
46 frame->functions.push_back(func);
47 } else {
48 LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc";
49 }
50}
51
52void BlockFrameNode::ExitWithScope() {
53 TIRFrameNode::ExitWithScope();
54 Array<tvm::tir::Buffer> tir_alloc_buffers;
55 for (const tvm::tir::Buffer& buffer : alloc_buffers) {
56 tir_alloc_buffers.push_back(buffer);
57 }
58 Map<String, ObjectRef> attrs = annotations.value_or({});
59 if (int detect_access = (!reads.defined()) | (!writes.defined() << 1)) {
60 attrs.Set("tir.script_parsing_detect_access", tvm::IntImm(DataType::Int(64), detect_access));
61 }
62 tvm::tir::Block block(iter_vars, reads.value_or(Array<tvm::tir::BufferRegion>()),
63 writes.value_or(Array<tvm::tir::BufferRegion>()), name, AsStmt(stmts), init,
64 tir_alloc_buffers, match_buffers, attrs);
65 if (no_realize) {
66 CHECK(iter_values.empty())
67 << "ValueError: Block bindings are not allowed when `no_realize=True`";
68 CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`";
69 AddToParent(block);
70 } else {
71 AddToParent(tvm::tir::BlockRealize(iter_values, predicate.value_or(Bool(true)), block));
72 }
73}
74
75void BlockInitFrameNode::EnterWithScope() {
76 BlockFrame frame = FindBlockFrame("T.init");
77 if (frame->init.defined()) {
78 LOG(FATAL) << "ValueError: Duplicate block init declaration";
79 }
80 TIRFrameNode::EnterWithScope();
81}
82
83void BlockInitFrameNode::ExitWithScope() {
84 TIRFrameNode::ExitWithScope();
85 BlockFrame frame = FindBlockFrame("T.init");
86 frame->init = AsStmt(stmts);
87}
88
89void ForFrameNode::ExitWithScope() {
90 TIRFrameNode::ExitWithScope();
91 AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts)));
92}
93
94void AssertFrameNode::ExitWithScope() {
95 TIRFrameNode::ExitWithScope();
96 AddToParent(tvm::tir::AssertStmt(condition, message, AsStmt(stmts)));
97}
98
99void LetFrameNode::ExitWithScope() {
100 TIRFrameNode::ExitWithScope();
101 AddToParent(tvm::tir::LetStmt(var, value, AsStmt(stmts)));
102}
103
104void RealizeFrameNode::ExitWithScope() {
105 TIRFrameNode::ExitWithScope();
106 AddToParent(tvm::tir::AttrStmt(buffer_slice->buffer, "realize_scope",
107 tvm::tir::StringImm(storage_scope),
108 tvm::tir::BufferRealize(buffer_slice->buffer, buffer_slice->region,
109 condition, AsStmt(stmts))));
110}
111
112void LaunchThreadFrameNode::ExitWithScope() {
113 TIRFrameNode::ExitWithScope();
114 AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts)));
115}
116
117void AllocateFrameNode::ExitWithScope() {
118 TIRFrameNode::ExitWithScope();
119 AddToParent(
120 tvm::tir::Allocate(buffer_var, dtype, extents, condition, AsStmt(stmts), annotations));
121}
122
123void AllocateConstFrameNode::ExitWithScope() {
124 TIRFrameNode::ExitWithScope();
125 AddToParent(
126 tvm::tir::AllocateConst(buffer_var, dtype, extents, data, AsStmt(stmts), annotations));
127}
128void AttrFrameNode::ExitWithScope() {
129 TIRFrameNode::ExitWithScope();
130 AddToParent(tvm::tir::AttrStmt(node, attr_key, value, AsStmt(stmts)));
131}
132
133void WhileFrameNode::ExitWithScope() {
134 TIRFrameNode::ExitWithScope();
135 AddToParent(tvm::tir::While(condition, AsStmt(stmts)));
136}
137
138void IfFrameNode::ExitWithScope() {
139 TIRFrameNode::ExitWithScope();
140 if (!stmts.empty()) {
141 LOG(FATAL) << "stmt within IfThenElse frame should be either in ThenFrame or ElseFrame";
142 }
143 if (!then_stmts.defined()) {
144 LOG(FATAL) << "IfThenElse frame should have at least one then branch";
145 }
146 AddToParent(tvm::tir::IfThenElse(
147 condition, AsStmt(then_stmts.value()),
148 else_stmts.defined() ? AsStmt(else_stmts.value()) : tvm::tir::Stmt(nullptr)));
149}
150
151void ThenFrameNode::EnterWithScope() {
152 IfFrame frame = FindIfFrame("T.then_");
153 if (frame->then_stmts.defined()) {
154 LOG(FATAL) << "ValueError: Duplicate then branch declaration, previous one is "
155 << frame->then_stmts.value();
156 }
157 TIRFrameNode::EnterWithScope();
158}
159
160void ThenFrameNode::ExitWithScope() {
161 TIRFrameNode::ExitWithScope();
162 FindIfFrame("T.then_")->then_stmts = stmts;
163}
164
165void ElseFrameNode::EnterWithScope() {
166 IfFrame frame = FindIfFrame("T.else_");
167 if (!frame->then_stmts.defined()) {
168 LOG(FATAL) << "The else branch should follow then branch";
169 }
170 if (frame->else_stmts.defined()) {
171 LOG(FATAL) << "ValueError: Duplicate else branch declaration, previous one is "
172 << frame->else_stmts.value();
173 }
174 TIRFrameNode::EnterWithScope();
175}
176
177void ElseFrameNode::ExitWithScope() {
178 TIRFrameNode::ExitWithScope();
179 FindIfFrame("T.else_")->else_stmts = stmts;
180}
181
182void DeclBufferFrameNode::ExitWithScope() {
183 TIRFrameNode::ExitWithScope();
184 if (allocated) {
185 AddToParent(tvm::tir::DeclBuffer(buffer, AsStmt(stmts)));
186 } else {
187 AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape,
188 tvm::IntImm(DataType::Bool(), 1),
189 tvm::tir::DeclBuffer(buffer, AsStmt(stmts))));
190 }
191}
192
// Register each TIR IR-builder frame node type defined above (plus the common
// TIRFrameNode base) via TVM_REGISTER_NODE_TYPE. A new frame node type added to
// this file presumably needs a matching entry here as well — confirm against
// the frame header.
193TVM_REGISTER_NODE_TYPE(TIRFrameNode);
194TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
195TVM_REGISTER_NODE_TYPE(BlockFrameNode);
196TVM_REGISTER_NODE_TYPE(BlockInitFrameNode);
197TVM_REGISTER_NODE_TYPE(ForFrameNode);
198TVM_REGISTER_NODE_TYPE(AssertFrameNode);
199TVM_REGISTER_NODE_TYPE(LetFrameNode);
200TVM_REGISTER_NODE_TYPE(RealizeFrameNode);
201TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode);
202TVM_REGISTER_NODE_TYPE(AllocateFrameNode);
203TVM_REGISTER_NODE_TYPE(AllocateConstFrameNode);
204TVM_REGISTER_NODE_TYPE(AttrFrameNode);
205TVM_REGISTER_NODE_TYPE(WhileFrameNode);
206TVM_REGISTER_NODE_TYPE(IfFrameNode);
207TVM_REGISTER_NODE_TYPE(ThenFrameNode);
208TVM_REGISTER_NODE_TYPE(ElseFrameNode);
209TVM_REGISTER_NODE_TYPE(DeclBufferFrameNode);
210
211} // namespace tir
212} // namespace ir_builder
213} // namespace script
214} // namespace tvm
215