1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "cpu/cpu_engine.hpp"
18
19#include "cpu/rnn/ref_rnn.hpp"
20
21namespace dnnl {
22namespace impl {
23namespace cpu {
24
25namespace {
26using namespace dnnl::impl::data_type;
27using namespace dnnl::impl::prop_kind;
28
29const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
30 // clang-format off
31 static std::map<pk_impl_key_t, std::vector<impl_list_item_t>> the_map = REG_RNN_P({
32 {{forward}, {
33 CPU_INSTANCE(ref_rnn_fwd_bf16_t)
34 CPU_INSTANCE(ref_rnn_fwd_f32_t)
35 CPU_INSTANCE(ref_rnn_fwd_s8s8_t)
36 CPU_INSTANCE(ref_rnn_fwd_u8s8_t)
37 nullptr,
38 }},
39 {{backward}, REG_BWD_PK({
40 CPU_INSTANCE(ref_rnn_bwd_f32_t)
41 CPU_INSTANCE(ref_rnn_bwd_bf16_t)
42 nullptr,
43 })},
44 });
45 // clang-format on
46 return the_map;
47}
48} // namespace
49
50const impl_list_item_t *get_rnn_impl_list(const rnn_desc_t *desc) {
51 static const impl_list_item_t empty_list[] = {nullptr};
52
53 const bool is_fwd = utils::one_of(
54 desc->prop_kind, forward_training, forward_inference);
55 prop_kind_t prop_kind = is_fwd ? forward : backward;
56
57 pk_impl_key_t key {prop_kind};
58
59 const auto impl_list_it = impl_list_map().find(key);
60 return impl_list_it != impl_list_map().cend() ? impl_list_it->second.data()
61 : empty_list;
62}
63
64} // namespace cpu
65} // namespace impl
66} // namespace dnnl
67