1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_REF_SOFTMAX_HPP |
18 | #define CPU_REF_SOFTMAX_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/dnnl_thread.hpp" |
24 | #include "common/memory_tracking.hpp" |
25 | #include "common/primitive.hpp" |
26 | #include "common/type_helpers.hpp" |
27 | #include "common/utils.hpp" |
28 | |
29 | #include "cpu/cpu_softmax_pd.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | |
35 | struct ref_softmax_fwd_t : public primitive_t { |
36 | struct pd_t : public cpu_softmax_fwd_pd_t { |
37 | using cpu_softmax_fwd_pd_t::cpu_softmax_fwd_pd_t; |
38 | |
39 | DECLARE_COMMON_PD_T("ref:any" , ref_softmax_fwd_t); |
40 | |
41 | status_t init(engine_t *engine) { |
42 | using namespace data_type; |
43 | using skip_mask_t = primitive_attr_t::skip_mask_t; |
44 | |
45 | bool ok = is_fwd() |
46 | && utils::one_of( |
47 | src_md()->data_type, f32, bf16, f16, s8, u8) |
48 | && utils::one_of( |
49 | dst_md()->data_type, f32, bf16, f16, s8, u8) |
50 | && platform::has_data_type_support(src_md()->data_type) |
51 | && platform::has_data_type_support(dst_md()->data_type) |
52 | && attr()->has_default_values(skip_mask_t::scales_runtime) |
53 | && attr_scales_ok() |
54 | && set_default_formats() == status::success; |
55 | if (!ok) return status::unimplemented; |
56 | |
57 | nthr_ = 0; |
58 | init_scratchpad(); |
59 | |
60 | return status::success; |
61 | } |
62 | |
63 | int nthr_; // To not exceed the limit in execute used for set up. |
64 | |
65 | bool need_int8_scratchpad() const { |
66 | return utils::one_of( |
67 | dst_md()->data_type, data_type::u8, data_type::s8); |
68 | } |
69 | |
70 | private: |
71 | void init_scratchpad() { |
72 | auto scratchpad = scratchpad_registry().registrar(); |
73 | const dim_t in_s = inner_size(); |
74 | |
75 | if (in_s > 1) { |
76 | const dim_t ou_s = outer_size(); |
77 | scratchpad.template book<float>( |
78 | memory_tracking::names::key_softmax_reduction, |
79 | 2 * in_s * ou_s); |
80 | } |
81 | |
82 | if (need_int8_scratchpad()) { |
83 | nthr_ = dnnl_get_max_threads(); |
84 | scratchpad.template book<char>( |
85 | memory_tracking::names::key_softmax_interim_store, |
86 | axis_size(true) * sizeof(float) * nthr_); |
87 | } |
88 | } |
89 | }; |
90 | |
91 | ref_softmax_fwd_t(const pd_t *apd) : primitive_t(apd) {} |
92 | |
93 | status_t init(engine_t *engine) override { |
94 | outer_size_ = pd()->outer_size(); |
95 | channels_ = pd()->axis_size(); |
96 | inner_size_ = pd()->inner_size(); |
97 | |
98 | const memory_desc_wrapper src_d(pd()->src_md()); |
99 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
100 | const auto &bd = src_d.blocking_desc(); |
101 | |
102 | auto axis = pd()->axis(); |
103 | dim_t axis_blk_size = 1; |
104 | for (int iblk = 0; iblk < bd.inner_nblks; ++iblk) |
105 | if (bd.inner_idxs[iblk] == axis) |
106 | axis_blk_size *= bd.inner_blks[iblk]; |
107 | |
108 | use_dense_ = inner_size_ == 1 && src_d == dst_d && src_d.is_dense(true) |
109 | && src_d.only_padded_dim(axis) |
110 | && bd.strides[axis] == axis_blk_size; |
111 | return status::success; |
112 | } |
113 | |
114 | status_t execute(const exec_ctx_t &ctx) const override { |
115 | if (use_dense_) |
116 | return execute_forward_dense(ctx); |
117 | else |
118 | return execute_forward_generic(ctx); |
119 | } |
120 | |
121 | private: |
122 | status_t execute_forward_dense(const exec_ctx_t &ctx) const; |
123 | status_t execute_forward_generic(const exec_ctx_t &ctx) const; |
124 | |
125 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
126 | |
127 | bool use_dense_; |
128 | int outer_size_, channels_, inner_size_; |
129 | }; |
130 | |
131 | struct ref_softmax_bwd_t : public primitive_t { |
132 | struct pd_t : public cpu_softmax_bwd_pd_t { |
133 | using cpu_softmax_bwd_pd_t::cpu_softmax_bwd_pd_t; |
134 | |
135 | DECLARE_COMMON_PD_T("ref:any" , ref_softmax_bwd_t); |
136 | |
137 | status_t init(engine_t *engine) { |
138 | using namespace data_type; |
139 | bool ok = !is_fwd() |
140 | && utils::one_of(dst_md()->data_type, f32, bf16, f16) |
141 | && platform::has_data_type_support(dst_md()->data_type) |
142 | && platform::has_data_type_support(diff_dst_md()->data_type) |
143 | && platform::has_data_type_support(diff_src_md()->data_type) |
144 | && dst_md()->data_type == diff_dst_md()->data_type |
145 | && attr()->has_default_values() |
146 | && set_default_formats() == status::success; |
147 | if (!ok) return status::unimplemented; |
148 | |
149 | return status::success; |
150 | } |
151 | }; |
152 | |
153 | ref_softmax_bwd_t(const pd_t *apd) : primitive_t(apd) {} |
154 | |
155 | status_t init(engine_t *engine) override { |
156 | outer_size_ = pd()->outer_size(); |
157 | channels_ = pd()->axis_size(); |
158 | inner_size_ = pd()->inner_size(); |
159 | |
160 | const memory_desc_wrapper data_d(pd()->dst_md()); |
161 | const memory_desc_wrapper diff_d(pd()->diff_dst_md()); |
162 | const auto &bd = diff_d.blocking_desc(); |
163 | |
164 | auto axis = pd()->axis(); |
165 | dim_t axis_blk_size = 1; |
166 | for (int iblk = 0; iblk < bd.inner_nblks; ++iblk) |
167 | if (bd.inner_idxs[iblk] == axis) |
168 | axis_blk_size *= bd.inner_blks[iblk]; |
169 | |
170 | use_dense_ = inner_size_ == 1 && diff_d == data_d && diff_d.is_dense() |
171 | && bd.strides[axis] == axis_blk_size; |
172 | return status::success; |
173 | } |
174 | |
175 | status_t execute(const exec_ctx_t &ctx) const override { |
176 | if (use_dense_) |
177 | return execute_backward_dense(ctx); |
178 | else |
179 | return execute_backward_generic(ctx); |
180 | } |
181 | |
182 | private: |
183 | status_t execute_backward_dense(const exec_ctx_t &ctx) const; |
184 | status_t execute_backward_generic(const exec_ctx_t &ctx) const; |
185 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
186 | |
187 | bool use_dense_; |
188 | int outer_size_, channels_, inner_size_; |
189 | }; |
190 | |
191 | } // namespace cpu |
192 | } // namespace impl |
193 | } // namespace dnnl |
194 | |
195 | #endif |
196 | |
197 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
198 | |