1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17/*
18 * Cell execution LSTM projection
19 */
20
21#include "common/dnnl_thread.hpp"
22#include "common/math_utils.hpp"
23
24#include "cpu/simple_q10n.hpp"
25
26#include "cpu/rnn/postgemm_dispatcher.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31
32using namespace dnnl::impl::utils;
33using namespace dnnl::impl::math;
34using namespace rnn_utils;
35
36namespace {
37template <typename dst_layer_t, typename dst_iter_t>
38void proj_dst_copy(const rnn_utils::rnn_conf_t &rnn,
39 rnn_utils::cell_position_t cell_position, dst_iter_t *dst_iter_,
40 const dst_layer_t *dst_layer_, int block_step) {
41 assert(rnn.dic == rnn.dlc);
42 static_assert(sizeof(dst_layer_t) == sizeof(dst_iter_t),
43 "memcpy requires the same data type size for src and dst");
44 const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true);
45 const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
46
47 // If dst_iter is not nullptr, we need to copy the state to dst_iter
48 if (dst_iter_ != nullptr) {
49 if (rnn.is_brgemm && !rnn.unfused_post_gemm) {
50 for (int i = 0; i < rnn.m_block; i++)
51 std::memcpy(dst_iter_ + i * dst_iter_ld,
52 dst_layer_ + i * dst_layer_ld, block_step);
53 } else {
54 parallel_nd(rnn.mb, [&](dim_t i) {
55 std::memcpy(dst_iter_ + i * dst_iter_ld,
56 dst_layer_ + i * dst_layer_ld, block_step);
57 });
58 }
59 }
60}
61} // namespace
62
63template <>
64rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::lstm_projection_postgemm) {
65 // nothing to do for f32, except copy to dst_iter if needed
66 proj_dst_copy(rnn, cell_position, dst_iter_, dst_layer_, block_step);
67}
68
69template <>
70rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::lstm_projection_postgemm) {
71 const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true);
72
73 // Currently, scratch_gates_ contains the output of the projection
74 const int n_elem = block_step / (int)sizeof(dst_layer_t);
75
76 const int m_block
77 = (rnn.is_brgemm && !rnn.unfused_post_gemm) ? rnn.m_block : rnn.mb;
78
79 for (int i = 0; i < m_block; i++)
80 cvt_float_to_bfloat16((bfloat16_t *)dst_layer_ + i * dst_layer_ld,
81 (float *)scratch_gates_ + i * rnn.scratch_gates_ld, n_elem);
82
83 // we copy to dst_iter if necessary
84 proj_dst_copy(rnn, cell_position, dst_iter_, dst_layer_, block_step);
85}
86
87template <>
88rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::lstm_projection_postgemm) {
89 // Here, we use
90 // - scratch_gates to pass the s32 output of the projection
91 // - src_iter_c to pass the projection compensation
92
93 const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true);
94 const auto w_proj_comp = static_cast<const float *>(src_iter_c_);
95
96 const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
97 const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;
98
99 const auto quantize_f32_u8 = [&](float f) {
100 float qf = f * data_scale + data_shift;
101 qf = nstl::min(qf, 255.0f);
102 qf = nstl::max(qf, 0.0f);
103 return qz_a1b0<float, dst_layer_t>()(qf);
104 };
105
106 const auto dequantize_s32_f32 = [&](gemm_acc_t s, int j) {
107 const float wscale
108 = pd_->attr()->rnn_weights_projection_qparams_.mask_ == 0
109 ? weights_scales_[0]
110 : weights_scales_[j];
111 const float wcomp = w_proj_comp[j] * data_shift;
112
113 return (saturate<float>(s) - wcomp) / (wscale * data_scale);
114 };
115
116 auto postgemm_call = [&](int i) {
117 const int n_elem = block_step / (int)sizeof(dst_layer_t);
118 PRAGMA_OMP_SIMD()
119 for (int j = 0; j < n_elem; j++) {
120 const int scratch_off = i * rnn.scratch_gates_ld + j;
121 const int dst_off = i * dst_layer_ld + j;
122 const float tmp
123 = dequantize_s32_f32(scratch_gates_[scratch_off], j);
124 dst_layer_[dst_off] = quantize_f32_u8(tmp);
125 }
126 };
127 if (rnn.is_brgemm && !rnn.unfused_post_gemm) {
128 for (int i = 0; i < rnn.m_block; i++)
129 postgemm_call(i);
130 } else {
131 parallel_nd(rnn.mb, [&](dim_t i) { postgemm_call(i); });
132 }
133 proj_dst_copy(rnn, cell_position, dst_iter_, dst_layer_, block_step);
134}
135
136template <>
137rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::lstm_projection_postgemm) {
138 // Here, we use
139 // - scratch_gates to pass the s32 output of the projection
140 // - no need to pass the projection compensation for s8s8 amx
141 const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true);
142
143 const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
144 const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;
145
146 const auto quantize_f32_s8 = [&](float f) {
147 const float qf = f * data_scale + data_shift;
148 return qz_a1b0<float, dst_layer_t>()(qf);
149 };
150
151 const auto dequantize_s32_f32 = [&](gemm_acc_t s, int j) {
152 const float wscale
153 = pd_->attr()->rnn_weights_projection_qparams_.mask_ == 0
154 ? weights_scales_[0]
155 : weights_scales_[j];
156
157 return (saturate<float>(s)) / (wscale * data_scale);
158 };
159
160 const auto postgemm_call = [&](dim_t i) {
161 const int n_elem = block_step / (int)sizeof(dst_layer_t);
162 PRAGMA_OMP_SIMD()
163 for (int j = 0; j < n_elem; j++) {
164 const int scratch_off = i * rnn.scratch_gates_ld + j;
165 const int dst_off = i * dst_layer_ld + j;
166 const float tmp
167 = dequantize_s32_f32(scratch_gates_[scratch_off], j);
168 dst_layer_[dst_off] = quantize_f32_s8(tmp);
169 }
170 };
171 if (rnn.is_brgemm && !rnn.unfused_post_gemm) {
172 for (int i = 0; i < rnn.m_block; i++)
173 postgemm_call(i);
174 } else {
175 parallel_nd(rnn.mb, [&](dim_t i) { postgemm_call(i); });
176 }
177 proj_dst_copy(rnn, cell_position, dst_iter_, dst_layer_, block_step);
178}
179
// f32 backward projection post-GEMM: not implemented.
// NOTE(review): this specialization presumably exists only so the dispatcher
// has a definition for every data type; confirm callers never reach it.
template <>
rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::lstm_projection_postgemm) {
    // assert() is compiled out under NDEBUG, so in release builds reaching
    // this function is a silent no-op rather than an error.
    assert(!"unsupported");
}
184
// bf16 backward projection post-GEMM: not implemented.
// NOTE(review): this specialization presumably exists only so the dispatcher
// has a definition for every data type; confirm callers never reach it.
template <>
rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::lstm_projection_postgemm) {
    // assert() is compiled out under NDEBUG, so in release builds reaching
    // this function is a silent no-op rather than an error.
    assert(!"unsupported");
}
189
190} // namespace cpu
191} // namespace impl
192} // namespace dnnl
193