1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_opencl.cc |
22 | */ |
23 | #include "codegen_opencl.h" |
24 | |
25 | #include <cmath> |
26 | #include <string> |
27 | #include <vector> |
28 | |
29 | #include "../../runtime/opencl/opencl_module.h" |
30 | #include "../../runtime/texture.h" |
31 | #include "../../runtime/thread_storage_scope.h" |
32 | #include "../build_common.h" |
33 | |
34 | namespace tvm { |
35 | namespace codegen { |
36 | |
37 | class InferTextureAccess : public StmtExprVisitor { |
38 | public: |
39 | static constexpr const uint8_t kReadAccess = 1; |
40 | static constexpr const uint8_t kWriteAccess = 2; |
41 | |
42 | InferTextureAccess() {} |
43 | std::unordered_map<const VarNode*, std::string> Infer(const Stmt& n) { |
44 | StmtExprVisitor::VisitStmt(n); |
45 | std::unordered_map<const VarNode*, std::string> storage_scope_qualifiers; |
46 | for (auto& texture : var_access_map_) { |
47 | if (texture.second == kReadAccess) { |
48 | storage_scope_qualifiers.insert({texture.first, "texture_read" }); |
49 | } else if (texture.second == kWriteAccess) { |
50 | storage_scope_qualifiers.insert({texture.first, "texture_write" }); |
51 | } else if (texture.second == (kReadAccess | kWriteAccess)) { |
52 | storage_scope_qualifiers.insert({texture.first, "" }); |
53 | } |
54 | } |
55 | return storage_scope_qualifiers; |
56 | } |
57 | void VisitExpr_(const CallNode* op) { |
58 | if (op->op.same_as(builtin::texture2d_load())) { |
59 | var_access_map_[op->args[0].as<VarNode>()] |= kReadAccess; |
60 | } else if (op->op.same_as(builtin::texture2d_store())) { |
61 | var_access_map_[op->args[0].as<VarNode>()] |= kWriteAccess; |
62 | } |
63 | StmtExprVisitor::VisitExpr_(op); |
64 | } |
65 | |
66 | private: |
67 | std::unordered_map<const VarNode*, uint8_t> var_access_map_; |
68 | }; |
69 | |
70 | CodeGenOpenCL::CodeGenOpenCL() { |
71 | // Set OpenCL specific restrict keyword |
72 | restrict_keyword_ = "restrict" ; |
73 | } |
74 | |
75 | void CodeGenOpenCL::InitFuncState(const PrimFunc& f) { |
76 | CodeGenC::InitFuncState(f); |
77 | this->SetTextureScope(InferTextureAccess().Infer(f->body)); |
78 | for (Var arg : f->params) { |
79 | auto ptr_type = arg->type_annotation.as<PointerTypeNode>(); |
80 | if (ptr_type && runtime::IsTextureStorage(std::string(ptr_type->storage_scope))) { |
81 | // Storage scope qualifiers for textures are inferred |
82 | // and set prior to function codegen. |
83 | continue; |
84 | } else if (arg.dtype().is_handle()) { |
85 | alloc_storage_scope_[arg.get()] = "global" ; |
86 | } |
87 | } |
88 | } |
89 | |
90 | void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel void" ; } |
91 | |
92 | void CodeGenOpenCL::PreFunctionBody(const PrimFunc& f) { |
93 | for (Var arg : f->params) { |
94 | auto ptr_type = arg->type_annotation.as<PointerTypeNode>(); |
95 | if (ptr_type && runtime::IsTextureStorage(std::string(ptr_type->storage_scope))) { |
96 | this->stream << " const sampler_t image_sampler = " |
97 | "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" ; |
98 | return; |
99 | } |
100 | } |
101 | } |
102 | |
103 | std::string CodeGenOpenCL::Finish() { |
104 | // inject extension enable pragma for fp16 and fp64 |
105 | if (enable_fp16_) { |
106 | decl_stream << "#ifdef cl_khr_fp16\n" |
107 | "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" |
108 | "#elif defined(cl_amd_fp16)\n" |
109 | "#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n" |
110 | "#else\n" |
111 | "#error \"Half precision floating point not supported" |
112 | " by OpenCL implementation on your device.\" \n" |
113 | "#endif\n\n" ; |
114 | } |
115 | |
116 | if (enable_fp64_) { |
117 | decl_stream << "#ifdef cl_khr_fp64\n" |
118 | "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n" |
119 | "#elif defined(cl_amd_fp64)\n" |
120 | "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n" |
121 | "#else\n" |
122 | "#error \"Double precision floating point not supported" |
123 | " by OpenCL implementation on your device.\" \n" |
124 | "#endif\n\n" ; |
125 | } |
126 | |
127 | // Enable atomic_add used by get_valid_counts. Only needed for OpenCL < 1.1. |
128 | if (enable_atomics_) { |
129 | decl_stream << "#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable\n" |
130 | "#pragma OPENCL EXTENSION cl_khr_global_int32_extended_atomics : enable\n\n" ; |
131 | } |
132 | |
133 | // Enable OpenCL 1.2 sampler-less texture reads, but utilize |
134 | // provided sampler in OpenCL 2.0. |
135 | if (enable_compliant_texture_reads_) { |
136 | // TODO(csullivan, lunderberg): Extend device attribute querying to support remote devices |
137 | // generically through the device API such that a target can be created from a specific device's |
138 | // attributes and utilized during codegen. Potential generlization of #8127 (c02cafb) for remote |
139 | // devices. |
140 | // |
141 | // E.g. Only provide an image sampler when the local or remote device supports OpenCL 2.0, |
142 | // see below for context. |
143 | // |
144 | // For backwards compatibility with OpenCL 1.2, sampler-less read_image calls are used. |
145 | // By default in sampler-less read_image calls OpenCL defaults to |
146 | // sampler_ = "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST"; |
147 | // See section 6.12.14.3 Built-in Image Sampler-less Read Functions in the OpenCL 1.2 |
148 | // specification. For OpenCL 2.0 it can be preferable to use, |
149 | // sampler_ = "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST"; |
150 | // For now we rely on OpenCL preprocessor directives to utilize the correct behavior |
151 | // depending on the OpenCL version detected at OpenCL compile time. |
152 | decl_stream << "#ifdef __OPENCL_VERSION__\n" |
153 | << "#if __OPENCL_VERSION__ == CL_VERSION_2_0" |
154 | << " || __OPENCL_VERSION__ == CL_VERSION_3_0 \n" |
155 | << "#define READ_IMAGEH(image, sampler, coord) " |
156 | << "read_imageh(image, sampler, coord)\n" |
157 | << "#define READ_IMAGEF(image, sampler, coord) " |
158 | << "read_imagef(image, sampler, coord)\n" |
159 | << "#else\n" |
160 | << "#define READ_IMAGEH(image, sampler, coord) " |
161 | << "read_imageh(image, coord)\n" |
162 | << "#define READ_IMAGEF(image, sampler, coord) " |
163 | << "read_imagef(image, coord)\n" |
164 | << "#endif\n" |
165 | << "#endif\n\n" ; |
166 | } |
167 | return CodeGenC::Finish(); |
168 | } |
169 | |
170 | void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { |
171 | ICHECK(!var_idmap_.count(iv->var.get())); |
172 | runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); |
173 | std::ostringstream os; |
174 | if (ts.rank == 1) { |
175 | os << "get_local_id(" << ts.dim_index << ")" ; |
176 | } else { |
177 | os << "get_group_id(" << ts.dim_index << ")" ; |
178 | } |
179 | var_idmap_[iv->var.get()] = CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype()); |
180 | } |
181 | |
182 | void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) |
183 | int lanes = t.lanes(); |
184 | if (t.is_handle()) { |
185 | ICHECK_EQ(lanes, 1) << "do not yet support vector types" ; |
186 | os << "void*" ; |
187 | return; |
188 | } |
189 | if (t.is_void()) { |
190 | os << "void" ; |
191 | return; |
192 | } |
193 | if (t == DataType::Bool()) { |
194 | os << "bool" ; |
195 | return; |
196 | } |
197 | bool fail = false; |
198 | if (t.is_float()) { |
199 | switch (t.bits()) { |
200 | case 16: |
201 | os << "half" ; |
202 | enable_fp16_ = true; |
203 | break; |
204 | case 32: |
205 | os << "float" ; |
206 | break; |
207 | case 64: |
208 | os << "double" ; |
209 | enable_fp64_ = true; |
210 | break; |
211 | default: |
212 | fail = true; |
213 | break; |
214 | } |
215 | if (!fail && lanes == 1) return; |
216 | if (!fail && ((lanes >= 2 && lanes <= 4) || lanes == 8 || lanes == 16)) { |
217 | os << lanes; |
218 | return; |
219 | } |
220 | } else if (t.is_uint() || t.is_int()) { |
221 | if (t.is_uint()) { |
222 | os << 'u'; |
223 | } |
224 | if (t.bits() == 8 && t.lanes() == 4) { |
225 | // directly 4 8 bit int in integer. |
226 | os << "int" ; |
227 | return; |
228 | } |
229 | switch (t.bits()) { |
230 | case 8: |
231 | os << "char" ; |
232 | break; |
233 | case 16: |
234 | os << "short" ; |
235 | break; |
236 | case 32: |
237 | os << "int" ; |
238 | break; |
239 | case 64: |
240 | os << "long" ; |
241 | break; |
242 | case 1: |
243 | os << "int" ; |
244 | break; |
245 | default: |
246 | fail = true; |
247 | break; |
248 | } |
249 | if (!fail && lanes == 1) return; |
250 | if (!fail && ((lanes >= 2 && lanes <= 4) || lanes == 8 || lanes == 16)) { |
251 | os << lanes; |
252 | return; |
253 | } |
254 | } |
255 | LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type" ; |
256 | } |
257 | |
258 | void CodeGenOpenCL::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) |
259 | if (auto* ptr = type.as<PrimTypeNode>()) { |
260 | return PrintType(ptr->dtype, os); |
261 | } else if (auto* ptr = type.as<PointerTypeNode>()) { |
262 | if (runtime::IsTextureStorage(std::string(ptr->storage_scope))) { |
263 | os << "image2d_t" ; |
264 | } else { |
265 | PrintType(ptr->element_type, os); |
266 | os << '*'; |
267 | } |
268 | } else if (IsVoidType(type)) { |
269 | os << "void" ; |
270 | } else { |
271 | LOG(FATAL) << "Type " << type << " does not have a corresponding C Type" ; |
272 | } |
273 | } |
274 | |
275 | void CodeGenOpenCL::PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base, |
276 | std::ostream& os) { // NOLINT(*) |
277 | const VarNode* buffer_var = buffer->data.get(); |
278 | if (!HandleTypeMatch(buffer_var, t.element_of())) { |
279 | os << '('; |
280 | auto it = alloc_storage_scope_.find(buffer_var); |
281 | if (it != alloc_storage_scope_.end()) { |
282 | PrintStorageScope(it->second, os); |
283 | } |
284 | PrintType(t.element_of(), os); |
285 | os << "*)" ; |
286 | } |
287 | os << GetVarID(buffer_var) << " + " ; |
288 | PrintExpr(base, os); |
289 | } |
290 | std::string CodeGenOpenCL::GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) { |
291 | std::ostringstream os; |
292 | os << "vload" << t.lanes() << "(0, " ; |
293 | PrintVecAddr(buffer, t, base, os); |
294 | os << ")" ; |
295 | return os.str(); |
296 | } |
297 | |
298 | void CodeGenOpenCL::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, |
299 | const std::string& value) { |
300 | this->PrintIndent(); |
301 | stream << "vstore" << t.lanes() << "(" << value << ", 0, " ; |
302 | PrintVecAddr(buffer, t, base, stream); |
303 | stream << ");\n" ; |
304 | } |
305 | |
306 | void CodeGenOpenCL::PrintStorageSync(const CallNode* op) { |
307 | const std::string& sync = op->args[0].as<StringImmNode>()->value; |
308 | if (sync == "warp" ) { |
309 | this->PrintIndent(); |
310 | this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n" ; |
311 | } else if (sync == "shared" ) { |
312 | this->PrintIndent(); |
313 | this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n" ; |
314 | } else if (sync == "global" ) { |
315 | LOG(FATAL) << "not supported" ; |
316 | } |
317 | } |
318 | |
319 | void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) |
320 | if (scope == "global" ) { |
321 | os << "__global " ; |
322 | } else if (scope == "shared" ) { |
323 | os << "__local " ; |
324 | } else if (scope == "texture_read" ) { |
325 | os << "__read_only " ; |
326 | } else if (scope == "texture_write" ) { |
327 | os << "__write_only " ; |
328 | } |
329 | } |
330 | |
331 | void CodeGenOpenCL::PrintRestrict(const Var& v, std::ostream& os) { |
332 | // Apply restrict qualifer for non-texture types only |
333 | if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) { |
334 | if (!runtime::IsTextureStorage(std::string(ptr->storage_scope))) { |
335 | os << ' ' << restrict_keyword_; |
336 | } |
337 | } |
338 | } |
339 | |
340 | std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType target) { |
341 | if (from == target) return value; |
342 | return CastTo(value, target); |
343 | } |
344 | |
345 | std::string CodeGenOpenCL::CastTo(std::string value, DataType target) { |
346 | std::ostringstream os; |
347 | if (target.lanes() == 1) { |
348 | os << "((" ; |
349 | this->PrintType(target, os); |
350 | os << ")" << value << ")" ; |
351 | } else { // convert vector type |
352 | os << "(" ; |
353 | os << "convert_" ; |
354 | this->PrintType(target, os); |
355 | os << "(" << value << "))" ; |
356 | } |
357 | return os.str(); |
358 | } |
359 | |
360 | void CodeGenOpenCL::VisitStmt_(const StoreNode* op) { |
361 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
362 | } |
363 | |
364 | void CodeGenOpenCL::VisitStmt_(const BufferStoreNode* op) { |
365 | if (auto call = op->value.as<CallNode>()) { |
366 | if (call->op.same_as(builtin::texture2d_load())) { |
367 | need_texture_ssa_ = false; |
368 | // If storing a texture load into a buffer, don't use an |
369 | // intermediate local unless the buffer allocation is a |
370 | // single element selected from the texture read. |
371 | auto it = allocation_size_.find(op->buffer->data.get()); |
372 | if (it != allocation_size_.end() && it->second == 1) { |
373 | need_texture_ssa_ = true; |
374 | } |
375 | } |
376 | } |
377 | CodeGenC::VisitStmt_(op); |
378 | need_texture_ssa_ = true; |
379 | } |
380 | |
381 | void CodeGenOpenCL::VisitExpr_(const CastNode* op, std::ostream& os) { |
382 | if (auto call = op->value.as<CallNode>()) { |
383 | if (call->op.same_as(builtin::texture2d_load())) { |
384 | need_texture_ssa_ = false; |
385 | } |
386 | } |
387 | CodeGenC::VisitExpr_(op, os); |
388 | need_texture_ssa_ = true; |
389 | } |
390 | |
391 | void CodeGenOpenCL::VisitStmt_(const AllocateNode* op) { |
392 | allocation_size_.insert({op->buffer_var.get(), op->ConstantAllocationSize() * op->dtype.lanes()}); |
393 | CodeGenC::VisitStmt_(op); |
394 | } |
395 | |
396 | void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { |
397 | if (op->op.same_as(builtin::address_of())) { |
398 | // Overload tvm_address_of to add storage scope (e.g. __global). |
399 | const BufferLoadNode* load = op->args[0].as<BufferLoadNode>(); |
400 | ICHECK(op->args.size() == 1 && load); |
401 | ICHECK_EQ(load->indices.size(), 1) << "CodeGenOpenCL only supports flat memory allocations." ; |
402 | os << "((" ; |
403 | auto it = alloc_storage_scope_.find(load->buffer->data.get()); |
404 | if (it != alloc_storage_scope_.end()) { |
405 | PrintStorageScope(it->second, os); |
406 | } |
407 | this->PrintType(load->dtype.element_of(), os); |
408 | os << " *)" << this->GetVarID(load->buffer->data.get()) << " + " ; |
409 | this->PrintExpr(load->indices[0], os); |
410 | os << ')'; |
411 | } else if (op->op.same_as(builtin::texture2d_store())) { |
412 | auto* ptr_type = op->args[0].as<VarNode>()->type_annotation.as<PointerTypeNode>(); |
413 | ICHECK(ptr_type != nullptr) << "Texture Var's must be of PointerType" ; |
414 | ICHECK(runtime::IsTextureStorage(std::string(ptr_type->storage_scope))) |
415 | << "builtin::texture2d_store() only supports storing to texture buffers" ; |
416 | DataType buffer_type = ptr_type->element_type.as<PrimTypeNode>()->dtype; |
417 | if (buffer_type.is_float16()) { |
418 | os << "write_imageh(" ; |
419 | } else if (buffer_type.is_float()) { |
420 | os << "write_imagef(" ; |
421 | } else { |
422 | LOG(FATAL) << "Unsupported type: " << buffer_type |
423 | << ", currently only float and half are supported for image2d OpenCL codegen." ; |
424 | } |
425 | this->PrintExpr(op->args[0], os); |
426 | os << ", " ; |
427 | os << "(int2)(" ; |
428 | this->PrintExpr(op->args[1], os); |
429 | os << ", " ; |
430 | this->PrintExpr(op->args[2], os); |
431 | os << "), " ; |
432 | this->PrintExpr(op->args[3], os); |
433 | os << ")" ; |
434 | } else if (op->op.same_as(builtin::texture2d_load())) { |
435 | enable_compliant_texture_reads_ = true; |
436 | std::stringstream ss; |
437 | if (op->dtype.is_float16()) { |
438 | ss << "READ_IMAGEH(" ; |
439 | } else if (op->dtype.is_float()) { |
440 | ss << "READ_IMAGEF(" ; |
441 | } else { |
442 | LOG(FATAL) << "Unsupported type: " << op->dtype |
443 | << ", currently only float and half are supported for image2d OpenCL codegen." ; |
444 | } |
445 | this->PrintExpr(op->args[0], ss); |
446 | ss << ", " ; |
447 | ss << "image_sampler, " ; |
448 | ss << "((int2)(" ; |
449 | this->PrintExpr(op->args[1], ss); |
450 | ss << ", " ; |
451 | this->PrintExpr(op->args[2], ss); |
452 | ss << ")))" ; |
453 | |
454 | // Only use local SSA if texture is not already being stored |
455 | if (need_texture_ssa_) { |
456 | std::string rhs = SSAGetID(ss.str(), op->dtype.with_lanes(4)); |
457 | if (op->args.back().as<RampNode>()) { |
458 | os << rhs; |
459 | } else { |
460 | os << "((" ; |
461 | this->PrintType(op->dtype.with_lanes(1), os); |
462 | os << "*)&" << rhs << ")[" ; |
463 | this->PrintExpr(op->args.back(), os); |
464 | os << "]" ; |
465 | } |
466 | } else { |
467 | os << ss.str(); |
468 | } |
469 | } else if (op->op.same_as(builtin_call_extern_)) { |
470 | auto func = Downcast<StringImm>(op->args[0]); |
471 | // Enable atomics extension if used. |
472 | if (func->value == "atomic_add" ) { |
473 | enable_atomics_ = true; |
474 | } |
475 | CodeGenC::VisitExpr_(op, os); |
476 | } else { |
477 | CodeGenC::VisitExpr_(op, os); |
478 | } |
479 | } |
480 | |
481 | void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) |
482 | std::string v = PrintExpr(op->value); |
483 | os << "((" ; |
484 | PrintType(op->dtype, os); |
485 | os << ")(" ; |
486 | for (int i = 0; i < op->lanes; ++i) { |
487 | if (i != 0) os << ", " ; |
488 | os << v; |
489 | } |
490 | os << "))" ; |
491 | } |
492 | |
493 | void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) |
494 | if (std::isinf(op->value)) { |
495 | if (op->value < 0) { |
496 | os << "-" ; |
497 | } |
498 | os << "INFINITY" ; |
499 | } else if (std::isnan(op->value)) { |
500 | os << "NAN" ; |
501 | } else { |
502 | CodeGenC::VisitExpr_(op, os); |
503 | } |
504 | } |
505 | |
506 | template <typename T> |
507 | inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, CodeGenOpenCL* p) { |
508 | if (op->dtype.lanes() == 1) { |
509 | os << opstr << "((" ; |
510 | p->PrintType(op->a->dtype, os); |
511 | os << ")" ; |
512 | p->PrintExpr(op->a, os); |
513 | os << ", (" ; |
514 | p->PrintType(op->b->dtype, os); |
515 | os << ")" ; |
516 | p->PrintExpr(op->b, os); |
517 | os << ')'; |
518 | } else { |
519 | p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os); |
520 | } |
521 | } |
522 | |
523 | void CodeGenOpenCL::VisitExpr_(const MinNode* op, std::ostream& os) { |
524 | PrintBinaryExpr(op, "min" , os, this); |
525 | } |
526 | |
527 | void CodeGenOpenCL::VisitExpr_(const MaxNode* op, std::ostream& os) { |
528 | PrintBinaryExpr(op, "max" , os, this); |
529 | } |
530 | |
531 | void CodeGenOpenCL::VisitExpr_(const AndNode* op, std::ostream& os) { |
532 | std::ostringstream oss; |
533 | os << "(" ; |
534 | this->PrintExpr(op->a, oss); |
535 | os << CastTo(oss.str(), op->dtype); |
536 | oss.str("" ); |
537 | os << " && " ; |
538 | this->PrintExpr(op->b, oss); |
539 | os << CastTo(oss.str(), op->dtype); |
540 | os << ")" ; |
541 | } |
542 | |
543 | void CodeGenOpenCL::VisitExpr_(const OrNode* op, std::ostream& os) { |
544 | std::ostringstream oss; |
545 | os << "(" ; |
546 | this->PrintExpr(op->a, oss); |
547 | os << CastTo(oss.str(), op->dtype); |
548 | oss.str("" ); |
549 | os << " || " ; |
550 | this->PrintExpr(op->b, oss); |
551 | os << CastTo(oss.str(), op->dtype); |
552 | os << ")" ; |
553 | } |
554 | |
555 | void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { |
556 | std::ostringstream oss; |
557 | os << "select(" ; |
558 | PrintExpr(op->false_value, oss); |
559 | os << CastFromTo(oss.str(), op->false_value.dtype(), op->dtype); |
560 | oss.str("" ); |
561 | os << ", " ; |
562 | PrintExpr(op->true_value, oss); |
563 | os << CastFromTo(oss.str(), op->true_value.dtype(), op->dtype); |
564 | oss.str("" ); |
565 | os << ", " ; |
566 | PrintExpr(op->condition, oss); |
567 | if (op->dtype.is_float()) { |
568 | if (op->condition.dtype().is_uint() || op->condition.dtype().is_int()) { |
569 | os << oss.str(); |
570 | } else { |
571 | os << CastTo(oss.str(), DataType::Int(op->dtype.bits(), op->dtype.lanes())); |
572 | } |
573 | } else { |
574 | os << CastFromTo(oss.str(), op->condition.dtype(), op->dtype); |
575 | } |
576 | os << ")" ; |
577 | } |
578 | |
579 | void CodeGenOpenCL::SetTextureScope( |
580 | const std::unordered_map<const VarNode*, std::string>& scope) { // NOLINT(*) |
581 | for (auto& texture : scope) { |
582 | alloc_storage_scope_.insert(texture); |
583 | } |
584 | } |
585 | |
586 | runtime::Module BuildOpenCL(IRModule mod, Target target) { |
587 | using tvm::runtime::Registry; |
588 | bool output_ssa = false; |
589 | |
590 | std::stringstream code; |
591 | const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc" ); |
592 | for (auto kv : mod->functions) { |
593 | ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc" ; |
594 | code << "// Function: " << kv.first->name_hint << std::endl; |
595 | CodeGenOpenCL cg; |
596 | cg.Init(output_ssa); |
597 | auto f = Downcast<PrimFunc>(kv.second); |
598 | auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); |
599 | ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) |
600 | << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch" ; |
601 | cg.AddFunction(f); |
602 | std::string fsource = cg.Finish(); |
603 | if (fpostproc) { |
604 | fsource = (*fpostproc)(fsource).operator std::string(); |
605 | } |
606 | code << fsource; |
607 | } |
608 | |
609 | return OpenCLModuleCreate(code.str(), "cl" , ExtractFuncInfo(mod), code.str()); |
610 | } |
611 | |
612 | TVM_REGISTER_GLOBAL("target.build.opencl" ).set_body_typed(BuildOpenCL); |
613 | } // namespace codegen |
614 | } // namespace tvm |
615 | |