1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_hybrid.h
 * \brief Code generator that dumps TVM statements as Hybrid Script.
23 */
24#ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
25#define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
26
#include <tvm/ir/name_supply.h>
#include <tvm/target/codegen.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

#include <map>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
39
40namespace tvm {
41namespace contrib {
42
43using namespace te;
44using namespace tir;
45/*!
46 * \brief A base class to generate Hybrid Script.
47 *
48 * **NOTE** CodeGenHybrid does not aim at generating Python scripts consumed by Python2/3.
49 * For runtime support, please refer the decorator in ``tvm/python/hybrid/api.py``.
50 */
51class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
52 public StmtFunctor<void(const Stmt&)> {
53 public:
54 /*!
55 * \brief Dump the given function body to hybrid script.
56 * \param stmt The function body to be dumped to hybrid script.
57 * \param inputs Input tensors of this schedule.
58 * \param outputs Output tensors of this schedule.
59 * \param name The name of the function.
60 */
61 void DumpStmt(const Stmt& stmt, const Array<ObjectRef>& inputs, const Array<Tensor>& outputs,
62 const std::string& name = "hybrid_func");
63 /*!
64 * \brief Finalize the compilation and return the code.
65 * \return The code.
66 */
67 std::string Finish();
68 /*! \brief Reserve keywords in avoid of name conflict. */
69 void ReserveKeywords();
70 /*!
71 * \brief Print the Stmt n to CodeGenHybrid->stream
72 * \param n The statement to be printed.
73 */
74 void PrintStmt(const Stmt& n) { this->VisitStmt(n); }
75 /*!
76 * \brief Print the expression n(or its ssa id if in ssa mode) into os
77 * \param n The expression to be printed.
78 * \param os The output stream
79 */
80 void PrintExpr(const PrimExpr& n, std::ostream& os) { this->VisitExpr(n, os); }
81 /*!
82 * \brief Same as PrintExpr, but simply returns result string
83 * \param n The expression to be printed.
84 */
85 std::string PrintExpr(const PrimExpr& n) {
86 std::ostringstream os;
87 PrintExpr(n, os);
88 return os.str();
89 }
90 // expression
91 void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
92 void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
93 void VisitExpr_(const BufferLoadNode* op, std::ostream& os) override; // NOLINT(*)
94 void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
95 void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
96 void VisitExpr_(const ProducerLoadNode* op, std::ostream& os) override; // NOLINT(*)
97 void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*)
98 void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*)
99 void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
100 void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
101 void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
102 void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*)
103 void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*)
104 void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
105 void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
106 void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)
107 void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*)
108 void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*)
109 void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*)
110 void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*)
111 void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*)
112 void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*)
113 void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*)
114 void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*)
115 void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*)
116 void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
117 void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
118 void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*)
119 void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
120 void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
121 void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
122 // statment
123 void VisitStmt_(const LetStmtNode* op) override;
124 void VisitStmt_(const StoreNode* op) override;
125 void VisitStmt_(const BufferStoreNode* op) override;
126 void VisitStmt_(const ProducerStoreNode* op) override;
127 void VisitStmt_(const ForNode* op) override;
128 void VisitStmt_(const IfThenElseNode* op) override;
129 void VisitStmt_(const AllocateNode* op) override;
130 void VisitStmt_(const ProducerRealizeNode* op) override;
131 void VisitStmt_(const AttrStmtNode* op) override;
132 void VisitStmt_(const AssertStmtNode* op) override;
133 void VisitStmt_(const EvaluateNode* op) override;
134 void VisitStmt_(const SeqStmtNode* op) override;
135 /*!
136 * \brief Print Type represetnation of type t.
137 * \param t The type representation.
138 * \param os The stream to print the ctype into
139 */
140 virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*)
141
142 private:
143 /*! \brief The current indent of the code dump. */
144 int indent_{0};
145 /*! \brief The tab size of code indent. */
146 const int tab_{4};
147 /*! \brief Print the current indent spaces. */
148 inline void PrintIndent();
149 /*! \brief NameSupply for allocated ids. */
150 NameSupply ids_allocated = NameSupply("");
151 /*!
152 * \brief Keys are either (tensors, value_index) or (variables, 0).
153 * Values are the corresponding IDs.*/
154 std::map<std::pair<const Object*, int>, std::string> id_map_;
155 /*! \brief Variables (keys) binded to the threads (values). */
156 std::map<const VarNode*, std::string> binds_;
157 /*! \brief The output code string builder. */
158 std::stringstream stream;
159 /*!
160 * \brief Get or allocate the ID for the given variable.
161 * \param v The given variable.
162 */
163 std::string GetVarID(const VarNode* v);
164 /*!
165 * \brief Get or allocate the ID for the given tensor.
166 * \param tensor The tensor to allocate a name.
167 */
168 std::string GetTensorID(const Tensor& tensor);
169};
170
171} // namespace contrib
172} // namespace tvm
173#endif // TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
174