1/*******************************************************************************
2* Copyright 2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_LRN_LRN_EXECUTOR_FACTORY_HPP
18#define CPU_X64_LRN_LRN_EXECUTOR_FACTORY_HPP
19
20#include <memory>
21#include "common/c_types_map.hpp"
22#include "common/utils.hpp"
23#include "cpu/x64/lrn/jit_avx512_common_lrn_utils.hpp"
24#include "cpu/x64/lrn/lrn_avx512_blocked_executor.hpp"
25#include "cpu/x64/lrn/lrn_avx512_nhwc_executor.hpp"
26#include "cpu/x64/lrn/lrn_executor.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31namespace x64 {
32namespace lrn {
33
34class lrn_executor_factory_t {
35public:
36 template <::dnnl::impl::data_type_t d_type, typename PD_T>
37 static std::unique_ptr<i_lrn_executor_t> create_executor(
38 const PD_T *pd, direction dir) {
39 const memory_desc_wrapper data_d(pd->src_md());
40
41 if (data_d.matches_tag(format_tag::nChw16c))
42 return create_jit_avx512_blocked_executor<d_type, PD_T>(pd, dir);
43
44 return create_jit_avx512_nhwc_executor<d_type, PD_T>(pd, dir);
45 }
46
47private:
48 template <::dnnl::impl::data_type_t d_type, typename PD_T>
49 static std::unique_ptr<i_lrn_executor_t> create_jit_avx512_nhwc_executor(
50 const PD_T *pd, direction dir) {
51
52 if (dir == direction::forward)
53 return utils::make_unique<
54 lrn_avx512_nhwc_executor_fwd_t<d_type, PD_T>>(pd);
55 return utils::make_unique<lrn_avx512_nhwc_executor_bwd_t<d_type, PD_T>>(
56 pd);
57 }
58
59 template <::dnnl::impl::data_type_t d_type, typename PD_T>
60 static std::unique_ptr<i_lrn_executor_t> create_jit_avx512_blocked_executor(
61 const PD_T *pd, direction dir) {
62
63 if (dir == direction::forward)
64 return utils::make_unique<
65 lrn_avx512_blocked_executor_fwd_t<d_type, PD_T>>(pd);
66 return utils::make_unique<
67 lrn_avx512_blocked_executor_bwd_t<d_type, PD_T>>(pd);
68 }
69};
70
71} // namespace lrn
72} // namespace x64
73} // namespace cpu
74} // namespace impl
75} // namespace dnnl
76
77#endif