1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_hybrid.h |
 * \brief Code generator that dumps TVM statements as Hybrid Script.
23 | */ |
24 | #ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ |
25 | #define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ |
26 | |
#include <tvm/ir/name_supply.h>
#include <tvm/target/codegen.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

#include <map>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
39 | |
40 | namespace tvm { |
41 | namespace contrib { |
42 | |
43 | using namespace te; |
44 | using namespace tir; |
45 | /*! |
46 | * \brief A base class to generate Hybrid Script. |
47 | * |
48 | * **NOTE** CodeGenHybrid does not aim at generating Python scripts consumed by Python2/3. |
49 | * For runtime support, please refer the decorator in ``tvm/python/hybrid/api.py``. |
50 | */ |
51 | class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>, |
52 | public StmtFunctor<void(const Stmt&)> { |
53 | public: |
54 | /*! |
55 | * \brief Dump the given function body to hybrid script. |
56 | * \param stmt The function body to be dumped to hybrid script. |
57 | * \param inputs Input tensors of this schedule. |
58 | * \param outputs Output tensors of this schedule. |
59 | * \param name The name of the function. |
60 | */ |
61 | void DumpStmt(const Stmt& stmt, const Array<ObjectRef>& inputs, const Array<Tensor>& outputs, |
62 | const std::string& name = "hybrid_func" ); |
63 | /*! |
64 | * \brief Finalize the compilation and return the code. |
65 | * \return The code. |
66 | */ |
67 | std::string Finish(); |
68 | /*! \brief Reserve keywords in avoid of name conflict. */ |
69 | void ReserveKeywords(); |
70 | /*! |
71 | * \brief Print the Stmt n to CodeGenHybrid->stream |
72 | * \param n The statement to be printed. |
73 | */ |
74 | void PrintStmt(const Stmt& n) { this->VisitStmt(n); } |
75 | /*! |
76 | * \brief Print the expression n(or its ssa id if in ssa mode) into os |
77 | * \param n The expression to be printed. |
78 | * \param os The output stream |
79 | */ |
80 | void PrintExpr(const PrimExpr& n, std::ostream& os) { this->VisitExpr(n, os); } |
81 | /*! |
82 | * \brief Same as PrintExpr, but simply returns result string |
83 | * \param n The expression to be printed. |
84 | */ |
85 | std::string PrintExpr(const PrimExpr& n) { |
86 | std::ostringstream os; |
87 | PrintExpr(n, os); |
88 | return os.str(); |
89 | } |
90 | // expression |
91 | void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) |
92 | void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) |
93 | void VisitExpr_(const BufferLoadNode* op, std::ostream& os) override; // NOLINT(*) |
94 | void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) |
95 | void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) |
96 | void VisitExpr_(const ProducerLoadNode* op, std::ostream& os) override; // NOLINT(*) |
97 | void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) |
98 | void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) |
99 | void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) |
100 | void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) |
101 | void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) |
102 | void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*) |
103 | void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*) |
104 | void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) |
105 | void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) |
106 | void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) |
107 | void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) |
108 | void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) |
109 | void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) |
110 | void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) |
111 | void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) |
112 | void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) |
113 | void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) |
114 | void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) |
115 | void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) |
116 | void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) |
117 | void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) |
118 | void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) |
119 | void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) |
120 | void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) |
121 | void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) |
122 | // statment |
123 | void VisitStmt_(const LetStmtNode* op) override; |
124 | void VisitStmt_(const StoreNode* op) override; |
125 | void VisitStmt_(const BufferStoreNode* op) override; |
126 | void VisitStmt_(const ProducerStoreNode* op) override; |
127 | void VisitStmt_(const ForNode* op) override; |
128 | void VisitStmt_(const IfThenElseNode* op) override; |
129 | void VisitStmt_(const AllocateNode* op) override; |
130 | void VisitStmt_(const ProducerRealizeNode* op) override; |
131 | void VisitStmt_(const AttrStmtNode* op) override; |
132 | void VisitStmt_(const AssertStmtNode* op) override; |
133 | void VisitStmt_(const EvaluateNode* op) override; |
134 | void VisitStmt_(const SeqStmtNode* op) override; |
135 | /*! |
136 | * \brief Print Type represetnation of type t. |
137 | * \param t The type representation. |
138 | * \param os The stream to print the ctype into |
139 | */ |
140 | virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) |
141 | |
142 | private: |
143 | /*! \brief The current indent of the code dump. */ |
144 | int indent_{0}; |
145 | /*! \brief The tab size of code indent. */ |
146 | const int tab_{4}; |
147 | /*! \brief Print the current indent spaces. */ |
148 | inline void PrintIndent(); |
149 | /*! \brief NameSupply for allocated ids. */ |
150 | NameSupply ids_allocated = NameSupply("" ); |
151 | /*! |
152 | * \brief Keys are either (tensors, value_index) or (variables, 0). |
153 | * Values are the corresponding IDs.*/ |
154 | std::map<std::pair<const Object*, int>, std::string> id_map_; |
155 | /*! \brief Variables (keys) binded to the threads (values). */ |
156 | std::map<const VarNode*, std::string> binds_; |
157 | /*! \brief The output code string builder. */ |
158 | std::stringstream stream; |
159 | /*! |
160 | * \brief Get or allocate the ID for the given variable. |
161 | * \param v The given variable. |
162 | */ |
163 | std::string GetVarID(const VarNode* v); |
164 | /*! |
165 | * \brief Get or allocate the ID for the given tensor. |
166 | * \param tensor The tensor to allocate a name. |
167 | */ |
168 | std::string GetTensorID(const Tensor& tensor); |
169 | }; |
170 | |
171 | } // namespace contrib |
172 | } // namespace tvm |
173 | #endif // TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ |
174 | |