1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/script/ir_builder/tir/frame.h> |
20 | #include <tvm/tir/function.h> |
21 | |
22 | #include "../../../tir/ir/script/script_complete.h" |
23 | #include "./utils.h" |
24 | |
25 | namespace tvm { |
26 | namespace script { |
27 | namespace ir_builder { |
28 | namespace tir { |
29 | |
30 | void PrimFuncFrameNode::ExitWithScope() { |
31 | TIRFrameNode::ExitWithScope(); |
32 | tvm::tir::PrimFunc func( |
33 | /*params=*/args, |
34 | /*body=*/AsStmt(stmts), |
35 | /*ret_type=*/ret_type.value_or(TupleType::Empty()), |
36 | /*buffer_map=*/buffer_map, |
37 | /*attrs=*/attrs.defined() ? DictAttrs(attrs.value()) : NullValue<DictAttrs>()); |
38 | func = tvm::tir::ScriptComplete(func, root_alloc_buffers); |
39 | IRBuilder builder = IRBuilder::Current(); |
40 | if (builder->frames.empty()) { |
41 | ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set" ; |
42 | builder->result = func; |
43 | } else if (Optional<ir::IRModuleFrame> opt_frame = builder->FindFrame<ir::IRModuleFrame>()) { |
44 | ir::IRModuleFrame frame = opt_frame.value(); |
45 | frame->global_vars.push_back(GlobalVar(name.value_or("" ))); |
46 | frame->functions.push_back(func); |
47 | } else { |
48 | LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc" ; |
49 | } |
50 | } |
51 | |
52 | void BlockFrameNode::ExitWithScope() { |
53 | TIRFrameNode::ExitWithScope(); |
54 | Array<tvm::tir::Buffer> tir_alloc_buffers; |
55 | for (const tvm::tir::Buffer& buffer : alloc_buffers) { |
56 | tir_alloc_buffers.push_back(buffer); |
57 | } |
58 | Map<String, ObjectRef> attrs = annotations.value_or({}); |
59 | if (int detect_access = (!reads.defined()) | (!writes.defined() << 1)) { |
60 | attrs.Set("tir.script_parsing_detect_access" , tvm::IntImm(DataType::Int(64), detect_access)); |
61 | } |
62 | tvm::tir::Block block(iter_vars, reads.value_or(Array<tvm::tir::BufferRegion>()), |
63 | writes.value_or(Array<tvm::tir::BufferRegion>()), name, AsStmt(stmts), init, |
64 | tir_alloc_buffers, match_buffers, attrs); |
65 | if (no_realize) { |
66 | CHECK(iter_values.empty()) |
67 | << "ValueError: Block bindings are not allowed when `no_realize=True`" ; |
68 | CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`" ; |
69 | AddToParent(block); |
70 | } else { |
71 | AddToParent(tvm::tir::BlockRealize(iter_values, predicate.value_or(Bool(true)), block)); |
72 | } |
73 | } |
74 | |
75 | void BlockInitFrameNode::EnterWithScope() { |
76 | BlockFrame frame = FindBlockFrame("T.init" ); |
77 | if (frame->init.defined()) { |
78 | LOG(FATAL) << "ValueError: Duplicate block init declaration" ; |
79 | } |
80 | TIRFrameNode::EnterWithScope(); |
81 | } |
82 | |
83 | void BlockInitFrameNode::ExitWithScope() { |
84 | TIRFrameNode::ExitWithScope(); |
85 | BlockFrame frame = FindBlockFrame("T.init" ); |
86 | frame->init = AsStmt(stmts); |
87 | } |
88 | |
89 | void ForFrameNode::ExitWithScope() { |
90 | TIRFrameNode::ExitWithScope(); |
91 | AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts))); |
92 | } |
93 | |
94 | void AssertFrameNode::ExitWithScope() { |
95 | TIRFrameNode::ExitWithScope(); |
96 | AddToParent(tvm::tir::AssertStmt(condition, message, AsStmt(stmts))); |
97 | } |
98 | |
99 | void LetFrameNode::ExitWithScope() { |
100 | TIRFrameNode::ExitWithScope(); |
101 | AddToParent(tvm::tir::LetStmt(var, value, AsStmt(stmts))); |
102 | } |
103 | |
104 | void RealizeFrameNode::ExitWithScope() { |
105 | TIRFrameNode::ExitWithScope(); |
106 | AddToParent(tvm::tir::AttrStmt(buffer_slice->buffer, "realize_scope" , |
107 | tvm::tir::StringImm(storage_scope), |
108 | tvm::tir::BufferRealize(buffer_slice->buffer, buffer_slice->region, |
109 | condition, AsStmt(stmts)))); |
110 | } |
111 | |
112 | void LaunchThreadFrameNode::ExitWithScope() { |
113 | TIRFrameNode::ExitWithScope(); |
114 | AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts))); |
115 | } |
116 | |
117 | void AllocateFrameNode::ExitWithScope() { |
118 | TIRFrameNode::ExitWithScope(); |
119 | AddToParent( |
120 | tvm::tir::Allocate(buffer_var, dtype, extents, condition, AsStmt(stmts), annotations)); |
121 | } |
122 | |
123 | void AllocateConstFrameNode::ExitWithScope() { |
124 | TIRFrameNode::ExitWithScope(); |
125 | AddToParent( |
126 | tvm::tir::AllocateConst(buffer_var, dtype, extents, data, AsStmt(stmts), annotations)); |
127 | } |
128 | void AttrFrameNode::ExitWithScope() { |
129 | TIRFrameNode::ExitWithScope(); |
130 | AddToParent(tvm::tir::AttrStmt(node, attr_key, value, AsStmt(stmts))); |
131 | } |
132 | |
133 | void WhileFrameNode::ExitWithScope() { |
134 | TIRFrameNode::ExitWithScope(); |
135 | AddToParent(tvm::tir::While(condition, AsStmt(stmts))); |
136 | } |
137 | |
138 | void IfFrameNode::ExitWithScope() { |
139 | TIRFrameNode::ExitWithScope(); |
140 | if (!stmts.empty()) { |
141 | LOG(FATAL) << "stmt within IfThenElse frame should be either in ThenFrame or ElseFrame" ; |
142 | } |
143 | if (!then_stmts.defined()) { |
144 | LOG(FATAL) << "IfThenElse frame should have at least one then branch" ; |
145 | } |
146 | AddToParent(tvm::tir::IfThenElse( |
147 | condition, AsStmt(then_stmts.value()), |
148 | else_stmts.defined() ? AsStmt(else_stmts.value()) : tvm::tir::Stmt(nullptr))); |
149 | } |
150 | |
151 | void ThenFrameNode::EnterWithScope() { |
152 | IfFrame frame = FindIfFrame("T.then_" ); |
153 | if (frame->then_stmts.defined()) { |
154 | LOG(FATAL) << "ValueError: Duplicate then branch declaration, previous one is " |
155 | << frame->then_stmts.value(); |
156 | } |
157 | TIRFrameNode::EnterWithScope(); |
158 | } |
159 | |
160 | void ThenFrameNode::ExitWithScope() { |
161 | TIRFrameNode::ExitWithScope(); |
162 | FindIfFrame("T.then_" )->then_stmts = stmts; |
163 | } |
164 | |
165 | void ElseFrameNode::EnterWithScope() { |
166 | IfFrame frame = FindIfFrame("T.else_" ); |
167 | if (!frame->then_stmts.defined()) { |
168 | LOG(FATAL) << "The else branch should follow then branch" ; |
169 | } |
170 | if (frame->else_stmts.defined()) { |
171 | LOG(FATAL) << "ValueError: Duplicate else branch declaration, previous one is " |
172 | << frame->else_stmts.value(); |
173 | } |
174 | TIRFrameNode::EnterWithScope(); |
175 | } |
176 | |
177 | void ElseFrameNode::ExitWithScope() { |
178 | TIRFrameNode::ExitWithScope(); |
179 | FindIfFrame("T.else_" )->else_stmts = stmts; |
180 | } |
181 | |
182 | void DeclBufferFrameNode::ExitWithScope() { |
183 | TIRFrameNode::ExitWithScope(); |
184 | if (allocated) { |
185 | AddToParent(tvm::tir::DeclBuffer(buffer, AsStmt(stmts))); |
186 | } else { |
187 | AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape, |
188 | tvm::IntImm(DataType::Bool(), 1), |
189 | tvm::tir::DeclBuffer(buffer, AsStmt(stmts)))); |
190 | } |
191 | } |
192 | |
// Register every TIR IR-builder frame node type defined in this file.
// NOTE(review): presumably this exposes the nodes to TVM's reflection/FFI
// machinery; confirm against the TVM_REGISTER_NODE_TYPE definition before
// relying on that.
TVM_REGISTER_NODE_TYPE(TIRFrameNode);
TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
TVM_REGISTER_NODE_TYPE(BlockFrameNode);
TVM_REGISTER_NODE_TYPE(BlockInitFrameNode);
TVM_REGISTER_NODE_TYPE(ForFrameNode);
TVM_REGISTER_NODE_TYPE(AssertFrameNode);
TVM_REGISTER_NODE_TYPE(LetFrameNode);
TVM_REGISTER_NODE_TYPE(RealizeFrameNode);
TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode);
TVM_REGISTER_NODE_TYPE(AllocateFrameNode);
TVM_REGISTER_NODE_TYPE(AllocateConstFrameNode);
TVM_REGISTER_NODE_TYPE(AttrFrameNode);
TVM_REGISTER_NODE_TYPE(WhileFrameNode);
TVM_REGISTER_NODE_TYPE(IfFrameNode);
TVM_REGISTER_NODE_TYPE(ThenFrameNode);
TVM_REGISTER_NODE_TYPE(ElseFrameNode);
TVM_REGISTER_NODE_TYPE(DeclBufferFrameNode);
210 | |
211 | } // namespace tir |
212 | } // namespace ir_builder |
213 | } // namespace script |
214 | } // namespace tvm |
215 | |