1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/toco/tflite/operator.h" |
16 | |
17 | #include <map> |
18 | #include <memory> |
19 | #include <string> |
20 | #include <utility> |
21 | |
22 | #include "tensorflow/core/framework/attr_value.pb.h" |
23 | #include "tensorflow/core/framework/node_def.pb.h" |
24 | #include "tensorflow/core/framework/op.h" |
25 | #include "tensorflow/core/framework/op_def.pb.h" |
26 | #include "tensorflow/core/util/ptr_util.h" |
27 | |
28 | // TODO(ycling): Consider refactoring to extract the LSTM definition out of |
29 | // graph_transformation module. |
30 | #include "tensorflow/lite/builtin_op_data.h" |
31 | #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h" |
32 | #include "tensorflow/lite/schema/schema_generated.h" |
33 | #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h" |
34 | #include "tensorflow/lite/toco/model.h" |
35 | #include "tensorflow/lite/toco/tflite/builtin_operator.h" |
36 | #include "tensorflow/lite/toco/tflite/custom_operator.h" |
37 | #include "tensorflow/lite/toco/tflite/simple_operator.h" |
38 | #include "tensorflow/lite/toco/tflite/types.h" |
39 | #include "tensorflow/lite/tools/versioning/op_version.h" |
40 | |
41 | namespace toco { |
42 | |
43 | namespace tflite { |
44 | |
45 | // LINT.IfChange |
46 | |
47 | TfLiteType GetTensorType(const ArrayDataType type) { |
48 | const std::map<ArrayDataType, TfLiteType> tensor_type_map = { |
49 | {ArrayDataType::kBool, kTfLiteBool}, |
50 | {ArrayDataType::kFloat, kTfLiteFloat32}, |
51 | {ArrayDataType::kInt8, kTfLiteInt8}, |
52 | {ArrayDataType::kUint8, kTfLiteUInt8}, |
53 | {ArrayDataType::kInt16, kTfLiteInt16}, |
54 | {ArrayDataType::kUint16, kTfLiteUInt16}, |
55 | {ArrayDataType::kInt32, kTfLiteInt32}, |
56 | {ArrayDataType::kUint32, kTfLiteUInt32}, |
57 | {ArrayDataType::kInt64, kTfLiteInt64}, |
58 | {ArrayDataType::kUint64, kTfLiteUInt64}, |
59 | {ArrayDataType::kString, kTfLiteString}, |
60 | {ArrayDataType::kComplex64, kTfLiteComplex64}, |
61 | {ArrayDataType::kComplex128, kTfLiteComplex128}, |
62 | {ArrayDataType::kFloat16, kTfLiteFloat16}, |
63 | {ArrayDataType::kFloat64, kTfLiteFloat64}}; |
64 | |
65 | auto it = tensor_type_map.find(type); |
66 | if (it != tensor_type_map.end()) { |
67 | return it->second; |
68 | } |
69 | return kTfLiteNoType; |
70 | } |
71 | |
72 | ::tflite::OpSignature GetVersioningOpSig( |
73 | const ::tflite::BuiltinOperator op, const OperatorSignature& op_signature) { |
74 | std::vector<::tflite::OpSignatureTensorSpec> inputs, outputs; |
75 | for (const auto& input_name : op_signature.op->inputs) { |
76 | ::tflite::OpSignatureTensorSpec tensor = {kTfLiteNoType}; |
77 | if (op_signature.model->HasArray(input_name)) { |
78 | const Array& input_array = op_signature.model->GetArray(input_name); |
79 | tensor.type = GetTensorType(input_array.data_type); |
80 | if (input_array.has_shape()) { |
81 | tensor.dims = input_array.shape().dims(); |
82 | } |
83 | } |
84 | inputs.push_back(tensor); |
85 | } |
86 | for (const auto& output_name : op_signature.op->outputs) { |
87 | ::tflite::OpSignatureTensorSpec tensor = {kTfLiteNoType}; |
88 | if (op_signature.model->HasArray(output_name)) { |
89 | const Array& output_array = op_signature.model->GetArray(output_name); |
90 | tensor.type = GetTensorType(output_array.data_type); |
91 | if (output_array.has_shape()) { |
92 | tensor.dims = output_array.shape().dims(); |
93 | } |
94 | } |
95 | outputs.push_back(tensor); |
96 | } |
97 | return ::tflite::OpSignature{op, inputs, outputs}; |
98 | } |
99 | |
100 | class AveragePool |
101 | : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions, |
102 | ::tflite::BuiltinOptions_Pool2DOptions> { |
103 | public: |
104 | using BuiltinOperator::BuiltinOperator; |
105 | |
106 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
107 | const TocoOperator& op, |
108 | flatbuffers::FlatBufferBuilder* builder) const override { |
109 | auto padding = Padding::Serialize(op.padding.type); |
110 | auto activation_function = |
111 | ActivationFunction::Serialize(op.fused_activation_function); |
112 | return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, |
113 | op.stride_height, op.kwidth, |
114 | op.kheight, activation_function); |
115 | } |
116 | |
117 | void ReadOptions(const TfLiteOptions& options, |
118 | TocoOperator* op) const override { |
119 | op->padding.type = Padding::Deserialize(options.padding()); |
120 | op->stride_width = options.stride_w(); |
121 | op->stride_height = options.stride_h(); |
122 | op->kwidth = options.filter_width(); |
123 | op->kheight = options.filter_height(); |
124 | op->fused_activation_function = |
125 | ActivationFunction::Deserialize(options.fused_activation_function()); |
126 | } |
127 | }; |
128 | |
129 | class Convolution |
130 | : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions, |
131 | ::tflite::BuiltinOptions_Conv2DOptions> { |
132 | public: |
133 | using BuiltinOperator::BuiltinOperator; |
134 | |
135 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
136 | const TocoOperator& op, |
137 | flatbuffers::FlatBufferBuilder* builder) const override { |
138 | auto padding = Padding::Serialize(op.padding.type); |
139 | auto activation_function = |
140 | ActivationFunction::Serialize(op.fused_activation_function); |
141 | return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width, |
142 | op.stride_height, activation_function, |
143 | op.dilation_width_factor, |
144 | op.dilation_height_factor); |
145 | } |
146 | |
147 | void ReadOptions(const TfLiteOptions& options, |
148 | TocoOperator* op) const override { |
149 | op->padding.type = Padding::Deserialize(options.padding()); |
150 | op->stride_width = options.stride_w(); |
151 | op->stride_height = options.stride_h(); |
152 | op->dilation_width_factor = options.dilation_w_factor(); |
153 | op->dilation_height_factor = options.dilation_h_factor(); |
154 | op->fused_activation_function = |
155 | ActivationFunction::Deserialize(options.fused_activation_function()); |
156 | } |
157 | }; |
158 | |
159 | class DepthwiseConvolution |
160 | : public BuiltinOperator<DepthwiseConvOperator, |
161 | ::tflite::DepthwiseConv2DOptions, |
162 | ::tflite::BuiltinOptions_DepthwiseConv2DOptions> { |
163 | public: |
164 | using BuiltinOperator::BuiltinOperator; |
165 | |
166 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
167 | const TocoOperator& op, |
168 | flatbuffers::FlatBufferBuilder* builder) const override { |
169 | auto padding = Padding::Serialize(op.padding.type); |
170 | auto activation_function = |
171 | ActivationFunction::Serialize(op.fused_activation_function); |
172 | return ::tflite::CreateDepthwiseConv2DOptions( |
173 | *builder, padding, op.stride_width, op.stride_height, |
174 | op.depth_multiplier, activation_function, op.dilation_width_factor, |
175 | op.dilation_height_factor); |
176 | } |
177 | |
178 | void ReadOptions(const TfLiteOptions& options, |
179 | TocoOperator* op) const override { |
180 | op->padding.type = Padding::Deserialize(options.padding()); |
181 | op->stride_width = options.stride_w(); |
182 | op->stride_height = options.stride_h(); |
183 | op->depth_multiplier = options.depth_multiplier(); |
184 | op->fused_activation_function = |
185 | ActivationFunction::Deserialize(options.fused_activation_function()); |
186 | op->dilation_width_factor = options.dilation_w_factor(); |
187 | op->dilation_height_factor = options.dilation_h_factor(); |
188 | } |
189 | |
190 | int GetVersion(const OperatorSignature& op_signature) const override { |
191 | const auto& conv_op = |
192 | static_cast<const DepthwiseConvOperator&>(*op_signature.op); |
193 | ::tflite::OpSignature op_sig = |
194 | GetVersioningOpSig(builtin_op(), op_signature); |
195 | TfLiteDepthwiseConvParams depthwise_conv_params = {}; |
196 | depthwise_conv_params.dilation_width_factor = conv_op.dilation_width_factor; |
197 | depthwise_conv_params.dilation_height_factor = |
198 | conv_op.dilation_height_factor; |
199 | op_sig.builtin_data = reinterpret_cast<void*>(&depthwise_conv_params); |
200 | return ::tflite::GetBuiltinOperatorVersion(op_sig); |
201 | } |
202 | }; |
203 | |
204 | class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions, |
205 | ::tflite::BuiltinOptions_AddOptions> { |
206 | public: |
207 | using BuiltinOperator::BuiltinOperator; |
208 | |
209 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
210 | const TocoOperator& op, |
211 | flatbuffers::FlatBufferBuilder* builder) const override { |
212 | auto activation_function = |
213 | ActivationFunction::Serialize(op.fused_activation_function); |
214 | return ::tflite::CreateAddOptions(*builder, activation_function); |
215 | } |
216 | |
217 | void ReadOptions(const TfLiteOptions& options, |
218 | TocoOperator* op) const override { |
219 | op->fused_activation_function = |
220 | ActivationFunction::Deserialize(options.fused_activation_function()); |
221 | } |
222 | }; |
223 | |
224 | class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions, |
225 | ::tflite::BuiltinOptions_AddNOptions> { |
226 | public: |
227 | using BuiltinOperator::BuiltinOperator; |
228 | |
229 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
230 | const TocoOperator& op, |
231 | flatbuffers::FlatBufferBuilder* builder) const override { |
232 | return ::tflite::CreateAddNOptions(*builder); |
233 | } |
234 | |
235 | void ReadOptions(const TfLiteOptions& options, |
236 | TocoOperator* op) const override {} |
237 | }; |
238 | |
239 | class SpaceToBatchND |
240 | : public BuiltinOperator<SpaceToBatchNDOperator, |
241 | ::tflite::SpaceToBatchNDOptions, |
242 | ::tflite::BuiltinOptions_SpaceToBatchNDOptions> { |
243 | public: |
244 | using BuiltinOperator::BuiltinOperator; |
245 | |
246 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
247 | const TocoOperator& op, |
248 | flatbuffers::FlatBufferBuilder* builder) const override { |
249 | return ::tflite::CreateSpaceToBatchNDOptions(*builder); |
250 | } |
251 | |
252 | void ReadOptions(const TfLiteOptions& options, |
253 | TocoOperator* op) const override {} |
254 | |
255 | int GetVersion(const OperatorSignature& op_signature) const override { |
256 | ::tflite::OpSignature op_sig = |
257 | GetVersioningOpSig(builtin_op(), op_signature); |
258 | return ::tflite::GetBuiltinOperatorVersion(op_sig); |
259 | } |
260 | }; |
261 | |
262 | class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions, |
263 | ::tflite::BuiltinOptions_SubOptions> { |
264 | public: |
265 | using BuiltinOperator::BuiltinOperator; |
266 | |
267 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
268 | const TocoOperator& op, |
269 | flatbuffers::FlatBufferBuilder* builder) const override { |
270 | auto activation_function = |
271 | ActivationFunction::Serialize(op.fused_activation_function); |
272 | return ::tflite::CreateSubOptions(*builder, activation_function); |
273 | } |
274 | |
275 | void ReadOptions(const TfLiteOptions& options, |
276 | TocoOperator* op) const override { |
277 | op->fused_activation_function = |
278 | ActivationFunction::Deserialize(options.fused_activation_function()); |
279 | } |
280 | |
281 | int GetVersion(const OperatorSignature& op_signature) const override { |
282 | ::tflite::OpSignature op_sig = |
283 | GetVersioningOpSig(builtin_op(), op_signature); |
284 | return ::tflite::GetBuiltinOperatorVersion(op_sig); |
285 | } |
286 | }; |
287 | |
288 | class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions, |
289 | ::tflite::BuiltinOptions_DivOptions> { |
290 | public: |
291 | using BuiltinOperator::BuiltinOperator; |
292 | |
293 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
294 | const TocoOperator& op, |
295 | flatbuffers::FlatBufferBuilder* builder) const override { |
296 | auto activation_function = |
297 | ActivationFunction::Serialize(op.fused_activation_function); |
298 | return ::tflite::CreateDivOptions(*builder, activation_function); |
299 | } |
300 | |
301 | void ReadOptions(const TfLiteOptions& options, |
302 | TocoOperator* op) const override { |
303 | op->fused_activation_function = |
304 | ActivationFunction::Deserialize(options.fused_activation_function()); |
305 | } |
306 | |
307 | int GetVersion(const OperatorSignature& op_signature) const override { |
308 | ::tflite::OpSignature op_sig = |
309 | GetVersioningOpSig(builtin_op(), op_signature); |
310 | return ::tflite::GetBuiltinOperatorVersion(op_sig); |
311 | } |
312 | }; |
313 | |
314 | class BatchToSpaceND |
315 | : public BuiltinOperator<BatchToSpaceNDOperator, |
316 | ::tflite::BatchToSpaceNDOptions, |
317 | ::tflite::BuiltinOptions_BatchToSpaceNDOptions> { |
318 | public: |
319 | using BuiltinOperator::BuiltinOperator; |
320 | |
321 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
322 | const TocoOperator& op, |
323 | flatbuffers::FlatBufferBuilder* builder) const override { |
324 | return ::tflite::CreateBatchToSpaceNDOptions(*builder); |
325 | } |
326 | |
327 | void ReadOptions(const TfLiteOptions& options, |
328 | TocoOperator* op) const override {} |
329 | |
330 | int GetVersion(const OperatorSignature& op_signature) const override { |
331 | ::tflite::OpSignature op_sig = |
332 | GetVersioningOpSig(builtin_op(), op_signature); |
333 | return ::tflite::GetBuiltinOperatorVersion(op_sig); |
334 | } |
335 | }; |
336 | |
337 | class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions, |
338 | ::tflite::BuiltinOptions_CastOptions> { |
339 | public: |
340 | using BuiltinOperator::BuiltinOperator; |
341 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
342 | const TocoOperator& op, |
343 | flatbuffers::FlatBufferBuilder* builder) const override { |
344 | return ::tflite::CreateCastOptions(*builder, |
345 | DataType::Serialize(op.src_data_type), |
346 | DataType::Serialize(op.dst_data_type)); |
347 | } |
348 | |
349 | void ReadOptions(const TfLiteOptions& options, |
350 | TocoOperator* op) const override { |
351 | op->src_data_type = DataType::Deserialize(options.in_data_type()); |
352 | op->dst_data_type = DataType::Deserialize(options.out_data_type()); |
353 | } |
354 | }; |
355 | |
356 | class Concatenation |
357 | : public BuiltinOperator<ConcatenationOperator, |
358 | ::tflite::ConcatenationOptions, |
359 | ::tflite::BuiltinOptions_ConcatenationOptions> { |
360 | public: |
361 | using BuiltinOperator::BuiltinOperator; |
362 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
363 | const TocoOperator& op, |
364 | flatbuffers::FlatBufferBuilder* builder) const override { |
365 | return ::tflite::CreateConcatenationOptions(*builder, op.axis); |
366 | } |
367 | |
368 | void ReadOptions(const TfLiteOptions& options, |
369 | TocoOperator* op) const override { |
370 | op->axis = options.axis(); |
371 | } |
372 | }; |
373 | |
374 | class DepthToSpace |
375 | : public BuiltinOperator<DepthToSpaceOperator, |
376 | ::tflite::DepthToSpaceOptions, |
377 | ::tflite::BuiltinOptions_DepthToSpaceOptions> { |
378 | public: |
379 | using BuiltinOperator::BuiltinOperator; |
380 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
381 | const TocoOperator& op, |
382 | flatbuffers::FlatBufferBuilder* builder) const override { |
383 | return ::tflite::CreateDepthToSpaceOptions(*builder, op.block_size); |
384 | } |
385 | |
386 | void ReadOptions(const TfLiteOptions& options, |
387 | TocoOperator* op) const override { |
388 | op->block_size = options.block_size(); |
389 | } |
390 | }; |
391 | |
392 | class FakeQuant |
393 | : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions, |
394 | ::tflite::BuiltinOptions_FakeQuantOptions> { |
395 | public: |
396 | using BuiltinOperator::BuiltinOperator; |
397 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
398 | const TocoOperator& op, |
399 | flatbuffers::FlatBufferBuilder* builder) const override { |
400 | return ::tflite::CreateFakeQuantOptions( |
401 | *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range); |
402 | } |
403 | void ReadOptions(const TfLiteOptions& options, |
404 | TocoOperator* op) const override { |
405 | auto* minmax = new MinMax; |
406 | minmax->min = options.min(); |
407 | minmax->max = options.max(); |
408 | op->minmax.reset(minmax); |
409 | op->num_bits = options.num_bits(); |
410 | op->narrow_range = options.narrow_range(); |
411 | } |
412 | int GetVersion(const OperatorSignature& op_signature) const override { |
413 | const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op); |
414 | ::tflite::OpSignature op_sig = |
415 | GetVersioningOpSig(builtin_op(), op_signature); |
416 | TfLiteFakeQuantParams fake_quant_params = {}; |
417 | fake_quant_params.narrow_range = fq_op.narrow_range; |
418 | op_sig.builtin_data = reinterpret_cast<void*>(&fake_quant_params); |
419 | return ::tflite::GetBuiltinOperatorVersion(op_sig); |
420 | } |
421 | }; |
422 | |
423 | class FullyConnected |
424 | : public BuiltinOperator<FullyConnectedOperator, |
425 | ::tflite::FullyConnectedOptions, |
426 | ::tflite::BuiltinOptions_FullyConnectedOptions> { |
427 | public: |
428 | using BuiltinOperator::BuiltinOperator; |
429 | |
430 | ::tflite::FullyConnectedOptionsWeightsFormat GetWeightFormat( |
431 | FullyConnectedWeightsFormat fmt) const { |
432 | switch (fmt) { |
433 | case FullyConnectedWeightsFormat::kDefault: |
434 | return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; |
435 | case FullyConnectedWeightsFormat::kShuffled4x16Int8: |
436 | return ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8; |
437 | default: |
438 | LOG(ERROR) << "Unhandled FC weights format" ; |
439 | return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; |
440 | } |
441 | } |
442 | |
443 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
444 | const TocoOperator& op, |
445 | flatbuffers::FlatBufferBuilder* builder) const override { |
446 | auto activation_function = |
447 | ActivationFunction::Serialize(op.fused_activation_function); |
448 | return ::tflite::CreateFullyConnectedOptions( |
449 | *builder, activation_function, GetWeightFormat(op.weights_format)); |
450 | } |
451 | |
452 | void ReadOptions(const TfLiteOptions& options, |
453 | TocoOperator* op) const override { |
454 | op->fused_activation_function = |
455 | ActivationFunction::Deserialize(options.fused_activation_function()); |
456 | switch (options.weights_format()) { |
457 | case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT: |
458 | op->weights_format = FullyConnectedWeightsFormat::kDefault; |
459 | break; |
460 | case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: |
461 | op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8; |
462 | break; |
463 | default: |
464 | LOG(ERROR) << "Unhandled FC weights format" ; |
465 | op->weights_format = FullyConnectedWeightsFormat::kDefault; |
466 | } |
467 | } |
468 | |
469 | int GetVersion(const OperatorSignature& op_signature) const override { |
470 | const auto& fc_op = |
471 | static_cast<const FullyConnectedOperator&>(*op_signature.op); |
472 | ::tflite::OpSignature op_sig = |
473 | GetVersioningOpSig(builtin_op(), op_signature); |
474 | TfLiteFullyConnectedParams fully_connected_params = {}; |
475 | fully_connected_params.keep_num_dims = fc_op.keep_num_dims; |
476 | fully_connected_params.weights_format = |
477 | static_cast<TfLiteFullyConnectedWeightsFormat>( |
478 | GetWeightFormat(fc_op.weights_format)); |
479 | op_sig.builtin_data = reinterpret_cast<void*>(&fully_connected_params); |
480 | return ::tflite::GetBuiltinOperatorVersion(op_sig); |
481 | } |
482 | }; |
483 | |
484 | class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions, |
485 | ::tflite::BuiltinOptions_GatherOptions> { |
486 | public: |
487 | using BuiltinOperator::BuiltinOperator; |
488 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
489 | const TocoOperator& op, |
490 | flatbuffers::FlatBufferBuilder* builder) const override { |
491 | int axis = op.axis ? op.axis.value() : 0; |
492 | return ::tflite::CreateGatherOptions(*builder, axis); |
493 | } |
494 | |
495 | void ReadOptions(const TfLiteOptions& options, |
496 | TocoOperator* op) const override { |
497 | op->axis = {options.axis()}; |
498 | } |
499 | }; |
500 | |
501 | class GatherNd |
502 | : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions, |
503 | ::tflite::BuiltinOptions_GatherNdOptions> { |
504 | public: |
505 | using BuiltinOperator::BuiltinOperator; |
506 | |
507 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
508 | const TocoOperator& op, |
509 | flatbuffers::FlatBufferBuilder* builder) const override { |
510 | return ::tflite::CreateGatherNdOptions(*builder); |
511 | } |
512 | |
513 | void ReadOptions(const TfLiteOptions& options, |
514 | TocoOperator* op) const override {} |
515 | }; |
516 | |
517 | class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions, |
518 | ::tflite::BuiltinOptions_SVDFOptions> { |
519 | public: |
520 | using BuiltinOperator::BuiltinOperator; |
521 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
522 | const TocoOperator& op, |
523 | flatbuffers::FlatBufferBuilder* builder) const override { |
524 | auto activation_function = |
525 | ActivationFunction::Serialize(op.fused_activation_function); |
526 | return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function); |
527 | } |
528 | |
529 | void ReadOptions(const TfLiteOptions& options, |
530 | TocoOperator* op) const override { |
531 | op->fused_activation_function = |
532 | ActivationFunction::Deserialize(options.fused_activation_function()); |
533 | op->rank = options.rank(); |
534 | } |
535 | }; |
536 | |
537 | class L2Normalization |
538 | : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions, |
539 | ::tflite::BuiltinOptions_L2NormOptions> { |
540 | public: |
541 | using BuiltinOperator::BuiltinOperator; |
542 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
543 | const TocoOperator& op, |
544 | flatbuffers::FlatBufferBuilder* builder) const override { |
545 | auto activation_function = |
546 | ActivationFunction::Serialize(op.fused_activation_function); |
547 | return ::tflite::CreateL2NormOptions(*builder, activation_function); |
548 | } |
549 | |
550 | void ReadOptions(const TfLiteOptions& options, |
551 | TocoOperator* op) const override { |
552 | op->fused_activation_function = |
553 | ActivationFunction::Deserialize(options.fused_activation_function()); |
554 | } |
555 | }; |
556 | |
557 | class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions, |
558 | ::tflite::BuiltinOptions_Pool2DOptions> { |
559 | public: |
560 | using BuiltinOperator::BuiltinOperator; |
561 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
562 | const TocoOperator& op, |
563 | flatbuffers::FlatBufferBuilder* builder) const override { |
564 | auto padding = Padding::Serialize(op.padding.type); |
565 | auto activation_function = |
566 | ActivationFunction::Serialize(op.fused_activation_function); |
567 | return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, |
568 | op.stride_height, op.kwidth, |
569 | op.kheight, activation_function); |
570 | } |
571 | |
572 | void ReadOptions(const TfLiteOptions& options, |
573 | TocoOperator* op) const override { |
574 | op->padding.type = Padding::Deserialize(options.padding()); |
575 | op->stride_width = options.stride_w(); |
576 | op->stride_height = options.stride_h(); |
577 | op->kwidth = options.filter_width(); |
578 | op->kheight = options.filter_height(); |
579 | op->fused_activation_function = |
580 | ActivationFunction::Deserialize(options.fused_activation_function()); |
581 | } |
582 | }; |
583 | |
584 | class LocalResponseNormalization |
585 | : public BuiltinOperator< |
586 | LocalResponseNormalizationOperator, |
587 | ::tflite::LocalResponseNormalizationOptions, |
588 | ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> { |
589 | public: |
590 | using BuiltinOperator::BuiltinOperator; |
591 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
592 | const TocoOperator& op, |
593 | flatbuffers::FlatBufferBuilder* builder) const override { |
594 | return ::tflite::CreateLocalResponseNormalizationOptions( |
595 | *builder, op.range, op.bias, op.alpha, op.beta); |
596 | } |
597 | |
598 | void ReadOptions(const TfLiteOptions& options, |
599 | TocoOperator* op) const override { |
600 | op->range = options.radius(); |
601 | op->bias = options.bias(); |
602 | op->alpha = options.alpha(); |
603 | op->beta = options.beta(); |
604 | } |
605 | }; |
606 | |
607 | class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions, |
608 | ::tflite::BuiltinOptions_Pool2DOptions> { |
609 | public: |
610 | using BuiltinOperator::BuiltinOperator; |
611 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
612 | const TocoOperator& op, |
613 | flatbuffers::FlatBufferBuilder* builder) const override { |
614 | auto padding = Padding::Serialize(op.padding.type); |
615 | auto activation_function = |
616 | ActivationFunction::Serialize(op.fused_activation_function); |
617 | return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, |
618 | op.stride_height, op.kwidth, |
619 | op.kheight, activation_function); |
620 | } |
621 | |
622 | void ReadOptions(const TfLiteOptions& options, |
623 | TocoOperator* op) const override { |
624 | op->padding.type = Padding::Deserialize(options.padding()); |
625 | op->stride_width = options.stride_w(); |
626 | op->stride_height = options.stride_h(); |
627 | op->kwidth = options.filter_width(); |
628 | op->kheight = options.filter_height(); |
629 | op->fused_activation_function = |
630 | ActivationFunction::Deserialize(options.fused_activation_function()); |
631 | } |
632 | }; |
633 | |
634 | class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions, |
635 | ::tflite::BuiltinOptions_MulOptions> { |
636 | public: |
637 | using BuiltinOperator::BuiltinOperator; |
638 | |
639 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
640 | const TocoOperator& op, |
641 | flatbuffers::FlatBufferBuilder* builder) const override { |
642 | auto activation_function = |
643 | ActivationFunction::Serialize(op.fused_activation_function); |
644 | return ::tflite::CreateMulOptions(*builder, activation_function); |
645 | } |
646 | |
647 | void ReadOptions(const TfLiteOptions& options, |
648 | TocoOperator* op) const override { |
649 | op->fused_activation_function = |
650 | ActivationFunction::Deserialize(options.fused_activation_function()); |
651 | } |
652 | |
653 | int GetVersion(const OperatorSignature& op_signature) const override { |
654 | const std::string& input1_name = op_signature.op->inputs[0]; |
655 | const std::string& input2_name = op_signature.op->inputs[1]; |
656 | const std::string& output_name = op_signature.op->outputs[0]; |
657 | const Array& input1_array = op_signature.model->GetArray(input1_name); |
658 | const Array& input2_array = op_signature.model->GetArray(input2_name); |
659 | const Array& output_array = op_signature.model->GetArray(output_name); |
660 | const auto& input1_quant = input1_array.quantization_params; |
661 | const auto& input2_quant = input2_array.quantization_params; |
662 | const auto& output_quant = output_array.quantization_params; |
663 | const float input1_scale = input1_quant ? input1_quant->scale : 0.0f; |
664 | const float input2_scale = input2_quant ? input2_quant->scale : 0.0f; |
665 | const float output_scale = output_quant ? output_quant->scale : 0.0f; |
666 | ::tflite::OpSignature op_sig = |
667 | GetVersioningOpSig(builtin_op(), op_signature); |
668 | op_sig.ext_options.mul.input1_scale = input1_scale; |
669 | op_sig.ext_options.mul.input2_scale = input2_scale; |
670 | op_sig.ext_options.mul.output_scale = output_scale; |
671 | return ::tflite::GetBuiltinOperatorVersion(op_sig); |
672 | } |
673 | }; |
674 | |
675 | class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions, |
676 | ::tflite::BuiltinOptions_PadOptions> { |
677 | public: |
678 | using BuiltinOperator::BuiltinOperator; |
679 | |
680 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
681 | const TocoOperator& op, |
682 | flatbuffers::FlatBufferBuilder* builder) const override { |
683 | return ::tflite::CreatePadOptions(*builder); |
684 | } |
685 | |
686 | void ReadOptions(const TfLiteOptions& options, |
687 | TocoOperator* op) const override {} |
688 | }; |
689 | |
690 | class Tile |
691 | : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions, |
692 | ::tflite::BuiltinOptions_TileOptions> { |
693 | using BuiltinOperator::BuiltinOperator; |
694 | |
695 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
696 | const TocoOperator& op, |
697 | flatbuffers::FlatBufferBuilder* builder) const override { |
698 | return ::tflite::CreateTileOptions(*builder); |
699 | } |
700 | |
701 | void ReadOptions(const TfLiteOptions& options, |
702 | TocoOperator* op) const override {} |
703 | }; |
704 | |
705 | class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options, |
706 | ::tflite::BuiltinOptions_PadV2Options> { |
707 | public: |
708 | using BuiltinOperator::BuiltinOperator; |
709 | |
710 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
711 | const TocoOperator& op, |
712 | flatbuffers::FlatBufferBuilder* builder) const override { |
713 | return ::tflite::CreatePadV2Options(*builder); |
714 | } |
715 | |
716 | void ReadOptions(const TfLiteOptions& options, |
717 | TocoOperator* op) const override {} |
718 | }; |
719 | |
720 | class Reshape |
721 | : public BuiltinOperator<TensorFlowReshapeOperator, |
722 | ::tflite::ReshapeOptions, |
723 | ::tflite::BuiltinOptions_ReshapeOptions> { |
724 | public: |
725 | using BuiltinOperator::BuiltinOperator; |
726 | |
727 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
728 | const TocoOperator& op, |
729 | flatbuffers::FlatBufferBuilder* builder) const override { |
730 | return ::tflite::CreateReshapeOptions(*builder, |
731 | builder->CreateVector(op.shape)); |
732 | } |
733 | |
734 | void ReadOptions(const TfLiteOptions& options, |
735 | TocoOperator* op) const override { |
736 | op->shape.insert(op->shape.end(), options.new_shape()->begin(), |
737 | options.new_shape()->end()); |
738 | } |
739 | }; |
740 | |
741 | class Softmax |
742 | : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions, |
743 | ::tflite::BuiltinOptions_SoftmaxOptions> { |
744 | public: |
745 | using BuiltinOperator::BuiltinOperator; |
746 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
747 | const TocoOperator& op, |
748 | flatbuffers::FlatBufferBuilder* builder) const override { |
749 | return ::tflite::CreateSoftmaxOptions(*builder, op.beta); |
750 | } |
751 | |
752 | void ReadOptions(const TfLiteOptions& options, |
753 | TocoOperator* op) const override { |
754 | op->beta = options.beta(); |
755 | } |
756 | }; |
757 | |
758 | class SpaceToDepth |
759 | : public BuiltinOperator<SpaceToDepthOperator, |
760 | ::tflite::SpaceToDepthOptions, |
761 | ::tflite::BuiltinOptions_SpaceToDepthOptions> { |
762 | public: |
763 | using BuiltinOperator::BuiltinOperator; |
764 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
765 | const TocoOperator& op, |
766 | flatbuffers::FlatBufferBuilder* builder) const override { |
767 | return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size); |
768 | } |
769 | |
770 | void ReadOptions(const TfLiteOptions& options, |
771 | TocoOperator* op) const override { |
772 | op->block_size = options.block_size(); |
773 | } |
774 | }; |
775 | |
776 | class Transpose |
777 | : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions, |
778 | ::tflite::BuiltinOptions_TransposeOptions> { |
779 | public: |
780 | using BuiltinOperator::BuiltinOperator; |
781 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
782 | const TocoOperator& op, |
783 | flatbuffers::FlatBufferBuilder* builder) const override { |
784 | return ::tflite::CreateTransposeOptions(*builder); |
785 | } |
786 | |
787 | void ReadOptions(const TfLiteOptions& options, |
788 | TocoOperator* op) const override {} |
789 | }; |
790 | |
791 | class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions, |
792 | ::tflite::BuiltinOptions_LSTMOptions> { |
793 | public: |
794 | using BuiltinOperator::BuiltinOperator; |
795 | |
796 | ::tflite::LSTMKernelType GetKernelType( |
797 | LstmCellOperator::KernelType type) const { |
798 | switch (type) { |
799 | case LstmCellOperator::KERNEL_BASIC: |
800 | return ::tflite::LSTMKernelType_BASIC; |
801 | break; |
802 | case LstmCellOperator::KERNEL_FULL: |
803 | return ::tflite::LSTMKernelType_FULL; |
804 | break; |
805 | default: |
806 | LOG(ERROR) << "Unhandled Kernel Type" ; |
807 | return static_cast<::tflite::LSTMKernelType>(-1); |
808 | } |
809 | } |
810 | |
811 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
812 | const TocoOperator& op, |
813 | flatbuffers::FlatBufferBuilder* builder) const override { |
814 | ::tflite::LSTMKernelType kernel_type = GetKernelType(op.kernel_type); |
815 | |
816 | // Current toco converter only supports tanh, no clip. |
817 | return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/ |
818 | ::tflite::ActivationFunctionType_TANH, |
819 | /*cell_clip=*/0.0, |
820 | /*proj_clip=*/0.0, kernel_type); |
821 | } |
822 | |
823 | void ReadOptions(const TfLiteOptions& options, |
824 | TocoOperator* op) const override { |
825 | // Only support tanh activation, so check that tflite type is tanh. |
826 | CHECK(options.fused_activation_function() == |
827 | ::tflite::ActivationFunctionType_TANH); |
828 | |
829 | switch (options.kernel_type()) { |
830 | case ::tflite::LSTMKernelType_BASIC: |
831 | op->kernel_type = LstmCellOperator::KERNEL_BASIC; |
832 | break; |
833 | case ::tflite::LSTMKernelType_FULL: |
834 | op->kernel_type = LstmCellOperator::KERNEL_FULL; |
835 | break; |
836 | } |
837 | } |
838 | |
839 | int GetVersion(const OperatorSignature& op_signature) const override { |
840 | const auto& lstm_op = |
841 | static_cast<const LstmCellOperator&>(*op_signature.op); |
842 | ::tflite::OpSignature op_sig = |
843 | GetVersioningOpSig(builtin_op(), op_signature); |
844 | TfLiteLSTMParams lstm_params = {}; |
845 | lstm_params.kernel_type = |
846 | static_cast<TfLiteLSTMKernelType>(GetKernelType(lstm_op.kernel_type)); |
847 | op_sig.builtin_data = reinterpret_cast<void*>(&lstm_params); |
848 | return ::tflite::GetBuiltinOperatorVersion(op_sig); |
849 | } |
850 | |
851 | std::vector<bool> GetMutatingInputVariables( |
852 | const Operator& op) const override { |
853 | const auto& lstm_op = static_cast<const LstmCellOperator&>(op); |
854 | |
855 | std::vector<bool> mutating_input_variables(op.inputs.size(), false); |
856 | switch (lstm_op.kernel_type) { |
857 | case LstmCellOperator::KERNEL_FULL: { |
858 | mutating_input_variables[kInputActivationStateTensor] = true; |
859 | mutating_input_variables[kInputCellStateTensor] = true; |
860 | break; |
861 | } |
862 | case LstmCellOperator::KERNEL_BASIC: { |
863 | mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true; |
864 | mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true; |
865 | break; |
866 | } |
867 | } |
868 | return mutating_input_variables; |
869 | } |
870 | }; |
871 | |
872 | class UnidirectionalSequenceLstm |
873 | : public BuiltinOperator< |
874 | UnidirectionalSequenceLstmOperator, |
875 | ::tflite::UnidirectionalSequenceLSTMOptions, |
876 | ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> { |
877 | public: |
878 | using BuiltinOperator::BuiltinOperator; |
879 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
880 | const TocoOperator& op, |
881 | flatbuffers::FlatBufferBuilder* builder) const override { |
882 | // Current toco converter only supports tanh, no clip. |
883 | return ::tflite::CreateUnidirectionalSequenceLSTMOptions( |
884 | *builder, /*fused_activation_function=*/ |
885 | ::tflite::ActivationFunctionType_TANH, |
886 | /*cell_clip=*/0.0, |
887 | /*proj_clip=*/0.0, |
888 | /*time_major=*/true); |
889 | } |
890 | |
891 | void ReadOptions(const TfLiteOptions& options, |
892 | TocoOperator* op) const override { |
893 | // Only support tanh activation, so check that tflite type is tanh. |
894 | DCHECK(options.fused_activation_function() == |
895 | ::tflite::ActivationFunctionType_TANH); |
896 | } |
897 | |
898 | std::vector<bool> GetMutatingInputVariables( |
899 | const Operator& op) const override { |
900 | std::vector<bool> mutating_input_variables(op.inputs.size(), false); |
901 | mutating_input_variables[kInputActivationStateTensor] = true; |
902 | mutating_input_variables[kInputCellStateTensor] = true; |
903 | return mutating_input_variables; |
904 | } |
905 | }; |
906 | |
907 | class BidirectionalSequenceLstm |
908 | : public BuiltinOperator< |
909 | BidirectionalSequenceLstmOperator, |
910 | ::tflite::BidirectionalSequenceLSTMOptions, |
911 | ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> { |
912 | public: |
913 | using BuiltinOperator::BuiltinOperator; |
914 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
915 | const TocoOperator& op, |
916 | flatbuffers::FlatBufferBuilder* builder) const override { |
917 | // Current toco converter only supports tanh, no clip. |
918 | return ::tflite::CreateBidirectionalSequenceLSTMOptions( |
919 | *builder, /*fused_activation_function=*/ |
920 | ::tflite::ActivationFunctionType_TANH, |
921 | /*cell_clip=*/0.0, |
922 | /*proj_clip=*/0.0, |
923 | /*merge_outputs=*/op.merge_outputs, |
924 | /*time_major=*/true); |
925 | } |
926 | |
927 | void ReadOptions(const TfLiteOptions& options, |
928 | TocoOperator* op) const override { |
929 | // Only support tanh activation, so check that tflite type is tanh. |
930 | DCHECK(options.fused_activation_function() == |
931 | ::tflite::ActivationFunctionType_TANH); |
932 | op->merge_outputs = options.merge_outputs(); |
933 | } |
934 | |
935 | std::vector<bool> GetMutatingInputVariables( |
936 | const Operator& op) const override { |
937 | std::vector<bool> mutating_input_variables(op.inputs.size(), false); |
938 | // Forward input activation state. |
939 | mutating_input_variables[35] = true; |
940 | // Forward input cell state. |
941 | mutating_input_variables[36] = true; |
942 | // Backward input activation state. |
943 | mutating_input_variables[37] = true; |
944 | // Backward input cell state. |
945 | mutating_input_variables[38] = true; |
946 | return mutating_input_variables; |
947 | } |
948 | }; |
949 | |
950 | class BidirectionalSequenceRnn |
951 | : public BuiltinOperator< |
952 | BidirectionalSequenceRnnOperator, |
953 | ::tflite::BidirectionalSequenceRNNOptions, |
954 | ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> { |
955 | public: |
956 | using BuiltinOperator::BuiltinOperator; |
957 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
958 | const TocoOperator& op, |
959 | flatbuffers::FlatBufferBuilder* builder) const override { |
960 | // Current toco converter only supports tanh, no clip. |
961 | return ::tflite::CreateBidirectionalSequenceRNNOptions( |
962 | *builder, /*time_major=*/true, |
963 | /*fused_activation_function=*/ |
964 | ::tflite::ActivationFunctionType_TANH, |
965 | /*merge_outputs=*/op.merge_outputs); |
966 | } |
967 | |
968 | void ReadOptions(const TfLiteOptions& options, |
969 | TocoOperator* op) const override { |
970 | // Only support tanh activation, so check that tflite type is tanh. |
971 | DCHECK(options.fused_activation_function() == |
972 | ::tflite::ActivationFunctionType_TANH); |
973 | op->merge_outputs = options.merge_outputs(); |
974 | } |
975 | |
976 | std::vector<bool> GetMutatingInputVariables( |
977 | const Operator& op) const override { |
978 | std::vector<bool> mutating_input_variables(op.inputs.size(), false); |
979 | // Forward hidden state. |
980 | mutating_input_variables[4] = true; |
981 | // Backward hidden state. |
982 | mutating_input_variables[8] = true; |
983 | return mutating_input_variables; |
984 | } |
985 | }; |
986 | |
987 | class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions, |
988 | ::tflite::BuiltinOptions_ReducerOptions> { |
989 | public: |
990 | using BuiltinOperator::BuiltinOperator; |
991 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
992 | const TocoOperator& op, |
993 | flatbuffers::FlatBufferBuilder* builder) const override { |
994 | return ::tflite::CreateReducerOptions(*builder, op.keep_dims); |
995 | } |
996 | |
997 | void ReadOptions(const TfLiteOptions& options, |
998 | TocoOperator* op) const override { |
999 | op->keep_dims = options.keep_dims(); |
1000 | } |
1001 | }; |
1002 | |
1003 | class Sum |
1004 | : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions, |
1005 | ::tflite::BuiltinOptions_ReducerOptions> { |
1006 | public: |
1007 | using BuiltinOperator::BuiltinOperator; |
1008 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1009 | const TocoOperator& op, |
1010 | flatbuffers::FlatBufferBuilder* builder) const override { |
1011 | return ::tflite::CreateReducerOptions(*builder, op.keep_dims); |
1012 | } |
1013 | |
1014 | void ReadOptions(const TfLiteOptions& options, |
1015 | TocoOperator* op) const override { |
1016 | op->keep_dims = options.keep_dims(); |
1017 | } |
1018 | }; |
1019 | |
1020 | class ReduceMax |
1021 | : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions, |
1022 | ::tflite::BuiltinOptions_ReducerOptions> { |
1023 | public: |
1024 | using BuiltinOperator::BuiltinOperator; |
1025 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1026 | const TocoOperator& op, |
1027 | flatbuffers::FlatBufferBuilder* builder) const override { |
1028 | return ::tflite::CreateReducerOptions(*builder, op.keep_dims); |
1029 | } |
1030 | |
1031 | void ReadOptions(const TfLiteOptions& options, |
1032 | TocoOperator* op) const override { |
1033 | op->keep_dims = options.keep_dims(); |
1034 | } |
1035 | }; |
1036 | |
1037 | class ReduceMin |
1038 | : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions, |
1039 | ::tflite::BuiltinOptions_ReducerOptions> { |
1040 | public: |
1041 | using BuiltinOperator::BuiltinOperator; |
1042 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1043 | const TocoOperator& op, |
1044 | flatbuffers::FlatBufferBuilder* builder) const override { |
1045 | return ::tflite::CreateReducerOptions(*builder, op.keep_dims); |
1046 | } |
1047 | |
1048 | void ReadOptions(const TfLiteOptions& options, |
1049 | TocoOperator* op) const override { |
1050 | op->keep_dims = options.keep_dims(); |
1051 | } |
1052 | }; |
1053 | |
1054 | class ReduceProd |
1055 | : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions, |
1056 | ::tflite::BuiltinOptions_ReducerOptions> { |
1057 | public: |
1058 | using BuiltinOperator::BuiltinOperator; |
1059 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1060 | const TocoOperator& op, |
1061 | flatbuffers::FlatBufferBuilder* builder) const override { |
1062 | return ::tflite::CreateReducerOptions(*builder, op.keep_dims); |
1063 | } |
1064 | |
1065 | void ReadOptions(const TfLiteOptions& options, |
1066 | TocoOperator* op) const override { |
1067 | op->keep_dims = options.keep_dims(); |
1068 | } |
1069 | }; |
1070 | |
1071 | class ReduceAny |
1072 | : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions, |
1073 | ::tflite::BuiltinOptions_ReducerOptions> { |
1074 | public: |
1075 | using BuiltinOperator::BuiltinOperator; |
1076 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1077 | const TocoOperator& op, |
1078 | flatbuffers::FlatBufferBuilder* builder) const override { |
1079 | return ::tflite::CreateReducerOptions(*builder, op.keep_dims); |
1080 | } |
1081 | |
1082 | void ReadOptions(const TfLiteOptions& options, |
1083 | TocoOperator* op) const override { |
1084 | op->keep_dims = options.keep_dims(); |
1085 | } |
1086 | }; |
1087 | |
1088 | class ResizeBilinear |
1089 | : public BuiltinOperator<ResizeBilinearOperator, |
1090 | ::tflite::ResizeBilinearOptions, |
1091 | ::tflite::BuiltinOptions_ResizeBilinearOptions> { |
1092 | public: |
1093 | using BuiltinOperator::BuiltinOperator; |
1094 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1095 | const TocoOperator& op, |
1096 | flatbuffers::FlatBufferBuilder* builder) const override { |
1097 | return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners, |
1098 | op.half_pixel_centers); |
1099 | } |
1100 | |
1101 | void ReadOptions(const TfLiteOptions& options, |
1102 | TocoOperator* op) const override { |
1103 | op->align_corners = options.align_corners(); |
1104 | op->half_pixel_centers = options.half_pixel_centers(); |
1105 | } |
1106 | |
1107 | int GetVersion(const OperatorSignature& op_signature) const override { |
1108 | const auto& resize_bilinear_op = |
1109 | static_cast<const ResizeBilinearOperator&>(*op_signature.op); |
1110 | ::tflite::OpSignature op_sig = |
1111 | GetVersioningOpSig(builtin_op(), op_signature); |
1112 | TfLiteResizeBilinearParams resize_bilinear_params = {}; |
1113 | resize_bilinear_params.half_pixel_centers = |
1114 | resize_bilinear_op.half_pixel_centers; |
1115 | resize_bilinear_params.align_corners = resize_bilinear_op.align_corners; |
1116 | op_sig.builtin_data = reinterpret_cast<void*>(&resize_bilinear_params); |
1117 | return ::tflite::GetBuiltinOperatorVersion(op_sig); |
1118 | } |
1119 | }; |
1120 | |
1121 | class ResizeNearestNeighbor |
1122 | : public BuiltinOperator< |
1123 | ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions, |
1124 | ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> { |
1125 | public: |
1126 | using BuiltinOperator::BuiltinOperator; |
1127 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1128 | const TocoOperator& op, |
1129 | flatbuffers::FlatBufferBuilder* builder) const override { |
1130 | return ::tflite::CreateResizeNearestNeighborOptions( |
1131 | *builder, op.align_corners, op.half_pixel_centers); |
1132 | } |
1133 | |
1134 | void ReadOptions(const TfLiteOptions& options, |
1135 | TocoOperator* op) const override { |
1136 | op->align_corners = options.align_corners(); |
1137 | op->half_pixel_centers = options.half_pixel_centers(); |
1138 | } |
1139 | |
1140 | int GetVersion(const OperatorSignature& op_signature) const override { |
1141 | const auto& resize_nn_op = |
1142 | static_cast<const ResizeNearestNeighborOperator&>(*op_signature.op); |
1143 | ::tflite::OpSignature op_sig = |
1144 | GetVersioningOpSig(builtin_op(), op_signature); |
1145 | TfLiteResizeNearestNeighborParams resize_nearest_neighbor_params = {}; |
1146 | resize_nearest_neighbor_params.half_pixel_centers = |
1147 | resize_nn_op.half_pixel_centers; |
1148 | resize_nearest_neighbor_params.align_corners = resize_nn_op.align_corners; |
1149 | op_sig.builtin_data = |
1150 | reinterpret_cast<void*>(&resize_nearest_neighbor_params); |
1151 | return ::tflite::GetBuiltinOperatorVersion(op_sig); |
1152 | } |
1153 | }; |
1154 | |
1155 | class Squeeze |
1156 | : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions, |
1157 | ::tflite::BuiltinOptions_SqueezeOptions> { |
1158 | public: |
1159 | using BuiltinOperator::BuiltinOperator; |
1160 | |
1161 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1162 | const TocoOperator& op, |
1163 | flatbuffers::FlatBufferBuilder* builder) const override { |
1164 | auto squeeze_dims = builder->CreateVector(op.squeeze_dims); |
1165 | return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims); |
1166 | } |
1167 | |
1168 | void ReadOptions(const TfLiteOptions& options, |
1169 | TocoOperator* op) const override { |
1170 | op->squeeze_dims.insert(op->squeeze_dims.end(), |
1171 | options.squeeze_dims()->begin(), |
1172 | options.squeeze_dims()->end()); |
1173 | } |
1174 | }; |
1175 | |
1176 | class Split |
1177 | : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions, |
1178 | ::tflite::BuiltinOptions_SplitOptions> { |
1179 | public: |
1180 | using BuiltinOperator::BuiltinOperator; |
1181 | |
1182 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1183 | const TocoOperator& op, |
1184 | flatbuffers::FlatBufferBuilder* builder) const override { |
1185 | return ::tflite::CreateSplitOptions(*builder, op.num_split); |
1186 | } |
1187 | |
1188 | void ReadOptions(const TfLiteOptions& options, |
1189 | TocoOperator* op) const override { |
1190 | op->num_split = options.num_splits(); |
1191 | } |
1192 | }; |
1193 | |
1194 | class SplitV |
1195 | : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions, |
1196 | ::tflite::BuiltinOptions_SplitVOptions> { |
1197 | public: |
1198 | using BuiltinOperator::BuiltinOperator; |
1199 | |
1200 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1201 | const TocoOperator& op, |
1202 | flatbuffers::FlatBufferBuilder* builder) const override { |
1203 | return ::tflite::CreateSplitVOptions(*builder, op.num_split); |
1204 | } |
1205 | |
1206 | void ReadOptions(const TfLiteOptions& options, |
1207 | TocoOperator* op) const override { |
1208 | op->num_split = options.num_splits(); |
1209 | } |
1210 | }; |
1211 | |
1212 | class StridedSlice |
1213 | : public BuiltinOperator<StridedSliceOperator, |
1214 | ::tflite::StridedSliceOptions, |
1215 | ::tflite::BuiltinOptions_StridedSliceOptions> { |
1216 | public: |
1217 | using BuiltinOperator::BuiltinOperator; |
1218 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1219 | const TocoOperator& op, |
1220 | flatbuffers::FlatBufferBuilder* builder) const override { |
1221 | return ::tflite::CreateStridedSliceOptions( |
1222 | *builder, op.begin_mask, op.end_mask, op.ellipsis_mask, |
1223 | op.new_axis_mask, op.shrink_axis_mask); |
1224 | } |
1225 | |
1226 | void ReadOptions(const TfLiteOptions& options, |
1227 | TocoOperator* op) const override { |
1228 | op->begin_mask = options.begin_mask(); |
1229 | op->end_mask = options.end_mask(); |
1230 | op->ellipsis_mask = options.ellipsis_mask(); |
1231 | op->new_axis_mask = options.new_axis_mask(); |
1232 | op->shrink_axis_mask = options.shrink_axis_mask(); |
1233 | } |
1234 | |
1235 | int GetVersion(const OperatorSignature& op_signature) const override { |
1236 | const auto& ss_op = |
1237 | static_cast<const StridedSliceOperator&>(*op_signature.op); |
1238 | ::tflite::OpSignature op_sig = |
1239 | GetVersioningOpSig(builtin_op(), op_signature); |
1240 | op_sig.ext_options.strided_slice.num_dims = ss_op.start_indices.size(); |
1241 | TfLiteStridedSliceParams strided_slice_params = {}; |
1242 | strided_slice_params.ellipsis_mask = ss_op.ellipsis_mask; |
1243 | strided_slice_params.new_axis_mask = ss_op.new_axis_mask; |
1244 | op_sig.builtin_data = reinterpret_cast<void*>(&strided_slice_params); |
1245 | return ::tflite::GetBuiltinOperatorVersion(op_sig); |
1246 | } |
1247 | }; |
1248 | |
1249 | class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options, |
1250 | ::tflite::BuiltinOptions_TopKV2Options> { |
1251 | public: |
1252 | using BuiltinOperator::BuiltinOperator; |
1253 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1254 | const TocoOperator& op, |
1255 | flatbuffers::FlatBufferBuilder* builder) const override { |
1256 | return ::tflite::CreateTopKV2Options(*builder); |
1257 | } |
1258 | |
1259 | void ReadOptions(const TfLiteOptions& options, |
1260 | TocoOperator* op) const override {} |
1261 | }; |
1262 | |
1263 | class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions, |
1264 | ::tflite::BuiltinOptions_ArgMaxOptions> { |
1265 | public: |
1266 | using BuiltinOperator::BuiltinOperator; |
1267 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1268 | const TocoOperator& op, |
1269 | flatbuffers::FlatBufferBuilder* builder) const override { |
1270 | return ::tflite::CreateArgMaxOptions( |
1271 | *builder, DataType::Serialize(op.output_data_type)); |
1272 | } |
1273 | |
1274 | void ReadOptions(const TfLiteOptions& options, |
1275 | TocoOperator* op) const override { |
1276 | op->output_data_type = DataType::Deserialize(options.output_type()); |
1277 | } |
1278 | }; |
1279 | |
1280 | class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions, |
1281 | ::tflite::BuiltinOptions_ArgMinOptions> { |
1282 | public: |
1283 | using BuiltinOperator::BuiltinOperator; |
1284 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1285 | const TocoOperator& op, |
1286 | flatbuffers::FlatBufferBuilder* builder) const override { |
1287 | return ::tflite::CreateArgMinOptions( |
1288 | *builder, DataType::Serialize(op.output_data_type)); |
1289 | } |
1290 | |
1291 | void ReadOptions(const TfLiteOptions& options, |
1292 | TocoOperator* op) const override { |
1293 | op->output_data_type = DataType::Deserialize(options.output_type()); |
1294 | } |
1295 | }; |
1296 | |
1297 | class TransposeConv |
1298 | : public BuiltinOperator<TransposeConvOperator, |
1299 | ::tflite::TransposeConvOptions, |
1300 | ::tflite::BuiltinOptions_TransposeConvOptions> { |
1301 | public: |
1302 | using BuiltinOperator::BuiltinOperator; |
1303 | |
1304 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1305 | const TocoOperator& op, |
1306 | flatbuffers::FlatBufferBuilder* builder) const override { |
1307 | auto padding = Padding::Serialize(op.padding.type); |
1308 | return ::tflite::CreateTransposeConvOptions( |
1309 | *builder, padding, op.stride_width, op.stride_height); |
1310 | } |
1311 | |
1312 | void ReadOptions(const TfLiteOptions& options, |
1313 | TocoOperator* op) const override { |
1314 | op->padding.type = Padding::Deserialize(options.padding()); |
1315 | op->stride_width = options.stride_w(); |
1316 | op->stride_height = options.stride_h(); |
1317 | } |
1318 | }; |
1319 | |
1320 | class SparseToDense |
1321 | : public BuiltinOperator<SparseToDenseOperator, |
1322 | ::tflite::SparseToDenseOptions, |
1323 | ::tflite::BuiltinOptions_SparseToDenseOptions> { |
1324 | public: |
1325 | using BuiltinOperator::BuiltinOperator; |
1326 | |
1327 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1328 | const TocoOperator& op, |
1329 | flatbuffers::FlatBufferBuilder* builder) const override { |
1330 | return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices); |
1331 | } |
1332 | |
1333 | void ReadOptions(const TfLiteOptions& options, |
1334 | TocoOperator* op) const override { |
1335 | op->validate_indices = options.validate_indices(); |
1336 | } |
1337 | }; |
1338 | |
1339 | class ExpandDims |
1340 | : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions, |
1341 | ::tflite::BuiltinOptions_ExpandDimsOptions> { |
1342 | public: |
1343 | using BuiltinOperator::BuiltinOperator; |
1344 | |
1345 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1346 | const TocoOperator& op, |
1347 | flatbuffers::FlatBufferBuilder* builder) const override { |
1348 | return ::tflite::CreateExpandDimsOptions(*builder); |
1349 | } |
1350 | |
1351 | void ReadOptions(const TfLiteOptions& options, |
1352 | TocoOperator* op) const override {} |
1353 | }; |
1354 | |
1355 | class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions, |
1356 | ::tflite::BuiltinOptions_PackOptions> { |
1357 | public: |
1358 | using BuiltinOperator::BuiltinOperator; |
1359 | |
1360 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1361 | const TocoOperator& op, |
1362 | flatbuffers::FlatBufferBuilder* builder) const override { |
1363 | return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis); |
1364 | } |
1365 | |
1366 | void ReadOptions(const TfLiteOptions& options, |
1367 | TocoOperator* op) const override { |
1368 | op->values_count = options.values_count(); |
1369 | op->axis = options.axis(); |
1370 | } |
1371 | }; |
1372 | |
1373 | class Shape |
1374 | : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions, |
1375 | ::tflite::BuiltinOptions_ShapeOptions> { |
1376 | public: |
1377 | using BuiltinOperator::BuiltinOperator; |
1378 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1379 | const TocoOperator& op, |
1380 | flatbuffers::FlatBufferBuilder* builder) const override { |
1381 | return ::tflite::CreateShapeOptions( |
1382 | *builder, DataType::Serialize(op.output_data_type)); |
1383 | } |
1384 | |
1385 | void ReadOptions(const TfLiteOptions& options, |
1386 | TocoOperator* op) const override { |
1387 | op->output_data_type = DataType::Deserialize(options.out_type()); |
1388 | } |
1389 | }; |
1390 | |
1391 | class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions, |
1392 | ::tflite::BuiltinOptions_OneHotOptions> { |
1393 | public: |
1394 | using BuiltinOperator::BuiltinOperator; |
1395 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1396 | const TocoOperator& op, |
1397 | flatbuffers::FlatBufferBuilder* builder) const override { |
1398 | return ::tflite::CreateOneHotOptions(*builder, op.axis); |
1399 | } |
1400 | void ReadOptions(const TfLiteOptions& options, |
1401 | TocoOperator* op) const override { |
1402 | op->axis = options.axis(); |
1403 | } |
1404 | }; |
1405 | |
1406 | class CTCBeamSearchDecoder |
1407 | : public CustomOperator<CTCBeamSearchDecoderOperator> { |
1408 | public: |
1409 | using CustomOperator::CustomOperator; |
1410 | |
1411 | void WriteOptions(const TocoOperator& op, |
1412 | flexbuffers::Builder* fbb) const override { |
1413 | fbb->Int("beam_width" , op.beam_width); |
1414 | fbb->Int("top_paths" , op.top_paths); |
1415 | fbb->Bool("merge_repeated" , op.merge_repeated); |
1416 | } |
1417 | |
1418 | void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { |
1419 | op->beam_width = m["beam_width" ].AsInt32(); |
1420 | op->top_paths = m["top_paths" ].AsInt32(); |
1421 | op->merge_repeated = m["merge_repeated" ].AsBool(); |
1422 | } |
1423 | |
1424 | int GetVersion(const OperatorSignature& op_signature) const override { |
1425 | return 1; |
1426 | } |
1427 | }; |
1428 | |
1429 | class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions, |
1430 | ::tflite::BuiltinOptions_UnpackOptions> { |
1431 | public: |
1432 | using BuiltinOperator::BuiltinOperator; |
1433 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1434 | const TocoOperator& op, |
1435 | flatbuffers::FlatBufferBuilder* builder) const override { |
1436 | return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis); |
1437 | } |
1438 | void ReadOptions(const TfLiteOptions& options, |
1439 | TocoOperator* op) const override { |
1440 | op->num = options.num(); |
1441 | op->axis = options.axis(); |
1442 | } |
1443 | |
1444 | int GetVersion(const OperatorSignature& op_signature) const override { |
1445 | const std::string& input_name = op_signature.op->inputs[0]; |
1446 | const Array& input_array = op_signature.model->GetArray(input_name); |
1447 | // If the op take int8/uint8 input, it is version 2. |
1448 | if (input_array.data_type == ArrayDataType::kInt8 || |
1449 | input_array.data_type == ArrayDataType::kUint8) { |
1450 | return 2; |
1451 | } |
1452 | // If the op take bool input, it is version 3. |
1453 | if (input_array.data_type == ArrayDataType::kBool) { |
1454 | return 3; |
1455 | } |
1456 | return 1; |
1457 | } |
1458 | }; |
1459 | |
1460 | class LeakyRelu |
1461 | : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions, |
1462 | ::tflite::BuiltinOptions_LeakyReluOptions> { |
1463 | public: |
1464 | using BuiltinOperator::BuiltinOperator; |
1465 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1466 | const TocoOperator& op, |
1467 | flatbuffers::FlatBufferBuilder* builder) const override { |
1468 | return ::tflite::CreateLeakyReluOptions(*builder, op.alpha); |
1469 | } |
1470 | void ReadOptions(const TfLiteOptions& options, |
1471 | TocoOperator* op) const override { |
1472 | op->alpha = options.alpha(); |
1473 | } |
1474 | }; |
1475 | |
1476 | class SquaredDifference |
1477 | : public BuiltinOperator< |
1478 | SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions, |
1479 | ::tflite::BuiltinOptions_SquaredDifferenceOptions> { |
1480 | public: |
1481 | using BuiltinOperator::BuiltinOperator; |
1482 | |
1483 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1484 | const TocoOperator& op, |
1485 | flatbuffers::FlatBufferBuilder* builder) const override { |
1486 | return ::tflite::CreateSquaredDifferenceOptions(*builder); |
1487 | } |
1488 | |
1489 | void ReadOptions(const TfLiteOptions& options, |
1490 | TocoOperator* op) const override {} |
1491 | }; |
1492 | |
1493 | class MirrorPad |
1494 | : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions, |
1495 | ::tflite::BuiltinOptions_MirrorPadOptions> { |
1496 | public: |
1497 | using BuiltinOperator::BuiltinOperator; |
1498 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1499 | const TocoOperator& op, |
1500 | flatbuffers::FlatBufferBuilder* builder) const override { |
1501 | return ::tflite::CreateMirrorPadOptions( |
1502 | *builder, op.mode == MirrorPadMode::kReflect |
1503 | ? ::tflite::MirrorPadMode::MirrorPadMode_REFLECT |
1504 | : ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC); |
1505 | } |
1506 | void ReadOptions(const TfLiteOptions& options, |
1507 | TocoOperator* op) const override { |
1508 | op->mode = options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT |
1509 | ? MirrorPadMode::kReflect |
1510 | : MirrorPadMode::kSymmetric; |
1511 | } |
1512 | }; |
1513 | |
1514 | class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions, |
1515 | ::tflite::BuiltinOptions_UniqueOptions> { |
1516 | public: |
1517 | using BuiltinOperator::BuiltinOperator; |
1518 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1519 | const TocoOperator& op, |
1520 | flatbuffers::FlatBufferBuilder* builder) const override { |
1521 | const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op); |
1522 | return ::tflite::CreateUniqueOptions( |
1523 | *builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64 |
1524 | ? ::tflite::TensorType::TensorType_INT64 |
1525 | : ::tflite::TensorType_INT32); |
1526 | } |
1527 | void ReadOptions(const TfLiteOptions& options, |
1528 | TocoOperator* op) const override { |
1529 | UniqueOperator* unique_op = static_cast<UniqueOperator*>(op); |
1530 | unique_op->idx_out_type = |
1531 | options.idx_out_type() == ::tflite::TensorType_INT64 |
1532 | ? toco::ArrayDataType::kInt64 |
1533 | : toco::ArrayDataType::kInt32; |
1534 | } |
1535 | }; |
1536 | |
1537 | class UnidirectionalSequenceRnn |
1538 | : public BuiltinOperator<UnidirectionalSequenceRnnOperator, |
1539 | ::tflite::SequenceRNNOptions, |
1540 | ::tflite::BuiltinOptions_SequenceRNNOptions> { |
1541 | public: |
1542 | using BuiltinOperator::BuiltinOperator; |
1543 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1544 | const TocoOperator& op, |
1545 | flatbuffers::FlatBufferBuilder* builder) const override { |
1546 | return ::tflite::CreateSequenceRNNOptions( |
1547 | *builder, /*time_major=*/true, |
1548 | /*fused_activation_function=*/ |
1549 | ::tflite::ActivationFunctionType_TANH); |
1550 | } |
1551 | void ReadOptions(const TfLiteOptions& options, |
1552 | TocoOperator* op) const override { |
1553 | // Only support tanh activation, so check that tflite type is tanh. |
1554 | DCHECK(options.fused_activation_function() == |
1555 | ::tflite::ActivationFunctionType_TANH); |
1556 | } |
1557 | |
1558 | std::vector<bool> GetMutatingInputVariables( |
1559 | const Operator& op) const override { |
1560 | std::vector<bool> mutating_input_variables(op.inputs.size(), false); |
1561 | mutating_input_variables[4] = true; |
1562 | return mutating_input_variables; |
1563 | } |
1564 | }; |
1565 | |
1566 | class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions, |
1567 | ::tflite::BuiltinOptions_WhereOptions> { |
1568 | public: |
1569 | using BuiltinOperator::BuiltinOperator; |
1570 | |
1571 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1572 | const TocoOperator& op, |
1573 | flatbuffers::FlatBufferBuilder* builder) const override { |
1574 | return ::tflite::CreateWhereOptions(*builder); |
1575 | } |
1576 | |
1577 | void ReadOptions(const TfLiteOptions& options, |
1578 | TocoOperator* op) const override {} |
1579 | }; |
1580 | |
1581 | std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions( |
1582 | const std::string& tensorflow_node_def) { |
1583 | auto fbb = std::make_unique<flexbuffers::Builder>(); |
1584 | |
1585 | ::tensorflow::NodeDef node_def; |
1586 | if (!node_def.ParseFromString(tensorflow_node_def)) { |
1587 | LOG(ERROR) << "Failed to parse TensorFlow NodeDef" ; |
1588 | return {}; |
1589 | } |
1590 | |
1591 | fbb->Vector([&]() { |
1592 | fbb->String(node_def.op()); |
1593 | fbb->String(tensorflow_node_def); |
1594 | }); |
1595 | fbb->Finish(); |
1596 | LOG(INFO) << "Writing flex op: " << node_def.op(); |
1597 | return std::unique_ptr<flexbuffers::Builder>(fbb.release()); |
1598 | } |
1599 | |
1600 | class TensorFlowUnsupported : public BaseOperator { |
1601 | public: |
1602 | TensorFlowUnsupported(const std::string& name, OperatorType type, |
1603 | bool enable_select_tf_ops) |
1604 | : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {} |
1605 | |
1606 | Options Serialize(const Operator& op, |
1607 | flatbuffers::FlatBufferBuilder* builder) const override { |
1608 | auto fbb = |
1609 | WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op)); |
1610 | if (fbb) { |
1611 | return Options::Custom(builder->CreateVector(fbb->GetBuffer())); |
1612 | } else { |
1613 | return Options::Custom(0); |
1614 | } |
1615 | } |
1616 | |
1617 | std::unique_ptr<Operator> Deserialize( |
1618 | const BuiltinOptions* builtin_options, |
1619 | const CustomOptions* custom_options) const override { |
1620 | // Deserializing Flex ops doesn't work now. |
1621 | // TODO(ycling): Revisit and decide if we should fix the flow for importing |
1622 | // TFLite models with Flex ops. |
1623 | auto op = std::make_unique<TensorFlowUnsupportedOperator>(); |
1624 | if (custom_options) { |
1625 | auto flexbuffer_map = |
1626 | flexbuffers::GetRoot(custom_options->data(), custom_options->size()) |
1627 | .AsMap(); |
1628 | ReadOptions(flexbuffer_map, op.get()); |
1629 | } |
1630 | return std::unique_ptr<Operator>(op.release()); |
1631 | } |
1632 | |
1633 | std::unique_ptr<flexbuffers::Builder> WriteOptions( |
1634 | const TensorFlowUnsupportedOperator& op) const { |
1635 | if (enable_select_tf_ops_) { |
1636 | return WriteFlexOpOptions(op.tensorflow_node_def); |
1637 | } |
1638 | auto fbb = std::make_unique<flexbuffers::Builder>(); |
1639 | |
1640 | ::tensorflow::NodeDef node_def; |
1641 | if (!node_def.ParseFromString(op.tensorflow_node_def)) { |
1642 | LOG(ERROR) << "Failed to parse TensorFlow NodeDef" ; |
1643 | return std::unique_ptr<flexbuffers::Builder>(); |
1644 | } |
1645 | |
1646 | if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) { |
1647 | fbb->Vector([&]() { |
1648 | fbb->String(node_def.op()); |
1649 | fbb->String(op.tensorflow_node_def); |
1650 | }); |
1651 | fbb->Finish(); |
1652 | LOG(INFO) << "Writing flex op: " << node_def.op(); |
1653 | return std::unique_ptr<flexbuffers::Builder>(fbb.release()); |
1654 | } |
1655 | |
1656 | bool has_valid_attr = false; |
1657 | size_t map_start = fbb->StartMap(); |
1658 | for (const auto& pair : node_def.attr()) { |
1659 | const char* key = pair.first.c_str(); |
1660 | const auto& attr = pair.second; |
1661 | switch (attr.value_case()) { |
1662 | case ::tensorflow::AttrValue::kS: |
1663 | fbb->String(key, attr.s()); |
1664 | has_valid_attr = true; |
1665 | break; |
1666 | case ::tensorflow::AttrValue::kI: |
1667 | fbb->Int(key, attr.i()); |
1668 | has_valid_attr = true; |
1669 | break; |
1670 | case ::tensorflow::AttrValue::kF: |
1671 | fbb->Float(key, attr.f()); |
1672 | has_valid_attr = true; |
1673 | break; |
1674 | case ::tensorflow::AttrValue::kB: |
1675 | fbb->Bool(key, attr.b()); |
1676 | has_valid_attr = true; |
1677 | break; |
1678 | case tensorflow::AttrValue::kList: |
1679 | if (attr.list().s_size() > 0) { |
1680 | auto start = fbb->StartVector(key); |
1681 | for (const std::string& v : attr.list().s()) { |
1682 | fbb->Add(v); |
1683 | } |
1684 | fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); |
1685 | has_valid_attr = true; |
1686 | } else if (attr.list().i_size() > 0) { |
1687 | auto start = fbb->StartVector(key); |
1688 | for (const int64_t v : attr.list().i()) { |
1689 | fbb->Add(v); |
1690 | } |
1691 | fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); |
1692 | has_valid_attr = true; |
1693 | } else if (attr.list().f_size() > 0) { |
1694 | auto start = fbb->StartVector(key); |
1695 | for (const float v : attr.list().f()) { |
1696 | fbb->Add(v); |
1697 | } |
1698 | fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); |
1699 | has_valid_attr = true; |
1700 | } else { |
1701 | LOG(WARNING) |
1702 | << "Ignoring unsupported type in list attribute with key '" |
1703 | << key << "'" ; |
1704 | } |
1705 | break; |
1706 | default: |
1707 | LOG(WARNING) << "Ignoring unsupported attribute type with key '" |
1708 | << key << "'" ; |
1709 | break; |
1710 | } |
1711 | } |
1712 | if (!has_valid_attr) { |
1713 | return std::unique_ptr<flexbuffers::Builder>(); |
1714 | } |
1715 | fbb->EndMap(map_start); |
1716 | fbb->Finish(); |
1717 | return std::unique_ptr<flexbuffers::Builder>(fbb.release()); |
1718 | } |
1719 | |
1720 | void ReadOptions(const flexbuffers::Map& m, |
1721 | TensorFlowUnsupportedOperator* op) const { |
1722 | ::tensorflow::NodeDef node_def; |
1723 | auto attr = node_def.mutable_attr(); |
1724 | |
1725 | const auto& keys = m.Keys(); |
1726 | for (size_t i = 0; i < keys.size(); ++i) { |
1727 | const auto key = keys[i].AsKey(); |
1728 | const auto& value = m[key]; |
1729 | switch (value.GetType()) { |
1730 | case flexbuffers::FBT_STRING: |
1731 | (*attr)[key].set_s(value.AsString().c_str()); |
1732 | break; |
1733 | case flexbuffers::FBT_INT: |
1734 | (*attr)[key].set_i(value.AsInt64()); |
1735 | break; |
1736 | case flexbuffers::FBT_FLOAT: |
1737 | (*attr)[key].set_f(value.AsFloat()); |
1738 | break; |
1739 | case flexbuffers::FBT_BOOL: |
1740 | (*attr)[key].set_b(value.AsBool()); |
1741 | if (std::string(key) == "_output_quantized" ) { |
1742 | op->quantized = value.AsBool(); |
1743 | } |
1744 | if (std::string(key) == |
1745 | "_support_output_type_float_in_quantized_op" ) { |
1746 | op->support_output_type_float_in_quantized_op = value.AsBool(); |
1747 | } |
1748 | break; |
1749 | case flexbuffers::FBT_VECTOR_INT: { |
1750 | auto* list = (*attr)[key].mutable_list(); |
1751 | const auto& vector = value.AsTypedVector(); |
1752 | for (size_t i = 0; i < vector.size(); i++) { |
1753 | list->add_i(vector[i].AsInt64()); |
1754 | } |
1755 | break; |
1756 | } |
1757 | case flexbuffers::FBT_VECTOR_FLOAT: { |
1758 | auto* list = (*attr)[key].mutable_list(); |
1759 | const auto& vector = value.AsTypedVector(); |
1760 | for (size_t i = 0; i < vector.size(); i++) { |
1761 | list->add_f(vector[i].AsFloat()); |
1762 | } |
1763 | break; |
1764 | } |
1765 | case 15 /* TO_DO(wvo): flexbuffers::FBT_VECTOR_STRING_DEPRECATED*/: { |
1766 | auto* list = (*attr)[key].mutable_list(); |
1767 | const auto& vector = value.AsTypedVector(); |
1768 | for (size_t i = 0; i < vector.size(); i++) { |
1769 | list->add_s(vector[i].AsString().str()); |
1770 | } |
1771 | break; |
1772 | } |
1773 | default: |
1774 | LOG(WARNING) << "Ignoring unsupported attribute type with key '" |
1775 | << key << "'" ; |
1776 | break; |
1777 | } |
1778 | } |
1779 | node_def.SerializeToString(&op->tensorflow_node_def); |
1780 | } |
1781 | |
1782 | int GetVersion(const OperatorSignature& op_signature) const override { |
1783 | // TODO(ycling): Design and implement a way to plumb the version of |
1784 | // custom ops. |
1785 | return 1; |
1786 | } |
1787 | |
1788 | private: |
1789 | const bool enable_select_tf_ops_; |
1790 | }; |
1791 | |
1792 | class Dequantize |
1793 | : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions, |
1794 | ::tflite::BuiltinOptions_DequantizeOptions> { |
1795 | public: |
1796 | using BuiltinOperator::BuiltinOperator; |
1797 | |
1798 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1799 | const TocoOperator& op, |
1800 | flatbuffers::FlatBufferBuilder* builder) const override { |
1801 | return ::tflite::CreateDequantizeOptions(*builder); |
1802 | } |
1803 | |
1804 | void ReadOptions(const TfLiteOptions& options, |
1805 | TocoOperator* op) const override {} |
1806 | }; |
1807 | |
1808 | class ReverseSequence |
1809 | : public BuiltinOperator<ReverseSequenceOperator, |
1810 | ::tflite::ReverseSequenceOptions, |
1811 | ::tflite::BuiltinOptions_ReverseSequenceOptions> { |
1812 | public: |
1813 | using BuiltinOperator::BuiltinOperator; |
1814 | |
1815 | flatbuffers::Offset<TfLiteOptions> WriteOptions( |
1816 | const TocoOperator& op, |
1817 | flatbuffers::FlatBufferBuilder* builder) const override { |
1818 | return ::tflite::CreateReverseSequenceOptions(*builder, op.seq_dim, |
1819 | op.batch_dim); |
1820 | } |
1821 | |
1822 | void ReadOptions(const TfLiteOptions& options, |
1823 | TocoOperator* op) const override { |
1824 | op->seq_dim = options.seq_dim(); |
1825 | op->batch_dim = options.batch_dim(); |
1826 | } |
1827 | }; |
1828 | |
1829 | namespace { |
1830 | // Build a vector containing all the known operators. |
1831 | std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList( |
1832 | bool enable_select_tf_ops = false) { |
1833 | std::vector<std::unique_ptr<BaseOperator>> ops; |
1834 | using tensorflow::MakeUnique; |
1835 | // Builtin Operators. |
1836 | ops.push_back( |
1837 | MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd)); |
1838 | ops.push_back( |
1839 | MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN)); |
1840 | ops.push_back( |
1841 | MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv)); |
1842 | ops.push_back( |
1843 | MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub)); |
1844 | ops.push_back(MakeUnique<AveragePool>( |
1845 | ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool)); |
1846 | ops.push_back( |
1847 | MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND, |
1848 | OperatorType::kSpaceToBatchND)); |
1849 | ops.push_back( |
1850 | MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND, |
1851 | OperatorType::kBatchToSpaceND)); |
1852 | ops.push_back(MakeUnique<Concatenation>( |
1853 | ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation)); |
1854 | ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D, |
1855 | OperatorType::kConv)); |
1856 | ops.push_back(MakeUnique<DepthwiseConvolution>( |
1857 | ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, |
1858 | OperatorType::kDepthwiseConv)); |
1859 | ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE, |
1860 | OperatorType::kDequantize)); |
1861 | ops.push_back( |
1862 | MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED, |
1863 | OperatorType::kFullyConnected)); |
1864 | ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER, |
1865 | OperatorType::kGather)); |
1866 | ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND, |
1867 | OperatorType::kGatherNd)); |
1868 | ops.push_back( |
1869 | MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION, |
1870 | OperatorType::kL2Normalization)); |
1871 | ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D, |
1872 | OperatorType::kL2Pool)); |
1873 | ops.push_back(MakeUnique<LocalResponseNormalization>( |
1874 | ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, |
1875 | OperatorType::kLocalResponseNormalization)); |
1876 | ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D, |
1877 | OperatorType::kMaxPool)); |
1878 | ops.push_back( |
1879 | MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul)); |
1880 | |
1881 | ops.push_back( |
1882 | MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad)); |
1883 | ops.push_back( |
1884 | MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2)); |
1885 | ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE, |
1886 | OperatorType::kReshape)); |
1887 | ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX, |
1888 | OperatorType::kSoftmax)); |
1889 | ops.push_back(MakeUnique<SpaceToDepth>( |
1890 | ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth)); |
1891 | ops.push_back(MakeUnique<DepthToSpace>( |
1892 | ::tflite::BuiltinOperator_DEPTH_TO_SPACE, OperatorType::kDepthToSpace)); |
1893 | ops.push_back( |
1894 | MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf)); |
1895 | ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE, |
1896 | OperatorType::kTranspose)); |
1897 | ops.push_back( |
1898 | MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); |
1899 | ops.push_back( |
1900 | MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum)); |
1901 | ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD, |
1902 | OperatorType::kReduceProd)); |
1903 | ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX, |
1904 | OperatorType::kReduceMax)); |
1905 | ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN, |
1906 | OperatorType::kReduceMin)); |
1907 | ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY, |
1908 | OperatorType::kAny)); |
1909 | ops.push_back( |
1910 | MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR, |
1911 | OperatorType::kResizeBilinear)); |
1912 | ops.push_back(MakeUnique<ResizeNearestNeighbor>( |
1913 | ::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, |
1914 | OperatorType::kResizeNearestNeighbor)); |
1915 | ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE, |
1916 | OperatorType::kSqueeze)); |
1917 | ops.push_back( |
1918 | MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit)); |
1919 | ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V, |
1920 | OperatorType::kSplitV)); |
1921 | ops.push_back(MakeUnique<StridedSlice>( |
1922 | ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice)); |
1923 | ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2, |
1924 | OperatorType::kTopK_V2)); |
1925 | ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM, |
1926 | OperatorType::kLstmCell)); |
1927 | ops.push_back( |
1928 | MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast)); |
1929 | ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX, |
1930 | OperatorType::kArgMax)); |
1931 | ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN, |
1932 | OperatorType::kArgMin)); |
1933 | ops.push_back( |
1934 | MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile)); |
1935 | ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS, |
1936 | OperatorType::kExpandDims)); |
1937 | ops.push_back(MakeUnique<TransposeConv>( |
1938 | ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv)); |
1939 | ops.push_back(MakeUnique<SparseToDense>( |
1940 | ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense)); |
1941 | ops.push_back( |
1942 | MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape)); |
1943 | ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT, |
1944 | OperatorType::kFakeQuant)); |
1945 | ops.push_back( |
1946 | MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); |
1947 | ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>( |
1948 | ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, |
1949 | OperatorType::kUnidirectionalSequenceLstm)); |
1950 | ops.emplace_back(MakeUnique<BidirectionalSequenceLstm>( |
1951 | ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, |
1952 | OperatorType::kBidirectionalSequenceLstm)); |
1953 | ops.emplace_back(MakeUnique<BidirectionalSequenceRnn>( |
1954 | ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, |
1955 | OperatorType::kBidirectionalSequenceRnn)); |
1956 | ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT, |
1957 | OperatorType::kOneHot)); |
1958 | ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK, |
1959 | OperatorType::kUnpack)); |
1960 | ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU, |
1961 | OperatorType::kLeakyRelu)); |
1962 | ops.push_back(MakeUnique<SquaredDifference>( |
1963 | ::tflite::BuiltinOperator_SQUARED_DIFFERENCE, |
1964 | OperatorType::kSquaredDifference)); |
1965 | ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD, |
1966 | OperatorType::kMirrorPad)); |
1967 | ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE, |
1968 | OperatorType::kUnique)); |
1969 | ops.push_back(MakeUnique<UnidirectionalSequenceRnn>( |
1970 | ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, |
1971 | OperatorType::kUnidirectionalSequenceRnn)); |
1972 | ops.push_back( |
1973 | MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere)); |
1974 | ops.push_back( |
1975 | MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE, |
1976 | OperatorType::kReverseSequence)); |
1977 | ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>( |
1978 | ::tflite::BuiltinOperator_MATRIX_DIAG, OperatorType::kMatrixDiag)); |
1979 | ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>( |
1980 | ::tflite::BuiltinOperator_MATRIX_SET_DIAG, OperatorType::kMatrixSetDiag)); |
1981 | // Custom Operators. |
1982 | ops.push_back(MakeUnique<CTCBeamSearchDecoder>( |
1983 | "CTC_BEAM_SEARCH_DECODER" , OperatorType::kCTCBeamSearchDecoder)); |
1984 | ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED" , |
1985 | OperatorType::kUnsupported, |
1986 | enable_select_tf_ops)); |
1987 | |
1988 | // SimpleOperator was designed to export CUSTOM TF Lite ops, but has since |
1989 | // been modified to also export builtins. As TOCO evolved we added warnings |
1990 | // when custom ops are exported but SimpleOperator bypasses thoses. To |
1991 | // prevent user confusion we are settling on using SimpleOperator only for |
1992 | // builtins. |
1993 | ops.push_back(MakeUnique<SimpleOperator<FloorOperator>>( |
1994 | ::tflite::BuiltinOperator_FLOOR, OperatorType::kFloor)); |
1995 | ops.push_back(MakeUnique<SimpleOperator<CeilOperator>>( |
1996 | ::tflite::BuiltinOperator_CEIL, OperatorType::kCeil)); |
1997 | ops.push_back(MakeUnique<SimpleOperator<EluOperator>>( |
1998 | ::tflite::BuiltinOperator_ELU, OperatorType::kElu)); |
1999 | ops.push_back(MakeUnique<SimpleOperator<RoundOperator>>( |
2000 | ::tflite::BuiltinOperator_ROUND, OperatorType::kRound)); |
2001 | ops.push_back(MakeUnique<SimpleOperator<ReluOperator>>( |
2002 | ::tflite::BuiltinOperator_RELU, OperatorType::kRelu)); |
2003 | ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>( |
2004 | ::tflite::BuiltinOperator_RELU_N1_TO_1, OperatorType::kRelu1)); |
2005 | ops.push_back(MakeUnique<SimpleOperator<Relu6Operator>>( |
2006 | ::tflite::BuiltinOperator_RELU6, OperatorType::kRelu6)); |
2007 | ops.push_back(MakeUnique<SimpleOperator<PReluOperator>>( |
2008 | ::tflite::BuiltinOperator_PRELU, OperatorType::kPRelu)); |
2009 | ops.push_back(MakeUnique<SimpleOperator<LogisticOperator>>( |
2010 | ::tflite::BuiltinOperator_LOGISTIC, OperatorType::kLogistic)); |
2011 | ops.push_back(MakeUnique<SimpleOperator<TanhOperator>>( |
2012 | ::tflite::BuiltinOperator_TANH, OperatorType::kTanh)); |
2013 | ops.push_back(MakeUnique<SimpleOperator<ExpOperator>>( |
2014 | ::tflite::BuiltinOperator_EXP, OperatorType::kExp)); |
2015 | ops.push_back(MakeUnique<SimpleOperator<CosOperator>>( |
2016 | ::tflite::BuiltinOperator_COS, OperatorType::kCos)); |
2017 | ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>( |
2018 | ::tflite::BuiltinOperator_LOG_SOFTMAX, OperatorType::kLogSoftmax)); |
2019 | ops.push_back(MakeUnique<SimpleOperator<TensorFlowMaximumOperator>>( |
2020 | ::tflite::BuiltinOperator_MAXIMUM, OperatorType::kMaximum)); |
2021 | ops.push_back(MakeUnique<SimpleOperator<TensorFlowMinimumOperator>>( |
2022 | ::tflite::BuiltinOperator_MINIMUM, OperatorType::kMinimum)); |
2023 | ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterOperator>>( |
2024 | ::tflite::BuiltinOperator_GREATER, OperatorType::kGreater)); |
2025 | ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterEqualOperator>>( |
2026 | ::tflite::BuiltinOperator_GREATER_EQUAL, OperatorType::kGreaterEqual)); |
2027 | ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessOperator>>( |
2028 | ::tflite::BuiltinOperator_LESS, OperatorType::kLess)); |
2029 | ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessEqualOperator>>( |
2030 | ::tflite::BuiltinOperator_LESS_EQUAL, OperatorType::kLessEqual)); |
2031 | ops.push_back(MakeUnique<SimpleOperator<TensorFlowEqualOperator>>( |
2032 | ::tflite::BuiltinOperator_EQUAL, OperatorType::kEqual)); |
2033 | ops.push_back(MakeUnique<SimpleOperator<TensorFlowNotEqualOperator>>( |
2034 | ::tflite::BuiltinOperator_NOT_EQUAL, OperatorType::kNotEqual)); |
2035 | ops.push_back(MakeUnique<SimpleOperator<NegOperator>>( |
2036 | ::tflite::BuiltinOperator_NEG, OperatorType::kNeg)); |
2037 | ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>( |
2038 | ::tflite::BuiltinOperator_SELECT, OperatorType::kSelect)); |
2039 | ops.push_back(MakeUnique<SimpleOperator<SliceOperator>>( |
2040 | ::tflite::BuiltinOperator_SLICE, OperatorType::kSlice)); |
2041 | ops.push_back(MakeUnique<SimpleOperator<PowOperator>>( |
2042 | ::tflite::BuiltinOperator_POW, OperatorType::kPow)); |
2043 | ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>( |
2044 | ::tflite::BuiltinOperator_LOGICAL_OR, OperatorType::kLogicalOr)); |
2045 | ops.emplace_back(new SimpleOperator<LogicalAndOperator>( |
2046 | ::tflite::BuiltinOperator_LOGICAL_AND, OperatorType::kLogicalAnd)); |
2047 | ops.emplace_back(new SimpleOperator<LogicalNotOperator>( |
2048 | ::tflite::BuiltinOperator_LOGICAL_NOT, OperatorType::kLogicalNot)); |
2049 | ops.emplace_back(new SimpleOperator<FloorDivOperator>( |
2050 | ::tflite::BuiltinOperator_FLOOR_DIV, OperatorType::kFloorDiv)); |
2051 | ops.emplace_back(new SimpleOperator<FloorModOperator>( |
2052 | ::tflite::BuiltinOperator_FLOOR_MOD, OperatorType::kFloorMod)); |
2053 | ops.emplace_back(new SimpleOperator<RangeOperator>( |
2054 | ::tflite::BuiltinOperator_RANGE, OperatorType::kRange)); |
2055 | // Element-wise operator |
2056 | ops.push_back(MakeUnique<SimpleOperator<SinOperator>>( |
2057 | ::tflite::BuiltinOperator_SIN, OperatorType::kSin)); |
2058 | ops.push_back(MakeUnique<SimpleOperator<LogOperator>>( |
2059 | ::tflite::BuiltinOperator_LOG, OperatorType::kLog)); |
2060 | ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>( |
2061 | ::tflite::BuiltinOperator_SQRT, OperatorType::kSqrt)); |
2062 | ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>( |
2063 | ::tflite::BuiltinOperator_RSQRT, OperatorType::kRsqrt)); |
2064 | ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>( |
2065 | ::tflite::BuiltinOperator_SQUARE, OperatorType::kSquare)); |
2066 | ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>( |
2067 | ::tflite::BuiltinOperator_ZEROS_LIKE, OperatorType::kZerosLike)); |
2068 | ops.push_back(MakeUnique<SimpleOperator<AbsOperator>>( |
2069 | ::tflite::BuiltinOperator_ABS, OperatorType::kAbs)); |
2070 | ops.push_back(MakeUnique<SimpleOperator<HardSwishOperator>>( |
2071 | ::tflite::BuiltinOperator_HARD_SWISH, OperatorType::kHardSwish)); |
2072 | ops.push_back(MakeUnique<SimpleOperator<FillOperator>>( |
2073 | ::tflite::BuiltinOperator_FILL, OperatorType::kFill)); |
2074 | ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>( |
2075 | ::tflite::BuiltinOperator_REVERSE_V2, OperatorType::kReverseV2)); |
2076 | ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>( |
2077 | ::tflite::BuiltinOperator_RANK, OperatorType::kRank)); |
2078 | ops.emplace_back(new SimpleOperator<SegmentSumOperator>( |
2079 | ::tflite::BuiltinOperator_SEGMENT_SUM, OperatorType::kSegmentSum)); |
2080 | ops.emplace_back(MakeUnique<SimpleOperator<ScatterNdOperator>>( |
2081 | ::tflite::BuiltinOperator_SCATTER_ND, OperatorType::kScatterNd)); |
2082 | return ops; |
2083 | } |
2084 | } // namespace |
2085 | |
2086 | // LINT.ThenChange(//tensorflow/lite/tools/versioning/op_version.cc) |
2087 | |
2088 | std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap( |
2089 | bool enable_select_tf_ops) { |
2090 | std::map<OperatorType, std::unique_ptr<BaseOperator>> result; |
2091 | |
2092 | std::vector<std::unique_ptr<BaseOperator>> ops = |
2093 | BuildOperatorList(enable_select_tf_ops); |
2094 | for (auto& op : ops) { |
2095 | result[op->type()] = std::move(op); |
2096 | } |
2097 | |
2098 | return result; |
2099 | } |
2100 | |
2101 | std::map<std::string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap( |
2102 | bool enable_select_tf_ops) { |
2103 | std::map<std::string, std::unique_ptr<BaseOperator>> result; |
2104 | |
2105 | std::vector<std::unique_ptr<BaseOperator>> ops = |
2106 | BuildOperatorList(enable_select_tf_ops); |
2107 | for (auto& op : ops) { |
2108 | result[op->name()] = std::move(op); |
2109 | } |
2110 | |
2111 | return result; |
2112 | } |
2113 | |
2114 | bool ShouldExportAsFlexOp(bool enable_select_tf_ops, |
2115 | const std::string& tensorflow_op_name) { |
2116 | // If Flex ops aren't allow at all, simply return false. |
2117 | if (!enable_select_tf_ops) { |
2118 | return false; |
2119 | } |
2120 | // Check if we can find the `OpDef` for the TensorFlow op. If we can find |
2121 | // it and it has been allowlisted, export the op as an Flex op. Otherwise, |
2122 | // export it as a regular custom op. |
2123 | const tensorflow::OpDef* op_def = nullptr; |
2124 | if (!tensorflow::OpRegistry::Global() |
2125 | ->LookUpOpDef(tensorflow_op_name, &op_def) |
2126 | .ok()) { |
2127 | return false; |
2128 | } |
2129 | |
2130 | if (!::tflite::flex::IsAllowlistedFlexOp(tensorflow_op_name)) { |
2131 | LOG(WARNING) << "Op " << tensorflow_op_name |
2132 | << " is a valid TensorFlow op but has not been allowlisted for" |
2133 | " the TensorFlow Lite flex op set." ; |
2134 | return false; |
2135 | } |
2136 | |
2137 | return true; |
2138 | } |
2139 | |
2140 | } // namespace tflite |
2141 | |
2142 | } // namespace toco |
2143 | |