1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_
20#define TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_
21
22#include <tvm/script/ir_builder/tir/frame.h>
23#include <tvm/script/ir_builder/tir/ir.h>
24#include <tvm/tir/op.h>
25#include <tvm/tir/stmt.h>
26
27namespace tvm {
28namespace script {
29namespace ir_builder {
30namespace tir {
31
32/*!
33 * \brief Add tir Stmt to the top frame in IRBuilder frame stack.
34 * \param stmt The Stmt.
35 */
36inline void AddToParent(tvm::tir::Stmt stmt) {
37 IRBuilder builder = IRBuilder::Current();
38 if (builder->frames.empty()) {
39 ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
40 builder->result = stmt;
41 } else if (const auto* tir_frame = builder->frames.back().as<TIRFrameNode>()) {
42 GetRef<TIRFrame>(tir_frame)->stmts.push_back(stmt);
43 } else {
44 LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back();
45 }
46}
47
48/*!
49 * \brief Convert array of tir Stmt to single Stmt.
50 * \param stmt The array of Stmt.
51 * \return The SeqStmt.
52 */
53inline tvm::tir::Stmt AsStmt(const Array<tvm::tir::Stmt>& stmt) {
54 using namespace tvm::tir;
55 if (stmt.empty()) {
56 return tvm::tir::Evaluate(0);
57 } else if (stmt.size() == 1) {
58 return stmt[0];
59 } else {
60 return SeqStmt(stmt);
61 }
62}
63
64/*!
65 * \brief Check whether the top frame in IRBuilder frame stack is PrimFuncFrame.
66 * \param method The method name to be printed when throwing exception.
67 * \return The top frame of PrimFuncFrame.
68 */
69inline PrimFuncFrame FindPrimFuncFrame(const String& method) {
70 if (Optional<PrimFuncFrame> frame = IRBuilder::Current()->GetLastFrame<PrimFuncFrame>()) {
71 return frame.value();
72 }
73 LOG(FATAL) << "ValueError: PrimFunc frame not find. Please ensure '" << method
74 << "' is called under T.prim_func()";
75 throw;
76}
77
78/*!
79 * \brief Check whether the top frame in IRBuilder frame stack is BlockFrame.
80 * \param method The method name to be printed when throwing exception.
81 * \return The top frame of BlockFrame.
82 */
83inline BlockFrame FindBlockFrame(const String& method) {
84 if (Optional<BlockFrame> frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
85 return frame.value();
86 }
87 LOG(FATAL) << "ValueError: Block frame not find. Please ensure '" << method
88 << "' is called under T.block()";
89 throw;
90}
91
92/*!
93 * \brief Check whether the top frame in IRBuilder frame stack is IfFrame.
94 * \param method The method name to be printed when throwing exception.
95 * \return The top frame of IfFrame.
96 */
97inline IfFrame FindIfFrame(const String& method) {
98 if (Optional<IfFrame> frame = IRBuilder::Current()->GetLastFrame<IfFrame>()) {
99 return frame.value();
100 } else {
101 LOG(FATAL) << "ValueError: IfThenElse frame not find. Please ensure '" << method
102 << "' is called under T.if_()";
103 }
104 throw;
105}
106
107/*!
108 * \brief Convert BufferLoad to BufferRegion.
109 * \param buffer_load The BufferLoad.
110 * \return The converted BufferRegion.
111 */
112inline tvm::tir::BufferRegion BufferRegionFromLoad(tvm::tir::BufferLoad buffer_load) {
113 Array<Range> ranges;
114 for (const PrimExpr& index : buffer_load->indices) {
115 ranges.push_back(Range::FromMinExtent(index, IntImm(index->dtype, 1)));
116 }
117 return tvm::tir::BufferRegion(buffer_load->buffer, ranges);
118}
119
120} // namespace tir
121} // namespace ir_builder
122} // namespace script
123} // namespace tvm
124
125#endif // TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_
126