1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
 * \file codegen_stackvm.h
22 | * \brief Codegen into Simple Stack VM. |
23 | */ |
24 | #ifndef TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_ |
25 | #define TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_ |
26 | |
27 | #include <tvm/target/codegen.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/function.h> |
30 | #include <tvm/tir/op.h> |
31 | #include <tvm/tir/stmt_functor.h> |
32 | |
33 | #include <string> |
34 | #include <unordered_map> |
35 | #include <vector> |
36 | |
37 | #include "../../runtime/stackvm/stackvm.h" |
38 | |
39 | namespace tvm { |
40 | namespace codegen { |
41 | |
42 | using namespace tir; |
43 | using runtime::StackVM; |
44 | |
45 | /*! |
46 | * \brief A base class to generate a stack VM. |
47 | * This module is used to generate host wrapper |
48 | * into device function when only device JIT is available. |
49 | */ |
50 | class CodeGenStackVM : public ExprFunctor<void(const PrimExpr&)>, |
51 | public StmtFunctor<void(const Stmt&)> { |
52 | public: |
53 | /*! |
54 | * \brief Generate a stack VM representing |
55 | * \param f The function to be compiled |
56 | * \param device_funcs The extern device functions to be linked. |
57 | * \note Only call compile once, |
58 | * create a new codegen object each time. |
59 | */ |
60 | StackVM Compile(const PrimFunc& f); |
61 | /*! \brief Push stmt to generate new code */ |
62 | void Push(const Stmt& n); |
63 | /*! \brief Push expr to generate new code */ |
64 | void Push(const PrimExpr& n) { VisitExpr(n); } |
65 | /*! |
66 | * \brief Push the opcode to the code. |
67 | * \param opcode The code to be pushed. |
68 | */ |
69 | void PushOp(StackVM::OpCode opcode); |
70 | /*! |
71 | * \brief Push the opcode and operand to the code. |
72 | * \param opcode The opcode. |
73 | * \param operand The operand to be pushed. |
74 | * \return operand_index, indicating location of operand |
75 | */ |
76 | int64_t PushOp(StackVM::OpCode opcode, int operand); |
77 | /*! |
78 | * \brief Set the relative jump offset to be offset. |
79 | * \param operand_index The indexed returned by PushOp. |
80 | * \param operand The operand to be set. |
81 | */ |
82 | void SetOperand(int64_t operand_index, int64_t operand); |
83 | /*! \return The current program pointer */ |
84 | int64_t GetPC() const { return static_cast<int64_t>(vm_.code.size()); } |
85 | /*! |
86 | * \brief Get string id in vm |
87 | * \param key The string to get id. |
88 | * \return the id of the string. |
89 | */ |
90 | int GetStrID(const std::string& key); |
91 | /*! |
92 | * \brief Allocate a variable name for a newly defined var. |
93 | * \param v The variable. |
94 | * \return the heap index of the var. |
95 | */ |
96 | int AllocVarID(const VarNode* v); |
97 | /*! |
98 | * \brief Get a variable name. |
99 | * \param v The variable. |
100 | * \return the heap index of the var. |
101 | */ |
102 | int GetVarID(const VarNode* v) const; |
103 | // Push binary operator |
104 | void PushBinary(StackVM::OpCode op_int64, const PrimExpr& a, const PrimExpr& b); |
105 | // push cast; |
106 | void PushCast(DataType dst, DataType src); |
107 | // overloadable functions |
108 | // expression |
109 | void VisitExpr_(const VarNode* op) final; |
110 | void VisitExpr_(const LoadNode* op) final; |
111 | void VisitExpr_(const BufferLoadNode* op) final; |
112 | void VisitExpr_(const LetNode* op) final; |
113 | void VisitExpr_(const CallNode* op) final; |
114 | void VisitExpr_(const AddNode* op) final; |
115 | void VisitExpr_(const SubNode* op) final; |
116 | void VisitExpr_(const MulNode* op) final; |
117 | void VisitExpr_(const DivNode* op) final; |
118 | void VisitExpr_(const ModNode* op) final; |
119 | void VisitExpr_(const MinNode* op) final; |
120 | void VisitExpr_(const MaxNode* op) final; |
121 | void VisitExpr_(const EQNode* op) final; |
122 | void VisitExpr_(const NENode* op) final; |
123 | void VisitExpr_(const LTNode* op) final; |
124 | void VisitExpr_(const LENode* op) final; |
125 | void VisitExpr_(const GTNode* op) final; |
126 | void VisitExpr_(const GENode* op) final; |
127 | void VisitExpr_(const AndNode* op) final; |
128 | void VisitExpr_(const OrNode* op) final; |
129 | void VisitExpr_(const CastNode* op) final; |
130 | void VisitExpr_(const NotNode* op) final; |
131 | void VisitExpr_(const SelectNode* op) final; |
132 | void VisitExpr_(const RampNode* op) final; |
133 | void VisitExpr_(const BroadcastNode* op) final; |
134 | void VisitExpr_(const IntImmNode* op) final; |
135 | void VisitExpr_(const FloatImmNode* op) final; |
136 | void VisitExpr_(const StringImmNode* op) final; |
137 | // statment |
138 | void VisitStmt_(const LetStmtNode* op) final; |
139 | void VisitStmt_(const StoreNode* op) final; |
140 | void VisitStmt_(const BufferStoreNode* op) final; |
141 | void VisitStmt_(const ForNode* op) final; |
142 | void VisitStmt_(const IfThenElseNode* op) final; |
143 | void VisitStmt_(const AllocateNode* op) final; |
144 | void VisitStmt_(const AttrStmtNode* op) final; |
145 | void VisitStmt_(const AssertStmtNode* op) final; |
146 | void VisitStmt_(const EvaluateNode* op) final; |
147 | void VisitStmt_(const SeqStmtNode* op) final; |
148 | |
149 | private: |
150 | bool debug_{false}; |
151 | /*! \brief The vm to be generated */ |
152 | StackVM vm_; |
153 | /*! \brief id of each variable */ |
154 | std::unordered_map<const VarNode*, int> var_idmap_; |
155 | /*! \brief id of each string */ |
156 | std::unordered_map<std::string, int> str_idmap_; |
157 | /*! \brief id of each global function */ |
158 | std::unordered_map<std::string, int> extern_fun_idmap_; |
159 | |
160 | Op backend_alloc_workspace_op_ = Op::Get("tir.TVMBackendAllocWorkspace" ); |
161 | Op backend_free_workspace_op_ = Op::Get("tir.TVMBackendFreeWorkspace" ); |
162 | }; |
163 | |
164 | } // namespace codegen |
165 | } // namespace tvm |
166 | #endif // TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_ |
167 | |