1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/gpu_impl_list.hpp"
18
19#include "gpu/ocl/convolution_inner_product.hpp"
20#include "gpu/ocl/gemm_inner_product.hpp"
21#include "gpu/ocl/gemm_post_ops_inner_product.hpp"
22#include "gpu/ocl/ref_inner_product.hpp"
23
24namespace dnnl {
25namespace impl {
26namespace gpu {
27
namespace {
using namespace dnnl::impl::prop_kind;

// Registry of OpenCL inner product implementations, keyed by propagation kind.
// Only two keys are registered here: `forward` and `backward`.
// get_inner_product_impl_list() collapses every requested prop_kind onto one
// of these two keys.
//
// Each vector is terminated by an explicit `nullptr` entry so that callers can
// walk it as a sentinel-terminated array via `.data()`.
//
// NOTE(review): candidates are presumably tried in list order (specialized
// gemm / convolution based kernels first, `ref_*` fallbacks last) -- confirm
// against the primitive-descriptor iterator before reordering entries.
// NOTE(review): INSTANCE / REG_IP_P / REG_BWD_PK are defined in
// gpu/gpu_impl_list.hpp. REG_BWD_PK presumably drops the backward entries in
// forward-only builds -- verify there.
// clang-format off
const std::map<pk_impl_key_t, std::vector<impl_list_item_t>>
        impl_list_map REG_IP_P({
    // Forward (training and inference) implementations.
    {{forward}, {
        INSTANCE(ocl::gemm_inner_product_fwd_t)
        INSTANCE(ocl::convolution_inner_product_fwd_t)
        INSTANCE(ocl::ref_inner_product_fwd_t)
        nullptr,
    }},
    // Backward implementations: both backward-data and backward-weights
    // variants share this single list.
    {{backward}, REG_BWD_PK({
        INSTANCE(ocl::gemm_inner_product_bwd_data_t)
        INSTANCE(ocl::gemm_inner_product_bwd_weights_t)
        INSTANCE(ocl::ref_inner_product_bwd_data_t)
        INSTANCE(ocl::ref_inner_product_bwd_weights_t)
        nullptr,
    })},
});
// clang-format on
} // namespace
50
51const impl_list_item_t *get_inner_product_impl_list(
52 const inner_product_desc_t *desc) {
53 static const impl_list_item_t empty_list[] = {nullptr};
54
55 const bool is_fwd = utils::one_of(
56 desc->prop_kind, forward_training, forward_inference);
57 prop_kind_t prop_kind = is_fwd ? forward : backward;
58
59 const auto impl_list_it = impl_list_map.find({prop_kind});
60 return impl_list_it != impl_list_map.cend() ? impl_list_it->second.data()
61 : empty_list;
62}
63
64} // namespace gpu
65} // namespace impl
66} // namespace dnnl
67