1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/toco/tflite/operator.h"
16
17#include <map>
18#include <memory>
19#include <string>
20#include <utility>
21
22#include "tensorflow/core/framework/attr_value.pb.h"
23#include "tensorflow/core/framework/node_def.pb.h"
24#include "tensorflow/core/framework/op.h"
25#include "tensorflow/core/framework/op_def.pb.h"
26#include "tensorflow/core/util/ptr_util.h"
27
28// TODO(ycling): Consider refactoring to extract the LSTM definition out of
29// graph_transformation module.
30#include "tensorflow/lite/builtin_op_data.h"
31#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
32#include "tensorflow/lite/schema/schema_generated.h"
33#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
34#include "tensorflow/lite/toco/model.h"
35#include "tensorflow/lite/toco/tflite/builtin_operator.h"
36#include "tensorflow/lite/toco/tflite/custom_operator.h"
37#include "tensorflow/lite/toco/tflite/simple_operator.h"
38#include "tensorflow/lite/toco/tflite/types.h"
39#include "tensorflow/lite/tools/versioning/op_version.h"
40
41namespace toco {
42
43namespace tflite {
44
45// LINT.IfChange
46
47TfLiteType GetTensorType(const ArrayDataType type) {
48 const std::map<ArrayDataType, TfLiteType> tensor_type_map = {
49 {ArrayDataType::kBool, kTfLiteBool},
50 {ArrayDataType::kFloat, kTfLiteFloat32},
51 {ArrayDataType::kInt8, kTfLiteInt8},
52 {ArrayDataType::kUint8, kTfLiteUInt8},
53 {ArrayDataType::kInt16, kTfLiteInt16},
54 {ArrayDataType::kUint16, kTfLiteUInt16},
55 {ArrayDataType::kInt32, kTfLiteInt32},
56 {ArrayDataType::kUint32, kTfLiteUInt32},
57 {ArrayDataType::kInt64, kTfLiteInt64},
58 {ArrayDataType::kUint64, kTfLiteUInt64},
59 {ArrayDataType::kString, kTfLiteString},
60 {ArrayDataType::kComplex64, kTfLiteComplex64},
61 {ArrayDataType::kComplex128, kTfLiteComplex128},
62 {ArrayDataType::kFloat16, kTfLiteFloat16},
63 {ArrayDataType::kFloat64, kTfLiteFloat64}};
64
65 auto it = tensor_type_map.find(type);
66 if (it != tensor_type_map.end()) {
67 return it->second;
68 }
69 return kTfLiteNoType;
70}
71
72::tflite::OpSignature GetVersioningOpSig(
73 const ::tflite::BuiltinOperator op, const OperatorSignature& op_signature) {
74 std::vector<::tflite::OpSignatureTensorSpec> inputs, outputs;
75 for (const auto& input_name : op_signature.op->inputs) {
76 ::tflite::OpSignatureTensorSpec tensor = {kTfLiteNoType};
77 if (op_signature.model->HasArray(input_name)) {
78 const Array& input_array = op_signature.model->GetArray(input_name);
79 tensor.type = GetTensorType(input_array.data_type);
80 if (input_array.has_shape()) {
81 tensor.dims = input_array.shape().dims();
82 }
83 }
84 inputs.push_back(tensor);
85 }
86 for (const auto& output_name : op_signature.op->outputs) {
87 ::tflite::OpSignatureTensorSpec tensor = {kTfLiteNoType};
88 if (op_signature.model->HasArray(output_name)) {
89 const Array& output_array = op_signature.model->GetArray(output_name);
90 tensor.type = GetTensorType(output_array.data_type);
91 if (output_array.has_shape()) {
92 tensor.dims = output_array.shape().dims();
93 }
94 }
95 outputs.push_back(tensor);
96 }
97 return ::tflite::OpSignature{op, inputs, outputs};
98}
99
100class AveragePool
101 : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
102 ::tflite::BuiltinOptions_Pool2DOptions> {
103 public:
104 using BuiltinOperator::BuiltinOperator;
105
106 flatbuffers::Offset<TfLiteOptions> WriteOptions(
107 const TocoOperator& op,
108 flatbuffers::FlatBufferBuilder* builder) const override {
109 auto padding = Padding::Serialize(op.padding.type);
110 auto activation_function =
111 ActivationFunction::Serialize(op.fused_activation_function);
112 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
113 op.stride_height, op.kwidth,
114 op.kheight, activation_function);
115 }
116
117 void ReadOptions(const TfLiteOptions& options,
118 TocoOperator* op) const override {
119 op->padding.type = Padding::Deserialize(options.padding());
120 op->stride_width = options.stride_w();
121 op->stride_height = options.stride_h();
122 op->kwidth = options.filter_width();
123 op->kheight = options.filter_height();
124 op->fused_activation_function =
125 ActivationFunction::Deserialize(options.fused_activation_function());
126 }
127};
128
129class Convolution
130 : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
131 ::tflite::BuiltinOptions_Conv2DOptions> {
132 public:
133 using BuiltinOperator::BuiltinOperator;
134
135 flatbuffers::Offset<TfLiteOptions> WriteOptions(
136 const TocoOperator& op,
137 flatbuffers::FlatBufferBuilder* builder) const override {
138 auto padding = Padding::Serialize(op.padding.type);
139 auto activation_function =
140 ActivationFunction::Serialize(op.fused_activation_function);
141 return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
142 op.stride_height, activation_function,
143 op.dilation_width_factor,
144 op.dilation_height_factor);
145 }
146
147 void ReadOptions(const TfLiteOptions& options,
148 TocoOperator* op) const override {
149 op->padding.type = Padding::Deserialize(options.padding());
150 op->stride_width = options.stride_w();
151 op->stride_height = options.stride_h();
152 op->dilation_width_factor = options.dilation_w_factor();
153 op->dilation_height_factor = options.dilation_h_factor();
154 op->fused_activation_function =
155 ActivationFunction::Deserialize(options.fused_activation_function());
156 }
157};
158
159class DepthwiseConvolution
160 : public BuiltinOperator<DepthwiseConvOperator,
161 ::tflite::DepthwiseConv2DOptions,
162 ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
163 public:
164 using BuiltinOperator::BuiltinOperator;
165
166 flatbuffers::Offset<TfLiteOptions> WriteOptions(
167 const TocoOperator& op,
168 flatbuffers::FlatBufferBuilder* builder) const override {
169 auto padding = Padding::Serialize(op.padding.type);
170 auto activation_function =
171 ActivationFunction::Serialize(op.fused_activation_function);
172 return ::tflite::CreateDepthwiseConv2DOptions(
173 *builder, padding, op.stride_width, op.stride_height,
174 op.depth_multiplier, activation_function, op.dilation_width_factor,
175 op.dilation_height_factor);
176 }
177
178 void ReadOptions(const TfLiteOptions& options,
179 TocoOperator* op) const override {
180 op->padding.type = Padding::Deserialize(options.padding());
181 op->stride_width = options.stride_w();
182 op->stride_height = options.stride_h();
183 op->depth_multiplier = options.depth_multiplier();
184 op->fused_activation_function =
185 ActivationFunction::Deserialize(options.fused_activation_function());
186 op->dilation_width_factor = options.dilation_w_factor();
187 op->dilation_height_factor = options.dilation_h_factor();
188 }
189
190 int GetVersion(const OperatorSignature& op_signature) const override {
191 const auto& conv_op =
192 static_cast<const DepthwiseConvOperator&>(*op_signature.op);
193 ::tflite::OpSignature op_sig =
194 GetVersioningOpSig(builtin_op(), op_signature);
195 TfLiteDepthwiseConvParams depthwise_conv_params = {};
196 depthwise_conv_params.dilation_width_factor = conv_op.dilation_width_factor;
197 depthwise_conv_params.dilation_height_factor =
198 conv_op.dilation_height_factor;
199 op_sig.builtin_data = reinterpret_cast<void*>(&depthwise_conv_params);
200 return ::tflite::GetBuiltinOperatorVersion(op_sig);
201 }
202};
203
204class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
205 ::tflite::BuiltinOptions_AddOptions> {
206 public:
207 using BuiltinOperator::BuiltinOperator;
208
209 flatbuffers::Offset<TfLiteOptions> WriteOptions(
210 const TocoOperator& op,
211 flatbuffers::FlatBufferBuilder* builder) const override {
212 auto activation_function =
213 ActivationFunction::Serialize(op.fused_activation_function);
214 return ::tflite::CreateAddOptions(*builder, activation_function);
215 }
216
217 void ReadOptions(const TfLiteOptions& options,
218 TocoOperator* op) const override {
219 op->fused_activation_function =
220 ActivationFunction::Deserialize(options.fused_activation_function());
221 }
222};
223
224class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
225 ::tflite::BuiltinOptions_AddNOptions> {
226 public:
227 using BuiltinOperator::BuiltinOperator;
228
229 flatbuffers::Offset<TfLiteOptions> WriteOptions(
230 const TocoOperator& op,
231 flatbuffers::FlatBufferBuilder* builder) const override {
232 return ::tflite::CreateAddNOptions(*builder);
233 }
234
235 void ReadOptions(const TfLiteOptions& options,
236 TocoOperator* op) const override {}
237};
238
239class SpaceToBatchND
240 : public BuiltinOperator<SpaceToBatchNDOperator,
241 ::tflite::SpaceToBatchNDOptions,
242 ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
243 public:
244 using BuiltinOperator::BuiltinOperator;
245
246 flatbuffers::Offset<TfLiteOptions> WriteOptions(
247 const TocoOperator& op,
248 flatbuffers::FlatBufferBuilder* builder) const override {
249 return ::tflite::CreateSpaceToBatchNDOptions(*builder);
250 }
251
252 void ReadOptions(const TfLiteOptions& options,
253 TocoOperator* op) const override {}
254
255 int GetVersion(const OperatorSignature& op_signature) const override {
256 ::tflite::OpSignature op_sig =
257 GetVersioningOpSig(builtin_op(), op_signature);
258 return ::tflite::GetBuiltinOperatorVersion(op_sig);
259 }
260};
261
262class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
263 ::tflite::BuiltinOptions_SubOptions> {
264 public:
265 using BuiltinOperator::BuiltinOperator;
266
267 flatbuffers::Offset<TfLiteOptions> WriteOptions(
268 const TocoOperator& op,
269 flatbuffers::FlatBufferBuilder* builder) const override {
270 auto activation_function =
271 ActivationFunction::Serialize(op.fused_activation_function);
272 return ::tflite::CreateSubOptions(*builder, activation_function);
273 }
274
275 void ReadOptions(const TfLiteOptions& options,
276 TocoOperator* op) const override {
277 op->fused_activation_function =
278 ActivationFunction::Deserialize(options.fused_activation_function());
279 }
280
281 int GetVersion(const OperatorSignature& op_signature) const override {
282 ::tflite::OpSignature op_sig =
283 GetVersioningOpSig(builtin_op(), op_signature);
284 return ::tflite::GetBuiltinOperatorVersion(op_sig);
285 }
286};
287
288class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
289 ::tflite::BuiltinOptions_DivOptions> {
290 public:
291 using BuiltinOperator::BuiltinOperator;
292
293 flatbuffers::Offset<TfLiteOptions> WriteOptions(
294 const TocoOperator& op,
295 flatbuffers::FlatBufferBuilder* builder) const override {
296 auto activation_function =
297 ActivationFunction::Serialize(op.fused_activation_function);
298 return ::tflite::CreateDivOptions(*builder, activation_function);
299 }
300
301 void ReadOptions(const TfLiteOptions& options,
302 TocoOperator* op) const override {
303 op->fused_activation_function =
304 ActivationFunction::Deserialize(options.fused_activation_function());
305 }
306
307 int GetVersion(const OperatorSignature& op_signature) const override {
308 ::tflite::OpSignature op_sig =
309 GetVersioningOpSig(builtin_op(), op_signature);
310 return ::tflite::GetBuiltinOperatorVersion(op_sig);
311 }
312};
313
314class BatchToSpaceND
315 : public BuiltinOperator<BatchToSpaceNDOperator,
316 ::tflite::BatchToSpaceNDOptions,
317 ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
318 public:
319 using BuiltinOperator::BuiltinOperator;
320
321 flatbuffers::Offset<TfLiteOptions> WriteOptions(
322 const TocoOperator& op,
323 flatbuffers::FlatBufferBuilder* builder) const override {
324 return ::tflite::CreateBatchToSpaceNDOptions(*builder);
325 }
326
327 void ReadOptions(const TfLiteOptions& options,
328 TocoOperator* op) const override {}
329
330 int GetVersion(const OperatorSignature& op_signature) const override {
331 ::tflite::OpSignature op_sig =
332 GetVersioningOpSig(builtin_op(), op_signature);
333 return ::tflite::GetBuiltinOperatorVersion(op_sig);
334 }
335};
336
337class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
338 ::tflite::BuiltinOptions_CastOptions> {
339 public:
340 using BuiltinOperator::BuiltinOperator;
341 flatbuffers::Offset<TfLiteOptions> WriteOptions(
342 const TocoOperator& op,
343 flatbuffers::FlatBufferBuilder* builder) const override {
344 return ::tflite::CreateCastOptions(*builder,
345 DataType::Serialize(op.src_data_type),
346 DataType::Serialize(op.dst_data_type));
347 }
348
349 void ReadOptions(const TfLiteOptions& options,
350 TocoOperator* op) const override {
351 op->src_data_type = DataType::Deserialize(options.in_data_type());
352 op->dst_data_type = DataType::Deserialize(options.out_data_type());
353 }
354};
355
356class Concatenation
357 : public BuiltinOperator<ConcatenationOperator,
358 ::tflite::ConcatenationOptions,
359 ::tflite::BuiltinOptions_ConcatenationOptions> {
360 public:
361 using BuiltinOperator::BuiltinOperator;
362 flatbuffers::Offset<TfLiteOptions> WriteOptions(
363 const TocoOperator& op,
364 flatbuffers::FlatBufferBuilder* builder) const override {
365 return ::tflite::CreateConcatenationOptions(*builder, op.axis);
366 }
367
368 void ReadOptions(const TfLiteOptions& options,
369 TocoOperator* op) const override {
370 op->axis = options.axis();
371 }
372};
373
374class DepthToSpace
375 : public BuiltinOperator<DepthToSpaceOperator,
376 ::tflite::DepthToSpaceOptions,
377 ::tflite::BuiltinOptions_DepthToSpaceOptions> {
378 public:
379 using BuiltinOperator::BuiltinOperator;
380 flatbuffers::Offset<TfLiteOptions> WriteOptions(
381 const TocoOperator& op,
382 flatbuffers::FlatBufferBuilder* builder) const override {
383 return ::tflite::CreateDepthToSpaceOptions(*builder, op.block_size);
384 }
385
386 void ReadOptions(const TfLiteOptions& options,
387 TocoOperator* op) const override {
388 op->block_size = options.block_size();
389 }
390};
391
392class FakeQuant
393 : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
394 ::tflite::BuiltinOptions_FakeQuantOptions> {
395 public:
396 using BuiltinOperator::BuiltinOperator;
397 flatbuffers::Offset<TfLiteOptions> WriteOptions(
398 const TocoOperator& op,
399 flatbuffers::FlatBufferBuilder* builder) const override {
400 return ::tflite::CreateFakeQuantOptions(
401 *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range);
402 }
403 void ReadOptions(const TfLiteOptions& options,
404 TocoOperator* op) const override {
405 auto* minmax = new MinMax;
406 minmax->min = options.min();
407 minmax->max = options.max();
408 op->minmax.reset(minmax);
409 op->num_bits = options.num_bits();
410 op->narrow_range = options.narrow_range();
411 }
412 int GetVersion(const OperatorSignature& op_signature) const override {
413 const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op);
414 ::tflite::OpSignature op_sig =
415 GetVersioningOpSig(builtin_op(), op_signature);
416 TfLiteFakeQuantParams fake_quant_params = {};
417 fake_quant_params.narrow_range = fq_op.narrow_range;
418 op_sig.builtin_data = reinterpret_cast<void*>(&fake_quant_params);
419 return ::tflite::GetBuiltinOperatorVersion(op_sig);
420 }
421};
422
423class FullyConnected
424 : public BuiltinOperator<FullyConnectedOperator,
425 ::tflite::FullyConnectedOptions,
426 ::tflite::BuiltinOptions_FullyConnectedOptions> {
427 public:
428 using BuiltinOperator::BuiltinOperator;
429
430 ::tflite::FullyConnectedOptionsWeightsFormat GetWeightFormat(
431 FullyConnectedWeightsFormat fmt) const {
432 switch (fmt) {
433 case FullyConnectedWeightsFormat::kDefault:
434 return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
435 case FullyConnectedWeightsFormat::kShuffled4x16Int8:
436 return ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
437 default:
438 LOG(ERROR) << "Unhandled FC weights format";
439 return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
440 }
441 }
442
443 flatbuffers::Offset<TfLiteOptions> WriteOptions(
444 const TocoOperator& op,
445 flatbuffers::FlatBufferBuilder* builder) const override {
446 auto activation_function =
447 ActivationFunction::Serialize(op.fused_activation_function);
448 return ::tflite::CreateFullyConnectedOptions(
449 *builder, activation_function, GetWeightFormat(op.weights_format));
450 }
451
452 void ReadOptions(const TfLiteOptions& options,
453 TocoOperator* op) const override {
454 op->fused_activation_function =
455 ActivationFunction::Deserialize(options.fused_activation_function());
456 switch (options.weights_format()) {
457 case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT:
458 op->weights_format = FullyConnectedWeightsFormat::kDefault;
459 break;
460 case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
461 op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
462 break;
463 default:
464 LOG(ERROR) << "Unhandled FC weights format";
465 op->weights_format = FullyConnectedWeightsFormat::kDefault;
466 }
467 }
468
469 int GetVersion(const OperatorSignature& op_signature) const override {
470 const auto& fc_op =
471 static_cast<const FullyConnectedOperator&>(*op_signature.op);
472 ::tflite::OpSignature op_sig =
473 GetVersioningOpSig(builtin_op(), op_signature);
474 TfLiteFullyConnectedParams fully_connected_params = {};
475 fully_connected_params.keep_num_dims = fc_op.keep_num_dims;
476 fully_connected_params.weights_format =
477 static_cast<TfLiteFullyConnectedWeightsFormat>(
478 GetWeightFormat(fc_op.weights_format));
479 op_sig.builtin_data = reinterpret_cast<void*>(&fully_connected_params);
480 return ::tflite::GetBuiltinOperatorVersion(op_sig);
481 }
482};
483
484class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
485 ::tflite::BuiltinOptions_GatherOptions> {
486 public:
487 using BuiltinOperator::BuiltinOperator;
488 flatbuffers::Offset<TfLiteOptions> WriteOptions(
489 const TocoOperator& op,
490 flatbuffers::FlatBufferBuilder* builder) const override {
491 int axis = op.axis ? op.axis.value() : 0;
492 return ::tflite::CreateGatherOptions(*builder, axis);
493 }
494
495 void ReadOptions(const TfLiteOptions& options,
496 TocoOperator* op) const override {
497 op->axis = {options.axis()};
498 }
499};
500
501class GatherNd
502 : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions,
503 ::tflite::BuiltinOptions_GatherNdOptions> {
504 public:
505 using BuiltinOperator::BuiltinOperator;
506
507 flatbuffers::Offset<TfLiteOptions> WriteOptions(
508 const TocoOperator& op,
509 flatbuffers::FlatBufferBuilder* builder) const override {
510 return ::tflite::CreateGatherNdOptions(*builder);
511 }
512
513 void ReadOptions(const TfLiteOptions& options,
514 TocoOperator* op) const override {}
515};
516
517class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
518 ::tflite::BuiltinOptions_SVDFOptions> {
519 public:
520 using BuiltinOperator::BuiltinOperator;
521 flatbuffers::Offset<TfLiteOptions> WriteOptions(
522 const TocoOperator& op,
523 flatbuffers::FlatBufferBuilder* builder) const override {
524 auto activation_function =
525 ActivationFunction::Serialize(op.fused_activation_function);
526 return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
527 }
528
529 void ReadOptions(const TfLiteOptions& options,
530 TocoOperator* op) const override {
531 op->fused_activation_function =
532 ActivationFunction::Deserialize(options.fused_activation_function());
533 op->rank = options.rank();
534 }
535};
536
537class L2Normalization
538 : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
539 ::tflite::BuiltinOptions_L2NormOptions> {
540 public:
541 using BuiltinOperator::BuiltinOperator;
542 flatbuffers::Offset<TfLiteOptions> WriteOptions(
543 const TocoOperator& op,
544 flatbuffers::FlatBufferBuilder* builder) const override {
545 auto activation_function =
546 ActivationFunction::Serialize(op.fused_activation_function);
547 return ::tflite::CreateL2NormOptions(*builder, activation_function);
548 }
549
550 void ReadOptions(const TfLiteOptions& options,
551 TocoOperator* op) const override {
552 op->fused_activation_function =
553 ActivationFunction::Deserialize(options.fused_activation_function());
554 }
555};
556
557class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
558 ::tflite::BuiltinOptions_Pool2DOptions> {
559 public:
560 using BuiltinOperator::BuiltinOperator;
561 flatbuffers::Offset<TfLiteOptions> WriteOptions(
562 const TocoOperator& op,
563 flatbuffers::FlatBufferBuilder* builder) const override {
564 auto padding = Padding::Serialize(op.padding.type);
565 auto activation_function =
566 ActivationFunction::Serialize(op.fused_activation_function);
567 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
568 op.stride_height, op.kwidth,
569 op.kheight, activation_function);
570 }
571
572 void ReadOptions(const TfLiteOptions& options,
573 TocoOperator* op) const override {
574 op->padding.type = Padding::Deserialize(options.padding());
575 op->stride_width = options.stride_w();
576 op->stride_height = options.stride_h();
577 op->kwidth = options.filter_width();
578 op->kheight = options.filter_height();
579 op->fused_activation_function =
580 ActivationFunction::Deserialize(options.fused_activation_function());
581 }
582};
583
584class LocalResponseNormalization
585 : public BuiltinOperator<
586 LocalResponseNormalizationOperator,
587 ::tflite::LocalResponseNormalizationOptions,
588 ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
589 public:
590 using BuiltinOperator::BuiltinOperator;
591 flatbuffers::Offset<TfLiteOptions> WriteOptions(
592 const TocoOperator& op,
593 flatbuffers::FlatBufferBuilder* builder) const override {
594 return ::tflite::CreateLocalResponseNormalizationOptions(
595 *builder, op.range, op.bias, op.alpha, op.beta);
596 }
597
598 void ReadOptions(const TfLiteOptions& options,
599 TocoOperator* op) const override {
600 op->range = options.radius();
601 op->bias = options.bias();
602 op->alpha = options.alpha();
603 op->beta = options.beta();
604 }
605};
606
607class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
608 ::tflite::BuiltinOptions_Pool2DOptions> {
609 public:
610 using BuiltinOperator::BuiltinOperator;
611 flatbuffers::Offset<TfLiteOptions> WriteOptions(
612 const TocoOperator& op,
613 flatbuffers::FlatBufferBuilder* builder) const override {
614 auto padding = Padding::Serialize(op.padding.type);
615 auto activation_function =
616 ActivationFunction::Serialize(op.fused_activation_function);
617 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
618 op.stride_height, op.kwidth,
619 op.kheight, activation_function);
620 }
621
622 void ReadOptions(const TfLiteOptions& options,
623 TocoOperator* op) const override {
624 op->padding.type = Padding::Deserialize(options.padding());
625 op->stride_width = options.stride_w();
626 op->stride_height = options.stride_h();
627 op->kwidth = options.filter_width();
628 op->kheight = options.filter_height();
629 op->fused_activation_function =
630 ActivationFunction::Deserialize(options.fused_activation_function());
631 }
632};
633
634class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
635 ::tflite::BuiltinOptions_MulOptions> {
636 public:
637 using BuiltinOperator::BuiltinOperator;
638
639 flatbuffers::Offset<TfLiteOptions> WriteOptions(
640 const TocoOperator& op,
641 flatbuffers::FlatBufferBuilder* builder) const override {
642 auto activation_function =
643 ActivationFunction::Serialize(op.fused_activation_function);
644 return ::tflite::CreateMulOptions(*builder, activation_function);
645 }
646
647 void ReadOptions(const TfLiteOptions& options,
648 TocoOperator* op) const override {
649 op->fused_activation_function =
650 ActivationFunction::Deserialize(options.fused_activation_function());
651 }
652
653 int GetVersion(const OperatorSignature& op_signature) const override {
654 const std::string& input1_name = op_signature.op->inputs[0];
655 const std::string& input2_name = op_signature.op->inputs[1];
656 const std::string& output_name = op_signature.op->outputs[0];
657 const Array& input1_array = op_signature.model->GetArray(input1_name);
658 const Array& input2_array = op_signature.model->GetArray(input2_name);
659 const Array& output_array = op_signature.model->GetArray(output_name);
660 const auto& input1_quant = input1_array.quantization_params;
661 const auto& input2_quant = input2_array.quantization_params;
662 const auto& output_quant = output_array.quantization_params;
663 const float input1_scale = input1_quant ? input1_quant->scale : 0.0f;
664 const float input2_scale = input2_quant ? input2_quant->scale : 0.0f;
665 const float output_scale = output_quant ? output_quant->scale : 0.0f;
666 ::tflite::OpSignature op_sig =
667 GetVersioningOpSig(builtin_op(), op_signature);
668 op_sig.ext_options.mul.input1_scale = input1_scale;
669 op_sig.ext_options.mul.input2_scale = input2_scale;
670 op_sig.ext_options.mul.output_scale = output_scale;
671 return ::tflite::GetBuiltinOperatorVersion(op_sig);
672 }
673};
674
675class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
676 ::tflite::BuiltinOptions_PadOptions> {
677 public:
678 using BuiltinOperator::BuiltinOperator;
679
680 flatbuffers::Offset<TfLiteOptions> WriteOptions(
681 const TocoOperator& op,
682 flatbuffers::FlatBufferBuilder* builder) const override {
683 return ::tflite::CreatePadOptions(*builder);
684 }
685
686 void ReadOptions(const TfLiteOptions& options,
687 TocoOperator* op) const override {}
688};
689
690class Tile
691 : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions,
692 ::tflite::BuiltinOptions_TileOptions> {
693 using BuiltinOperator::BuiltinOperator;
694
695 flatbuffers::Offset<TfLiteOptions> WriteOptions(
696 const TocoOperator& op,
697 flatbuffers::FlatBufferBuilder* builder) const override {
698 return ::tflite::CreateTileOptions(*builder);
699 }
700
701 void ReadOptions(const TfLiteOptions& options,
702 TocoOperator* op) const override {}
703};
704
705class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
706 ::tflite::BuiltinOptions_PadV2Options> {
707 public:
708 using BuiltinOperator::BuiltinOperator;
709
710 flatbuffers::Offset<TfLiteOptions> WriteOptions(
711 const TocoOperator& op,
712 flatbuffers::FlatBufferBuilder* builder) const override {
713 return ::tflite::CreatePadV2Options(*builder);
714 }
715
716 void ReadOptions(const TfLiteOptions& options,
717 TocoOperator* op) const override {}
718};
719
720class Reshape
721 : public BuiltinOperator<TensorFlowReshapeOperator,
722 ::tflite::ReshapeOptions,
723 ::tflite::BuiltinOptions_ReshapeOptions> {
724 public:
725 using BuiltinOperator::BuiltinOperator;
726
727 flatbuffers::Offset<TfLiteOptions> WriteOptions(
728 const TocoOperator& op,
729 flatbuffers::FlatBufferBuilder* builder) const override {
730 return ::tflite::CreateReshapeOptions(*builder,
731 builder->CreateVector(op.shape));
732 }
733
734 void ReadOptions(const TfLiteOptions& options,
735 TocoOperator* op) const override {
736 op->shape.insert(op->shape.end(), options.new_shape()->begin(),
737 options.new_shape()->end());
738 }
739};
740
741class Softmax
742 : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
743 ::tflite::BuiltinOptions_SoftmaxOptions> {
744 public:
745 using BuiltinOperator::BuiltinOperator;
746 flatbuffers::Offset<TfLiteOptions> WriteOptions(
747 const TocoOperator& op,
748 flatbuffers::FlatBufferBuilder* builder) const override {
749 return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
750 }
751
752 void ReadOptions(const TfLiteOptions& options,
753 TocoOperator* op) const override {
754 op->beta = options.beta();
755 }
756};
757
758class SpaceToDepth
759 : public BuiltinOperator<SpaceToDepthOperator,
760 ::tflite::SpaceToDepthOptions,
761 ::tflite::BuiltinOptions_SpaceToDepthOptions> {
762 public:
763 using BuiltinOperator::BuiltinOperator;
764 flatbuffers::Offset<TfLiteOptions> WriteOptions(
765 const TocoOperator& op,
766 flatbuffers::FlatBufferBuilder* builder) const override {
767 return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
768 }
769
770 void ReadOptions(const TfLiteOptions& options,
771 TocoOperator* op) const override {
772 op->block_size = options.block_size();
773 }
774};
775
776class Transpose
777 : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
778 ::tflite::BuiltinOptions_TransposeOptions> {
779 public:
780 using BuiltinOperator::BuiltinOperator;
781 flatbuffers::Offset<TfLiteOptions> WriteOptions(
782 const TocoOperator& op,
783 flatbuffers::FlatBufferBuilder* builder) const override {
784 return ::tflite::CreateTransposeOptions(*builder);
785 }
786
787 void ReadOptions(const TfLiteOptions& options,
788 TocoOperator* op) const override {}
789};
790
791class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
792 ::tflite::BuiltinOptions_LSTMOptions> {
793 public:
794 using BuiltinOperator::BuiltinOperator;
795
796 ::tflite::LSTMKernelType GetKernelType(
797 LstmCellOperator::KernelType type) const {
798 switch (type) {
799 case LstmCellOperator::KERNEL_BASIC:
800 return ::tflite::LSTMKernelType_BASIC;
801 break;
802 case LstmCellOperator::KERNEL_FULL:
803 return ::tflite::LSTMKernelType_FULL;
804 break;
805 default:
806 LOG(ERROR) << "Unhandled Kernel Type";
807 return static_cast<::tflite::LSTMKernelType>(-1);
808 }
809 }
810
811 flatbuffers::Offset<TfLiteOptions> WriteOptions(
812 const TocoOperator& op,
813 flatbuffers::FlatBufferBuilder* builder) const override {
814 ::tflite::LSTMKernelType kernel_type = GetKernelType(op.kernel_type);
815
816 // Current toco converter only supports tanh, no clip.
817 return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
818 ::tflite::ActivationFunctionType_TANH,
819 /*cell_clip=*/0.0,
820 /*proj_clip=*/0.0, kernel_type);
821 }
822
823 void ReadOptions(const TfLiteOptions& options,
824 TocoOperator* op) const override {
825 // Only support tanh activation, so check that tflite type is tanh.
826 CHECK(options.fused_activation_function() ==
827 ::tflite::ActivationFunctionType_TANH);
828
829 switch (options.kernel_type()) {
830 case ::tflite::LSTMKernelType_BASIC:
831 op->kernel_type = LstmCellOperator::KERNEL_BASIC;
832 break;
833 case ::tflite::LSTMKernelType_FULL:
834 op->kernel_type = LstmCellOperator::KERNEL_FULL;
835 break;
836 }
837 }
838
839 int GetVersion(const OperatorSignature& op_signature) const override {
840 const auto& lstm_op =
841 static_cast<const LstmCellOperator&>(*op_signature.op);
842 ::tflite::OpSignature op_sig =
843 GetVersioningOpSig(builtin_op(), op_signature);
844 TfLiteLSTMParams lstm_params = {};
845 lstm_params.kernel_type =
846 static_cast<TfLiteLSTMKernelType>(GetKernelType(lstm_op.kernel_type));
847 op_sig.builtin_data = reinterpret_cast<void*>(&lstm_params);
848 return ::tflite::GetBuiltinOperatorVersion(op_sig);
849 }
850
851 std::vector<bool> GetMutatingInputVariables(
852 const Operator& op) const override {
853 const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
854
855 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
856 switch (lstm_op.kernel_type) {
857 case LstmCellOperator::KERNEL_FULL: {
858 mutating_input_variables[kInputActivationStateTensor] = true;
859 mutating_input_variables[kInputCellStateTensor] = true;
860 break;
861 }
862 case LstmCellOperator::KERNEL_BASIC: {
863 mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
864 mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
865 break;
866 }
867 }
868 return mutating_input_variables;
869 }
870};
871
872class UnidirectionalSequenceLstm
873 : public BuiltinOperator<
874 UnidirectionalSequenceLstmOperator,
875 ::tflite::UnidirectionalSequenceLSTMOptions,
876 ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
877 public:
878 using BuiltinOperator::BuiltinOperator;
879 flatbuffers::Offset<TfLiteOptions> WriteOptions(
880 const TocoOperator& op,
881 flatbuffers::FlatBufferBuilder* builder) const override {
882 // Current toco converter only supports tanh, no clip.
883 return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
884 *builder, /*fused_activation_function=*/
885 ::tflite::ActivationFunctionType_TANH,
886 /*cell_clip=*/0.0,
887 /*proj_clip=*/0.0,
888 /*time_major=*/true);
889 }
890
891 void ReadOptions(const TfLiteOptions& options,
892 TocoOperator* op) const override {
893 // Only support tanh activation, so check that tflite type is tanh.
894 DCHECK(options.fused_activation_function() ==
895 ::tflite::ActivationFunctionType_TANH);
896 }
897
898 std::vector<bool> GetMutatingInputVariables(
899 const Operator& op) const override {
900 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
901 mutating_input_variables[kInputActivationStateTensor] = true;
902 mutating_input_variables[kInputCellStateTensor] = true;
903 return mutating_input_variables;
904 }
905};
906
907class BidirectionalSequenceLstm
908 : public BuiltinOperator<
909 BidirectionalSequenceLstmOperator,
910 ::tflite::BidirectionalSequenceLSTMOptions,
911 ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> {
912 public:
913 using BuiltinOperator::BuiltinOperator;
914 flatbuffers::Offset<TfLiteOptions> WriteOptions(
915 const TocoOperator& op,
916 flatbuffers::FlatBufferBuilder* builder) const override {
917 // Current toco converter only supports tanh, no clip.
918 return ::tflite::CreateBidirectionalSequenceLSTMOptions(
919 *builder, /*fused_activation_function=*/
920 ::tflite::ActivationFunctionType_TANH,
921 /*cell_clip=*/0.0,
922 /*proj_clip=*/0.0,
923 /*merge_outputs=*/op.merge_outputs,
924 /*time_major=*/true);
925 }
926
927 void ReadOptions(const TfLiteOptions& options,
928 TocoOperator* op) const override {
929 // Only support tanh activation, so check that tflite type is tanh.
930 DCHECK(options.fused_activation_function() ==
931 ::tflite::ActivationFunctionType_TANH);
932 op->merge_outputs = options.merge_outputs();
933 }
934
935 std::vector<bool> GetMutatingInputVariables(
936 const Operator& op) const override {
937 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
938 // Forward input activation state.
939 mutating_input_variables[35] = true;
940 // Forward input cell state.
941 mutating_input_variables[36] = true;
942 // Backward input activation state.
943 mutating_input_variables[37] = true;
944 // Backward input cell state.
945 mutating_input_variables[38] = true;
946 return mutating_input_variables;
947 }
948};
949
950class BidirectionalSequenceRnn
951 : public BuiltinOperator<
952 BidirectionalSequenceRnnOperator,
953 ::tflite::BidirectionalSequenceRNNOptions,
954 ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> {
955 public:
956 using BuiltinOperator::BuiltinOperator;
957 flatbuffers::Offset<TfLiteOptions> WriteOptions(
958 const TocoOperator& op,
959 flatbuffers::FlatBufferBuilder* builder) const override {
960 // Current toco converter only supports tanh, no clip.
961 return ::tflite::CreateBidirectionalSequenceRNNOptions(
962 *builder, /*time_major=*/true,
963 /*fused_activation_function=*/
964 ::tflite::ActivationFunctionType_TANH,
965 /*merge_outputs=*/op.merge_outputs);
966 }
967
968 void ReadOptions(const TfLiteOptions& options,
969 TocoOperator* op) const override {
970 // Only support tanh activation, so check that tflite type is tanh.
971 DCHECK(options.fused_activation_function() ==
972 ::tflite::ActivationFunctionType_TANH);
973 op->merge_outputs = options.merge_outputs();
974 }
975
976 std::vector<bool> GetMutatingInputVariables(
977 const Operator& op) const override {
978 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
979 // Forward hidden state.
980 mutating_input_variables[4] = true;
981 // Backward hidden state.
982 mutating_input_variables[8] = true;
983 return mutating_input_variables;
984 }
985};
986
987class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
988 ::tflite::BuiltinOptions_ReducerOptions> {
989 public:
990 using BuiltinOperator::BuiltinOperator;
991 flatbuffers::Offset<TfLiteOptions> WriteOptions(
992 const TocoOperator& op,
993 flatbuffers::FlatBufferBuilder* builder) const override {
994 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
995 }
996
997 void ReadOptions(const TfLiteOptions& options,
998 TocoOperator* op) const override {
999 op->keep_dims = options.keep_dims();
1000 }
1001};
1002
1003class Sum
1004 : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
1005 ::tflite::BuiltinOptions_ReducerOptions> {
1006 public:
1007 using BuiltinOperator::BuiltinOperator;
1008 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1009 const TocoOperator& op,
1010 flatbuffers::FlatBufferBuilder* builder) const override {
1011 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1012 }
1013
1014 void ReadOptions(const TfLiteOptions& options,
1015 TocoOperator* op) const override {
1016 op->keep_dims = options.keep_dims();
1017 }
1018};
1019
1020class ReduceMax
1021 : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
1022 ::tflite::BuiltinOptions_ReducerOptions> {
1023 public:
1024 using BuiltinOperator::BuiltinOperator;
1025 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1026 const TocoOperator& op,
1027 flatbuffers::FlatBufferBuilder* builder) const override {
1028 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1029 }
1030
1031 void ReadOptions(const TfLiteOptions& options,
1032 TocoOperator* op) const override {
1033 op->keep_dims = options.keep_dims();
1034 }
1035};
1036
1037class ReduceMin
1038 : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
1039 ::tflite::BuiltinOptions_ReducerOptions> {
1040 public:
1041 using BuiltinOperator::BuiltinOperator;
1042 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1043 const TocoOperator& op,
1044 flatbuffers::FlatBufferBuilder* builder) const override {
1045 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1046 }
1047
1048 void ReadOptions(const TfLiteOptions& options,
1049 TocoOperator* op) const override {
1050 op->keep_dims = options.keep_dims();
1051 }
1052};
1053
1054class ReduceProd
1055 : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
1056 ::tflite::BuiltinOptions_ReducerOptions> {
1057 public:
1058 using BuiltinOperator::BuiltinOperator;
1059 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1060 const TocoOperator& op,
1061 flatbuffers::FlatBufferBuilder* builder) const override {
1062 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1063 }
1064
1065 void ReadOptions(const TfLiteOptions& options,
1066 TocoOperator* op) const override {
1067 op->keep_dims = options.keep_dims();
1068 }
1069};
1070
1071class ReduceAny
1072 : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
1073 ::tflite::BuiltinOptions_ReducerOptions> {
1074 public:
1075 using BuiltinOperator::BuiltinOperator;
1076 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1077 const TocoOperator& op,
1078 flatbuffers::FlatBufferBuilder* builder) const override {
1079 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1080 }
1081
1082 void ReadOptions(const TfLiteOptions& options,
1083 TocoOperator* op) const override {
1084 op->keep_dims = options.keep_dims();
1085 }
1086};
1087
1088class ResizeBilinear
1089 : public BuiltinOperator<ResizeBilinearOperator,
1090 ::tflite::ResizeBilinearOptions,
1091 ::tflite::BuiltinOptions_ResizeBilinearOptions> {
1092 public:
1093 using BuiltinOperator::BuiltinOperator;
1094 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1095 const TocoOperator& op,
1096 flatbuffers::FlatBufferBuilder* builder) const override {
1097 return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners,
1098 op.half_pixel_centers);
1099 }
1100
1101 void ReadOptions(const TfLiteOptions& options,
1102 TocoOperator* op) const override {
1103 op->align_corners = options.align_corners();
1104 op->half_pixel_centers = options.half_pixel_centers();
1105 }
1106
1107 int GetVersion(const OperatorSignature& op_signature) const override {
1108 const auto& resize_bilinear_op =
1109 static_cast<const ResizeBilinearOperator&>(*op_signature.op);
1110 ::tflite::OpSignature op_sig =
1111 GetVersioningOpSig(builtin_op(), op_signature);
1112 TfLiteResizeBilinearParams resize_bilinear_params = {};
1113 resize_bilinear_params.half_pixel_centers =
1114 resize_bilinear_op.half_pixel_centers;
1115 resize_bilinear_params.align_corners = resize_bilinear_op.align_corners;
1116 op_sig.builtin_data = reinterpret_cast<void*>(&resize_bilinear_params);
1117 return ::tflite::GetBuiltinOperatorVersion(op_sig);
1118 }
1119};
1120
1121class ResizeNearestNeighbor
1122 : public BuiltinOperator<
1123 ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions,
1124 ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> {
1125 public:
1126 using BuiltinOperator::BuiltinOperator;
1127 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1128 const TocoOperator& op,
1129 flatbuffers::FlatBufferBuilder* builder) const override {
1130 return ::tflite::CreateResizeNearestNeighborOptions(
1131 *builder, op.align_corners, op.half_pixel_centers);
1132 }
1133
1134 void ReadOptions(const TfLiteOptions& options,
1135 TocoOperator* op) const override {
1136 op->align_corners = options.align_corners();
1137 op->half_pixel_centers = options.half_pixel_centers();
1138 }
1139
1140 int GetVersion(const OperatorSignature& op_signature) const override {
1141 const auto& resize_nn_op =
1142 static_cast<const ResizeNearestNeighborOperator&>(*op_signature.op);
1143 ::tflite::OpSignature op_sig =
1144 GetVersioningOpSig(builtin_op(), op_signature);
1145 TfLiteResizeNearestNeighborParams resize_nearest_neighbor_params = {};
1146 resize_nearest_neighbor_params.half_pixel_centers =
1147 resize_nn_op.half_pixel_centers;
1148 resize_nearest_neighbor_params.align_corners = resize_nn_op.align_corners;
1149 op_sig.builtin_data =
1150 reinterpret_cast<void*>(&resize_nearest_neighbor_params);
1151 return ::tflite::GetBuiltinOperatorVersion(op_sig);
1152 }
1153};
1154
1155class Squeeze
1156 : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
1157 ::tflite::BuiltinOptions_SqueezeOptions> {
1158 public:
1159 using BuiltinOperator::BuiltinOperator;
1160
1161 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1162 const TocoOperator& op,
1163 flatbuffers::FlatBufferBuilder* builder) const override {
1164 auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
1165 return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
1166 }
1167
1168 void ReadOptions(const TfLiteOptions& options,
1169 TocoOperator* op) const override {
1170 op->squeeze_dims.insert(op->squeeze_dims.end(),
1171 options.squeeze_dims()->begin(),
1172 options.squeeze_dims()->end());
1173 }
1174};
1175
1176class Split
1177 : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
1178 ::tflite::BuiltinOptions_SplitOptions> {
1179 public:
1180 using BuiltinOperator::BuiltinOperator;
1181
1182 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1183 const TocoOperator& op,
1184 flatbuffers::FlatBufferBuilder* builder) const override {
1185 return ::tflite::CreateSplitOptions(*builder, op.num_split);
1186 }
1187
1188 void ReadOptions(const TfLiteOptions& options,
1189 TocoOperator* op) const override {
1190 op->num_split = options.num_splits();
1191 }
1192};
1193
1194class SplitV
1195 : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions,
1196 ::tflite::BuiltinOptions_SplitVOptions> {
1197 public:
1198 using BuiltinOperator::BuiltinOperator;
1199
1200 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1201 const TocoOperator& op,
1202 flatbuffers::FlatBufferBuilder* builder) const override {
1203 return ::tflite::CreateSplitVOptions(*builder, op.num_split);
1204 }
1205
1206 void ReadOptions(const TfLiteOptions& options,
1207 TocoOperator* op) const override {
1208 op->num_split = options.num_splits();
1209 }
1210};
1211
1212class StridedSlice
1213 : public BuiltinOperator<StridedSliceOperator,
1214 ::tflite::StridedSliceOptions,
1215 ::tflite::BuiltinOptions_StridedSliceOptions> {
1216 public:
1217 using BuiltinOperator::BuiltinOperator;
1218 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1219 const TocoOperator& op,
1220 flatbuffers::FlatBufferBuilder* builder) const override {
1221 return ::tflite::CreateStridedSliceOptions(
1222 *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
1223 op.new_axis_mask, op.shrink_axis_mask);
1224 }
1225
1226 void ReadOptions(const TfLiteOptions& options,
1227 TocoOperator* op) const override {
1228 op->begin_mask = options.begin_mask();
1229 op->end_mask = options.end_mask();
1230 op->ellipsis_mask = options.ellipsis_mask();
1231 op->new_axis_mask = options.new_axis_mask();
1232 op->shrink_axis_mask = options.shrink_axis_mask();
1233 }
1234
1235 int GetVersion(const OperatorSignature& op_signature) const override {
1236 const auto& ss_op =
1237 static_cast<const StridedSliceOperator&>(*op_signature.op);
1238 ::tflite::OpSignature op_sig =
1239 GetVersioningOpSig(builtin_op(), op_signature);
1240 op_sig.ext_options.strided_slice.num_dims = ss_op.start_indices.size();
1241 TfLiteStridedSliceParams strided_slice_params = {};
1242 strided_slice_params.ellipsis_mask = ss_op.ellipsis_mask;
1243 strided_slice_params.new_axis_mask = ss_op.new_axis_mask;
1244 op_sig.builtin_data = reinterpret_cast<void*>(&strided_slice_params);
1245 return ::tflite::GetBuiltinOperatorVersion(op_sig);
1246 }
1247};
1248
1249class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
1250 ::tflite::BuiltinOptions_TopKV2Options> {
1251 public:
1252 using BuiltinOperator::BuiltinOperator;
1253 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1254 const TocoOperator& op,
1255 flatbuffers::FlatBufferBuilder* builder) const override {
1256 return ::tflite::CreateTopKV2Options(*builder);
1257 }
1258
1259 void ReadOptions(const TfLiteOptions& options,
1260 TocoOperator* op) const override {}
1261};
1262
1263class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
1264 ::tflite::BuiltinOptions_ArgMaxOptions> {
1265 public:
1266 using BuiltinOperator::BuiltinOperator;
1267 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1268 const TocoOperator& op,
1269 flatbuffers::FlatBufferBuilder* builder) const override {
1270 return ::tflite::CreateArgMaxOptions(
1271 *builder, DataType::Serialize(op.output_data_type));
1272 }
1273
1274 void ReadOptions(const TfLiteOptions& options,
1275 TocoOperator* op) const override {
1276 op->output_data_type = DataType::Deserialize(options.output_type());
1277 }
1278};
1279
1280class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
1281 ::tflite::BuiltinOptions_ArgMinOptions> {
1282 public:
1283 using BuiltinOperator::BuiltinOperator;
1284 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1285 const TocoOperator& op,
1286 flatbuffers::FlatBufferBuilder* builder) const override {
1287 return ::tflite::CreateArgMinOptions(
1288 *builder, DataType::Serialize(op.output_data_type));
1289 }
1290
1291 void ReadOptions(const TfLiteOptions& options,
1292 TocoOperator* op) const override {
1293 op->output_data_type = DataType::Deserialize(options.output_type());
1294 }
1295};
1296
1297class TransposeConv
1298 : public BuiltinOperator<TransposeConvOperator,
1299 ::tflite::TransposeConvOptions,
1300 ::tflite::BuiltinOptions_TransposeConvOptions> {
1301 public:
1302 using BuiltinOperator::BuiltinOperator;
1303
1304 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1305 const TocoOperator& op,
1306 flatbuffers::FlatBufferBuilder* builder) const override {
1307 auto padding = Padding::Serialize(op.padding.type);
1308 return ::tflite::CreateTransposeConvOptions(
1309 *builder, padding, op.stride_width, op.stride_height);
1310 }
1311
1312 void ReadOptions(const TfLiteOptions& options,
1313 TocoOperator* op) const override {
1314 op->padding.type = Padding::Deserialize(options.padding());
1315 op->stride_width = options.stride_w();
1316 op->stride_height = options.stride_h();
1317 }
1318};
1319
1320class SparseToDense
1321 : public BuiltinOperator<SparseToDenseOperator,
1322 ::tflite::SparseToDenseOptions,
1323 ::tflite::BuiltinOptions_SparseToDenseOptions> {
1324 public:
1325 using BuiltinOperator::BuiltinOperator;
1326
1327 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1328 const TocoOperator& op,
1329 flatbuffers::FlatBufferBuilder* builder) const override {
1330 return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices);
1331 }
1332
1333 void ReadOptions(const TfLiteOptions& options,
1334 TocoOperator* op) const override {
1335 op->validate_indices = options.validate_indices();
1336 }
1337};
1338
1339class ExpandDims
1340 : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions,
1341 ::tflite::BuiltinOptions_ExpandDimsOptions> {
1342 public:
1343 using BuiltinOperator::BuiltinOperator;
1344
1345 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1346 const TocoOperator& op,
1347 flatbuffers::FlatBufferBuilder* builder) const override {
1348 return ::tflite::CreateExpandDimsOptions(*builder);
1349 }
1350
1351 void ReadOptions(const TfLiteOptions& options,
1352 TocoOperator* op) const override {}
1353};
1354
1355class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
1356 ::tflite::BuiltinOptions_PackOptions> {
1357 public:
1358 using BuiltinOperator::BuiltinOperator;
1359
1360 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1361 const TocoOperator& op,
1362 flatbuffers::FlatBufferBuilder* builder) const override {
1363 return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis);
1364 }
1365
1366 void ReadOptions(const TfLiteOptions& options,
1367 TocoOperator* op) const override {
1368 op->values_count = options.values_count();
1369 op->axis = options.axis();
1370 }
1371};
1372
1373class Shape
1374 : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
1375 ::tflite::BuiltinOptions_ShapeOptions> {
1376 public:
1377 using BuiltinOperator::BuiltinOperator;
1378 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1379 const TocoOperator& op,
1380 flatbuffers::FlatBufferBuilder* builder) const override {
1381 return ::tflite::CreateShapeOptions(
1382 *builder, DataType::Serialize(op.output_data_type));
1383 }
1384
1385 void ReadOptions(const TfLiteOptions& options,
1386 TocoOperator* op) const override {
1387 op->output_data_type = DataType::Deserialize(options.out_type());
1388 }
1389};
1390
1391class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
1392 ::tflite::BuiltinOptions_OneHotOptions> {
1393 public:
1394 using BuiltinOperator::BuiltinOperator;
1395 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1396 const TocoOperator& op,
1397 flatbuffers::FlatBufferBuilder* builder) const override {
1398 return ::tflite::CreateOneHotOptions(*builder, op.axis);
1399 }
1400 void ReadOptions(const TfLiteOptions& options,
1401 TocoOperator* op) const override {
1402 op->axis = options.axis();
1403 }
1404};
1405
1406class CTCBeamSearchDecoder
1407 : public CustomOperator<CTCBeamSearchDecoderOperator> {
1408 public:
1409 using CustomOperator::CustomOperator;
1410
1411 void WriteOptions(const TocoOperator& op,
1412 flexbuffers::Builder* fbb) const override {
1413 fbb->Int("beam_width", op.beam_width);
1414 fbb->Int("top_paths", op.top_paths);
1415 fbb->Bool("merge_repeated", op.merge_repeated);
1416 }
1417
1418 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
1419 op->beam_width = m["beam_width"].AsInt32();
1420 op->top_paths = m["top_paths"].AsInt32();
1421 op->merge_repeated = m["merge_repeated"].AsBool();
1422 }
1423
1424 int GetVersion(const OperatorSignature& op_signature) const override {
1425 return 1;
1426 }
1427};
1428
1429class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
1430 ::tflite::BuiltinOptions_UnpackOptions> {
1431 public:
1432 using BuiltinOperator::BuiltinOperator;
1433 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1434 const TocoOperator& op,
1435 flatbuffers::FlatBufferBuilder* builder) const override {
1436 return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis);
1437 }
1438 void ReadOptions(const TfLiteOptions& options,
1439 TocoOperator* op) const override {
1440 op->num = options.num();
1441 op->axis = options.axis();
1442 }
1443
1444 int GetVersion(const OperatorSignature& op_signature) const override {
1445 const std::string& input_name = op_signature.op->inputs[0];
1446 const Array& input_array = op_signature.model->GetArray(input_name);
1447 // If the op take int8/uint8 input, it is version 2.
1448 if (input_array.data_type == ArrayDataType::kInt8 ||
1449 input_array.data_type == ArrayDataType::kUint8) {
1450 return 2;
1451 }
1452 // If the op take bool input, it is version 3.
1453 if (input_array.data_type == ArrayDataType::kBool) {
1454 return 3;
1455 }
1456 return 1;
1457 }
1458};
1459
1460class LeakyRelu
1461 : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions,
1462 ::tflite::BuiltinOptions_LeakyReluOptions> {
1463 public:
1464 using BuiltinOperator::BuiltinOperator;
1465 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1466 const TocoOperator& op,
1467 flatbuffers::FlatBufferBuilder* builder) const override {
1468 return ::tflite::CreateLeakyReluOptions(*builder, op.alpha);
1469 }
1470 void ReadOptions(const TfLiteOptions& options,
1471 TocoOperator* op) const override {
1472 op->alpha = options.alpha();
1473 }
1474};
1475
1476class SquaredDifference
1477 : public BuiltinOperator<
1478 SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
1479 ::tflite::BuiltinOptions_SquaredDifferenceOptions> {
1480 public:
1481 using BuiltinOperator::BuiltinOperator;
1482
1483 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1484 const TocoOperator& op,
1485 flatbuffers::FlatBufferBuilder* builder) const override {
1486 return ::tflite::CreateSquaredDifferenceOptions(*builder);
1487 }
1488
1489 void ReadOptions(const TfLiteOptions& options,
1490 TocoOperator* op) const override {}
1491};
1492
1493class MirrorPad
1494 : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions,
1495 ::tflite::BuiltinOptions_MirrorPadOptions> {
1496 public:
1497 using BuiltinOperator::BuiltinOperator;
1498 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1499 const TocoOperator& op,
1500 flatbuffers::FlatBufferBuilder* builder) const override {
1501 return ::tflite::CreateMirrorPadOptions(
1502 *builder, op.mode == MirrorPadMode::kReflect
1503 ? ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1504 : ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC);
1505 }
1506 void ReadOptions(const TfLiteOptions& options,
1507 TocoOperator* op) const override {
1508 op->mode = options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1509 ? MirrorPadMode::kReflect
1510 : MirrorPadMode::kSymmetric;
1511 }
1512};
1513
1514class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
1515 ::tflite::BuiltinOptions_UniqueOptions> {
1516 public:
1517 using BuiltinOperator::BuiltinOperator;
1518 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1519 const TocoOperator& op,
1520 flatbuffers::FlatBufferBuilder* builder) const override {
1521 const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
1522 return ::tflite::CreateUniqueOptions(
1523 *builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64
1524 ? ::tflite::TensorType::TensorType_INT64
1525 : ::tflite::TensorType_INT32);
1526 }
1527 void ReadOptions(const TfLiteOptions& options,
1528 TocoOperator* op) const override {
1529 UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
1530 unique_op->idx_out_type =
1531 options.idx_out_type() == ::tflite::TensorType_INT64
1532 ? toco::ArrayDataType::kInt64
1533 : toco::ArrayDataType::kInt32;
1534 }
1535};
1536
1537class UnidirectionalSequenceRnn
1538 : public BuiltinOperator<UnidirectionalSequenceRnnOperator,
1539 ::tflite::SequenceRNNOptions,
1540 ::tflite::BuiltinOptions_SequenceRNNOptions> {
1541 public:
1542 using BuiltinOperator::BuiltinOperator;
1543 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1544 const TocoOperator& op,
1545 flatbuffers::FlatBufferBuilder* builder) const override {
1546 return ::tflite::CreateSequenceRNNOptions(
1547 *builder, /*time_major=*/true,
1548 /*fused_activation_function=*/
1549 ::tflite::ActivationFunctionType_TANH);
1550 }
1551 void ReadOptions(const TfLiteOptions& options,
1552 TocoOperator* op) const override {
1553 // Only support tanh activation, so check that tflite type is tanh.
1554 DCHECK(options.fused_activation_function() ==
1555 ::tflite::ActivationFunctionType_TANH);
1556 }
1557
1558 std::vector<bool> GetMutatingInputVariables(
1559 const Operator& op) const override {
1560 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1561 mutating_input_variables[4] = true;
1562 return mutating_input_variables;
1563 }
1564};
1565
1566class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
1567 ::tflite::BuiltinOptions_WhereOptions> {
1568 public:
1569 using BuiltinOperator::BuiltinOperator;
1570
1571 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1572 const TocoOperator& op,
1573 flatbuffers::FlatBufferBuilder* builder) const override {
1574 return ::tflite::CreateWhereOptions(*builder);
1575 }
1576
1577 void ReadOptions(const TfLiteOptions& options,
1578 TocoOperator* op) const override {}
1579};
1580
1581std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
1582 const std::string& tensorflow_node_def) {
1583 auto fbb = std::make_unique<flexbuffers::Builder>();
1584
1585 ::tensorflow::NodeDef node_def;
1586 if (!node_def.ParseFromString(tensorflow_node_def)) {
1587 LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1588 return {};
1589 }
1590
1591 fbb->Vector([&]() {
1592 fbb->String(node_def.op());
1593 fbb->String(tensorflow_node_def);
1594 });
1595 fbb->Finish();
1596 LOG(INFO) << "Writing flex op: " << node_def.op();
1597 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1598}
1599
1600class TensorFlowUnsupported : public BaseOperator {
1601 public:
1602 TensorFlowUnsupported(const std::string& name, OperatorType type,
1603 bool enable_select_tf_ops)
1604 : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {}
1605
1606 Options Serialize(const Operator& op,
1607 flatbuffers::FlatBufferBuilder* builder) const override {
1608 auto fbb =
1609 WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
1610 if (fbb) {
1611 return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
1612 } else {
1613 return Options::Custom(0);
1614 }
1615 }
1616
1617 std::unique_ptr<Operator> Deserialize(
1618 const BuiltinOptions* builtin_options,
1619 const CustomOptions* custom_options) const override {
1620 // Deserializing Flex ops doesn't work now.
1621 // TODO(ycling): Revisit and decide if we should fix the flow for importing
1622 // TFLite models with Flex ops.
1623 auto op = std::make_unique<TensorFlowUnsupportedOperator>();
1624 if (custom_options) {
1625 auto flexbuffer_map =
1626 flexbuffers::GetRoot(custom_options->data(), custom_options->size())
1627 .AsMap();
1628 ReadOptions(flexbuffer_map, op.get());
1629 }
1630 return std::unique_ptr<Operator>(op.release());
1631 }
1632
1633 std::unique_ptr<flexbuffers::Builder> WriteOptions(
1634 const TensorFlowUnsupportedOperator& op) const {
1635 if (enable_select_tf_ops_) {
1636 return WriteFlexOpOptions(op.tensorflow_node_def);
1637 }
1638 auto fbb = std::make_unique<flexbuffers::Builder>();
1639
1640 ::tensorflow::NodeDef node_def;
1641 if (!node_def.ParseFromString(op.tensorflow_node_def)) {
1642 LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1643 return std::unique_ptr<flexbuffers::Builder>();
1644 }
1645
1646 if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) {
1647 fbb->Vector([&]() {
1648 fbb->String(node_def.op());
1649 fbb->String(op.tensorflow_node_def);
1650 });
1651 fbb->Finish();
1652 LOG(INFO) << "Writing flex op: " << node_def.op();
1653 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1654 }
1655
1656 bool has_valid_attr = false;
1657 size_t map_start = fbb->StartMap();
1658 for (const auto& pair : node_def.attr()) {
1659 const char* key = pair.first.c_str();
1660 const auto& attr = pair.second;
1661 switch (attr.value_case()) {
1662 case ::tensorflow::AttrValue::kS:
1663 fbb->String(key, attr.s());
1664 has_valid_attr = true;
1665 break;
1666 case ::tensorflow::AttrValue::kI:
1667 fbb->Int(key, attr.i());
1668 has_valid_attr = true;
1669 break;
1670 case ::tensorflow::AttrValue::kF:
1671 fbb->Float(key, attr.f());
1672 has_valid_attr = true;
1673 break;
1674 case ::tensorflow::AttrValue::kB:
1675 fbb->Bool(key, attr.b());
1676 has_valid_attr = true;
1677 break;
1678 case tensorflow::AttrValue::kList:
1679 if (attr.list().s_size() > 0) {
1680 auto start = fbb->StartVector(key);
1681 for (const std::string& v : attr.list().s()) {
1682 fbb->Add(v);
1683 }
1684 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1685 has_valid_attr = true;
1686 } else if (attr.list().i_size() > 0) {
1687 auto start = fbb->StartVector(key);
1688 for (const int64_t v : attr.list().i()) {
1689 fbb->Add(v);
1690 }
1691 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1692 has_valid_attr = true;
1693 } else if (attr.list().f_size() > 0) {
1694 auto start = fbb->StartVector(key);
1695 for (const float v : attr.list().f()) {
1696 fbb->Add(v);
1697 }
1698 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1699 has_valid_attr = true;
1700 } else {
1701 LOG(WARNING)
1702 << "Ignoring unsupported type in list attribute with key '"
1703 << key << "'";
1704 }
1705 break;
1706 default:
1707 LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1708 << key << "'";
1709 break;
1710 }
1711 }
1712 if (!has_valid_attr) {
1713 return std::unique_ptr<flexbuffers::Builder>();
1714 }
1715 fbb->EndMap(map_start);
1716 fbb->Finish();
1717 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1718 }
1719
1720 void ReadOptions(const flexbuffers::Map& m,
1721 TensorFlowUnsupportedOperator* op) const {
1722 ::tensorflow::NodeDef node_def;
1723 auto attr = node_def.mutable_attr();
1724
1725 const auto& keys = m.Keys();
1726 for (size_t i = 0; i < keys.size(); ++i) {
1727 const auto key = keys[i].AsKey();
1728 const auto& value = m[key];
1729 switch (value.GetType()) {
1730 case flexbuffers::FBT_STRING:
1731 (*attr)[key].set_s(value.AsString().c_str());
1732 break;
1733 case flexbuffers::FBT_INT:
1734 (*attr)[key].set_i(value.AsInt64());
1735 break;
1736 case flexbuffers::FBT_FLOAT:
1737 (*attr)[key].set_f(value.AsFloat());
1738 break;
1739 case flexbuffers::FBT_BOOL:
1740 (*attr)[key].set_b(value.AsBool());
1741 if (std::string(key) == "_output_quantized") {
1742 op->quantized = value.AsBool();
1743 }
1744 if (std::string(key) ==
1745 "_support_output_type_float_in_quantized_op") {
1746 op->support_output_type_float_in_quantized_op = value.AsBool();
1747 }
1748 break;
1749 case flexbuffers::FBT_VECTOR_INT: {
1750 auto* list = (*attr)[key].mutable_list();
1751 const auto& vector = value.AsTypedVector();
1752 for (size_t i = 0; i < vector.size(); i++) {
1753 list->add_i(vector[i].AsInt64());
1754 }
1755 break;
1756 }
1757 case flexbuffers::FBT_VECTOR_FLOAT: {
1758 auto* list = (*attr)[key].mutable_list();
1759 const auto& vector = value.AsTypedVector();
1760 for (size_t i = 0; i < vector.size(); i++) {
1761 list->add_f(vector[i].AsFloat());
1762 }
1763 break;
1764 }
1765 case 15 /* TO_DO(wvo): flexbuffers::FBT_VECTOR_STRING_DEPRECATED*/: {
1766 auto* list = (*attr)[key].mutable_list();
1767 const auto& vector = value.AsTypedVector();
1768 for (size_t i = 0; i < vector.size(); i++) {
1769 list->add_s(vector[i].AsString().str());
1770 }
1771 break;
1772 }
1773 default:
1774 LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1775 << key << "'";
1776 break;
1777 }
1778 }
1779 node_def.SerializeToString(&op->tensorflow_node_def);
1780 }
1781
1782 int GetVersion(const OperatorSignature& op_signature) const override {
1783 // TODO(ycling): Design and implement a way to plumb the version of
1784 // custom ops.
1785 return 1;
1786 }
1787
1788 private:
1789 const bool enable_select_tf_ops_;
1790};
1791
1792class Dequantize
1793 : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions,
1794 ::tflite::BuiltinOptions_DequantizeOptions> {
1795 public:
1796 using BuiltinOperator::BuiltinOperator;
1797
1798 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1799 const TocoOperator& op,
1800 flatbuffers::FlatBufferBuilder* builder) const override {
1801 return ::tflite::CreateDequantizeOptions(*builder);
1802 }
1803
1804 void ReadOptions(const TfLiteOptions& options,
1805 TocoOperator* op) const override {}
1806};
1807
1808class ReverseSequence
1809 : public BuiltinOperator<ReverseSequenceOperator,
1810 ::tflite::ReverseSequenceOptions,
1811 ::tflite::BuiltinOptions_ReverseSequenceOptions> {
1812 public:
1813 using BuiltinOperator::BuiltinOperator;
1814
1815 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1816 const TocoOperator& op,
1817 flatbuffers::FlatBufferBuilder* builder) const override {
1818 return ::tflite::CreateReverseSequenceOptions(*builder, op.seq_dim,
1819 op.batch_dim);
1820 }
1821
1822 void ReadOptions(const TfLiteOptions& options,
1823 TocoOperator* op) const override {
1824 op->seq_dim = options.seq_dim();
1825 op->batch_dim = options.batch_dim();
1826 }
1827};
1828
1829namespace {
1830// Build a vector containing all the known operators.
1831std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
1832 bool enable_select_tf_ops = false) {
1833 std::vector<std::unique_ptr<BaseOperator>> ops;
1834 using tensorflow::MakeUnique;
1835 // Builtin Operators.
1836 ops.push_back(
1837 MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
1838 ops.push_back(
1839 MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN));
1840 ops.push_back(
1841 MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
1842 ops.push_back(
1843 MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
1844 ops.push_back(MakeUnique<AveragePool>(
1845 ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool));
1846 ops.push_back(
1847 MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
1848 OperatorType::kSpaceToBatchND));
1849 ops.push_back(
1850 MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
1851 OperatorType::kBatchToSpaceND));
1852 ops.push_back(MakeUnique<Concatenation>(
1853 ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation));
1854 ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D,
1855 OperatorType::kConv));
1856 ops.push_back(MakeUnique<DepthwiseConvolution>(
1857 ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
1858 OperatorType::kDepthwiseConv));
1859 ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
1860 OperatorType::kDequantize));
1861 ops.push_back(
1862 MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
1863 OperatorType::kFullyConnected));
1864 ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER,
1865 OperatorType::kGather));
1866 ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND,
1867 OperatorType::kGatherNd));
1868 ops.push_back(
1869 MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION,
1870 OperatorType::kL2Normalization));
1871 ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D,
1872 OperatorType::kL2Pool));
1873 ops.push_back(MakeUnique<LocalResponseNormalization>(
1874 ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
1875 OperatorType::kLocalResponseNormalization));
1876 ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D,
1877 OperatorType::kMaxPool));
1878 ops.push_back(
1879 MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
1880
1881 ops.push_back(
1882 MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
1883 ops.push_back(
1884 MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
1885 ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE,
1886 OperatorType::kReshape));
1887 ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX,
1888 OperatorType::kSoftmax));
1889 ops.push_back(MakeUnique<SpaceToDepth>(
1890 ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth));
1891 ops.push_back(MakeUnique<DepthToSpace>(
1892 ::tflite::BuiltinOperator_DEPTH_TO_SPACE, OperatorType::kDepthToSpace));
1893 ops.push_back(
1894 MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
1895 ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE,
1896 OperatorType::kTranspose));
1897 ops.push_back(
1898 MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
1899 ops.push_back(
1900 MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
1901 ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD,
1902 OperatorType::kReduceProd));
1903 ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX,
1904 OperatorType::kReduceMax));
1905 ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
1906 OperatorType::kReduceMin));
1907 ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
1908 OperatorType::kAny));
1909 ops.push_back(
1910 MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
1911 OperatorType::kResizeBilinear));
1912 ops.push_back(MakeUnique<ResizeNearestNeighbor>(
1913 ::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
1914 OperatorType::kResizeNearestNeighbor));
1915 ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE,
1916 OperatorType::kSqueeze));
1917 ops.push_back(
1918 MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
1919 ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V,
1920 OperatorType::kSplitV));
1921 ops.push_back(MakeUnique<StridedSlice>(
1922 ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice));
1923 ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2,
1924 OperatorType::kTopK_V2));
1925 ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM,
1926 OperatorType::kLstmCell));
1927 ops.push_back(
1928 MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
1929 ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX,
1930 OperatorType::kArgMax));
1931 ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN,
1932 OperatorType::kArgMin));
1933 ops.push_back(
1934 MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
1935 ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS,
1936 OperatorType::kExpandDims));
1937 ops.push_back(MakeUnique<TransposeConv>(
1938 ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv));
1939 ops.push_back(MakeUnique<SparseToDense>(
1940 ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense));
1941 ops.push_back(
1942 MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
1943 ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT,
1944 OperatorType::kFakeQuant));
1945 ops.push_back(
1946 MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
1947 ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>(
1948 ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
1949 OperatorType::kUnidirectionalSequenceLstm));
1950 ops.emplace_back(MakeUnique<BidirectionalSequenceLstm>(
1951 ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
1952 OperatorType::kBidirectionalSequenceLstm));
1953 ops.emplace_back(MakeUnique<BidirectionalSequenceRnn>(
1954 ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
1955 OperatorType::kBidirectionalSequenceRnn));
1956 ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
1957 OperatorType::kOneHot));
1958 ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
1959 OperatorType::kUnpack));
1960 ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU,
1961 OperatorType::kLeakyRelu));
1962 ops.push_back(MakeUnique<SquaredDifference>(
1963 ::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
1964 OperatorType::kSquaredDifference));
1965 ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
1966 OperatorType::kMirrorPad));
1967 ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
1968 OperatorType::kUnique));
1969 ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
1970 ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
1971 OperatorType::kUnidirectionalSequenceRnn));
1972 ops.push_back(
1973 MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
1974 ops.push_back(
1975 MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
1976 OperatorType::kReverseSequence));
1977 ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
1978 ::tflite::BuiltinOperator_MATRIX_DIAG, OperatorType::kMatrixDiag));
1979 ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>(
1980 ::tflite::BuiltinOperator_MATRIX_SET_DIAG, OperatorType::kMatrixSetDiag));
1981 // Custom Operators.
1982 ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
1983 "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
1984 ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
1985 OperatorType::kUnsupported,
1986 enable_select_tf_ops));
1987
1988 // SimpleOperator was designed to export CUSTOM TF Lite ops, but has since
1989 // been modified to also export builtins. As TOCO evolved we added warnings
1990 // when custom ops are exported but SimpleOperator bypasses thoses. To
1991 // prevent user confusion we are settling on using SimpleOperator only for
1992 // builtins.
1993 ops.push_back(MakeUnique<SimpleOperator<FloorOperator>>(
1994 ::tflite::BuiltinOperator_FLOOR, OperatorType::kFloor));
1995 ops.push_back(MakeUnique<SimpleOperator<CeilOperator>>(
1996 ::tflite::BuiltinOperator_CEIL, OperatorType::kCeil));
1997 ops.push_back(MakeUnique<SimpleOperator<EluOperator>>(
1998 ::tflite::BuiltinOperator_ELU, OperatorType::kElu));
1999 ops.push_back(MakeUnique<SimpleOperator<RoundOperator>>(
2000 ::tflite::BuiltinOperator_ROUND, OperatorType::kRound));
2001 ops.push_back(MakeUnique<SimpleOperator<ReluOperator>>(
2002 ::tflite::BuiltinOperator_RELU, OperatorType::kRelu));
2003 ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
2004 ::tflite::BuiltinOperator_RELU_N1_TO_1, OperatorType::kRelu1));
2005 ops.push_back(MakeUnique<SimpleOperator<Relu6Operator>>(
2006 ::tflite::BuiltinOperator_RELU6, OperatorType::kRelu6));
2007 ops.push_back(MakeUnique<SimpleOperator<PReluOperator>>(
2008 ::tflite::BuiltinOperator_PRELU, OperatorType::kPRelu));
2009 ops.push_back(MakeUnique<SimpleOperator<LogisticOperator>>(
2010 ::tflite::BuiltinOperator_LOGISTIC, OperatorType::kLogistic));
2011 ops.push_back(MakeUnique<SimpleOperator<TanhOperator>>(
2012 ::tflite::BuiltinOperator_TANH, OperatorType::kTanh));
2013 ops.push_back(MakeUnique<SimpleOperator<ExpOperator>>(
2014 ::tflite::BuiltinOperator_EXP, OperatorType::kExp));
2015 ops.push_back(MakeUnique<SimpleOperator<CosOperator>>(
2016 ::tflite::BuiltinOperator_COS, OperatorType::kCos));
2017 ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>(
2018 ::tflite::BuiltinOperator_LOG_SOFTMAX, OperatorType::kLogSoftmax));
2019 ops.push_back(MakeUnique<SimpleOperator<TensorFlowMaximumOperator>>(
2020 ::tflite::BuiltinOperator_MAXIMUM, OperatorType::kMaximum));
2021 ops.push_back(MakeUnique<SimpleOperator<TensorFlowMinimumOperator>>(
2022 ::tflite::BuiltinOperator_MINIMUM, OperatorType::kMinimum));
2023 ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterOperator>>(
2024 ::tflite::BuiltinOperator_GREATER, OperatorType::kGreater));
2025 ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterEqualOperator>>(
2026 ::tflite::BuiltinOperator_GREATER_EQUAL, OperatorType::kGreaterEqual));
2027 ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessOperator>>(
2028 ::tflite::BuiltinOperator_LESS, OperatorType::kLess));
2029 ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessEqualOperator>>(
2030 ::tflite::BuiltinOperator_LESS_EQUAL, OperatorType::kLessEqual));
2031 ops.push_back(MakeUnique<SimpleOperator<TensorFlowEqualOperator>>(
2032 ::tflite::BuiltinOperator_EQUAL, OperatorType::kEqual));
2033 ops.push_back(MakeUnique<SimpleOperator<TensorFlowNotEqualOperator>>(
2034 ::tflite::BuiltinOperator_NOT_EQUAL, OperatorType::kNotEqual));
2035 ops.push_back(MakeUnique<SimpleOperator<NegOperator>>(
2036 ::tflite::BuiltinOperator_NEG, OperatorType::kNeg));
2037 ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>(
2038 ::tflite::BuiltinOperator_SELECT, OperatorType::kSelect));
2039 ops.push_back(MakeUnique<SimpleOperator<SliceOperator>>(
2040 ::tflite::BuiltinOperator_SLICE, OperatorType::kSlice));
2041 ops.push_back(MakeUnique<SimpleOperator<PowOperator>>(
2042 ::tflite::BuiltinOperator_POW, OperatorType::kPow));
2043 ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>(
2044 ::tflite::BuiltinOperator_LOGICAL_OR, OperatorType::kLogicalOr));
2045 ops.emplace_back(new SimpleOperator<LogicalAndOperator>(
2046 ::tflite::BuiltinOperator_LOGICAL_AND, OperatorType::kLogicalAnd));
2047 ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
2048 ::tflite::BuiltinOperator_LOGICAL_NOT, OperatorType::kLogicalNot));
2049 ops.emplace_back(new SimpleOperator<FloorDivOperator>(
2050 ::tflite::BuiltinOperator_FLOOR_DIV, OperatorType::kFloorDiv));
2051 ops.emplace_back(new SimpleOperator<FloorModOperator>(
2052 ::tflite::BuiltinOperator_FLOOR_MOD, OperatorType::kFloorMod));
2053 ops.emplace_back(new SimpleOperator<RangeOperator>(
2054 ::tflite::BuiltinOperator_RANGE, OperatorType::kRange));
2055 // Element-wise operator
2056 ops.push_back(MakeUnique<SimpleOperator<SinOperator>>(
2057 ::tflite::BuiltinOperator_SIN, OperatorType::kSin));
2058 ops.push_back(MakeUnique<SimpleOperator<LogOperator>>(
2059 ::tflite::BuiltinOperator_LOG, OperatorType::kLog));
2060 ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>(
2061 ::tflite::BuiltinOperator_SQRT, OperatorType::kSqrt));
2062 ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
2063 ::tflite::BuiltinOperator_RSQRT, OperatorType::kRsqrt));
2064 ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
2065 ::tflite::BuiltinOperator_SQUARE, OperatorType::kSquare));
2066 ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
2067 ::tflite::BuiltinOperator_ZEROS_LIKE, OperatorType::kZerosLike));
2068 ops.push_back(MakeUnique<SimpleOperator<AbsOperator>>(
2069 ::tflite::BuiltinOperator_ABS, OperatorType::kAbs));
2070 ops.push_back(MakeUnique<SimpleOperator<HardSwishOperator>>(
2071 ::tflite::BuiltinOperator_HARD_SWISH, OperatorType::kHardSwish));
2072 ops.push_back(MakeUnique<SimpleOperator<FillOperator>>(
2073 ::tflite::BuiltinOperator_FILL, OperatorType::kFill));
2074 ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
2075 ::tflite::BuiltinOperator_REVERSE_V2, OperatorType::kReverseV2));
2076 ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
2077 ::tflite::BuiltinOperator_RANK, OperatorType::kRank));
2078 ops.emplace_back(new SimpleOperator<SegmentSumOperator>(
2079 ::tflite::BuiltinOperator_SEGMENT_SUM, OperatorType::kSegmentSum));
2080 ops.emplace_back(MakeUnique<SimpleOperator<ScatterNdOperator>>(
2081 ::tflite::BuiltinOperator_SCATTER_ND, OperatorType::kScatterNd));
2082 return ops;
2083}
2084} // namespace
2085
2086// LINT.ThenChange(//tensorflow/lite/tools/versioning/op_version.cc)
2087
2088std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
2089 bool enable_select_tf_ops) {
2090 std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
2091
2092 std::vector<std::unique_ptr<BaseOperator>> ops =
2093 BuildOperatorList(enable_select_tf_ops);
2094 for (auto& op : ops) {
2095 result[op->type()] = std::move(op);
2096 }
2097
2098 return result;
2099}
2100
2101std::map<std::string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
2102 bool enable_select_tf_ops) {
2103 std::map<std::string, std::unique_ptr<BaseOperator>> result;
2104
2105 std::vector<std::unique_ptr<BaseOperator>> ops =
2106 BuildOperatorList(enable_select_tf_ops);
2107 for (auto& op : ops) {
2108 result[op->name()] = std::move(op);
2109 }
2110
2111 return result;
2112}
2113
2114bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
2115 const std::string& tensorflow_op_name) {
2116 // If Flex ops aren't allow at all, simply return false.
2117 if (!enable_select_tf_ops) {
2118 return false;
2119 }
2120 // Check if we can find the `OpDef` for the TensorFlow op. If we can find
2121 // it and it has been allowlisted, export the op as an Flex op. Otherwise,
2122 // export it as a regular custom op.
2123 const tensorflow::OpDef* op_def = nullptr;
2124 if (!tensorflow::OpRegistry::Global()
2125 ->LookUpOpDef(tensorflow_op_name, &op_def)
2126 .ok()) {
2127 return false;
2128 }
2129
2130 if (!::tflite::flex::IsAllowlistedFlexOp(tensorflow_op_name)) {
2131 LOG(WARNING) << "Op " << tensorflow_op_name
2132 << " is a valid TensorFlow op but has not been allowlisted for"
2133 " the TensorFlow Lite flex op set.";
2134 return false;
2135 }
2136
2137 return true;
2138}
2139
2140} // namespace tflite
2141
2142} // namespace toco
2143