1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_metal.cc
22 */
23#include "codegen_metal.h"
24
25#include <algorithm>
26#include <string>
27#include <vector>
28
29#include "../../runtime/metal/metal_module.h"
30#include "../../runtime/thread_storage_scope.h"
31#include "../build_common.h"
32
33namespace tvm {
34namespace codegen {
35
36void CodeGenMetal::InitFuncState(const PrimFunc& f) {
37 CodeGenC::InitFuncState(f);
38 // analyze the data;
39 for (Var arg : f->params) {
40 if (arg.dtype().is_handle()) {
41 alloc_storage_scope_[arg.get()] = "global";
42 }
43 }
44}
45
46CodeGenMetal::CodeGenMetal(Target target) : target_(target) {
47 decl_stream << "#include <metal_stdlib>\n";
48 decl_stream << "using namespace metal;\n\n";
49 decl_stream << "union __TVMArgUnion {\n"
50 << " int v_int[2];\n"
51 << "};\n\n";
52}
53
54void CodeGenMetal::AddFunction(const PrimFunc& f) {
55 // clear previous generated state.
56 this->InitFuncState(f);
57 // skip the first underscore, so SSA variable starts from _1
58 name_supply_->FreshName("_");
59
60 // add to alloc buffer type.
61 auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
62 ICHECK(global_symbol.defined())
63 << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
64
65 // Function header.
66 this->stream << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";
67
68 // Buffer arguments
69 size_t num_buffer = 0;
70 int limit = target_->GetAttr<Integer>("max_function_args").value().IntValue();
71 if (static_cast<int>(f->params.size()) > limit) {
72 LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of "
73 "buffers in the kernel";
74 }
75 for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) {
76 Var v = f->params[i];
77 if (!v.dtype().is_handle()) break;
78 stream << " ";
79 std::string vid = AllocVarID(v.get());
80 auto it = alloc_storage_scope_.find(v.get());
81 if (it != alloc_storage_scope_.end()) {
82 PrintStorageScope(it->second, stream);
83 }
84 PrintType(GetType(v), stream);
85 // Register handle data type
86 // TODO(tvm-team): consider simply keep type info in the
87 // type annotation(via a normalizing rewriting).
88 if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
89 if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
90 RegisterHandleType(v.get(), prim->dtype);
91 }
92 }
93 stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
94 }
95 // Setup normal arguments.
96 size_t nargs = f->params.size() - num_buffer;
97 std::string varg = name_supply_->FreshName("arg");
98 if (nargs != 0) {
99 std::string arg_buf_type = static_cast<std::string>(global_symbol.value()) + "_args_t";
100 stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer
101 << ") ]],\n";
102 // declare the struct
103 decl_stream << "struct " << arg_buf_type << " {\n";
104 for (size_t i = num_buffer; i < f->params.size(); ++i) {
105 Var v = f->params[i];
106 ICHECK(!v.dtype().is_handle());
107 std::string vid = AllocVarID(v.get());
108 std::ostringstream vref;
109 if (v.dtype().bits() == 32) {
110 decl_stream << " ";
111 PrintType(v.dtype(), decl_stream);
112 decl_stream << " " << vid << "[2];\n";
113 vref << varg << "." << vid << "[0]";
114 } else if (v.dtype().bits() == 64) {
115 decl_stream << " ";
116 PrintType(v.dtype(), decl_stream);
117 decl_stream << " " << vid << ";\n";
118 vref << varg << "." << vid;
119 } else {
120 // For non 32bit type, ref through arg union.
121 decl_stream << " __TVMArgUnion " << vid << ";\n";
122 vref << varg << "." << vid << ".v_";
123 PrintType(v.dtype(), vref);
124 }
125 var_idmap_[v.get()] = vref.str();
126 }
127 decl_stream << "};\n\n";
128 }
129 // Setup the thread group info.
130 ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
131 ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
132 int work_dim = 0;
133 auto thread_axis = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis).value();
134
135 for (IterVar iv : thread_axis) {
136 runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
137 work_dim = std::max(work_dim, scope.dim_index + 1);
138 }
139 if (work_dim != 0) {
140 // use ushort by default for now
141 stream << " ";
142 PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
143 stream << " blockIdx [[threadgroup_position_in_grid]],\n";
144 stream << " ";
145 PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
146 stream << " threadIdx [[thread_position_in_threadgroup]]\n";
147 }
148 // bind thread axis
149 for (IterVar iv : thread_axis) {
150 ICHECK(!var_idmap_.count(iv->var.get()));
151 std::string vname = iv->thread_tag;
152 if (work_dim <= 1) {
153 vname = vname.substr(0, iv->thread_tag.length() - 2);
154 }
155 var_idmap_[iv->var.get()] =
156 CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype());
157 }
158 // the function scope.
159 stream << ") {\n";
160 int func_scope = this->BeginScope();
161 this->PrintStmt(f->body);
162 this->EndScope(func_scope);
163 this->PrintIndent();
164 this->stream << "}\n\n";
165}
166
167void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
168 ICHECK(!var_idmap_.count(iv->var.get()));
169 var_idmap_[iv->var.get()] =
170 CastFromTo(iv->thread_tag, DataType::UInt(thread_index_bits_), iv->var.dtype());
171}
172
173void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
174 int lanes = t.lanes();
175 if (t.is_handle()) {
176 ICHECK_EQ(lanes, 1) << "do not yet support vector types";
177 os << "void*";
178 return;
179 }
180
181 if (t.is_void()) {
182 os << "void";
183 return;
184 }
185 if (t == DataType::Bool()) {
186 os << "bool";
187 return;
188 }
189 bool fail = false;
190 if (t.is_float()) {
191 // Need to care about sizes and alignment of half3/float3 because tir representation might not
192 // be aware of Metal half3/float3 details and can treat them as just three elements,
193 // while sizes and alignmnents of half3/float3 are one element more (half3-8 bytes/
194 // float13 - 16bytes).
195 // Example of problematic pattern: filling of threadgroup packed array using float3 elements
196 // by threads concurrently can lead to datarace and wrong data in threadgroup shared array.
197 // packed_(half3/float3) are exactly datatypes dealing with 3 elements and per-element
198 // alignment
199 if (lanes == 3) {
200 os << "packed_";
201 }
202 switch (t.bits()) {
203 case 16:
204 os << "half";
205 break;
206 case 32:
207 os << "float";
208 break;
209 default:
210 fail = true;
211 break;
212 }
213 if (!fail && lanes == 1) return;
214 if (!fail && (lanes >= 2 && lanes <= 4)) {
215 os << lanes;
216 return;
217 }
218 } else if (t.is_uint() || t.is_int()) {
219 if (t.is_uint()) {
220 os << 'u';
221 }
222 if (t.bits() == 8 && t.lanes() == 4) {
223 // directly 4 8 bit int in integer.
224 os << "int";
225 return;
226 }
227 switch (t.bits()) {
228 case 8:
229 os << "char";
230 break;
231 case 16:
232 os << "short";
233 break;
234 case 32:
235 os << "int";
236 break;
237 case 64:
238 os << "long";
239 break;
240 case 1:
241 os << "bool";
242 break;
243 default:
244 fail = true;
245 break;
246 }
247 if (!fail && lanes == 1) return;
248 if (!fail && (lanes >= 2 && lanes <= 4)) {
249 os << lanes;
250 return;
251 }
252 }
253 LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
254}
255
256void CodeGenMetal::PrintStorageSync(const CallNode* op) {
257 const std::string& sync = op->args[0].as<StringImmNode>()->value;
258 if (sync == "warp") {
259 this->PrintIndent();
260 this->stream << "simdgroup_barrier(mem_flags::mem_threadgroup);\n";
261 } else if (sync == "shared") {
262 this->PrintIndent();
263 this->stream << "threadgroup_barrier(mem_flags::mem_threadgroup);\n";
264 } else if (sync == "global") {
265 LOG(FATAL) << "global barrier not supported";
266 }
267}
268
269void CodeGenMetal::PrintVecElemLoad(const std::string& vec, DataType t, int i,
270 std::ostream& os) { // NOLINT(*)
271 os << vec << "[" << i << "]";
272}
273
274void CodeGenMetal::PrintVecElemStore(const std::string& vec, DataType t, int i,
275 const std::string& value) {
276 this->PrintIndent();
277 stream << vec << "[" << i << "]"
278 << " = " << value << ";\n";
279}
280
281void CodeGenMetal::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
282 if (scope == "global") {
283 os << "device ";
284 } else if (scope == "shared") {
285 os << "threadgroup ";
286 } else {
287 os << "thread ";
288 }
289}
290
291void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
292 std::string v = PrintExpr(op->value);
293 PrintType(op->dtype, os);
294 os << "(";
295 for (int i = 0; i < op->lanes; ++i) {
296 if (i != 0) os << ", ";
297 os << v;
298 }
299 os << ')';
300}
301
302void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
303 if (op->op.same_as(builtin::reinterpret())) {
304 // generate as_type<TYPE>(ARG)
305 os << "(as_type<";
306 this->PrintType(op->dtype, os);
307 os << ">(";
308 this->PrintExpr(op->args[0], os);
309 os << "))";
310 } else {
311 CodeGenC::VisitExpr_(op, os);
312 }
313}
314
315void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
316 std::ostringstream temp;
317 if (std::isinf(op->value)) {
318 if (op->value < 0) {
319 temp << "-";
320 }
321 temp << "INFINITY";
322 } else if (std::isnan(op->value)) {
323 temp << "NAN";
324 } else {
325 temp << std::scientific << op->value;
326 if (op->dtype.bits() == 32)
327 temp << 'f';
328 else if (op->dtype.bits() == 16)
329 temp << 'h';
330 }
331 MarkConst(temp.str());
332 os << temp.str();
333}
334
335runtime::Module BuildMetal(IRModule mod, Target target) {
336 using tvm::runtime::Registry;
337 bool output_ssa = false;
338
339 std::stringstream code;
340 std::stringstream source;
341 std::string fmt = "metal";
342 for (auto kv : mod->functions) {
343 ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
344 code << "// Function: " << kv.first->name_hint << std::endl;
345 CodeGenMetal cg(target);
346 cg.Init(output_ssa);
347 auto f = Downcast<PrimFunc>(kv.second);
348 auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
349 ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
350 << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
351 cg.AddFunction(f);
352 std::string fsource = cg.Finish();
353 if (const auto* f = Registry::Get("tvm_callback_metal_compile")) {
354 source << fsource;
355 fsource = (*f)(fsource).operator std::string();
356 fmt = "metallib";
357 }
358 code << fsource;
359 }
360
361 return MetalModuleCreate(code.str(), fmt, ExtractFuncInfo(mod), source.str());
362}
363
364TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal);
365} // namespace codegen
366} // namespace tvm
367