1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_metal.cc |
22 | */ |
23 | #include "codegen_metal.h" |
24 | |
25 | #include <algorithm> |
26 | #include <string> |
27 | #include <vector> |
28 | |
29 | #include "../../runtime/metal/metal_module.h" |
30 | #include "../../runtime/thread_storage_scope.h" |
31 | #include "../build_common.h" |
32 | |
33 | namespace tvm { |
34 | namespace codegen { |
35 | |
36 | void CodeGenMetal::InitFuncState(const PrimFunc& f) { |
37 | CodeGenC::InitFuncState(f); |
38 | // analyze the data; |
39 | for (Var arg : f->params) { |
40 | if (arg.dtype().is_handle()) { |
41 | alloc_storage_scope_[arg.get()] = "global" ; |
42 | } |
43 | } |
44 | } |
45 | |
46 | CodeGenMetal::CodeGenMetal(Target target) : target_(target) { |
47 | decl_stream << "#include <metal_stdlib>\n" ; |
48 | decl_stream << "using namespace metal;\n\n" ; |
49 | decl_stream << "union __TVMArgUnion {\n" |
50 | << " int v_int[2];\n" |
51 | << "};\n\n" ; |
52 | } |
53 | |
54 | void CodeGenMetal::AddFunction(const PrimFunc& f) { |
55 | // clear previous generated state. |
56 | this->InitFuncState(f); |
57 | // skip the first underscore, so SSA variable starts from _1 |
58 | name_supply_->FreshName("_" ); |
59 | |
60 | // add to alloc buffer type. |
61 | auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); |
62 | ICHECK(global_symbol.defined()) |
63 | << "CodeGenC: Expect PrimFunc to have the global_symbol attribute" ; |
64 | |
65 | // Function header. |
66 | this->stream << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(" ; |
67 | |
68 | // Buffer arguments |
69 | size_t num_buffer = 0; |
70 | int limit = target_->GetAttr<Integer>("max_function_args" ).value().IntValue(); |
71 | if (static_cast<int>(f->params.size()) > limit) { |
72 | LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of " |
73 | "buffers in the kernel" ; |
74 | } |
75 | for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) { |
76 | Var v = f->params[i]; |
77 | if (!v.dtype().is_handle()) break; |
78 | stream << " " ; |
79 | std::string vid = AllocVarID(v.get()); |
80 | auto it = alloc_storage_scope_.find(v.get()); |
81 | if (it != alloc_storage_scope_.end()) { |
82 | PrintStorageScope(it->second, stream); |
83 | } |
84 | PrintType(GetType(v), stream); |
85 | // Register handle data type |
86 | // TODO(tvm-team): consider simply keep type info in the |
87 | // type annotation(via a normalizing rewriting). |
88 | if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) { |
89 | if (auto* prim = ptr->element_type.as<PrimTypeNode>()) { |
90 | RegisterHandleType(v.get(), prim->dtype); |
91 | } |
92 | } |
93 | stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n" ; |
94 | } |
95 | // Setup normal arguments. |
96 | size_t nargs = f->params.size() - num_buffer; |
97 | std::string varg = name_supply_->FreshName("arg" ); |
98 | if (nargs != 0) { |
99 | std::string arg_buf_type = static_cast<std::string>(global_symbol.value()) + "_args_t" ; |
100 | stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer |
101 | << ") ]],\n" ; |
102 | // declare the struct |
103 | decl_stream << "struct " << arg_buf_type << " {\n" ; |
104 | for (size_t i = num_buffer; i < f->params.size(); ++i) { |
105 | Var v = f->params[i]; |
106 | ICHECK(!v.dtype().is_handle()); |
107 | std::string vid = AllocVarID(v.get()); |
108 | std::ostringstream vref; |
109 | if (v.dtype().bits() == 32) { |
110 | decl_stream << " " ; |
111 | PrintType(v.dtype(), decl_stream); |
112 | decl_stream << " " << vid << "[2];\n" ; |
113 | vref << varg << "." << vid << "[0]" ; |
114 | } else if (v.dtype().bits() == 64) { |
115 | decl_stream << " " ; |
116 | PrintType(v.dtype(), decl_stream); |
117 | decl_stream << " " << vid << ";\n" ; |
118 | vref << varg << "." << vid; |
119 | } else { |
120 | // For non 32bit type, ref through arg union. |
121 | decl_stream << " __TVMArgUnion " << vid << ";\n" ; |
122 | vref << varg << "." << vid << ".v_" ; |
123 | PrintType(v.dtype(), vref); |
124 | } |
125 | var_idmap_[v.get()] = vref.str(); |
126 | } |
127 | decl_stream << "};\n\n" ; |
128 | } |
129 | // Setup the thread group info. |
130 | ICHECK_EQ(name_supply_->FreshName("threadIdx" ), "threadIdx" ); |
131 | ICHECK_EQ(name_supply_->FreshName("blockIdx" ), "blockIdx" ); |
132 | int work_dim = 0; |
133 | auto thread_axis = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis).value(); |
134 | |
135 | for (IterVar iv : thread_axis) { |
136 | runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); |
137 | work_dim = std::max(work_dim, scope.dim_index + 1); |
138 | } |
139 | if (work_dim != 0) { |
140 | // use ushort by default for now |
141 | stream << " " ; |
142 | PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); |
143 | stream << " blockIdx [[threadgroup_position_in_grid]],\n" ; |
144 | stream << " " ; |
145 | PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); |
146 | stream << " threadIdx [[thread_position_in_threadgroup]]\n" ; |
147 | } |
148 | // bind thread axis |
149 | for (IterVar iv : thread_axis) { |
150 | ICHECK(!var_idmap_.count(iv->var.get())); |
151 | std::string vname = iv->thread_tag; |
152 | if (work_dim <= 1) { |
153 | vname = vname.substr(0, iv->thread_tag.length() - 2); |
154 | } |
155 | var_idmap_[iv->var.get()] = |
156 | CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype()); |
157 | } |
158 | // the function scope. |
159 | stream << ") {\n" ; |
160 | int func_scope = this->BeginScope(); |
161 | this->PrintStmt(f->body); |
162 | this->EndScope(func_scope); |
163 | this->PrintIndent(); |
164 | this->stream << "}\n\n" ; |
165 | } |
166 | |
167 | void CodeGenMetal::BindThreadIndex(const IterVar& iv) { |
168 | ICHECK(!var_idmap_.count(iv->var.get())); |
169 | var_idmap_[iv->var.get()] = |
170 | CastFromTo(iv->thread_tag, DataType::UInt(thread_index_bits_), iv->var.dtype()); |
171 | } |
172 | |
173 | void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) |
174 | int lanes = t.lanes(); |
175 | if (t.is_handle()) { |
176 | ICHECK_EQ(lanes, 1) << "do not yet support vector types" ; |
177 | os << "void*" ; |
178 | return; |
179 | } |
180 | |
181 | if (t.is_void()) { |
182 | os << "void" ; |
183 | return; |
184 | } |
185 | if (t == DataType::Bool()) { |
186 | os << "bool" ; |
187 | return; |
188 | } |
189 | bool fail = false; |
190 | if (t.is_float()) { |
191 | // Need to care about sizes and alignment of half3/float3 because tir representation might not |
192 | // be aware of Metal half3/float3 details and can treat them as just three elements, |
193 | // while sizes and alignmnents of half3/float3 are one element more (half3-8 bytes/ |
194 | // float13 - 16bytes). |
195 | // Example of problematic pattern: filling of threadgroup packed array using float3 elements |
196 | // by threads concurrently can lead to datarace and wrong data in threadgroup shared array. |
197 | // packed_(half3/float3) are exactly datatypes dealing with 3 elements and per-element |
198 | // alignment |
199 | if (lanes == 3) { |
200 | os << "packed_" ; |
201 | } |
202 | switch (t.bits()) { |
203 | case 16: |
204 | os << "half" ; |
205 | break; |
206 | case 32: |
207 | os << "float" ; |
208 | break; |
209 | default: |
210 | fail = true; |
211 | break; |
212 | } |
213 | if (!fail && lanes == 1) return; |
214 | if (!fail && (lanes >= 2 && lanes <= 4)) { |
215 | os << lanes; |
216 | return; |
217 | } |
218 | } else if (t.is_uint() || t.is_int()) { |
219 | if (t.is_uint()) { |
220 | os << 'u'; |
221 | } |
222 | if (t.bits() == 8 && t.lanes() == 4) { |
223 | // directly 4 8 bit int in integer. |
224 | os << "int" ; |
225 | return; |
226 | } |
227 | switch (t.bits()) { |
228 | case 8: |
229 | os << "char" ; |
230 | break; |
231 | case 16: |
232 | os << "short" ; |
233 | break; |
234 | case 32: |
235 | os << "int" ; |
236 | break; |
237 | case 64: |
238 | os << "long" ; |
239 | break; |
240 | case 1: |
241 | os << "bool" ; |
242 | break; |
243 | default: |
244 | fail = true; |
245 | break; |
246 | } |
247 | if (!fail && lanes == 1) return; |
248 | if (!fail && (lanes >= 2 && lanes <= 4)) { |
249 | os << lanes; |
250 | return; |
251 | } |
252 | } |
253 | LOG(FATAL) << "Cannot convert type " << t << " to Metal type" ; |
254 | } |
255 | |
256 | void CodeGenMetal::PrintStorageSync(const CallNode* op) { |
257 | const std::string& sync = op->args[0].as<StringImmNode>()->value; |
258 | if (sync == "warp" ) { |
259 | this->PrintIndent(); |
260 | this->stream << "simdgroup_barrier(mem_flags::mem_threadgroup);\n" ; |
261 | } else if (sync == "shared" ) { |
262 | this->PrintIndent(); |
263 | this->stream << "threadgroup_barrier(mem_flags::mem_threadgroup);\n" ; |
264 | } else if (sync == "global" ) { |
265 | LOG(FATAL) << "global barrier not supported" ; |
266 | } |
267 | } |
268 | |
269 | void CodeGenMetal::PrintVecElemLoad(const std::string& vec, DataType t, int i, |
270 | std::ostream& os) { // NOLINT(*) |
271 | os << vec << "[" << i << "]" ; |
272 | } |
273 | |
274 | void CodeGenMetal::PrintVecElemStore(const std::string& vec, DataType t, int i, |
275 | const std::string& value) { |
276 | this->PrintIndent(); |
277 | stream << vec << "[" << i << "]" |
278 | << " = " << value << ";\n" ; |
279 | } |
280 | |
281 | void CodeGenMetal::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) |
282 | if (scope == "global" ) { |
283 | os << "device " ; |
284 | } else if (scope == "shared" ) { |
285 | os << "threadgroup " ; |
286 | } else { |
287 | os << "thread " ; |
288 | } |
289 | } |
290 | |
291 | void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) |
292 | std::string v = PrintExpr(op->value); |
293 | PrintType(op->dtype, os); |
294 | os << "(" ; |
295 | for (int i = 0; i < op->lanes; ++i) { |
296 | if (i != 0) os << ", " ; |
297 | os << v; |
298 | } |
299 | os << ')'; |
300 | } |
301 | |
302 | void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) |
303 | if (op->op.same_as(builtin::reinterpret())) { |
304 | // generate as_type<TYPE>(ARG) |
305 | os << "(as_type<" ; |
306 | this->PrintType(op->dtype, os); |
307 | os << ">(" ; |
308 | this->PrintExpr(op->args[0], os); |
309 | os << "))" ; |
310 | } else { |
311 | CodeGenC::VisitExpr_(op, os); |
312 | } |
313 | } |
314 | |
315 | void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) |
316 | std::ostringstream temp; |
317 | if (std::isinf(op->value)) { |
318 | if (op->value < 0) { |
319 | temp << "-" ; |
320 | } |
321 | temp << "INFINITY" ; |
322 | } else if (std::isnan(op->value)) { |
323 | temp << "NAN" ; |
324 | } else { |
325 | temp << std::scientific << op->value; |
326 | if (op->dtype.bits() == 32) |
327 | temp << 'f'; |
328 | else if (op->dtype.bits() == 16) |
329 | temp << 'h'; |
330 | } |
331 | MarkConst(temp.str()); |
332 | os << temp.str(); |
333 | } |
334 | |
335 | runtime::Module BuildMetal(IRModule mod, Target target) { |
336 | using tvm::runtime::Registry; |
337 | bool output_ssa = false; |
338 | |
339 | std::stringstream code; |
340 | std::stringstream source; |
341 | std::string fmt = "metal" ; |
342 | for (auto kv : mod->functions) { |
343 | ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc" ; |
344 | code << "// Function: " << kv.first->name_hint << std::endl; |
345 | CodeGenMetal cg(target); |
346 | cg.Init(output_ssa); |
347 | auto f = Downcast<PrimFunc>(kv.second); |
348 | auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); |
349 | ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) |
350 | << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch" ; |
351 | cg.AddFunction(f); |
352 | std::string fsource = cg.Finish(); |
353 | if (const auto* f = Registry::Get("tvm_callback_metal_compile" )) { |
354 | source << fsource; |
355 | fsource = (*f)(fsource).operator std::string(); |
356 | fmt = "metallib" ; |
357 | } |
358 | code << fsource; |
359 | } |
360 | |
361 | return MetalModuleCreate(code.str(), fmt, ExtractFuncInfo(mod), source.str()); |
362 | } |
363 | |
// Register BuildMetal as the global function "target.build.metal".
TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal);
365 | } // namespace codegen |
366 | } // namespace tvm |
367 | |