1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_COMPUTE_COMPUTE_ENGINE_HPP |
18 | #define GPU_COMPUTE_COMPUTE_ENGINE_HPP |
19 | |
#include <cassert>
#include <functional>
#include <initializer_list>
#include <memory>
#include <mutex>
#include <vector>

#include "common/c_types_map.hpp"
#include "common/engine.hpp"
#include "common/primitive.hpp"
#include "common/primitive_desc_iterator.hpp"
#include "common/resource.hpp"
#include "common/verbose.hpp"
#include "gpu/compute/device_info.hpp"
#include "gpu/compute/dispatch.hpp"
#include "gpu/compute/kernel.hpp"
#include "gpu/compute/kernel_ctx.hpp"
#include "gpu/jit/jit_generator_base.hpp"
36 | |
37 | namespace dnnl { |
38 | namespace impl { |
39 | namespace gpu { |
40 | namespace compute { |
41 | |
42 | class compute_engine_t : public engine_t { |
43 | public: |
44 | compute_engine_t( |
45 | engine_kind_t kind, runtime_kind_t runtime_kind, size_t index) |
46 | : engine_t(kind, runtime_kind, index) {} |
47 | |
48 | virtual status_t init(); |
49 | status_t init(const std::vector<uint8_t> &cache_blob); |
50 | |
51 | const device_info_t *device_info() const { return device_info_.get(); } |
52 | |
53 | virtual status_t create_kernel(compute::kernel_t *kernel, |
54 | jit::jit_generator_base *jitter, cache_blob_t cache_blob) const = 0; |
55 | |
56 | virtual status_t create_kernels(std::vector<compute::kernel_t> *kernels, |
57 | const std::vector<const char *> &kernel_names, |
58 | const compute::kernel_ctx_t &kernel_ctx, |
59 | cache_blob_t cache_blob) const = 0; |
60 | |
61 | virtual status_t create_kernels_from_ocl_source( |
62 | std::vector<compute::kernel_t> *kernels, |
63 | const std::vector<const char *> &kernel_names, |
64 | const char *source_string, |
65 | const compute::kernel_ctx_t &kernel_ctx) const { |
66 | assert(!"unexpected" ); |
67 | return status::success; |
68 | }; |
69 | |
70 | status_t get_zero_pad_primitive( |
71 | primitive_t *&result, const resource_mapper_t *&resources) { |
72 | std::call_once(zero_pad_init_, [&]() -> void { |
73 | zero_pad_desc_t desc; |
74 | desc.primitive_kind = primitive_kind::zero_pad; |
75 | primitive_desc_iterator_t it( |
76 | this, (op_desc_t *)&desc, nullptr, nullptr); |
77 | std::shared_ptr<primitive_desc_t> zero_pad_pd(*(++it)); |
78 | if (zero_pad_pd == nullptr) return; |
79 | |
80 | status_t status |
81 | = zero_pad_pd->create_primitive(zero_pad_primitive_, this); |
82 | if (status != status::success) { zero_pad_primitive_.reset(); } |
83 | }); |
84 | |
85 | result = zero_pad_primitive_.get(); |
86 | resources = &zero_pad_resources_; |
87 | return result != nullptr ? status::success : status::unimplemented; |
88 | }; |
89 | |
90 | bool mayiuse_f16_accumulator_with_f16() const override { |
91 | // XeHPC+ must use f32 accumulation with f16 operations as documented. |
92 | switch (device_info_->gpu_arch()) { |
93 | case gpu_arch_t::gen9: |
94 | case gpu_arch_t::gen11: |
95 | case gpu_arch_t::xe_lp: |
96 | case gpu_arch_t::xe_hp: |
97 | case gpu_arch_t::xe_hpg: return true; |
98 | default: return false; |
99 | } |
100 | } |
101 | |
102 | bool mayiuse(device_ext_t ext) const { return device_info_->has(ext); } |
103 | |
104 | bool is_gen9() const { |
105 | return device_info_->gpu_arch() == gpu_arch_t::gen9; |
106 | } |
107 | bool is_gen11() const { |
108 | return device_info_->gpu_arch() == gpu_arch_t::gen11; |
109 | } |
110 | bool is_xe_lp() const { |
111 | return device_info_->gpu_arch() == gpu_arch_t::xe_lp; |
112 | } |
113 | bool is_xe_hp() const { |
114 | return device_info_->gpu_arch() == gpu_arch_t::xe_hp; |
115 | } |
116 | bool is_xe_hpg() const { |
117 | return device_info_->gpu_arch() == gpu_arch_t::xe_hpg; |
118 | } |
119 | bool is_xe_hpc() const { |
120 | return device_info_->gpu_arch() == gpu_arch_t::xe_hpc; |
121 | } |
122 | bool mayiuse_ngen_kernels() const { |
123 | return device_info_->mayiuse_ngen_kernels(); |
124 | } |
125 | bool mayiuse_non_uniform_work_groups() const { |
126 | return device_info_->mayiuse_non_uniform_work_groups(); |
127 | } |
128 | bool mayiuse_sub_group(int size) const { |
129 | return device_info_->mayiuse_sub_group(size); |
130 | } |
131 | bool mayiuse_sub_group(std::initializer_list<int> sizes) const { |
132 | for (int size : sizes) |
133 | if (!mayiuse_sub_group(size)) return false; |
134 | return true; |
135 | } |
136 | bool mayiuse_large_grf_mode() const { |
137 | // XXX: XeHPG 128EU A0 causes hangs with large GRF mode. |
138 | if (is_xe_hpg() && device_info()->eu_count() == 128 |
139 | && device_info()->stepping_id() == 0) |
140 | return false; |
141 | return device_info_->gpu_arch() >= compute::gpu_arch_t::xe_hp; |
142 | } |
143 | |
144 | dispatch_t create_dispatch(const memory_desc_t *md = nullptr) const { |
145 | return dispatch_t(this, md); |
146 | } |
147 | |
148 | status_t get_service_stream(stream_t *&stream) override { |
149 | status_t status = status::success; |
150 | if (service_stream_ == nullptr) { |
151 | const std::lock_guard<std::mutex> lock(service_stream_mutex_); |
152 | if (service_stream_ == nullptr) { |
153 | stream_t *service_stream_ptr; |
154 | status = create_stream( |
155 | &service_stream_ptr, stream_flags::default_flags); |
156 | if (status == status::success) |
157 | service_stream_.reset(service_stream_ptr); |
158 | } |
159 | } |
160 | stream = service_stream_.get(); |
161 | return status; |
162 | } |
163 | |
164 | // non-blocking query to check if service stream is already created |
165 | bool is_service_stream_created() const { return (bool)service_stream_; } |
166 | |
167 | virtual std::function<void(void *)> get_program_list_deleter() const = 0; |
168 | |
169 | protected: |
170 | virtual status_t init_device_info() = 0; |
171 | virtual status_t init_device_info(const std::vector<uint8_t> &cache_blob) { |
172 | assert(!"unexpected" ); |
173 | return status::runtime_error; |
174 | } |
175 | |
176 | ~compute_engine_t() override = default; |
177 | |
178 | std::shared_ptr<device_info_t> device_info_; |
179 | |
180 | private: |
181 | // Implement a zero_pad_primitive shared across the engine. The purpose is |
182 | // to prevent extra overhead associated with creating zero_pad_primitives |
183 | // for different inputs as ideally the zero_pad operations fast relative to |
184 | // the time to create the primitive. |
185 | std::shared_ptr<primitive_t> zero_pad_primitive_; |
186 | resource_mapper_t zero_pad_resources_; |
187 | std::once_flag zero_pad_init_; |
188 | std::unique_ptr<stream_t> service_stream_; |
189 | std::mutex service_stream_mutex_; |
190 | }; |
191 | |
192 | } // namespace compute |
193 | } // namespace gpu |
194 | } // namespace impl |
195 | } // namespace dnnl |
196 | |
197 | // Exported for testing purposes only. |
198 | extern "C" bool DNNL_API dnnl_impl_gpu_mayiuse_ngen_kernels( |
199 | dnnl::impl::engine_t *engine); |
200 | |
201 | #endif |
202 | |