1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/gpu_impl_list.hpp"
18
19#include "gpu/jit/binary_format.hpp"
20
21#include "gpu/jit/conv/gen_convolution.hpp"
22#include "gpu/ocl/gen9_convolution.hpp"
23#include "gpu/ocl/gen9_wino_convolution.hpp"
24#include "gpu/ocl/ref_convolution.hpp"
25#include "gpu/ocl/xe_lp_x8s8x_1x1_convolution.hpp"
26#include "gpu/ocl/xe_lp_x8s8x_convolution.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace gpu {
31
32namespace {
33using namespace dnnl::impl::prop_kind;
34
35// clang-format off
36const std::map<pk_impl_key_t, std::vector<impl_list_item_t>>
37 impl_list_map REG_CONV_P({
38 {{forward}, {
39 INSTANCE(jit::gen_convolution_fwd_t)
40 INSTANCE(ocl::xe_lp_x8s8x_1x1_convolution_fwd_t)
41 INSTANCE(ocl::xe_lp_x8s8x_convolution_fwd_t)
42 INSTANCE(ocl::gen9_wino_convolution_fwd_t)
43 INSTANCE(ocl::gen9_convolution_fwd_t)
44 INSTANCE(ocl::ref_convolution_fwd_t)
45 nullptr,
46 }},
47 {{backward_data}, REG_BWD_D_PK({
48 INSTANCE(jit::gen_convolution_bwd_data_t)
49 INSTANCE(ocl::xe_lp_x8s8x_convolution_bwd_data_t)
50 INSTANCE(ocl::gen9_convolution_bwd_data_t)
51 INSTANCE(ocl::ref_convolution_bwd_data_t)
52 nullptr,
53 })},
54 {{backward_weights}, REG_BWD_PK({
55 INSTANCE(jit::gen_convolution_bwd_weights_t)
56 INSTANCE(ocl::gen9_convolution_bwd_weights_t)
57 INSTANCE(ocl::ref_convolution_bwd_weights_t)
58 nullptr,
59 })},
60});
61// clang-format on
62} // namespace
63
64const impl_list_item_t *get_convolution_impl_list(
65 const convolution_desc_t *desc) {
66 static const impl_list_item_t empty_list[] = {nullptr};
67
68 const bool is_fwd = utils::one_of(
69 desc->prop_kind, forward_training, forward_inference);
70 prop_kind_t prop_kind = is_fwd ? forward : desc->prop_kind;
71
72 const auto impl_list_it = impl_list_map.find({prop_kind});
73 return impl_list_it != impl_list_map.cend() ? impl_list_it->second.data()
74 : empty_list;
75}
76
77} // namespace gpu
78} // namespace impl
79} // namespace dnnl
80