1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_cpu.h |
22 | * \brief Common base class for generating into LLVM IR on CPU host. |
23 | */ |
24 | #ifndef TVM_TARGET_LLVM_CODEGEN_CPU_H_ |
25 | #define TVM_TARGET_LLVM_CODEGEN_CPU_H_ |
26 | |
27 | #ifdef TVM_LLVM_VERSION |
28 | |
29 | #include <memory> |
30 | #include <string> |
31 | #include <unordered_map> |
32 | #include <utility> |
33 | #include <vector> |
34 | |
35 | #include "codegen_llvm.h" |
36 | |
// Forward declarations of the LLVM types that this header refers to only by
// pointer or through std::unique_ptr, so the full LLVM headers are not needed
// here. Kept in alphabetical order.
namespace llvm {
class BasicBlock;
class Constant;
class DIBuilder;
// Return type of CodeGenCPU::CreateDebugFunction; declared here so the header
// does not depend on a transitive include for this name.
class DISubprogram;
class DIType;
class Function;
class FunctionType;
class GlobalVariable;
class LLVMContext;
class MDNode;
class StructType;
class TargetMachine;
class Type;
class Value;

// Used in std::unique_ptr
class Module;
}  // namespace llvm
55 | |
56 | namespace tvm { |
57 | namespace codegen { |
58 | |
59 | class LLVMTarget; |
60 | |
61 | // CPU host code generation |
62 | class CodeGenCPU : public CodeGenLLVM { |
63 | public: |
64 | CodeGenCPU(); |
65 | virtual ~CodeGenCPU(); |
66 | |
67 | void Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, |
68 | bool dynamic_lookup, bool target_c_runtime) override; |
69 | void AddFunction(const PrimFunc& f) override; |
70 | void AddMainFunction(const std::string& entry_func_name) override; |
71 | std::unique_ptr<llvm::Module> Finish() override; |
72 | void VisitStmt_(const AssertStmtNode* op) override; |
73 | void VisitStmt_(const AttrStmtNode* op) override; |
74 | void VisitStmt_(const ForNode* op) override; |
75 | llvm::Value* CreateIntrinsic(const CallNode* op) override; |
76 | llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args, |
77 | bool skip_first_arg) override; |
78 | |
79 | /*! |
80 | * \brief A CPU-specific function to create the FuncRegistry. |
81 | * \param func_names List of functions to be included, in order. |
82 | */ |
83 | void DefineFunctionRegistry(Array<String> func_names); |
84 | |
85 | /*! |
86 | * \brief Serialize the metadata object as data, and implement get_c_metadata function. |
87 | * \param metadata The metadata which should be serialized. |
88 | */ |
89 | void DefineMetadata(runtime::metadata::Metadata metadata); |
90 | |
91 | protected: |
92 | void AddStartupFunction() final; |
93 | // meta data |
94 | llvm::MDNode* md_tbaa_ctx_ptr_{nullptr}; |
95 | // TVM related data types |
96 | llvm::Type* t_tvm_shape_index_{nullptr}; |
97 | llvm::Type* t_tvm_func_handle_{nullptr}; |
98 | llvm::StructType* t_tvm_device_{nullptr}; |
99 | llvm::StructType* t_tvm_type_{nullptr}; |
100 | llvm::StructType* t_tvm_array_{nullptr}; |
101 | llvm::StructType* t_tvm_value_{nullptr}; |
102 | llvm::StructType* t_tvm_parallel_group_env_{nullptr}; |
103 | |
104 | llvm::FunctionType* ftype_tvm_backend_packed_c_func_{nullptr}; |
105 | llvm::StructType* t_tvm_crt_func_registry_{nullptr}; |
106 | llvm::StructType* t_tvm_crt_module_{nullptr}; |
107 | |
108 | llvm::FunctionType* ftype_tvm_parallel_lambda_{nullptr}; |
109 | llvm::FunctionType* ftype_tvm_func_call_{nullptr}; |
110 | llvm::FunctionType* ftype_tvm_get_func_from_env_{nullptr}; |
111 | llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr}; |
112 | llvm::FunctionType* ftype_tvm_parallel_launch_{nullptr}; |
113 | llvm::FunctionType* ftype_tvm_parallel_barrier_{nullptr}; |
114 | llvm::FunctionType* ftype_tvm_register_system_symbol_{nullptr}; |
115 | // Lazy entry for function call. |
116 | llvm::FunctionType* ftype_tvm_static_init_callback_{nullptr}; |
117 | llvm::FunctionType* ftype_tvm_static_init_{nullptr}; |
118 | |
119 | private: |
120 | // the parallel group information |
121 | struct ParallelEnv { |
122 | Var task_id; |
123 | Var num_task; |
124 | bool stride_pattern{false}; |
125 | bool in_parallel_loop{false}; |
126 | int parallel_loop_count{0}; |
127 | llvm::Value* penv{nullptr}; |
128 | }; |
129 | // Get runtime functions |
130 | void InitGlobalContext(bool dynamic_lookup); |
131 | llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name); |
132 | llvm::Value* GetContextPtr(llvm::GlobalVariable* gv); |
133 | llvm::Value* RuntimeTVMFuncCall(); |
134 | llvm::Value* RuntimeTVMGetFuncFromEnv(); |
135 | llvm::Value* RuntimeTVMAPISetLastError(); |
136 | llvm::Value* RuntimeTVMParallelLaunch(); |
137 | llvm::Value* RuntimeTVMParallelBarrier(); |
138 | llvm::Value* CreateStaticHandle(); |
139 | llvm::Value* GetPackedFuncHandle(const std::string& str); |
140 | TypedPointer PackClosureData(const Array<Var>& fields, uint64_t* num_bytes, |
141 | std::string struct_name = "" ); |
142 | TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); |
143 | void UnpackClosureData(TypedPointer cdata, const Array<Var>& fields, |
144 | std::unordered_map<const VarNode*, llvm::Value*>* vmap); |
145 | // Make packed call. |
146 | struct PackedCall { |
147 | llvm::Value* ret_value; |
148 | llvm::Value* ret_tcode; |
149 | llvm::BasicBlock* end_block; |
150 | }; |
151 | PackedCall MakeCallPackedLowered(const Array<PrimExpr>& args, const DataType& r_type, |
152 | const int64_t begin, const int64_t end, bool use_string_lookup); |
153 | // create call into tvm packed function. |
154 | llvm::Value* CreateCallPacked(const CallNode* op, bool use_string_lookup); |
155 | // Create trace call into tvm packed function. |
156 | llvm::Value* CreateCallTracePacked(const CallNode* op); |
157 | // Create static initialization |
158 | void CreateStaticInit(const std::string& init_fname, const Stmt& body); |
159 | // Create parallel launch |
160 | void CreateParallelLaunch(const Stmt& body, int num_task, std::string name = "" ); |
161 | // Create a new compute scope. |
162 | void CreateComputeScope(const AttrStmtNode* op); |
163 | // Check if the call to packed function is successful |
164 | // if not directly finalize function and pass on return code. |
165 | // return the end block after the check |
166 | llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode); |
167 | llvm::DISubprogram* CreateDebugFunction(const PrimFunc& f); |
168 | // Context for injection lookup |
169 | llvm::GlobalVariable* gv_mod_ctx_{nullptr}; |
170 | llvm::GlobalVariable* gv_tvm_func_call_{nullptr}; |
171 | llvm::GlobalVariable* gv_tvm_get_func_from_env_{nullptr}; |
172 | llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr}; |
173 | llvm::GlobalVariable* gv_tvm_parallel_launch_{nullptr}; |
174 | llvm::GlobalVariable* gv_tvm_parallel_barrier_{nullptr}; |
175 | std::unordered_map<String, llvm::GlobalVariable*> gv_func_map_; |
176 | // context for direct dynamic lookup |
177 | llvm::Function* f_tvm_func_call_{nullptr}; |
178 | llvm::Function* f_tvm_get_func_from_env_{nullptr}; |
179 | llvm::Function* f_tvm_api_set_last_error_{nullptr}; |
180 | llvm::Function* f_tvm_parallel_launch_{nullptr}; |
181 | llvm::Function* f_tvm_parallel_barrier_{nullptr}; |
182 | llvm::Function* f_tvm_register_system_symbol_{nullptr}; |
183 | // Current parallel environment scope. |
184 | ParallelEnv parallel_env_; |
185 | // global to packed function handle |
186 | std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_; |
187 | // List of symbols to be exported to TVM system lib. |
188 | std::vector<std::pair<std::string, llvm::Constant*>> export_system_symbols_; |
189 | // List of functions to be registered in the FuncRegistry, if generated. |
190 | std::vector<std::pair<std::string, llvm::Function*>> registry_functions_; |
191 | // internal debug information, to be populated by |
192 | std::unique_ptr<DebugInfo> dbg_info_; |
193 | bool target_c_runtime_; |
194 | bool is_system_lib_; |
195 | |
196 | // Get the DWARF type corresponding to the LLVM type |ty|. The current API in practice only |
197 | // generates |int32|, and |int8*|. |
198 | llvm::DIType* GetDebugType(const Type& ty_tir); |
199 | llvm::DIType* GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm); |
200 | // Adds the DWARF debug information for |function| to |dbg_info_|. |
201 | void AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm); |
202 | }; |
203 | |
204 | } // namespace codegen |
205 | } // namespace tvm |
206 | |
207 | #endif // TVM_LLVM_VERSION |
208 | #endif // TVM_TARGET_LLVM_CODEGEN_CPU_H_ |
209 | |