/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_
#define TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_

#include <cstdint>
#include <memory>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace lstm_eval {

// Parameters for integer LSTM.
// Consider splitting this into two integer parameter structs if more fields
// are added.
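// Note: each effective_*_scale_a / effective_*_scale_b pair below is assumed
// to follow the usual TfLite quantized-multiplier representation: scale_a is
// a Q31 fixed-point multiplier and scale_b the accompanying (possibly
// negative) shift, so real_scale ~= (scale_a / 2^31) * 2^scale_b. See
// QuantizeMultiplier() in
// tensorflow/lite/kernels/internal/quantization_util.h.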
struct IntegerLstmParameter {
  int32_t effective_input_to_input_scale_a;
  int32_t effective_input_to_input_scale_b;
  int32_t effective_recurrent_to_input_scale_a;
  int32_t effective_recurrent_to_input_scale_b;
  int32_t effective_cell_to_input_scale_a;
  int32_t effective_cell_to_input_scale_b;
  int32_t effective_input_to_forget_scale_a;
  int32_t effective_input_to_forget_scale_b;
  int32_t effective_recurrent_to_forget_scale_a;
  int32_t effective_recurrent_to_forget_scale_b;
  int32_t effective_cell_to_forget_scale_a;
  int32_t effective_cell_to_forget_scale_b;
  int32_t effective_input_to_cell_scale_a;
  int32_t effective_input_to_cell_scale_b;
  int32_t effective_recurrent_to_cell_scale_a;
  int32_t effective_recurrent_to_cell_scale_b;
  int32_t effective_input_to_output_scale_a;
  int32_t effective_input_to_output_scale_b;
  int32_t effective_recurrent_to_output_scale_a;
  int32_t effective_recurrent_to_output_scale_b;
  int32_t effective_cell_to_output_scale_a;
  int32_t effective_cell_to_output_scale_b;
  int32_t effective_proj_scale_a;
  int32_t effective_proj_scale_b;
  int32_t effective_hidden_scale_a;
  int32_t effective_hidden_scale_b;
  int32_t layer_norm_input_scale_a;
  int32_t layer_norm_input_scale_b;
  int32_t layer_norm_forget_scale_a;
  int32_t layer_norm_forget_scale_b;
  int32_t layer_norm_cell_scale_a;
  int32_t layer_norm_cell_scale_b;
  int32_t layer_norm_output_scale_a;
  int32_t layer_norm_output_scale_b;
  // Quantized clip value for cell and projection. Zero value means no
  // clipping.
  int16_t quantized_cell_clip;
  int8_t quantized_proj_clip;
  int32_t hidden_zp;
  int32_t cell_scale;

  int32_t input_variance_guard;
  int32_t forget_variance_guard;
  int32_t cell_variance_guard;
  int32_t output_variance_guard;

  // Pre-calculated bias + zero_point * weight.
  // Unable to use temporary tensors since those are used in Prepare() and
  // the scratch buffer is only allocated after Prepare().
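  // As a rough sketch of the folding (exact signs follow lstm_eval.cc): a
  // quantized matmul accumulates sum_j((x[j] - x_zp) * W[i][j]) + bias[i],
  // and the zero-point term does not depend on the input values, so
  //   effective_bias[i] ~= bias[i] - x_zp * sum_j(W[i][j])
  // can be computed once and reused for every invocation.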
  std::unique_ptr<int32_t[]> input_to_forget_effective_bias;
  std::unique_ptr<int32_t[]> recurrent_to_forget_effective_bias;
  std::unique_ptr<int32_t[]> input_to_cell_effective_bias;
  std::unique_ptr<int32_t[]> recurrent_to_cell_effective_bias;
  std::unique_ptr<int32_t[]> input_to_output_effective_bias;
  std::unique_ptr<int32_t[]> recurrent_to_output_effective_bias;
  std::unique_ptr<int32_t[]> input_to_input_effective_bias;
  std::unique_ptr<int32_t[]> recurrent_to_input_effective_bias;
  std::unique_ptr<int32_t[]> projection_effective_bias;

  // Scale and zero point for intermediate tensors.
  // Used only in the 8x8_8 case.
  int32_t intermediate_scale_a[8];
  int32_t intermediate_scale_b[8];
  int32_t intermediate_zp[12];
};

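// Float LSTM evaluation. Optional tensors (the *_to_input_* weights under
// CIFG, the peephole cell_to_* weights, the layer-norm coefficients, the
// projection weights/bias, and the aux-input tensors) may be null when the
// corresponding feature is unused.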
TfLiteStatus EvalFloat(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
    int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
    TfLiteTensor* cell_state, TfLiteTensor* output, CpuBackendContext* context);

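// Hybrid LSTM evaluation: float activations with (possibly sparse) quantized
// weights. The *_ledger tensors carry the sparsity ledgers used by the sparse
// hybrid kernels; the scaling-factor, quantized, zero-point and row-sum
// tensors are caller-provided temporaries.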
TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_input_weights_ledger,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_forget_weights_ledger,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_cell_weights_ledger,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* input_to_output_weights_ledger,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_input_weights_ledger,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_forget_weights_ledger,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_cell_weights_ledger,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* recurrent_to_output_weights_ledger,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights,
    const TfLiteTensor* projection_weights_ledger,
    const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
    bool forward_sequence, bool time_major, int output_offset,
    TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
    TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
    TfLiteTensor* output_state, TfLiteTensor* cell_state,
    TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
    TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
    TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
    bool* compute_row_sums, CpuBackendContext* context);

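// Fully integer LSTM evaluation for the 8x8_16 recipe (8-bit activations and
// weights with a 16-bit cell state), using the precomputed effective scales
// and biases in |integer_lstm_param| and six caller-provided scratch tensors.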
TfLiteStatus EvalInteger8x8_16(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
    TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
    TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
    TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
    CpuBackendContext* context);

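// Fully integer LSTM evaluation for the 8x8_8 recipe. In addition to the
// effective scales and biases, this variant uses the intermediate tensor
// scales and zero points stored in |integer_lstm_param|.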
TfLiteStatus EvalInteger8x8_8(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, TfLiteTensor* output_state,
    TfLiteTensor* cell_state, TfLiteTensor* output,
    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
    TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
    TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
    TfLiteTensor* scratch6, TfLiteTensor* scratch7);

}  // namespace lstm_eval
}  // namespace builtin
}  // namespace ops
}  // namespace tflite
#endif  // TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_