1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_COMPUTE_COMPUTE_ENGINE_HPP
18#define GPU_COMPUTE_COMPUTE_ENGINE_HPP
19
#include <cassert>
#include <functional>
#include <initializer_list>
#include <memory>
#include <mutex>
#include <vector>

#include "common/c_types_map.hpp"
#include "common/engine.hpp"
#include "common/primitive.hpp"
#include "common/primitive_desc_iterator.hpp"
#include "common/resource.hpp"
#include "common/verbose.hpp"
#include "gpu/compute/device_info.hpp"
#include "gpu/compute/dispatch.hpp"
#include "gpu/compute/kernel.hpp"
#include "gpu/compute/kernel_ctx.hpp"
#include "gpu/jit/jit_generator_base.hpp"
36
37namespace dnnl {
38namespace impl {
39namespace gpu {
40namespace compute {
41
42class compute_engine_t : public engine_t {
43public:
44 compute_engine_t(
45 engine_kind_t kind, runtime_kind_t runtime_kind, size_t index)
46 : engine_t(kind, runtime_kind, index) {}
47
48 virtual status_t init();
49 status_t init(const std::vector<uint8_t> &cache_blob);
50
51 const device_info_t *device_info() const { return device_info_.get(); }
52
53 virtual status_t create_kernel(compute::kernel_t *kernel,
54 jit::jit_generator_base *jitter, cache_blob_t cache_blob) const = 0;
55
56 virtual status_t create_kernels(std::vector<compute::kernel_t> *kernels,
57 const std::vector<const char *> &kernel_names,
58 const compute::kernel_ctx_t &kernel_ctx,
59 cache_blob_t cache_blob) const = 0;
60
61 virtual status_t create_kernels_from_ocl_source(
62 std::vector<compute::kernel_t> *kernels,
63 const std::vector<const char *> &kernel_names,
64 const char *source_string,
65 const compute::kernel_ctx_t &kernel_ctx) const {
66 assert(!"unexpected");
67 return status::success;
68 };
69
70 status_t get_zero_pad_primitive(
71 primitive_t *&result, const resource_mapper_t *&resources) {
72 std::call_once(zero_pad_init_, [&]() -> void {
73 zero_pad_desc_t desc;
74 desc.primitive_kind = primitive_kind::zero_pad;
75 primitive_desc_iterator_t it(
76 this, (op_desc_t *)&desc, nullptr, nullptr);
77 std::shared_ptr<primitive_desc_t> zero_pad_pd(*(++it));
78 if (zero_pad_pd == nullptr) return;
79
80 status_t status
81 = zero_pad_pd->create_primitive(zero_pad_primitive_, this);
82 if (status != status::success) { zero_pad_primitive_.reset(); }
83 });
84
85 result = zero_pad_primitive_.get();
86 resources = &zero_pad_resources_;
87 return result != nullptr ? status::success : status::unimplemented;
88 };
89
90 bool mayiuse_f16_accumulator_with_f16() const override {
91 // XeHPC+ must use f32 accumulation with f16 operations as documented.
92 switch (device_info_->gpu_arch()) {
93 case gpu_arch_t::gen9:
94 case gpu_arch_t::gen11:
95 case gpu_arch_t::xe_lp:
96 case gpu_arch_t::xe_hp:
97 case gpu_arch_t::xe_hpg: return true;
98 default: return false;
99 }
100 }
101
102 bool mayiuse(device_ext_t ext) const { return device_info_->has(ext); }
103
104 bool is_gen9() const {
105 return device_info_->gpu_arch() == gpu_arch_t::gen9;
106 }
107 bool is_gen11() const {
108 return device_info_->gpu_arch() == gpu_arch_t::gen11;
109 }
110 bool is_xe_lp() const {
111 return device_info_->gpu_arch() == gpu_arch_t::xe_lp;
112 }
113 bool is_xe_hp() const {
114 return device_info_->gpu_arch() == gpu_arch_t::xe_hp;
115 }
116 bool is_xe_hpg() const {
117 return device_info_->gpu_arch() == gpu_arch_t::xe_hpg;
118 }
119 bool is_xe_hpc() const {
120 return device_info_->gpu_arch() == gpu_arch_t::xe_hpc;
121 }
122 bool mayiuse_ngen_kernels() const {
123 return device_info_->mayiuse_ngen_kernels();
124 }
125 bool mayiuse_non_uniform_work_groups() const {
126 return device_info_->mayiuse_non_uniform_work_groups();
127 }
128 bool mayiuse_sub_group(int size) const {
129 return device_info_->mayiuse_sub_group(size);
130 }
131 bool mayiuse_sub_group(std::initializer_list<int> sizes) const {
132 for (int size : sizes)
133 if (!mayiuse_sub_group(size)) return false;
134 return true;
135 }
136 bool mayiuse_large_grf_mode() const {
137 // XXX: XeHPG 128EU A0 causes hangs with large GRF mode.
138 if (is_xe_hpg() && device_info()->eu_count() == 128
139 && device_info()->stepping_id() == 0)
140 return false;
141 return device_info_->gpu_arch() >= compute::gpu_arch_t::xe_hp;
142 }
143
144 dispatch_t create_dispatch(const memory_desc_t *md = nullptr) const {
145 return dispatch_t(this, md);
146 }
147
148 status_t get_service_stream(stream_t *&stream) override {
149 status_t status = status::success;
150 if (service_stream_ == nullptr) {
151 const std::lock_guard<std::mutex> lock(service_stream_mutex_);
152 if (service_stream_ == nullptr) {
153 stream_t *service_stream_ptr;
154 status = create_stream(
155 &service_stream_ptr, stream_flags::default_flags);
156 if (status == status::success)
157 service_stream_.reset(service_stream_ptr);
158 }
159 }
160 stream = service_stream_.get();
161 return status;
162 }
163
164 // non-blocking query to check if service stream is already created
165 bool is_service_stream_created() const { return (bool)service_stream_; }
166
167 virtual std::function<void(void *)> get_program_list_deleter() const = 0;
168
169protected:
170 virtual status_t init_device_info() = 0;
171 virtual status_t init_device_info(const std::vector<uint8_t> &cache_blob) {
172 assert(!"unexpected");
173 return status::runtime_error;
174 }
175
176 ~compute_engine_t() override = default;
177
178 std::shared_ptr<device_info_t> device_info_;
179
180private:
181 // Implement a zero_pad_primitive shared across the engine. The purpose is
182 // to prevent extra overhead associated with creating zero_pad_primitives
183 // for different inputs as ideally the zero_pad operations fast relative to
184 // the time to create the primitive.
185 std::shared_ptr<primitive_t> zero_pad_primitive_;
186 resource_mapper_t zero_pad_resources_;
187 std::once_flag zero_pad_init_;
188 std::unique_ptr<stream_t> service_stream_;
189 std::mutex service_stream_mutex_;
190};
191
192} // namespace compute
193} // namespace gpu
194} // namespace impl
195} // namespace dnnl
196
197// Exported for testing purposes only.
198extern "C" bool DNNL_API dnnl_impl_gpu_mayiuse_ngen_kernels(
199 dnnl::impl::engine_t *engine);
200
201#endif
202