1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_UNI_TBB_BATCH_NORMALIZATION_HPP |
18 | #define CPU_X64_JIT_UNI_TBB_BATCH_NORMALIZATION_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/primitive.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | #include "common/utils.hpp" |
24 | |
25 | #include "cpu/cpu_batch_normalization_pd.hpp" |
26 | #include "cpu/x64/cpu_isa_traits.hpp" |
27 | #include "jit_primitive_conf.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | namespace x64 { |
33 | |
34 | namespace bnorm_tbb_impl { |
35 | template <cpu_isa_t isa> |
36 | struct driver_t; |
37 | } |
38 | |
39 | template <cpu_isa_t isa> |
40 | struct jit_uni_tbb_batch_normalization_fwd_t : public primitive_t { |
41 | struct pd_t : public cpu_batch_normalization_fwd_pd_t { |
42 | pd_t(const batch_normalization_desc_t *adesc, |
43 | const primitive_attr_t *attr, |
44 | const batch_normalization_fwd_pd_t *hint_fwd_pd) |
45 | : cpu_batch_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd) {} |
46 | |
47 | DECLARE_COMMON_PD_T( |
48 | JIT_IMPL_NAME_HELPER("bnorm_tbb_jit:" , |
49 | (src_md()->data_type == data_type::bf16) |
50 | ? (mayiuse(avx512_core_bf16) |
51 | ? avx512_core_bf16 |
52 | : mayiuse(avx512_core) |
53 | ? bf16_emulation_t:: |
54 | get_isa() |
55 | : avx2_vnni_2) |
56 | : (src_md()->data_type == data_type::f16) |
57 | ? (mayiuse(avx512_core_fp16) |
58 | ? avx512_core_fp16 |
59 | : avx2_vnni_2) |
60 | : isa, |
61 | "" ), |
62 | jit_uni_tbb_batch_normalization_fwd_t); |
63 | |
64 | status_t init(engine_t *engine); |
65 | |
66 | jit_memory_tag_kind_t tag_kind_; |
67 | }; |
68 | |
69 | jit_uni_tbb_batch_normalization_fwd_t(const pd_t *apd); |
70 | ~jit_uni_tbb_batch_normalization_fwd_t(); |
71 | |
72 | status_t init(engine_t *engine) override; |
73 | |
74 | status_t execute(const exec_ctx_t &ctx) const override; |
75 | |
76 | private: |
77 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
78 | |
79 | std::unique_ptr<bnorm_tbb_impl::driver_t<isa>> bnorm_driver_; |
80 | }; |
81 | |
82 | template <cpu_isa_t isa> |
83 | struct jit_uni_tbb_batch_normalization_bwd_t : public primitive_t { |
84 | struct pd_t : public cpu_batch_normalization_bwd_pd_t { |
85 | pd_t(const batch_normalization_desc_t *adesc, |
86 | const primitive_attr_t *attr, |
87 | const batch_normalization_fwd_pd_t *hint_fwd_pd) |
88 | : cpu_batch_normalization_bwd_pd_t(adesc, attr, hint_fwd_pd) {} |
89 | |
90 | DECLARE_COMMON_PD_T( |
91 | JIT_IMPL_NAME_HELPER("bnorm_tbb_jit:" , |
92 | (src_md()->data_type == data_type::bf16) |
93 | ? (mayiuse(avx512_core_bf16) |
94 | ? avx512_core_bf16 |
95 | : bf16_emulation_t::get_isa()) |
96 | : (src_md()->data_type == data_type::f16) |
97 | ? avx512_core_fp16 |
98 | : isa, |
99 | "" ), |
100 | jit_uni_tbb_batch_normalization_bwd_t); |
101 | |
102 | status_t init(engine_t *engine); |
103 | |
104 | jit_memory_tag_kind_t tag_kind_; |
105 | }; |
106 | |
107 | jit_uni_tbb_batch_normalization_bwd_t(const pd_t *apd); |
108 | ~jit_uni_tbb_batch_normalization_bwd_t(); |
109 | |
110 | status_t init(engine_t *engine) override; |
111 | |
112 | status_t execute(const exec_ctx_t &ctx) const override; |
113 | |
114 | private: |
115 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
116 | |
117 | std::unique_ptr<bnorm_tbb_impl::driver_t<isa>> bnorm_driver_; |
118 | }; |
119 | |
120 | } // namespace x64 |
121 | } // namespace cpu |
122 | } // namespace impl |
123 | } // namespace dnnl |
124 | |
125 | #endif |
126 | |
127 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
128 | |