1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
 * \file codegen_stackvm.h
22 * \brief Codegen into Simple Stack VM.
23 */
24#ifndef TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_
25#define TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_
26
27#include <tvm/target/codegen.h>
28#include <tvm/tir/expr.h>
29#include <tvm/tir/function.h>
30#include <tvm/tir/op.h>
31#include <tvm/tir/stmt_functor.h>
32
33#include <string>
34#include <unordered_map>
35#include <vector>
36
37#include "../../runtime/stackvm/stackvm.h"
38
39namespace tvm {
40namespace codegen {
41
42using namespace tir;
43using runtime::StackVM;
44
45/*!
46 * \brief A base class to generate a stack VM.
47 * This module is used to generate host wrapper
48 * into device function when only device JIT is available.
49 */
50class CodeGenStackVM : public ExprFunctor<void(const PrimExpr&)>,
51 public StmtFunctor<void(const Stmt&)> {
52 public:
53 /*!
54 * \brief Generate a stack VM representing
55 * \param f The function to be compiled
56 * \param device_funcs The extern device functions to be linked.
57 * \note Only call compile once,
58 * create a new codegen object each time.
59 */
60 StackVM Compile(const PrimFunc& f);
61 /*! \brief Push stmt to generate new code */
62 void Push(const Stmt& n);
63 /*! \brief Push expr to generate new code */
64 void Push(const PrimExpr& n) { VisitExpr(n); }
65 /*!
66 * \brief Push the opcode to the code.
67 * \param opcode The code to be pushed.
68 */
69 void PushOp(StackVM::OpCode opcode);
70 /*!
71 * \brief Push the opcode and operand to the code.
72 * \param opcode The opcode.
73 * \param operand The operand to be pushed.
74 * \return operand_index, indicating location of operand
75 */
76 int64_t PushOp(StackVM::OpCode opcode, int operand);
77 /*!
78 * \brief Set the relative jump offset to be offset.
79 * \param operand_index The indexed returned by PushOp.
80 * \param operand The operand to be set.
81 */
82 void SetOperand(int64_t operand_index, int64_t operand);
83 /*! \return The current program pointer */
84 int64_t GetPC() const { return static_cast<int64_t>(vm_.code.size()); }
85 /*!
86 * \brief Get string id in vm
87 * \param key The string to get id.
88 * \return the id of the string.
89 */
90 int GetStrID(const std::string& key);
91 /*!
92 * \brief Allocate a variable name for a newly defined var.
93 * \param v The variable.
94 * \return the heap index of the var.
95 */
96 int AllocVarID(const VarNode* v);
97 /*!
98 * \brief Get a variable name.
99 * \param v The variable.
100 * \return the heap index of the var.
101 */
102 int GetVarID(const VarNode* v) const;
103 // Push binary operator
104 void PushBinary(StackVM::OpCode op_int64, const PrimExpr& a, const PrimExpr& b);
105 // push cast;
106 void PushCast(DataType dst, DataType src);
107 // overloadable functions
108 // expression
109 void VisitExpr_(const VarNode* op) final;
110 void VisitExpr_(const LoadNode* op) final;
111 void VisitExpr_(const BufferLoadNode* op) final;
112 void VisitExpr_(const LetNode* op) final;
113 void VisitExpr_(const CallNode* op) final;
114 void VisitExpr_(const AddNode* op) final;
115 void VisitExpr_(const SubNode* op) final;
116 void VisitExpr_(const MulNode* op) final;
117 void VisitExpr_(const DivNode* op) final;
118 void VisitExpr_(const ModNode* op) final;
119 void VisitExpr_(const MinNode* op) final;
120 void VisitExpr_(const MaxNode* op) final;
121 void VisitExpr_(const EQNode* op) final;
122 void VisitExpr_(const NENode* op) final;
123 void VisitExpr_(const LTNode* op) final;
124 void VisitExpr_(const LENode* op) final;
125 void VisitExpr_(const GTNode* op) final;
126 void VisitExpr_(const GENode* op) final;
127 void VisitExpr_(const AndNode* op) final;
128 void VisitExpr_(const OrNode* op) final;
129 void VisitExpr_(const CastNode* op) final;
130 void VisitExpr_(const NotNode* op) final;
131 void VisitExpr_(const SelectNode* op) final;
132 void VisitExpr_(const RampNode* op) final;
133 void VisitExpr_(const BroadcastNode* op) final;
134 void VisitExpr_(const IntImmNode* op) final;
135 void VisitExpr_(const FloatImmNode* op) final;
136 void VisitExpr_(const StringImmNode* op) final;
137 // statment
138 void VisitStmt_(const LetStmtNode* op) final;
139 void VisitStmt_(const StoreNode* op) final;
140 void VisitStmt_(const BufferStoreNode* op) final;
141 void VisitStmt_(const ForNode* op) final;
142 void VisitStmt_(const IfThenElseNode* op) final;
143 void VisitStmt_(const AllocateNode* op) final;
144 void VisitStmt_(const AttrStmtNode* op) final;
145 void VisitStmt_(const AssertStmtNode* op) final;
146 void VisitStmt_(const EvaluateNode* op) final;
147 void VisitStmt_(const SeqStmtNode* op) final;
148
149 private:
150 bool debug_{false};
151 /*! \brief The vm to be generated */
152 StackVM vm_;
153 /*! \brief id of each variable */
154 std::unordered_map<const VarNode*, int> var_idmap_;
155 /*! \brief id of each string */
156 std::unordered_map<std::string, int> str_idmap_;
157 /*! \brief id of each global function */
158 std::unordered_map<std::string, int> extern_fun_idmap_;
159
160 Op backend_alloc_workspace_op_ = Op::Get("tir.TVMBackendAllocWorkspace");
161 Op backend_free_workspace_op_ = Op::Get("tir.TVMBackendFreeWorkspace");
162};
163
164} // namespace codegen
165} // namespace tvm
166#endif // TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_
167