1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3* Copyright 2020-2022 Arm Ltd. and affiliates
4*
5* Licensed under the Apache License, Version 2.0 (the "License");
6* you may not use this file except in compliance with the License.
7* You may obtain a copy of the License at
8*
9* http://www.apache.org/licenses/LICENSE-2.0
10*
11* Unless required by applicable law or agreed to in writing, software
12* distributed under the License is distributed on an "AS IS" BASIS,
13* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14* See the License for the specific language governing permissions and
15* limitations under the License.
16*******************************************************************************/
17
18#ifndef CPU_CPU_ENGINE_HPP
19#define CPU_CPU_ENGINE_HPP
20
21#include <assert.h>
22
23#include "oneapi/dnnl/dnnl.h"
24
25#include "common/c_types_map.hpp"
26#include "common/engine.hpp"
27#include "common/engine_id.hpp"
28#include "common/impl_list_item.hpp"
29
30#include "cpu/platform.hpp"
31
32#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL
33#include "cpu/aarch64/acl_thread.hpp"
34#endif
35
// Helpers for writing implementation-list entries. Every macro below expands
// to a single entry *followed by a comma*, so invocations can be stacked one
// per line without writing separators between them.

// Base form: an impl_list_item_t constructed from the type-deduction helper
// instantiated on the `pd_t` nested in the given implementation type. The
// parameter is variadic so that implementation types whose template argument
// lists contain commas are passed through intact.
#define CPU_INSTANCE(...) \
    impl_list_item_t(impl_list_item_t::type_deduction_helper_t< \
            __VA_ARGS__::pd_t>()),

// Architecture-gated forms: the entry is handed to a gating macro that is
// defined outside this header.
// NOTE(review): presumably each gate expands to its argument only when the
// library is built for that architecture / with ACL -- confirm at definition.
#define CPU_INSTANCE_X64(...) DNNL_X64_ONLY(CPU_INSTANCE(__VA_ARGS__))
#define CPU_INSTANCE_AARCH64(...) DNNL_AARCH64_ONLY(CPU_INSTANCE(__VA_ARGS__))
#define CPU_INSTANCE_AARCH64_ACL(...) \
    DNNL_AARCH64_ACL_ONLY(CPU_INSTANCE(__VA_ARGS__))

// ISA-gated forms: same idea, routed through the REG_*_ISA macros (also
// defined outside this header).
#define CPU_INSTANCE_SSE41(...) REG_SSE41_ISA(CPU_INSTANCE(__VA_ARGS__))
#define CPU_INSTANCE_AVX2(...) REG_AVX2_ISA(CPU_INSTANCE(__VA_ARGS__))
#define CPU_INSTANCE_AVX512(...) REG_AVX512_ISA(CPU_INSTANCE(__VA_ARGS__))
#define CPU_INSTANCE_AMX(...) REG_AMX_ISA(CPU_INSTANCE(__VA_ARGS__))
47
48namespace dnnl {
49namespace impl {
50namespace cpu {
51
52#define DECLARE_IMPL_LIST(kind) \
53 const impl_list_item_t *get_##kind##_impl_list(const kind##_desc_t *desc);
54
55DECLARE_IMPL_LIST(batch_normalization);
56DECLARE_IMPL_LIST(binary);
57DECLARE_IMPL_LIST(convolution);
58DECLARE_IMPL_LIST(deconvolution);
59DECLARE_IMPL_LIST(eltwise);
60DECLARE_IMPL_LIST(inner_product);
61DECLARE_IMPL_LIST(layer_normalization);
62DECLARE_IMPL_LIST(lrn);
63DECLARE_IMPL_LIST(matmul);
64DECLARE_IMPL_LIST(pooling);
65DECLARE_IMPL_LIST(prelu);
66DECLARE_IMPL_LIST(reduction);
67DECLARE_IMPL_LIST(resampling);
68DECLARE_IMPL_LIST(rnn);
69DECLARE_IMPL_LIST(shuffle);
70DECLARE_IMPL_LIST(softmax);
71
72#undef DECLARE_IMPL_LIST
73
74class cpu_engine_impl_list_t {
75public:
76 static const impl_list_item_t *get_concat_implementation_list();
77 static const impl_list_item_t *get_reorder_implementation_list(
78 const memory_desc_t *src_md, const memory_desc_t *dst_md);
79 static const impl_list_item_t *get_sum_implementation_list();
80
81 static const impl_list_item_t *get_implementation_list(
82 const op_desc_t *desc) {
83 static const impl_list_item_t empty_list[] = {nullptr};
84
85// clang-format off
86#define CASE(kind) \
87 case primitive_kind::kind: \
88 return get_##kind##_impl_list((const kind##_desc_t *)desc);
89 switch (desc->kind) {
90 CASE(batch_normalization);
91 CASE(binary);
92 CASE(convolution);
93 CASE(deconvolution);
94 CASE(eltwise);
95 CASE(inner_product);
96 CASE(layer_normalization);
97 CASE(lrn);
98 CASE(matmul);
99 CASE(pooling);
100 CASE(prelu);
101 CASE(reduction);
102 CASE(resampling);
103 CASE(rnn);
104 CASE(shuffle);
105 CASE(softmax);
106 default: assert(!"unknown primitive kind"); return empty_list;
107 }
108#undef CASE
109 }
110 // clang-format on
111};
112
113class cpu_engine_t : public engine_t {
114public:
115 cpu_engine_t() : engine_t(engine_kind::cpu, get_cpu_native_runtime(), 0) {}
116
117 /* implementation part */
118
119 status_t create_memory_storage(memory_storage_t **storage, unsigned flags,
120 size_t size, void *handle) override;
121
122 status_t create_stream(stream_t **stream, unsigned flags) override;
123
124#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
125 status_t create_stream(stream_t **stream,
126 dnnl::threadpool_interop::threadpool_iface *threadpool) override;
127#endif
128
129 const impl_list_item_t *get_concat_implementation_list() const override {
130 return cpu_engine_impl_list_t::get_concat_implementation_list();
131 }
132
133 const impl_list_item_t *get_reorder_implementation_list(
134 const memory_desc_t *src_md,
135 const memory_desc_t *dst_md) const override {
136 return cpu_engine_impl_list_t::get_reorder_implementation_list(
137 src_md, dst_md);
138 }
139 const impl_list_item_t *get_sum_implementation_list() const override {
140 return cpu_engine_impl_list_t::get_sum_implementation_list();
141 }
142 const impl_list_item_t *get_implementation_list(
143 const op_desc_t *desc) const override {
144 return cpu_engine_impl_list_t::get_implementation_list(desc);
145 }
146
147 device_id_t device_id() const override { return std::make_tuple(0, 0, 0); }
148
149 engine_id_t engine_id() const override {
150 // Non-sycl CPU engine doesn't have device and context.
151 return {};
152 }
153
154protected:
155 ~cpu_engine_t() override = default;
156};
157
158class cpu_engine_factory_t : public engine_factory_t {
159public:
160 size_t count() const override { return 1; }
161 status_t engine_create(engine_t **engine, size_t index) const override {
162 assert(index == 0);
163 *engine = new cpu_engine_t();
164
165#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL
166#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
167 // Number of threads in Compute Library is set by OMP_NUM_THREADS
168 // dnnl_get_max_threads() == OMP_NUM_THREADS
169 dnnl::impl::cpu::aarch64::acl_thread_utils::acl_thread_bind();
170#endif
171
172#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
173 // Set ACL scheduler for threadpool runtime
174 dnnl::impl::cpu::aarch64::acl_thread_utils::acl_set_custom_scheduler();
175#endif
176#endif
177 return status::success;
178 };
179};
180
// Declared here, defined elsewhere.
// NOTE(review): presumably returns a CPU engine that the library keeps for
// its own internal use rather than one created by the user -- confirm at the
// definition, including who owns the returned pointer.
engine_t *get_service_engine();
182
183} // namespace cpu
184} // namespace impl
185} // namespace dnnl
186
187#endif
188
189// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
190