1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_ |
16 | #define TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_ |
17 | |
18 | #include <cstdint> |
19 | #include <memory> |
20 | |
21 | #include "tensorflow/lite/c/builtin_op_data.h" |
22 | #include "tensorflow/lite/c/common.h" |
23 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
24 | |
25 | namespace tflite { |
26 | namespace ops { |
27 | namespace builtin { |
28 | namespace lstm_eval { |
29 | |
30 | // Pamameters for integer LSTM. |
31 | // Consider split this into two Integer Parameters if more fields are added. |
// Parameters for integer LSTM.
// Consider splitting this into two integer-parameter structs if more fields
// are added.
struct IntegerLstmParameter {
  // Effective combined scales for each gate's matrix multiplications.
  // Each (_a, _b) pair presumably encodes a quantized multiplier and shift
  // for the corresponding input/recurrent/peephole product — confirm against
  // the implementation in lstm_eval.cc.
  int32_t effective_input_to_input_scale_a;
  int32_t effective_input_to_input_scale_b;
  int32_t effective_recurrent_to_input_scale_a;
  int32_t effective_recurrent_to_input_scale_b;
  int32_t effective_cell_to_input_scale_a;
  int32_t effective_cell_to_input_scale_b;
  int32_t effective_input_to_forget_scale_a;
  int32_t effective_input_to_forget_scale_b;
  int32_t effective_recurrent_to_forget_scale_a;
  int32_t effective_recurrent_to_forget_scale_b;
  int32_t effective_cell_to_forget_scale_a;
  int32_t effective_cell_to_forget_scale_b;
  int32_t effective_input_to_cell_scale_a;
  int32_t effective_input_to_cell_scale_b;
  int32_t effective_recurrent_to_cell_scale_a;
  int32_t effective_recurrent_to_cell_scale_b;
  int32_t effective_input_to_output_scale_a;
  int32_t effective_input_to_output_scale_b;
  int32_t effective_recurrent_to_output_scale_a;
  int32_t effective_recurrent_to_output_scale_b;
  int32_t effective_cell_to_output_scale_a;
  int32_t effective_cell_to_output_scale_b;
  // Scales for the projection layer and the hidden-state rescale.
  int32_t effective_proj_scale_a;
  int32_t effective_proj_scale_b;
  int32_t effective_hidden_scale_a;
  int32_t effective_hidden_scale_b;
  // Scales for the per-gate layer-normalization coefficients.
  int32_t layer_norm_input_scale_a;
  int32_t layer_norm_input_scale_b;
  int32_t layer_norm_forget_scale_a;
  int32_t layer_norm_forget_scale_b;
  int32_t layer_norm_cell_scale_a;
  int32_t layer_norm_cell_scale_b;
  int32_t layer_norm_output_scale_a;
  int32_t layer_norm_output_scale_b;
  // Quantized clip value for cell and projection. Zero value means no clipping.
  int16_t quantized_cell_clip;
  int8_t quantized_proj_clip;
  // Zero point of the hidden (intermediate) tensor.
  int32_t hidden_zp;
  // Scale of the cell state.
  int32_t cell_scale;

  // Variance guards for the per-gate layer-norm computations (presumably
  // lower bounds added to the variance to avoid numeric issues — confirm
  // against the implementation).
  int32_t input_variance_guard;
  int32_t forget_variance_guard;
  int32_t cell_variance_guard;
  int32_t output_variance_guard;

  // Pre-calculated bias + zero_point * weight.
  // Unable to use temporary tensors since those are used in Prepare() and
  // scratch buffer is only allocated after Prepare().
  std::unique_ptr<int32_t[]> input_to_forget_effective_bias;
  std::unique_ptr<int32_t[]> recurrent_to_forget_effective_bias;
  std::unique_ptr<int32_t[]> input_to_cell_effective_bias;
  std::unique_ptr<int32_t[]> recurrent_to_cell_effective_bias;
  std::unique_ptr<int32_t[]> input_to_output_effective_bias;
  std::unique_ptr<int32_t[]> recurrent_to_output_effective_bias;
  std::unique_ptr<int32_t[]> input_to_input_effective_bias;
  std::unique_ptr<int32_t[]> recurrent_to_input_effective_bias;
  std::unique_ptr<int32_t[]> projection_effective_bias;

  // Scale and zero point for intermediate tensors.
  // Used only in the 8x8_8 case.
  int32_t intermediate_scale_a[8];
  int32_t intermediate_scale_b[8];
  int32_t intermediate_zp[12];
};
97 | |
// Evaluates an LSTM with float weights and activations.
// Optional tensors (peephole `cell_to_*` weights, layer-norm coefficients,
// `aux_input*`, projection, input-gate inputs for CIFG) may be nullptr, as is
// conventional for TfLiteTensor* parameters — confirm against the callers in
// lstm.cc / unidirectional_sequence_lstm.cc.
// `forward_sequence`/`time_major` select the iteration order over the
// sequence; `output_offset` offsets writes into `output`. `scratch_buffer`,
// `output_state` and `cell_state` are mutated; the result is written to
// `output`. Returns a TfLiteStatus indicating success or failure.
TfLiteStatus EvalFloat(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
    int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
    TfLiteTensor* cell_state, TfLiteTensor* output, CpuBackendContext* context);
125 | |
// Evaluates an LSTM with quantized weights and float activations ("hybrid").
// Each weight tensor is accompanied by an optional `_ledger` tensor
// (presumably sparsity bookkeeping for sparse hybrid kernels — confirm
// against lstm_eval.cc). The many TfLiteTensor* scratch arguments
// (`*_sf`, `*_quantized`, `*_zp`, `row_sums`, ...) are caller-allocated
// temporaries used for on-the-fly input quantization; `compute_row_sums`
// is an in/out flag so row sums are computed only once across invocations.
// Mutates `output_state` and `cell_state`, writes the result to `output`,
// and returns a TfLiteStatus.
TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_input_weights_ledger,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_forget_weights_ledger,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_cell_weights_ledger,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* input_to_output_weights_ledger,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_input_weights_ledger,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_forget_weights_ledger,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_cell_weights_ledger,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* recurrent_to_output_weights_ledger,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights,
    const TfLiteTensor* projection_weights_ledger,
    const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
    bool forward_sequence, bool time_major, int output_offset,
    TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
    TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
    TfLiteTensor* output_state, TfLiteTensor* cell_state,
    TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
    TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
    TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
    bool* compute_row_sums, CpuBackendContext* context);
171 | |
// Evaluates a fully-quantized LSTM with 8-bit weights/activations and 16-bit
// intermediate gate values (the "8x8_16" kernel). Quantization parameters come
// from `integer_lstm_param` (see IntegerLstmParameter above). `scratch0`..
// `scratch5` are caller-allocated temporaries. Mutates `output_state` and
// `cell_state`, writes the result to `output`, and returns a TfLiteStatus.
TfLiteStatus EvalInteger8x8_16(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
    TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
    TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
    TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
    CpuBackendContext* context);
197 | |
// Evaluates a fully-quantized LSTM with 8-bit weights/activations and 8-bit
// intermediates (the "8x8_8" kernel; uses the `intermediate_*` fields of
// IntegerLstmParameter). Unlike the other Eval* overloads, this declaration
// takes no `forward_sequence`/`time_major` flags and no CpuBackendContext.
// `scratch0`..`scratch7` are caller-allocated temporaries. Mutates
// `output_state` and `cell_state`, writes the result to `output`, and
// returns a TfLiteStatus.
TfLiteStatus EvalInteger8x8_8(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, TfLiteTensor* output_state,
    TfLiteTensor* cell_state, TfLiteTensor* output,
    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
    TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
    TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
    TfLiteTensor* scratch6, TfLiteTensor* scratch7);
223 | |
224 | } // namespace lstm_eval |
225 | } // namespace builtin |
226 | } // namespace ops |
227 | } // namespace tflite |
228 | #endif // TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_ |
229 | |