1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_TBB_BATCH_NORMALIZATION_HPP
18#define CPU_X64_JIT_UNI_TBB_BATCH_NORMALIZATION_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive.hpp"
22#include "common/type_helpers.hpp"
23#include "common/utils.hpp"
24
25#include "cpu/cpu_batch_normalization_pd.hpp"
26#include "cpu/x64/cpu_isa_traits.hpp"
27#include "jit_primitive_conf.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace cpu {
32namespace x64 {
33
34namespace bnorm_tbb_impl {
35template <cpu_isa_t isa>
36struct driver_t;
37}
38
39template <cpu_isa_t isa>
40struct jit_uni_tbb_batch_normalization_fwd_t : public primitive_t {
41 struct pd_t : public cpu_batch_normalization_fwd_pd_t {
42 pd_t(const batch_normalization_desc_t *adesc,
43 const primitive_attr_t *attr,
44 const batch_normalization_fwd_pd_t *hint_fwd_pd)
45 : cpu_batch_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
46
47 DECLARE_COMMON_PD_T(
48 JIT_IMPL_NAME_HELPER("bnorm_tbb_jit:",
49 (src_md()->data_type == data_type::bf16)
50 ? (mayiuse(avx512_core_bf16)
51 ? avx512_core_bf16
52 : mayiuse(avx512_core)
53 ? bf16_emulation_t::
54 get_isa()
55 : avx2_vnni_2)
56 : (src_md()->data_type == data_type::f16)
57 ? (mayiuse(avx512_core_fp16)
58 ? avx512_core_fp16
59 : avx2_vnni_2)
60 : isa,
61 ""),
62 jit_uni_tbb_batch_normalization_fwd_t);
63
64 status_t init(engine_t *engine);
65
66 jit_memory_tag_kind_t tag_kind_;
67 };
68
69 jit_uni_tbb_batch_normalization_fwd_t(const pd_t *apd);
70 ~jit_uni_tbb_batch_normalization_fwd_t();
71
72 status_t init(engine_t *engine) override;
73
74 status_t execute(const exec_ctx_t &ctx) const override;
75
76private:
77 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
78
79 std::unique_ptr<bnorm_tbb_impl::driver_t<isa>> bnorm_driver_;
80};
81
82template <cpu_isa_t isa>
83struct jit_uni_tbb_batch_normalization_bwd_t : public primitive_t {
84 struct pd_t : public cpu_batch_normalization_bwd_pd_t {
85 pd_t(const batch_normalization_desc_t *adesc,
86 const primitive_attr_t *attr,
87 const batch_normalization_fwd_pd_t *hint_fwd_pd)
88 : cpu_batch_normalization_bwd_pd_t(adesc, attr, hint_fwd_pd) {}
89
90 DECLARE_COMMON_PD_T(
91 JIT_IMPL_NAME_HELPER("bnorm_tbb_jit:",
92 (src_md()->data_type == data_type::bf16)
93 ? (mayiuse(avx512_core_bf16)
94 ? avx512_core_bf16
95 : bf16_emulation_t::get_isa())
96 : (src_md()->data_type == data_type::f16)
97 ? avx512_core_fp16
98 : isa,
99 ""),
100 jit_uni_tbb_batch_normalization_bwd_t);
101
102 status_t init(engine_t *engine);
103
104 jit_memory_tag_kind_t tag_kind_;
105 };
106
107 jit_uni_tbb_batch_normalization_bwd_t(const pd_t *apd);
108 ~jit_uni_tbb_batch_normalization_bwd_t();
109
110 status_t init(engine_t *engine) override;
111
112 status_t execute(const exec_ctx_t &ctx) const override;
113
114private:
115 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
116
117 std::unique_ptr<bnorm_tbb_impl::driver_t<isa>> bnorm_driver_;
118};
119
120} // namespace x64
121} // namespace cpu
122} // namespace impl
123} // namespace dnnl
124
125#endif
126
127// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
128