1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_opencl.h
22 * \brief Generate OpenCL device code.
23 */
24#ifndef TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_
25#define TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_
26
27#include <tvm/target/codegen.h>
28
29#include <string>
30#include <unordered_map>
31
32#include "codegen_c.h"
33
34namespace tvm {
35namespace codegen {
36
/*!
 * \brief Code generator that emits OpenCL device code.
 *
 * Specializes CodeGenC. Member functions marked `final` below replace the
 * corresponding CodeGenC virtual hooks and cannot be overridden further;
 * the class itself is also `final`.
 */
class CodeGenOpenCL final : public CodeGenC {
 public:
  CodeGenOpenCL();
  // Finish code generation; the result is returned as a std::string
  // (presumably the complete OpenCL source -- see the out-of-line definition).
  std::string Finish();

  // CodeGenC hooks overridden for OpenCL: per-function state, function
  // prefix/body preamble, thread-index binding, storage scopes and
  // synchronization, type printing, vector load/store, and restrict qualifiers.
  void InitFuncState(const PrimFunc& f) final;
  void PrintFuncPrefix(std::ostream& os) final;                              // NOLINT(*)
  void PreFunctionBody(const PrimFunc& f) final;                             // NOLINT(*)
  void BindThreadIndex(const IterVar& iv) final;                             // NOLINT(*)
  void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
  void PrintStorageSync(const CallNode* op) final;                           // NOLINT(*)
  void PrintType(DataType t, std::ostream& os) final;                        // NOLINT(*)
  void PrintType(const Type& type, std::ostream& os) final;                  // NOLINT(*)
  std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) final;
  void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base,
                     const std::string& value) final;  // NOLINT(*)
  // Print to `os` the address used by a vector load/store of type `t`
  // on `buffer` at offset `base`.
  void PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base,
                    std::ostream& os);                                       // NOLINT(*)
  void PrintRestrict(const Var& v, std::ostream& os) final;                  // NOLINT(*)
  // Cast helpers: build a source-string expression that converts `value` to
  // `target`; CastFromTo additionally receives the source type `from`.
  // NOTE(review): exact emitted form (cast vs. reinterpret) lives in the .cc.
  std::string CastFromTo(std::string value, DataType from, DataType target);  // NOLINT(*)
  std::string CastTo(std::string value, DataType target);                     // NOLINT(*)
  // Supply a per-variable storage-scope string (keyed by VarNode*);
  // presumably used to recognize texture-backed buffers -- confirm in the .cc.
  void SetTextureScope(const std::unordered_map<const VarNode*, std::string>&);  // NOLINT(*)

  // Statement/expression visitor overrides (replace the CodeGenC versions).
  void VisitStmt_(const AllocateNode* op) final;                        // NOLINT(*)
  void VisitExpr_(const BroadcastNode* op, std::ostream& os) final;     // NOLINT(*)
  void VisitExpr_(const CallNode* op, std::ostream& os) final;          // NOLINT(*)
  void VisitExpr_(const CastNode* op, std::ostream& os) final;          // NOLINT(*)
  void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;      // NOLINT(*)
  void VisitStmt_(const StoreNode* op) final;                           // NOLINT(*)
  void VisitStmt_(const BufferStoreNode* op) final;                     // NOLINT(*)

  // Override min and max to avoid ambiguous call errors; logical and/or and
  // select expressions are overridden here as well.
  void VisitExpr_(const MinNode* op, std::ostream& os) final;
  void VisitExpr_(const MaxNode* op, std::ostream& os) final;
  void VisitExpr_(const AndNode* op, std::ostream& os) final;
  void VisitExpr_(const OrNode* op, std::ostream& os) final;
  void VisitExpr_(const SelectNode* op, std::ostream& os) final;

 private:
  // Whether to enable the fp16 and fp64 extensions (both off by default).
  bool enable_fp16_{false};
  bool enable_fp64_{false};
  // Whether to enable the atomics extension (off by default).
  bool enable_atomics_{false};
  // Whether to use sampler or sampler-less texture reads; the right choice
  // depends on the OpenCL version being targeted.
  bool enable_compliant_texture_reads_{false};
  // Flag used to disable texture SSA in certain scenarios, e.g. when a loaded
  // value is stored directly into a user-declared l-value buffer. On by default.
  bool need_texture_ssa_{true};
  // Mapping from buffer to its allocation size.
  // Used to detect when a scalar store of a vectorized texture load is required.
  std::unordered_map<const Object*, size_t> allocation_size_;
};
94
95} // namespace codegen
96} // namespace tvm
97
98#endif // TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_
99