1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_c.h
22 * \brief Common utilities to generate C-style code.
23 */
24#ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_
25#define TVM_TARGET_SOURCE_CODEGEN_C_H_
26
27#include <tvm/ir/op.h>
28#include <tvm/target/codegen.h>
29#include <tvm/tir/analysis.h>
30#include <tvm/tir/builtin.h>
31#include <tvm/tir/expr.h>
32#include <tvm/tir/function.h>
33#include <tvm/tir/op_attr_types.h>
34#include <tvm/tir/stmt.h>
35#include <tvm/tir/stmt_functor.h>
36
37#include <string>
38#include <unordered_map>
39#include <unordered_set>
40#include <vector>
41
42#include "../../tir/transforms/ir_utils.h"
43#include "codegen_source_base.h"
44
45namespace tvm {
46namespace codegen {
47
48using namespace tir;
49/*!
50 * \brief A base class to generate C code.
51 *
52 * CodeGenC has two modes: generate SSA formed C code or normal form.
53 *
54 * **NOTE** CodeGenC does not aim at generating C code consumed by MSVC or GCC.
55 * Rather, it provides an infrastructural abstraction for C variants like CUDA
56 * and OpenCL-C. You might find some odd variant features, e.g., type `int3` for
57 * a vector of 3 `int`s. For a native C code generator, see `CodeGenLLVM`.
58 */
59class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
60 public StmtFunctor<void(const Stmt&)>,
61 public CodeGenSourceBase {
62 public:
63 /*!
64 * \brief Initialize the code generator.
65 * \param output_ssa Whether to output SSA form.
66 */
67 void Init(bool output_ssa);
68 /*!
69 * \brief Add the function to the generated module.
70 *
71 * \param f The function to be compiled.
72 */
73 void AddFunction(const PrimFunc& f);
74 /*!
75 * \brief Finalize the compilation and return the code.
76 * \return The generated code as a string.
77 */
78 virtual std::string Finish();
79 /*!
80 * \brief Print the Stmt n to CodeGenC->stream (dispatches to VisitStmt).
81 * \param n The statement to be printed.
82 */
83 void PrintStmt(const Stmt& n) { VisitStmt(n); }
84 /*!
85 * \brief Print the expression n (or its SSA id if in SSA mode) into os
86 * \param n The expression to be printed.
87 * \param os The output stream
88 */
89 void PrintExpr(const PrimExpr& n, std::ostream& os);
90 /*!
91 * \brief Same as PrintExpr(n, os), but returns the result as a string.
92 * \param n The expression to be printed.
93 */
94 std::string PrintExpr(const PrimExpr& n) {
95 std::ostringstream os;
96 PrintExpr(n, os);
97 return os.str();
98 }
99 // The following parts are overridable (virtual) print operations.
100 /*!
101 * \brief Print the function header before the argument list
102 * \param os The output stream
103 *
104 * Example: stream << "void";
105 */
106 virtual void PrintFuncPrefix(std::ostream& os); // NOLINT(*)
107 /*!
108 * \brief Print extra function attributes
109 *
110 * Example: __launch_bounds__(256) for CUDA functions
111 */
112 virtual void PrintExtraAttrs(const PrimFunc& f);
113 /*!
114 * \brief Print the final return at the end of the function.
115 */
116 virtual void PrintFinalReturn(); // NOLINT(*)
117 /*!
118 * \brief Insert statements before the function body; the default implementation does nothing.
119 * \param f The function to be compiled.
120 */
121 virtual void PreFunctionBody(const PrimFunc& f) {}
122 /*!
123 * \brief Initialize codegen state for generating f.
124 * \param f The function to be compiled.
125 */
126 virtual void InitFuncState(const PrimFunc& f);
127 // expression visitors: one override per PrimExpr node type handled here
128 void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
129 void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
130 void VisitExpr_(const BufferLoadNode* op, std::ostream& os) override; // NOLINT(*)
131 void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
132 void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
133 void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*)
134 void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*)
135 void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
136 void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
137 void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
138 void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
139 void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
140 void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)
141 void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*)
142 void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*)
143 void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*)
144 void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*)
145 void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*)
146 void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*)
147 void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*)
148 void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*)
149 void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*)
150 void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
151 void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
152 void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*)
153 void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*)
154 void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
155 void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
156 void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
157 // statement visitors: one override per Stmt node type handled here
158 void VisitStmt_(const LetStmtNode* op) override;
159 void VisitStmt_(const StoreNode* op) override;
160 void VisitStmt_(const BufferStoreNode* op) override;
161 void VisitStmt_(const ForNode* op) override;
162 void VisitStmt_(const WhileNode* op) override;
163 void VisitStmt_(const IfThenElseNode* op) override;
164 void VisitStmt_(const AllocateNode* op) override;
165 void VisitStmt_(const AttrStmtNode* op) override;
166 void VisitStmt_(const AssertStmtNode* op) override;
167 void VisitStmt_(const EvaluateNode* op) override;
168 void VisitStmt_(const SeqStmtNode* op) override;
169 void VisitStmt_(const AllocateConstNode* op) override;
170 void VisitStmt_(const DeclBufferNode* op) override;
171
172 /*!
173 * \brief Print expr representing the thread tag
174 * \param iv The thread index to be bound.
175 */
176 virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*)
177 virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
178 virtual void PrintStorageSync(const CallNode* op); // NOLINT(*)
179 // Print a binary vector op (operator string op applied to lhs and rhs) into os.
180 virtual void PrintVecBinaryOp(const std::string& op, DataType op_type, PrimExpr lhs, PrimExpr rhs,
181 std::ostream& os); // NOLINT(*)
182 // vector load: returns the generated load expression as a string
183 virtual std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base);
184 // print vector store of value into buffer
185 virtual void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base,
186 const std::string& value); // NOLINT(*)
187 // print load of single element i of vector vec into os
188 virtual void PrintVecElemLoad(const std::string& vec, DataType t, int i,
189 std::ostream& os); // NOLINT(*)
190 // print store of value into single element i of vector vec.
191 virtual void PrintVecElemStore(const std::string& vec, DataType t, int i,
192 const std::string& value);
193 // Get the string that casts value from type `from` to type `target`
194 virtual std::string CastFromTo(std::string value, DataType from, DataType target);
195 // Print load of a single element with expression (writes into os)
196 virtual void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os);
197 // Print restrict keyword for a given Var if applicable
198 virtual void PrintRestrict(const Var& v, std::ostream& os);
199
200 virtual void SetConstantsByteAlignment(Integer constants_byte_alignment) {
201 constants_byte_alignment_ = constants_byte_alignment; // replaces the in-class default of 16
202 }
203
204 protected:
205 // Get reference to struct location as a string
206 std::string GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind);
207 // Get reference to a buffer as type t in index, as a string.
208 virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index);
209
210 /*!
211 * \brief Handle volatile loads.
212 *
213 * This is to work around a bug in CUDA cuda_fp16.h. Volatile accesses
214 * to shared memory are required for reductions. However, __half class
215 * does not implement volatile member functions. CUDA codegen will cast
216 * away volatile qualifier from CUDA __half types.
217 */
218 virtual void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
219 std::ostream& os) {
220 // By default, do nothing but print the loaded value.
221 os << value;
222 }
223
224 /*!
225 * \brief Check if scope is part of type in the target language.
226 *
227 * **NOTE** In OpenCL, __local is part of type, so "__local int *"
228 * is legal. This is not the case for CUDA, where "__shared__"
229 * or "__constant__" is not part of type but a storage class (like
230 * C/C++ static). The default implementation returns true.
231 */
232 virtual bool IsScopePartOfType() const { return true; }
233
234 /*!
235 * \brief Generate forward function declarations (no-op by default).
236 *
237 * \param global_symbol The symbol of the target function.
238 * \param args The arguments to the function.
239 */
240 virtual void GenerateForwardFunctionDeclarations(String global_symbol,
241 const Array<PrimExpr>& args) {}
242 /*!
243 * \brief Print external function call.
244 * \param ret_type The return type.
245 * \param global_symbol The symbol of the target function.
246 * \param args The arguments to the function.
247 * \param skip_first_arg Whether to skip the first argument.
248 * \param os The output stream.
249 */
250 virtual void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
251 bool skip_first_arg, std::ostream& os); // NOLINT(*)
252 /*!
253 * \brief Check if the buffer is allocated as type t.
254 * \param buf_var The buffer variable.
255 * \param t The type to be checked.
256 */
257 bool HandleTypeMatch(const VarNode* buf_var, DataType t) const;
258 /*!
259 * \brief Register the data type of buf_var
260 * \param buf_var The buffer variable.
261 * \param t The data type to register.
262 */
263 void RegisterHandleType(const VarNode* buf_var, DataType t);
264 // override (declared final, so subclasses cannot override it further)
265 void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final;
266 /*! \brief reserves common C keywords */
267 void ReserveKeywordsAsUnique();
268
269 /*! \brief Check if buf_var is volatile or not (i.e. whether it is recorded in volatile_buf_). */
270 bool IsVolatile(const VarNode* buf_var) const { return volatile_buf_.count(buf_var) != 0; }
271
272 /*! \brief restrict keyword (empty by default) */
273 std::string restrict_keyword_{""};
274 /*! \brief the storage scope of allocation */
275 std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
276 /*! \brief the data type of allocated buffers */
277 std::unordered_map<const VarNode*, DataType> handle_data_type_;
278 /*! \brief Record of ops that have pre-defined global symbol. */
279 OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
280 // cache commonly used ops
281 const Op& builtin_call_extern_ = builtin::call_extern();
282 const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
283 Integer constants_byte_alignment_ = 16; // default; changed via SetConstantsByteAlignment()
284
285 private:
286 /*! \brief whether to print in SSA form */
287 bool print_ssa_form_{false};
288 /*! \brief set of volatile buf access */
289 std::unordered_set<const VarNode*> volatile_buf_;
290 // deep comparison of PrimExpr
291 ExprDeepEqual deep_equal_;
292 // binding of let variables. Enables duplicate var defs that map to same value
293 std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
294};
295
296} // namespace codegen
297} // namespace tvm
298#endif // TVM_TARGET_SOURCE_CODEGEN_C_H_
299