1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_ |
20 | #define TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_ |
21 | |
22 | #include <tvm/script/ir_builder/tir/frame.h> |
23 | #include <tvm/script/ir_builder/tir/ir.h> |
24 | #include <tvm/tir/op.h> |
25 | #include <tvm/tir/stmt.h> |
26 | |
27 | namespace tvm { |
28 | namespace script { |
29 | namespace ir_builder { |
30 | namespace tir { |
31 | |
32 | /*! |
33 | * \brief Add tir Stmt to the top frame in IRBuilder frame stack. |
34 | * \param stmt The Stmt. |
35 | */ |
36 | inline void AddToParent(tvm::tir::Stmt stmt) { |
37 | IRBuilder builder = IRBuilder::Current(); |
38 | if (builder->frames.empty()) { |
39 | ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set" ; |
40 | builder->result = stmt; |
41 | } else if (const auto* tir_frame = builder->frames.back().as<TIRFrameNode>()) { |
42 | GetRef<TIRFrame>(tir_frame)->stmts.push_back(stmt); |
43 | } else { |
44 | LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back(); |
45 | } |
46 | } |
47 | |
48 | /*! |
49 | * \brief Convert array of tir Stmt to single Stmt. |
50 | * \param stmt The array of Stmt. |
51 | * \return The SeqStmt. |
52 | */ |
53 | inline tvm::tir::Stmt AsStmt(const Array<tvm::tir::Stmt>& stmt) { |
54 | using namespace tvm::tir; |
55 | if (stmt.empty()) { |
56 | return tvm::tir::Evaluate(0); |
57 | } else if (stmt.size() == 1) { |
58 | return stmt[0]; |
59 | } else { |
60 | return SeqStmt(stmt); |
61 | } |
62 | } |
63 | |
64 | /*! |
65 | * \brief Check whether the top frame in IRBuilder frame stack is PrimFuncFrame. |
66 | * \param method The method name to be printed when throwing exception. |
67 | * \return The top frame of PrimFuncFrame. |
68 | */ |
69 | inline PrimFuncFrame FindPrimFuncFrame(const String& method) { |
70 | if (Optional<PrimFuncFrame> frame = IRBuilder::Current()->GetLastFrame<PrimFuncFrame>()) { |
71 | return frame.value(); |
72 | } |
73 | LOG(FATAL) << "ValueError: PrimFunc frame not find. Please ensure '" << method |
74 | << "' is called under T.prim_func()" ; |
75 | throw; |
76 | } |
77 | |
78 | /*! |
79 | * \brief Check whether the top frame in IRBuilder frame stack is BlockFrame. |
80 | * \param method The method name to be printed when throwing exception. |
81 | * \return The top frame of BlockFrame. |
82 | */ |
83 | inline BlockFrame FindBlockFrame(const String& method) { |
84 | if (Optional<BlockFrame> frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) { |
85 | return frame.value(); |
86 | } |
87 | LOG(FATAL) << "ValueError: Block frame not find. Please ensure '" << method |
88 | << "' is called under T.block()" ; |
89 | throw; |
90 | } |
91 | |
92 | /*! |
93 | * \brief Check whether the top frame in IRBuilder frame stack is IfFrame. |
94 | * \param method The method name to be printed when throwing exception. |
95 | * \return The top frame of IfFrame. |
96 | */ |
97 | inline IfFrame FindIfFrame(const String& method) { |
98 | if (Optional<IfFrame> frame = IRBuilder::Current()->GetLastFrame<IfFrame>()) { |
99 | return frame.value(); |
100 | } else { |
101 | LOG(FATAL) << "ValueError: IfThenElse frame not find. Please ensure '" << method |
102 | << "' is called under T.if_()" ; |
103 | } |
104 | throw; |
105 | } |
106 | |
107 | /*! |
108 | * \brief Convert BufferLoad to BufferRegion. |
109 | * \param buffer_load The BufferLoad. |
110 | * \return The converted BufferRegion. |
111 | */ |
112 | inline tvm::tir::BufferRegion BufferRegionFromLoad(tvm::tir::BufferLoad buffer_load) { |
113 | Array<Range> ranges; |
114 | for (const PrimExpr& index : buffer_load->indices) { |
115 | ranges.push_back(Range::FromMinExtent(index, IntImm(index->dtype, 1))); |
116 | } |
117 | return tvm::tir::BufferRegion(buffer_load->buffer, ranges); |
118 | } |
119 | |
120 | } // namespace tir |
121 | } // namespace ir_builder |
122 | } // namespace script |
123 | } // namespace tvm |
124 | |
125 | #endif // TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_ |
126 | |