1 | /******************************************************************************* |
2 | * Copyright 2020-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /* |
18 | * Cell execution LSTM projection |
19 | */ |
20 | |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/math_utils.hpp" |
23 | |
24 | #include "cpu/simple_q10n.hpp" |
25 | |
26 | #include "cpu/rnn/postgemm_dispatcher.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | |
32 | using namespace dnnl::impl::utils; |
33 | using namespace dnnl::impl::math; |
34 | using namespace rnn_utils; |
35 | |
36 | namespace { |
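// Copies the projected state from dst_layer to dst_iter when the latter is
// requested. The two buffers may have different leading dimensions, so the
// copy proceeds row by row, block_step bytes at a time.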
37 | template <typename dst_layer_t, typename dst_iter_t> |
38 | void proj_dst_copy(const rnn_utils::rnn_conf_t &rnn, |
39 | rnn_utils::cell_position_t cell_position, dst_iter_t *dst_iter_, |
40 | const dst_layer_t *dst_layer_, int block_step) { |
41 | assert(rnn.dic == rnn.dlc); |
    static_assert(sizeof(dst_layer_t) == sizeof(dst_iter_t),
            "memcpy requires the same data type size for src and dst");
44 | const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true); |
45 | const auto dst_iter_ld = rnn.dst_iter_ld(cell_position); |
46 | |
47 | // If dst_iter is not nullptr, we need to copy the state to dst_iter |
48 | if (dst_iter_ != nullptr) { |
49 | if (rnn.is_brgemm && !rnn.unfused_post_gemm) { |
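            // Fused brgemm post-gemm: only the rows of the current m_block
            // tile are valid here, so copy just those.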
50 | for (int i = 0; i < rnn.m_block; i++) |
51 | std::memcpy(dst_iter_ + i * dst_iter_ld, |
52 | dst_layer_ + i * dst_layer_ld, block_step); |
53 | } else { |
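            // Standalone post-gemm: copy the whole minibatch in parallel.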
54 | parallel_nd(rnn.mb, [&](dim_t i) { |
55 | std::memcpy(dst_iter_ + i * dst_iter_ld, |
56 | dst_layer_ + i * dst_layer_ld, block_step); |
57 | }); |
58 | } |
59 | } |
60 | } |
61 | } // namespace |
62 | |
63 | template <> |
64 | rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::lstm_projection_postgemm) { |
65 | // nothing to do for f32, except copy to dst_iter if needed |
66 | proj_dst_copy(rnn, cell_position, dst_iter_, dst_layer_, block_step); |
67 | } |
68 | |
69 | template <> |
70 | rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::lstm_projection_postgemm) { |
71 | const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true); |
72 | |
73 | // Currently, scratch_gates_ contains the output of the projection |
74 | const int n_elem = block_step / (int)sizeof(dst_layer_t); |
75 | |
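    // In the fused brgemm path the post-gemm is called per m_block tile of
    // rows; otherwise it covers the whole minibatch at once.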
76 | const int m_block |
77 | = (rnn.is_brgemm && !rnn.unfused_post_gemm) ? rnn.m_block : rnn.mb; |
78 | |
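    // Down-convert the f32 projection output held in scratch_gates_ to bf16
    // directly into dst_layer_.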
79 | for (int i = 0; i < m_block; i++) |
80 | cvt_float_to_bfloat16((bfloat16_t *)dst_layer_ + i * dst_layer_ld, |
81 | (float *)scratch_gates_ + i * rnn.scratch_gates_ld, n_elem); |
82 | |
83 | // we copy to dst_iter if necessary |
84 | proj_dst_copy(rnn, cell_position, dst_iter_, dst_layer_, block_step); |
85 | } |
86 | |
87 | template <> |
88 | rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::lstm_projection_postgemm) { |
89 | // Here, we use |
90 | // - scratch_gates to pass the s32 output of the projection |
91 | // - src_iter_c to pass the projection compensation |
92 | |
93 | const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true); |
94 | const auto w_proj_comp = static_cast<const float *>(src_iter_c_); |
95 | |
96 | const float data_shift = pd_->attr()->rnn_data_qparams_.shift_; |
97 | const float data_scale = pd_->attr()->rnn_data_qparams_.scale_; |
98 | |
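    // Quantize f32 -> u8: q = f * data_scale + data_shift, clamped to the u8
    // range [0, 255]; qz_a1b0 then rounds and saturates to dst_layer_t.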
99 | const auto quantize_f32_u8 = [&](float f) { |
100 | float qf = f * data_scale + data_shift; |
101 | qf = nstl::min(qf, 255.0f); |
102 | qf = nstl::max(qf, 0.0f); |
103 | return qz_a1b0<float, dst_layer_t>()(qf); |
104 | }; |
105 | |
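    // The s32 accumulator for output channel j is
    //     s = sum_k xq[k] * wq[k][j], with xq[k] = x[k] * data_scale + data_shift,
    // so it carries a spurious data_shift * sum_k wq[k][j] term. Assuming
    // w_proj_comp[j] holds that per-channel weight sum, subtracting
    // w_proj_comp[j] * data_shift and dividing out the scales recovers the
    // f32 value.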
106 | const auto dequantize_s32_f32 = [&](gemm_acc_t s, int j) { |
107 | const float wscale |
108 | = pd_->attr()->rnn_weights_projection_qparams_.mask_ == 0 |
109 | ? weights_scales_[0] |
110 | : weights_scales_[j]; |
111 | const float wcomp = w_proj_comp[j] * data_shift; |
112 | |
113 | return (saturate<float>(s) - wcomp) / (wscale * data_scale); |
114 | }; |
115 | |
    const auto postgemm_call = [&](dim_t i) {
117 | const int n_elem = block_step / (int)sizeof(dst_layer_t); |
118 | PRAGMA_OMP_SIMD() |
119 | for (int j = 0; j < n_elem; j++) { |
120 | const int scratch_off = i * rnn.scratch_gates_ld + j; |
121 | const int dst_off = i * dst_layer_ld + j; |
122 | const float tmp |
123 | = dequantize_s32_f32(scratch_gates_[scratch_off], j); |
124 | dst_layer_[dst_off] = quantize_f32_u8(tmp); |
125 | } |
126 | }; |
127 | if (rnn.is_brgemm && !rnn.unfused_post_gemm) { |
128 | for (int i = 0; i < rnn.m_block; i++) |
129 | postgemm_call(i); |
130 | } else { |
131 | parallel_nd(rnn.mb, [&](dim_t i) { postgemm_call(i); }); |
132 | } |
133 | proj_dst_copy(rnn, cell_position, dst_iter_, dst_layer_, block_step); |
134 | } |
135 | |
136 | template <> |
137 | rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::lstm_projection_postgemm) { |
138 | // Here, we use |
139 | // - scratch_gates to pass the s32 output of the projection |
140 | // - no need to pass the projection compensation for s8s8 amx |
141 | const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true); |
142 | |
143 | const float data_shift = pd_->attr()->rnn_data_qparams_.shift_; |
144 | const float data_scale = pd_->attr()->rnn_data_qparams_.scale_; |
145 | |
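    // Quantize f32 -> s8: q = f * data_scale + data_shift; qz_a1b0 rounds and
    // saturates to the s8 range, so no explicit clamping is needed here.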
146 | const auto quantize_f32_s8 = [&](float f) { |
147 | const float qf = f * data_scale + data_shift; |
148 | return qz_a1b0<float, dst_layer_t>()(qf); |
149 | }; |
150 | |
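    // Unlike the u8 path, no zero-point compensation is subtracted: per the
    // note above, s8s8 (amx) needs no projection compensation, so
    // dequantization only divides out the scales.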
151 | const auto dequantize_s32_f32 = [&](gemm_acc_t s, int j) { |
152 | const float wscale |
153 | = pd_->attr()->rnn_weights_projection_qparams_.mask_ == 0 |
154 | ? weights_scales_[0] |
155 | : weights_scales_[j]; |
156 | |
        return saturate<float>(s) / (wscale * data_scale);
158 | }; |
159 | |
160 | const auto postgemm_call = [&](dim_t i) { |
161 | const int n_elem = block_step / (int)sizeof(dst_layer_t); |
162 | PRAGMA_OMP_SIMD() |
163 | for (int j = 0; j < n_elem; j++) { |
164 | const int scratch_off = i * rnn.scratch_gates_ld + j; |
165 | const int dst_off = i * dst_layer_ld + j; |
166 | const float tmp |
167 | = dequantize_s32_f32(scratch_gates_[scratch_off], j); |
168 | dst_layer_[dst_off] = quantize_f32_s8(tmp); |
169 | } |
170 | }; |
171 | if (rnn.is_brgemm && !rnn.unfused_post_gemm) { |
172 | for (int i = 0; i < rnn.m_block; i++) |
173 | postgemm_call(i); |
174 | } else { |
175 | parallel_nd(rnn.mb, [&](dim_t i) { postgemm_call(i); }); |
176 | } |
177 | proj_dst_copy(rnn, cell_position, dst_iter_, dst_layer_, block_step); |
178 | } |
179 | |
180 | template <> |
181 | rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::lstm_projection_postgemm) { |
182 | assert(!"unsupported" ); |
183 | } |
184 | |
185 | template <> |
186 | rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::lstm_projection_postgemm) { |
187 | assert(!"unsupported" ); |
188 | } |
189 | |
190 | } // namespace cpu |
191 | } // namespace impl |
192 | } // namespace dnnl |
193 | |