1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_UNI_SOFTMAX_HPP |
18 | #define CPU_X64_JIT_UNI_SOFTMAX_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/memory_tracking.hpp" |
24 | #include "common/primitive.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | #include "common/utils.hpp" |
27 | |
28 | #include "cpu/cpu_softmax_pd.hpp" |
29 | #include "cpu/x64/cpu_isa_traits.hpp" |
30 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | namespace x64 { |
36 | |
37 | namespace softmax_impl { |
38 | template <cpu_isa_t isa> |
39 | struct driver_t; |
40 | } |
41 | |
42 | template <cpu_isa_t isa> |
43 | struct jit_uni_softmax_fwd_t : public primitive_t { |
44 | struct pd_t : public cpu_softmax_fwd_pd_t { |
45 | using cpu_softmax_fwd_pd_t::cpu_softmax_fwd_pd_t; |
46 | |
47 | DECLARE_COMMON_PD_T( |
48 | JIT_IMPL_NAME_HELPER("jit:" , isa, "" ), jit_uni_softmax_fwd_t); |
49 | |
50 | status_t init(engine_t *engine) { |
51 | auto is_dense = [&]() { |
52 | const memory_desc_wrapper src_d(src_md()); |
53 | const auto &bd = src_d.blocking_desc(); |
54 | |
55 | if (!src_d.is_dense(true) || !src_d.only_padded_dim(axis())) |
56 | return false; |
57 | |
58 | if (src_d.is_plain()) return bd.strides[axis()] == 1; |
59 | |
60 | // It is fine to use float here as the kernel uses halfs of |
61 | // vector registers. |
62 | const auto blk_size = cpu_isa_traits<isa>::vlen / sizeof(float); |
63 | // 31 is a general limit, 2 is for unroll_regs_ = 4; |
64 | const size_t max_stride = (1LL << (31 - 2)) - 1; |
65 | const int last_blk = bd.inner_nblks - 1; |
66 | return bd.inner_blks[last_blk] == blk_size |
67 | && bd.inner_idxs[last_blk] == axis() |
68 | && sizeof(float) * bd.strides[axis()] < max_stride; |
69 | }; |
70 | |
71 | using namespace data_type; |
72 | using skip_mask_t = primitive_attr_t::skip_mask_t; |
73 | |
74 | const auto src_dt = src_md()->data_type; |
75 | const auto dst_dt = dst_md()->data_type; |
76 | bool ok = mayiuse(isa) && is_fwd() && !has_zero_dim_memory() |
77 | && utils::one_of(src_dt, f32, bf16, f16, s8, u8) |
78 | && utils::one_of(dst_dt, f32, bf16, f16, s8, u8) |
79 | // s8/u8 are temporary limitations due to priorities |
80 | && IMPLICATION( |
81 | (utils::one_of(s8, src_dt, dst_dt) |
82 | || utils::one_of(u8, src_dt, dst_dt)), |
83 | is_superset(isa, avx512_core)) |
84 | && IMPLICATION(utils::one_of(bf16, src_dt, dst_dt), |
85 | (is_superset(isa, avx512_core) |
86 | || (isa == avx2 && mayiuse(avx2_vnni_2)))) |
87 | // for f16 we reuse avx512_core/avx2 just to avoid |
88 | // additional instantiation. Possible because we do not |
89 | // currently support post-ops for this primitive |
90 | && IMPLICATION(utils::one_of(f16, src_dt, dst_dt), |
91 | (is_superset(isa, avx512_core) |
92 | && mayiuse(avx512_core_fp16)) |
93 | || (isa == avx2 && mayiuse(avx2_vnni_2))) |
94 | && attr()->has_default_values(skip_mask_t::scales_runtime) |
95 | && attr_scales_ok() |
96 | && set_default_formats() == status::success; |
97 | if (!ok) return status::unimplemented; |
98 | |
99 | ok = memory_desc_wrapper(src_md()).similar_to( |
100 | memory_desc_wrapper(dst_md()), true, false, 0) |
101 | && is_dense(); // not dense impl can be easily done |
102 | if (!ok) return status::unimplemented; |
103 | |
104 | // AVX2 only supports xf16 on plain layout now |
105 | ok = IMPLICATION(isa == avx2 && mayiuse(avx2_vnni_2) |
106 | && (utils::one_of(bf16, src_dt, dst_dt) |
107 | || utils::one_of(f16, src_dt, dst_dt)), |
108 | memory_desc_wrapper(src_md()).is_plain()); |
109 | if (!ok) return status::unimplemented; |
110 | |
111 | nthr_ = dnnl_get_max_threads(); |
112 | init_scratchpad(); |
113 | |
114 | return status::success; |
115 | }; |
116 | |
117 | int nthr_; // To not exceed the limit in execute used for set up. |
118 | |
119 | private: |
120 | void init_scratchpad() { |
121 | if (utils::one_of( |
122 | dst_md()->data_type, data_type::u8, data_type::s8)) { |
123 | auto scratchpad = scratchpad_registry().registrar(); |
124 | scratchpad.template book<char>( |
125 | memory_tracking::names::key_softmax_interim_store, |
126 | axis_size(true) * sizeof(float) * nthr_); |
127 | } |
128 | } |
129 | }; |
130 | |
131 | jit_uni_softmax_fwd_t(const pd_t *apd); |
132 | ~jit_uni_softmax_fwd_t(); |
133 | |
134 | status_t init(engine_t *engine) override; |
135 | |
136 | status_t execute(const exec_ctx_t &ctx) const override; |
137 | |
138 | private: |
139 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
140 | softmax_impl::driver_t<isa> *softmax_driver_; |
141 | }; |
142 | |
143 | template <cpu_isa_t isa> |
144 | struct jit_uni_softmax_bwd_t : public primitive_t { |
145 | struct pd_t : public cpu_softmax_bwd_pd_t { |
146 | using cpu_softmax_bwd_pd_t::cpu_softmax_bwd_pd_t; |
147 | |
148 | DECLARE_COMMON_PD_T( |
149 | JIT_IMPL_NAME_HELPER("jit:" , isa, "" ), jit_uni_softmax_bwd_t); |
150 | |
151 | status_t init(engine_t *engine) { |
152 | auto is_dense = [&]() { |
153 | const memory_desc_wrapper dst_d(dst_md()); |
154 | const auto &bd = dst_d.blocking_desc(); |
155 | |
156 | if (!dst_d.is_dense(true) || !dst_d.only_padded_dim(axis())) |
157 | return false; |
158 | |
159 | // It is fine to use float here as the kernel uses halfs of |
160 | // vector registers. |
161 | const auto blk_size = cpu_isa_traits<isa>::vlen / sizeof(float); |
162 | if (dst_d.is_plain()) |
163 | return bd.strides[axis()] == 1; |
164 | else { |
165 | // 31 is a general limit, 2 is for unroll_regs_ = 4; |
166 | const size_t max_stride = (1LL << (31 - 2)) - 1; |
167 | const int last_blk = bd.inner_nblks - 1; |
168 | return bd.inner_blks[last_blk] == blk_size |
169 | && bd.inner_idxs[last_blk] == axis() |
170 | && sizeof(float) * bd.strides[axis()] < max_stride; |
171 | } |
172 | }; |
173 | |
174 | using namespace data_type; |
175 | bool ok = mayiuse(isa) && !is_fwd() && !has_zero_dim_memory() |
176 | && utils::one_of(dst_md()->data_type, f32, bf16, f16) |
177 | && utils::one_of(diff_dst_md()->data_type, f32, bf16, f16) |
178 | && utils::one_of(diff_src_md()->data_type, f32, bf16, f16) |
179 | && IMPLICATION(utils::one_of(bf16, dst_md()->data_type, |
180 | diff_dst_md()->data_type, |
181 | diff_src_md()->data_type), |
182 | is_superset(isa, avx512_core)) |
183 | // for f16 we reuse avx512_core just to avoid additional |
184 | // instantiation. |
185 | && IMPLICATION(utils::one_of(f16, dst_md()->data_type, |
186 | diff_dst_md()->data_type, |
187 | diff_src_md()->data_type), |
188 | is_superset(isa, avx512_core) |
189 | && mayiuse(avx512_core_fp16)) |
190 | && attr()->has_default_values() |
191 | && set_default_formats() == status::success; |
192 | if (!ok) return status::unimplemented; |
193 | |
194 | ok = memory_desc_wrapper(diff_src_md()) |
195 | .similar_to(memory_desc_wrapper(diff_dst_md()), |
196 | true, false, 0) |
197 | && memory_desc_wrapper(diff_dst_md()) |
198 | == memory_desc_wrapper(dst_md()) |
199 | && is_dense(); // not dense impl can be easily done |
200 | if (!ok) return status::unimplemented; |
201 | |
202 | return status::success; |
203 | }; |
204 | }; |
205 | |
206 | jit_uni_softmax_bwd_t(const pd_t *apd); |
207 | ~jit_uni_softmax_bwd_t(); |
208 | |
209 | status_t init(engine_t *engine) override; |
210 | |
211 | status_t execute(const exec_ctx_t &ctx) const override; |
212 | |
213 | private: |
214 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
215 | |
216 | softmax_impl::driver_t<isa> *softmax_driver_; |
217 | }; |
218 | |
219 | } // namespace x64 |
220 | } // namespace cpu |
221 | } // namespace impl |
222 | } // namespace dnnl |
223 | |
224 | #endif |
225 | |
226 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
227 | |