/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOCO_MODEL_H_
#define TENSORFLOW_LITE_TOCO_MODEL_H_

#include <algorithm>
#include <complex>
#include <functional>
#include <initializer_list>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_types.h"

namespace toco {

using tflite::QuantizationParams;

enum class OperatorType : uint8 {
  kNone,
  // General-purpose neural network operators.
  kAdd,
  kAddN,
  kAveragePool,
  kBatchMatMul,
  kBatchNormalization,
  kCeil,
  kConv,
  kConcatenation,
  kCos,
  kDepthwiseConv,
  kDepthToSpace,
  kSpaceToDepth,
  kDequantize,
  kDiv,
  kExp,
  kExpandDims,
  kFill,
  kFloorDiv,
  kFloorMod,
  kFullyConnected,
  kL2Normalization,
  kL2Pool,
  kLstmCell,
  kUnidirectionalSequenceLstm,
  kLocalResponseNormalization,
  kLog,
  kLogistic,
  kMaxPool,
  kFakeQuant,
  kMul,
  kOneHot,
  kRandomUniform,
  kRange,
  kRank,
  kRelu,
  kRelu1,
  kRelu6,
  kPRelu,
  kHardSwish,
  kSoftmax,
  kLogSoftmax,
  kSub,
  kTanh,
  kTransposeConv,
  kCast,
  kFloor,
  kRound,
  kGather,
  kResizeBilinear,
  kSin,
  kSpaceToBatchND,
  kPack,
  kBatchToSpaceND,
  kPad,
  kPadV2,
  kReduceProd,  // Reduction product
  kStridedSlice,
  kSlice,
  kSqueeze,
  kMean,
  kArgMax,
  // The SVDF Op is a decomposition of a densely connected Op into
  // low rank filters. For details:
  // https://research.google.com/pubs/pub43813.html
  kSvdf,
  // Special operators used for importing TensorFlow nodes.
  // The general intent is to have some graph transformation either
  // drop them or rewrite them as general-purpose operators.
  kAll,
  kAssert,
  kConcat,
  kConcatV2,
  kGreater,
  kGreaterEqual,
  kIdentity,
  kLess,
  kLessEqual,
  kReduceMax,  // Reduction Max
  kMaximum,    // Element-wise Maximum
  kReduceMin,  // Reduction Min
  kMinimum,    // Element-wise Minimum
  kMatMul,
  kMerge,
  kNeg,
  kReshape,
  kRsqrt,
  kShape,
  kSplit,
  kSplitV,
  kSqrt,
  kSquare,
  kSquaredDifference,
  kSum,
  kSwitch,
  kTile,
  kTranspose,
  kTopK_V2,
  kDynamicPartition,
  kDynamicStitch,
  // An unsupported TF operation. It's only needed to be able to represent the
  // TF graph internally and is expected to be dropped by graph
  // transformations.
  kUnsupported,
  // Finally, TensorFlow uses different conventions for axes ordering,
  // see AxesOrder, and this cannot always be resolved at the time of importing
  // nodes, as TensorFlow parameters may be constant-expression subgraphs
  // instead of being given as plain constant arrays. So we need to insert
  // special nodes in the graph to shuffle axes.
  kReorderAxes,
  kSegmentSum,
  kSelect,
  kSelectV2,
  kSparseToDense,
  kEqual,
  kNotEqual,
  kPow,
  kArgMin,
  kAny,
  kLogicalAnd,
  kLogicalNot,
  kLogicalOr,
  kCTCBeamSearchDecoder,
  kUnpack,
  kZerosLike,
  kResizeNearestNeighbor,
  kLeakyRelu,
  kAbs,
  kMirrorPad,
  kUnique,
  kUnidirectionalSequenceRnn,
  kBidirectionalSequenceLstm,
  kReverseV2,
  kBidirectionalSequenceRnn,
  kGatherNd,
  kWhere,
  kElu,
  kReverseSequence,
  kMatrixDiag,
  kMatrixSetDiag,
  kMatrixDiagV2,
  kMatrixSetDiagV2,
  kMatrixDiagV3,
  kMatrixSetDiagV3,
  kScatterNd,
  // Debugging operators.
  kNumericVerify
};

// Helper to deal with TensorFlow arrays using a different ordering of
// dimensions ("axes") than our own.
// TODO(benoitjacob): Ultimately, we shouldn't have any "ordering" of axes,
// we should have associative arrays mapping symbolic axes identifiers (like
// "output_depth") to dimensions. We would then not need this anymore.
enum class AxesOrder {
  kOneAxis,  // one-dimensional array, one unique axis.
  kCR,       // column-major matrix storage order. Our standard.
  kRC,       // row-major matrix storage order. TensorFlow default.
  kOHWI,     // Our standard for conv weights
  kHWIO,     // TensorFlow conv weights
  k1HWO,     // Our standard for DepthwiseConv weights
  kHWIM,     // TensorFlow DepthwiseConv weights
  kNHWC,     // TensorFlow activations
  kHWOI,     // TensorFlow back-prop conv weights
};

// The type of the scalars in an array.
// Note that the type does not by itself tell whether the values in the array
// are non-quantized (can be accessed directly) or quantized (must be
// interpreted in conjunction with QuantizationParams).
//
// In practice though:
//   float values are never quantized
//   uint8 values are always quantized
//   int32 values are sometimes quantized (depending on whether
//   QuantizationParams are present).
//   complex values are never quantized
//   other types are never quantized at the moment.
//
// kNone means that we don't know the data type yet, or that we don't care
// because we'll be dropping the array anyway (e.g. some exotic array types
// may be involved only in debug-only subgraphs that we may not be interested
// in actually supporting).
enum class ArrayDataType : uint8 {
  kNone,  // 0
  kBool,
  kFloat,
  kInt8,
  kUint8,
  kInt16,  // 5
  kUint16,
  kInt32,
  kUint32,
  kInt64,
  kUint64,  // 10
  kString,
  kComplex64,
  kFloat16,
  kFloat64,
  kComplex128,
};

// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type
template <ArrayDataType A>
struct DataTypeImpl {};
template <>
struct DataTypeImpl<ArrayDataType::kNone> {
  typedef int Type;
};
template <>
struct DataTypeImpl<ArrayDataType::kBool> {
  typedef bool Type;
};
template <>
struct DataTypeImpl<ArrayDataType::kFloat> {
  typedef float Type;
};
template <>
struct DataTypeImpl<ArrayDataType::kInt8> {
  typedef int8 Type;
};
template <>
struct DataTypeImpl<ArrayDataType::kUint8> {
  typedef uint8 Type;
};
template <>
struct DataTypeImpl<ArrayDataType::kInt16> {
  typedef int16 Type;
};
template <>
struct DataTypeImpl<ArrayDataType::kUint16> {
  typedef uint16 Type;
};
template <>
struct DataTypeImpl<ArrayDataType::kInt32> {
  typedef int32 Type;
};
template <>
struct DataTypeImpl<ArrayDataType::kUint32> {
  typedef uint32 Type;
};
template <>
struct DataTypeImpl<ArrayDataType::kInt64> {
  typedef int64_t Type;
};
template <>
struct DataTypeImpl<ArrayDataType::kUint64> {
  typedef uint64 Type;
};
template <>
struct DataTypeImpl<ArrayDataType::kString> {
  typedef std::string Type;
};
template <>
struct DataTypeImpl<ArrayDataType::kComplex64> {
  typedef std::complex<float> Type;
};

template <ArrayDataType A>
using DataType = typename DataTypeImpl<A>::Type;
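
// For example, DataType<ArrayDataType::kFloat> is float and
// DataType<ArrayDataType::kInt32> is int32, so generic code can declare
// type-correct storage (an illustrative sketch):
//   std::vector<DataType<ArrayDataType::kFloat>> values = {1.0f, 2.5f};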

// Base class for type-specific buffer types.
struct GenericBuffer {
  // Non-default-constructible: only ArrayDataType-specific subclass
  // objects may be constructed.
  GenericBuffer() = delete;
  // Non-copyable-or-movable: we should only store pointers-to-Buffer
  // in containers, not Buffers themselves, so there should be no
  // copy or move.
  GenericBuffer(const GenericBuffer&) = delete;
  GenericBuffer(const GenericBuffer&&) = delete;

  // We need a virtual destructor so we can store pointers-to-Buffer
  // in containers and have the containers call the right subclass destructor.
  virtual ~GenericBuffer() {}

  virtual int Length() const = 0;

  const ArrayDataType type;

 protected:
  // Constructor used by subclasses for specific ArrayDataType's.
  explicit GenericBuffer(ArrayDataType t) : type(t) {}
};

// Type-specific buffer, containing type-specific storage.
template <ArrayDataType A>
struct Buffer : GenericBuffer {
  Buffer() : GenericBuffer(A) {}

  int Length() const override { return data.size(); }

  std::vector<DataType<A>> data;
};
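
// Illustrative sketch: a type-specific float buffer holding four scalars.
//   auto buf = std::make_unique<Buffer<ArrayDataType::kFloat>>();
//   buf->data = {0.f, 1.f, 2.f, 3.f};
//   CHECK_EQ(buf->Length(), 4);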

class Shape {
 public:
  // For Shape, we stick to half-way encapsulation for now: we hide the raw
  // dims_ member but expose it through plain accessors, because it is not
  // at all easy to anticipate which flavor of more hermetic encapsulation
  // would actually buy us future-proofness without being needlessly
  // cumbersome.
  Shape() {}
  Shape(std::initializer_list<int> dim_list) : dims_(dim_list) {}

  void ReplaceDims(std::initializer_list<int> dim_list) {
    dims_ = std::vector<int>(dim_list);
  }

  const std::vector<int>& dims() const { return dims_; }
  std::vector<int>* mutable_dims() { return &dims_; }
  int dimensions_count() const { return dims_.size(); }

  // We still have this one convenience accessor to avoid
  // the awkward double-bracket issue: shape.dims()[i].
  int dims(int i) const {
    // Always check for out-of-bounds accesses, even in optimized builds where
    // standard assertions are disabled. Out-of-bounds access here is a common
    // occurrence.
    CHECK_GE(i, 0);
    CHECK_GT(dims_.size(), i);
    return dims_[i];
  }

  bool operator==(const Shape& comp) const {
    return (this->dims_ == comp.dims());
  }

  bool operator!=(const Shape& comp) const { return !((*this) == comp); }

 private:
  std::vector<int> dims_;
};
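
// Example (illustrative):
//   Shape shape({1, 224, 224, 3});
//   CHECK_EQ(shape.dimensions_count(), 4);
//   CHECK_EQ(shape.dims(3), 3);  // channel count in NHWC order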

// Base class for all operator classes.
struct Operator {
  // Non-default-constructible: only OperatorType-specific subclass
  // objects may be constructed.
  Operator() = delete;
  // Non-copyable-or-movable: we should only store pointers-to-Operator
  // in containers, not Operators themselves, so there should be no
  // copy or move.
  Operator(const Operator&) = delete;
  Operator(const Operator&&) = delete;

  // We need a virtual destructor so we can store pointers-to-Operator
  // in containers and have the containers call the right subclass destructor.
  virtual ~Operator() {}

  // The specific type of operator. Corresponds 1:1 to subclasses.
  const OperatorType type;

  // The activation function that may be fused into this operator,
  // or None if no activation function is fused.
  FusedActivationFunctionType fused_activation_function;

  // Input arrays: either activation arrays or constant array parameters.
  // We refer to them by their name, not by their address; the mapping of
  // names to addresses is given by the Model, which owns both Operator's and
  // Array's. Thus, an Operator on its own doesn't contain much information;
  // it is meant to be used in conjunction with the Model that owns it.
  std::vector<std::string> inputs;

  // Output activation arrays. Same comments as for inputs apply here too.
  std::vector<std::string> outputs;

  // If true, the operator has more outputs than are listed in the 'outputs'
  // member. These need to be resolved by some graph transformation.
  // This flag is only here to indicate that an operator should not be
  // discarded as unused, even if from its 'outputs' member alone it
  // looks unused.
  bool unresolved_outputs = false;

  // A serialized tensorflow::NodeDef string.
  // The field is filled only when importing from TensorFlow.
  // It's guaranteed to be filled for `TensorFlowUnsupportedOperator`.
  // It's not guaranteed to be filled for other ops. Ops created by graph
  // transformations won't have a TensorFlow NodeDef.
  std::string tensorflow_node_def;

 protected:
  // Constructor used by subclasses for specific OperatorType's.
  explicit Operator(OperatorType t)
      : type(t),
        fused_activation_function(FusedActivationFunctionType::kNone) {}
};
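
// Illustrative sketch of the name-based wiring described above (AddOperator
// is defined further below; the array names here are hypothetical and would
// be owned by the Model):
//   auto op = std::make_unique<AddOperator>();
//   op->inputs = {"lhs_activations", "rhs_activations"};
//   op->outputs = {"sum_activations"};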

// Padding types for Conv-like operators. This is how padding is typically
// specified in model files. But for inference, we will need to resolve this
// to a FixedPadding, see below.
enum class PaddingType { kNone, kSame, kValid };

// Padding as resolved for a specific layer shape, as needed for inference.
// For a given layer shape, a given padding type will resolve to a choice of
// a number of padding rows and columns, which we call the padding height and
// width respectively.
struct FixedPadding {
  int width = 0;
  int height = 0;
};

// "Universal" padding struct containing both a generic PaddingType (as
// represented in a model file), and a FixedPadding (as needed for inference).
// The latter is resolved during the PropagateFixedSizes pass.
struct Padding {
  FixedPadding& GetOrCreateFixedPadding() {
    if (!fixed) {
      FixedPadding* ptr = new FixedPadding;
      fixed = std::unique_ptr<FixedPadding>(ptr);
    }
    return *fixed;
  }

  Padding() : type(PaddingType::kNone) {}
  PaddingType type;
  std::unique_ptr<FixedPadding> fixed;
};
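
// Illustrative usage during the PropagateFixedSizes pass (the numeric values
// are hypothetical; they would normally be computed from the layer shape):
//   Padding padding;
//   padding.type = PaddingType::kSame;
//   FixedPadding& fixed = padding.GetOrCreateFixedPadding();
//   fixed.width = 1;
//   fixed.height = 1;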

// "Convolutional" layer, as represented in model files.
//
// Inputs:
//   inputs[0]: required: the input activations array
//   inputs[1]: required: the Conv weights
//   inputs[2]: optional: the bias vector, specifying the biases for each
//                        output channel.
//
// Outputs:
//   outputs[0]: required: the output activations array
//   outputs[1]: optional: the intermediate array of im2col-replicated input
//                         activations. Present when targeting implementations
//                         of Conv layers as Im2col+GEMM.
//
// TensorFlow equivalent: Conv2D
struct ConvOperator : Operator {
  ConvOperator() : Operator(OperatorType::kConv) {}
  Padding padding;
  int stride_width = 0;
  int stride_height = 0;
  // A dilation factor of 0 is invalid, and this field is an optional
  // attribute, so it is initialized to 1 to give default conv behavior when
  // the attribute is not present.
  int dilation_width_factor = 1;
  int dilation_height_factor = 1;
};

// CTCBeamSearchDecoder operator:
//
// Inputs:
//   inputs[0]: required: the logits.
//   inputs[1]: required: sequence length.
//   inputs[2]: optional: beam width.
//   inputs[3]: optional: top paths.
//   inputs[4]: optional: merge repeated.
//
// Outputs:
//   outputs[0]: decoded.
//   outputs[1]: log probability.
//
// TensorFlow equivalent: CTCBeamSearchDecoder
struct CTCBeamSearchDecoderOperator : Operator {
  CTCBeamSearchDecoderOperator()
      : Operator(OperatorType::kCTCBeamSearchDecoder) {}
  int beam_width;
  int top_paths;
  bool merge_repeated = true;
};

// Depthwise-separable convolution operator.
//
// Inputs:
//   inputs[0]: required: the input activations array
//   inputs[1]: required: the DepthwiseConv weights
//   inputs[2]: optional: the bias vector, specifying the biases for each
//                        output channel.
//
// TensorFlow equivalent: DepthwiseConv2dNative
struct DepthwiseConvOperator : Operator {
  DepthwiseConvOperator() : Operator(OperatorType::kDepthwiseConv) {}
  Padding padding;
  int stride_height = 0;
  int stride_width = 0;
  int depth_multiplier = 0;
  // A dilation factor of 0 is invalid, and this field is an optional
  // attribute, so it is initialized to 1 to give default conv behavior when
  // the attribute is not present.
  int dilation_width_factor = 1;
  int dilation_height_factor = 1;
};

// Depth-to-space transform operator.
//
// Inputs:
//   inputs[0]: required: the input activations array
//
// TensorFlow equivalent: DepthToSpace
struct DepthToSpaceOperator : Operator {
  DepthToSpaceOperator() : Operator(OperatorType::kDepthToSpace) {}
  int block_size = 0;
};

// Space-to-depth transform operator.
//
// Inputs:
//   inputs[0]: required: the input activations array
//
// TensorFlow equivalent: SpaceToDepth
struct SpaceToDepthOperator : Operator {
  SpaceToDepthOperator() : Operator(OperatorType::kSpaceToDepth) {}
  int block_size = 0;
};

// Fully-connected operator.
//
// Inputs:
//   inputs[0]: required: the input activations array
//   inputs[1]: required: the FullyConnected weights
//   inputs[2]: optional: the bias vector, specifying the biases for each
//                        output channel.
//
// TensorFlow equivalent: a pair consisting of a Reshape node reshaping the
// input activations as a matrix, followed by a MatMul node.
struct FullyConnectedOperator : Operator {
  FullyConnectedOperator() : Operator(OperatorType::kFullyConnected) {}
  FullyConnectedWeightsFormat weights_format =
      FullyConnectedWeightsFormat::kDefault;

  // `keep_num_dims` is supported in the FullyConnected kernel since version 5,
  // but it is not supported by Toco.
  bool keep_num_dims = false;
};

// Dequantization operator, converting a quantized array of integers with
// quantization parameters specifying how these integers correspond to real
// numbers (see QuantizationParams) to an output activations array of
// floating-point values.
//
// In floating-point image models, there is typically a Dequantization operator
// at the very beginning, converting the input image RGB data, consisting of
// uint8 integer values, to floating-point input activations. That is where
// image model parameters such as "mean_value" and "std_value" are typically
// handled.
//
// This is the only operator type that converts from quantized to
// floating-point, and there is at the moment no operator type at all to
// convert from floating-point to quantized. Every other operator does either
// float->float or quantized->quantized.
//
// Inputs:
//   inputs[0]: required: the input quantized activations array
//
// TensorFlow equivalent: Dequantize
struct DequantizeOperator : Operator {
  DequantizeOperator() : Operator(OperatorType::kDequantize) {}
};

// Numeric verification operator: dequantizes a quantized array of integers
// with quantization parameters specifying how these integers correspond to
// real numbers (see QuantizationParams), and verifies the result against an
// array of reference floating-point values.
//
// Inputs:
//   inputs[0]: required: the input quantized activations array
//   inputs[1]: required: the input reference activations array
//
// TensorFlow equivalent: Dequantize
struct NumericVerifyOperator : Operator {
  NumericVerifyOperator() : Operator(OperatorType::kNumericVerify) {}
};

// Batch-normalization operator.
//
// We only support batch-normalization using pre-learned moments, so this is
// just computing (input - mean) * multiplier + offset. As such, this can be
// expressed as a combination of Add and Mul nodes, and indeed this is how
// we break it down during tooling for the purpose of fusing it into
// other operators.
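//
// For example, this folds into a single Mul+Add pair (a sketch; the exact
// rewrite is left to the graph transformations):
//   output = input * multiplier + (offset - mean * multiplier)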
//
// Inputs:
//   inputs[0]: required: the input activations array
//   inputs[1]: required: the learned mean array
//   inputs[2]: required: the learned multiplier array
//   inputs[3]: required: the learned offset array
//
// TensorFlow equivalent: a combination of Add and Mul nodes
struct BatchNormalizationOperator : Operator {
  BatchNormalizationOperator()
      : Operator(OperatorType::kBatchNormalization),
        global_normalization(false) {}
  bool global_normalization;
};

// L2-normalization operator.
//
// Inputs:
//   inputs[0]: required: the input activations array
//
// TensorFlow equivalent: none. In TensorFlow, L2 normalization is implemented
// by a sub-graph of operators implementing L2-normalization from lower-level
// arithmetic nodes; during tooling, we identify such sub-graphs and replace
// them with L2NormalizationOperator's. See IdentifyL2Normalization.
struct L2NormalizationOperator : Operator {
  L2NormalizationOperator() : Operator(OperatorType::kL2Normalization) {}
};

// LSTM Cell operator.
//
// Inputs:
//   inputs[0]: required: the input data array
//   inputs[1]: required: the previous output activations array
//   inputs[2]: required: the learned weights array
//   inputs[3]: required: the learned biases array
//   inputs[4]: required: the previous output state
//   outputs[0]: required: the output activations array
//   outputs[1]: required: the new state array
//
// TensorFlow equivalent: none. In TensorFlow, an LSTM is implemented
// with a sub-graph of lower-level arithmetic nodes; during tooling, we identify
// such sub-graphs and replace them with LstmCells. See IdentifyLstmCell().
struct LstmCellOperator : Operator {
  enum Inputs {
    DATA_INPUT = 0,
    PREV_ACTIV_INPUT = 1,
    WEIGHTS_INPUT = 2,
    BIASES_INPUT = 3,
    PREV_STATE_INPUT = 4,
    NUM_INPUTS = 5
  };
  enum Outputs {
    ACTIV_OUTPUT = 0,
    STATE_OUTPUT = 1,
    CONCAT_TEMP = 2,
    ACTIV_TEMP = 3,
    NUM_OUTPUTS = 4
  };
  enum KernelType {
    KERNEL_BASIC = 0,
    KERNEL_FULL = 1,
  };

  LstmCellOperator()
      : Operator(OperatorType::kLstmCell), kernel_type(KERNEL_BASIC) {}

  KernelType kernel_type;
};

struct UnidirectionalSequenceLstmOperator : Operator {
  UnidirectionalSequenceLstmOperator()
      : Operator(OperatorType::kUnidirectionalSequenceLstm) {}
};

struct BidirectionalSequenceLstmOperator : Operator {
  BidirectionalSequenceLstmOperator()
      : Operator(OperatorType::kBidirectionalSequenceLstm) {}
  bool merge_outputs;
};

struct BidirectionalSequenceRnnOperator : Operator {
  BidirectionalSequenceRnnOperator()
      : Operator(OperatorType::kBidirectionalSequenceRnn) {}
  bool merge_outputs;
};

// Element-wise multiplication operator.
//
// Inputs:
//   inputs[0]: required: the left-hand side array
//   inputs[1]: required: the right-hand side array
//
// TensorFlow equivalent: Mul
struct MulOperator : Operator {
  MulOperator() : Operator(OperatorType::kMul) {}
};

// Element-wise Abs operator:
//   x -> abs(x)
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: abs
struct AbsOperator : Operator {
  AbsOperator() : Operator(OperatorType::kAbs) {}
};

// Element-wise HardSwish operator:
//   x -> x * relu6(x+3)/6
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: hard_swish
struct HardSwishOperator : Operator {
  HardSwishOperator() : Operator(OperatorType::kHardSwish) {}
};

// Elu
//   x -> exp(x) - 1 for x < 0, x for x >= 0.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Elu
struct EluOperator : Operator {
  EluOperator() : Operator(OperatorType::kElu) {}
};

// Element-wise Relu operator:
//   x -> max(0, x)
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Relu
struct ReluOperator : Operator {
  ReluOperator() : Operator(OperatorType::kRelu) {}
};

// Element-wise Relu1 operator:
//   x -> min(max(x, -1), 1)
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: none. We can construct the operator with Minimum
// and Maximum operations
struct Relu1Operator : Operator {
  Relu1Operator() : Operator(OperatorType::kRelu1) {}
};

// Element-wise Relu6 operator:
//   x -> max(0, min(6, x))
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Relu6
struct Relu6Operator : Operator {
  Relu6Operator() : Operator(OperatorType::kRelu6) {}
};

// PRelu
//   f(x) = alpha * x for x < 0, f(x) = x for x >= 0.
//
// Inputs:
//   inputs[0]: required: the input array
//   inputs[1]: required: the alpha array
//
// Equivalent to keras.layers.PReLU.
struct PReluOperator : Operator {
  PReluOperator() : Operator(OperatorType::kPRelu) {}
};

// LeakyRelu
//   x -> max(x, alpha * x)
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: LeakyRelu
struct LeakyReluOperator : Operator {
  LeakyReluOperator() : Operator(OperatorType::kLeakyRelu) {}

  float alpha = 0.2f;  // 0.2 matches the default value for the TF op attribute.
};

// Element-wise Logistic operator:
//   x -> Logistic(x) = 1 / (1 + exp(-x))
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Sigmoid
struct LogisticOperator : Operator {
  LogisticOperator() : Operator(OperatorType::kLogistic) {}
};

// Element-wise natural log operator:
//   x -> ln(x)
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Log
struct LogOperator : Operator {
  LogOperator() : Operator(OperatorType::kLog) {}
};

// Element-wise Tanh operator:
//   x -> Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Tanh
struct TanhOperator : Operator {
  TanhOperator() : Operator(OperatorType::kTanh) {}
};

// Element-wise Sin operator:
//   x -> Sin(x) = sin(x)
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Sin
struct SinOperator : Operator {
  SinOperator() : Operator(OperatorType::kSin) {}
};

// Element-wise addition operator.
//
// Inputs:
//   inputs[0]: required: the left-hand side array
//   inputs[1]: required: the right-hand side array
//
// TensorFlow equivalent: Add
struct AddOperator : Operator {
  AddOperator() : Operator(OperatorType::kAdd) {}
};

// Element-wise addition operator for N inputs.
//
// Inputs:
//   inputs[i]: the i-th array to add together to form the output.
//
// TensorFlow equivalent: AddN
struct AddNOperator : Operator {
  AddNOperator() : Operator(OperatorType::kAddN) {}
};

// Concatenation operator: concatenates its inputs along the axis.
//
// Inputs: this operator accepts any number >= 1 of inputs.
//   inputs[i]: the i-th array to concatenate.
//
// TensorFlow equivalent: Concat.
struct ConcatenationOperator : Operator {
  ConcatenationOperator() : Operator(OperatorType::kConcatenation) {}
  int axis = 0;
};

// Reordering dimensions. Used only during tooling to transform graphs from
// the TensorFlow format.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: none. This is only useful to convert between formats.
struct ReorderAxesOperator : Operator {
  ReorderAxesOperator() : Operator(OperatorType::kReorderAxes) {}
  AxesOrder input_axes_order;
  AxesOrder output_axes_order;
};

// Average-pooling operator.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: AveragePool
struct AveragePoolOperator : Operator {
  AveragePoolOperator() : Operator(OperatorType::kAveragePool) {}
  Padding padding;
  int stride_height = 0;
  int stride_width = 0;
  int kheight = 0;
  int kwidth = 0;
};

// Local response normalization operator.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: LRN
struct LocalResponseNormalizationOperator : Operator {
  LocalResponseNormalizationOperator()
      : Operator(OperatorType::kLocalResponseNormalization) {}

  int range = 0;
  float bias = 0.f;
  float alpha = 0.f;
  float beta = 0.f;
};

// Max-pooling operator.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: MaxPool
struct MaxPoolOperator : Operator {
  MaxPoolOperator() : Operator(OperatorType::kMaxPool) {}
  Padding padding;
  int stride_height = 0;
  int stride_width = 0;
  int kheight = 0;
  int kwidth = 0;
};

// L2-pooling operator.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: none. Can be shimmed by squaring+avgpool+sqrt.
struct L2PoolOperator : Operator {
  L2PoolOperator() : Operator(OperatorType::kL2Pool) {}
  Padding padding;
  int stride_height = 0;
  int stride_width = 0;
  int kheight = 0;
  int kwidth = 0;
};

// The expected [min, max] range of values in a given array.
// Used for quantization only.
// This information typically comes from special nodes found in quantized
// models, see FakeQuantOperator, and is used during quantization to resolve
// actual quantization parameters (see QuantizationParams).
struct MinMax {
  double min = 0.;
  double max = 0.;
};

inline bool operator==(const MinMax& m1, const MinMax& m2) {
  return m1.min == m2.min && m1.max == m2.max;
}

inline bool operator!=(const MinMax& m1, const MinMax& m2) {
  return m1.min != m2.min || m1.max != m2.max;
}
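
// For instance, under a common asymmetric 8-bit scheme, a MinMax resolves to
// QuantizationParams roughly as follows (a sketch, not the exact tooling
// code):
//   scale = (max - min) / 255.0
//   zero_point = round(-min / scale), clamped to [0, 255]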

// Fake-quantization operator. This does two things:
//   - Annotate its input and output arrays with MinMax information,
//   - Arithmetic-wise, this operator rounds incoming activation values
//     to the nearest representable value on the scale of 256
//     values from the min to the max value dictated by its MinMax info.
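//
// For example, with the default num_bits = 8 and a hypothetical MinMax of
// {0.0, 1.0}, the quantization step is 1/255, so an input of 0.25 becomes
// round(0.25 * 255) / 255 = 64 / 255 ≈ 0.251.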
//
// Inputs:
//   inputs[0]: required: the input array
//   inputs[1]: optional: the 'min' value, if it has not yet been resolved
//              to a constant.
//   inputs[2]: optional: the 'max' value, if it has not yet been resolved
//              to a constant.
//
// TensorFlow equivalent: FakeQuantWithMinMaxVars, FakeQuantWithMinMaxArgs.
struct FakeQuantOperator : Operator {
  FakeQuantOperator() : Operator(OperatorType::kFakeQuant) {}
  std::unique_ptr<MinMax> minmax;
  int num_bits = 8;
  bool narrow_range = false;
};

// Element-wise division operator.
//
// Inputs:
//   inputs[0]: required: the left-hand side array
//   inputs[1]: required: the right-hand side array
//
// TensorFlow equivalent: Div
struct DivOperator : Operator {
  DivOperator() : Operator(OperatorType::kDiv) {}
};

// Element-wise identity (x -> x) operator.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Identity
struct TensorFlowIdentityOperator : Operator {
  TensorFlowIdentityOperator() : Operator(OperatorType::kIdentity) {}
};

// Batch matrix multiplication operator. This comes from a tf.matmul where one
// of the operands has rank 3 or more.
//
// Inputs:
//   inputs[0]: required: the left-hand side matrix
//   inputs[1]: required: the right-hand side matrix
//
// TensorFlow equivalent: MatMul
struct BatchMatMulOperator : Operator {
  BatchMatMulOperator() : Operator(OperatorType::kBatchMatMul) {}
  bool adj_x = false;
  bool adj_y = false;
};

// General matrix multiplication operator. We don't want to support general
// matrix multiplication at inference time, so we resolve it during tooling
// to more specific operator types, namely, FullyConnected.
//
// Inputs:
//   inputs[0]: required: the left-hand side matrix
//   inputs[1]: required: the right-hand side matrix
//
// TensorFlow equivalent: MatMul
struct TensorFlowMatMulOperator : Operator {
  TensorFlowMatMulOperator() : Operator(OperatorType::kMatMul) {}
  bool transpose_a = false;
  bool transpose_b = false;
};

// Padding operator. Pads a tensor with zeros.
//
// Inputs:
//   inputs[0]: required: the input array
//   inputs[1]: required: the padding array
//
// This operation pads `input` with zeros according to the `paddings` you
// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is
// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]`
// indicates how many zeros to add before the contents of `input` in that
// dimension, and `paddings[D, 1]` indicates how many zeros to add after the
// contents of `input` in that dimension.
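//
// For example, an input of shape [2, 3] with paddings [[1, 1], [2, 2]] yields
// an output of shape [2 + 1 + 1, 3 + 2 + 2] = [4, 7].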
//
// TensorFlow equivalent: Pad
struct PadOperator : Operator {
  PadOperator() : Operator(OperatorType::kPad) {}

  std::vector<int> left_padding;
  std::vector<int> right_padding;
};

// PaddingV2 operator. Pads a tensor with the given constant value.
//
// Inputs:
//   inputs[0]: required: the input array
//   inputs[1]: required: the padding array
//   inputs[2]: required: the scalar constant_values
//
// This operation pads input according to the paddings and constant_values you
// specify. paddings is an integer tensor with shape [Dn, 2], where n is the
// rank of input. For each dimension D of input, paddings[D, 0] indicates how
// many padding values to add before the contents of input in that dimension,
// and paddings[D, 1] indicates how many padding values to add after the
// contents of input in that dimension. constant_values is a scalar tensor of
// the same type as input that indicates the value to use for padding input.
//
// TensorFlow equivalent: PadV2
struct PadV2Operator : Operator {
  PadV2Operator() : Operator(OperatorType::kPadV2) {}

  std::vector<int> left_padding;
  std::vector<int> right_padding;
};

// Strided slice operator.
//
// Inputs:
//   inputs[0]: required: the input array
//   inputs[1]: required: the begin array
//   inputs[2]: required: the end array
//   inputs[3]: optional: the strides array
//
// TensorFlow equivalent: StridedSlice
struct StridedSliceOperator : Operator {
  StridedSliceOperator() : Operator(OperatorType::kStridedSlice) {}

  std::vector<int> start_indices;
  std::vector<int> stop_indices;
  std::vector<int> strides;

  int begin_mask;
  int ellipsis_mask;
  int end_mask;
  int new_axis_mask;
  int shrink_axis_mask;

  StridedSliceOperator(const StridedSliceOperator& other)
      : Operator(OperatorType::kStridedSlice) {
    inputs = other.inputs;
    outputs = other.outputs;

    start_indices = other.start_indices;
    stop_indices = other.stop_indices;
    strides = other.strides;

    begin_mask = other.begin_mask;
    ellipsis_mask = other.ellipsis_mask;
    end_mask = other.end_mask;
    new_axis_mask = other.new_axis_mask;
    shrink_axis_mask = other.shrink_axis_mask;
  }

  void PadIndices(int dim_count) {
    // Add indices and mask bits to fully include extra dimensions
    CHECK_GE(dim_count, start_indices.size());
    CHECK_EQ(start_indices.size(), stop_indices.size());
    CHECK_EQ(stop_indices.size(), strides.size());

    for (int i = start_indices.size(); i < dim_count; i++) {
      start_indices.push_back(0);
      stop_indices.push_back(0);
      strides.push_back(1);
      begin_mask |= 1 << i;
      end_mask |= 1 << i;
    }
  }

  void ReverseIndices() {
    CHECK_EQ(start_indices.size(), stop_indices.size());
    CHECK_EQ(stop_indices.size(), strides.size());

    std::reverse(start_indices.begin(), start_indices.end());
    std::reverse(stop_indices.begin(), stop_indices.end());
    std::reverse(strides.begin(), strides.end());

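    // The masks are bit fields indexed by axis, so reversing the axes also
    // requires bit-reversing the masks; e.g. over 3 axes, a begin_mask of
    // 0b011 (axes 0 and 1) becomes 0b110 (axes 1 and 2).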
    begin_mask = toco::port::ReverseBits32(static_cast<uint32>(begin_mask)) >>
                 (32 - start_indices.size());
    ellipsis_mask =
        toco::port::ReverseBits32(static_cast<uint32>(ellipsis_mask)) >>
        (32 - start_indices.size());
    end_mask = toco::port::ReverseBits32(static_cast<uint32>(end_mask)) >>
               (32 - start_indices.size());
    new_axis_mask =
        toco::port::ReverseBits32(static_cast<uint32>(new_axis_mask)) >>
        (32 - start_indices.size());
    shrink_axis_mask =
        toco::port::ReverseBits32(static_cast<uint32>(shrink_axis_mask)) >>
        (32 - start_indices.size());
  }
};

// Reshaping operator, reshaping its input array to a two-dimensional shape
// (a "matrix"). This is used in the TensorFlow format, in conjunction with
// MatMul nodes, to implement fully-connected layers.
//
// Inputs:
//   inputs[0]: required: the input array
//   inputs[1]: optional: the output tensor shape
//
// TensorFlow equivalent: Reshape --- except that we only support a special
// case here, where the output shape is a matrix (2D) shape.
struct TensorFlowReshapeOperator : Operator {
  TensorFlowReshapeOperator() : Operator(OperatorType::kReshape) {}
  std::vector<int> shape;
};

// Removes dimensions of size 1 from the shape of a tensor.
// https://www.tensorflow.org/api_docs/python/tf/squeeze
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Squeeze
struct SqueezeOperator : Operator {
  SqueezeOperator() : Operator(OperatorType::kSqueeze) {}

  std::vector<int> squeeze_dims;
};

// Transpose-convolution operator.
//
// Inputs:
//   inputs[0]: required: the output shape
//   inputs[1]: required: the weights
//   inputs[2]: required: the input activations array
//   inputs[3]: optional: the bias vector, specifying the biases for each
//                        output channel.
//   NOTE: The input activations are NOT the first input.
//
// Outputs:
//   outputs[0]: required: the output activations array
//
// TensorFlow equivalent: Conv2DBackpropInput
struct TransposeConvOperator : Operator {
  enum Inputs {
    OUTPUT_SHAPE = 0,
    WEIGHTS = 1,
    DATA_INPUT = 2,
    BIAS = 3,
  };

  TransposeConvOperator() : Operator(OperatorType::kTransposeConv) {}
  Padding padding;
  int stride_width = 0;
  int stride_height = 0;
  // Dilation is possible with transpose convolution, but TensorFlow does not
  // currently support it, so we omit it.
};

// Given a tensor input, this operation calculates the element-wise
// exponential (y = e^x).
//
// Inputs:
//   inputs[0]: required: input tensor
//
// TensorFlow equivalent: Exp
struct ExpOperator : Operator {
  ExpOperator() : Operator(OperatorType::kExp) {}
};

// Given a tensor input, this operation calculates the element-wise cosine
// (y = cos(x)).
//
// Inputs:
//   inputs[0]: required: input tensor
//
// TensorFlow equivalent: Cos
struct CosOperator : Operator {
  CosOperator() : Operator(OperatorType::kCos) {}
};

// Given a tensor input, this operation inserts a dimension of 1 at the
// dimension index axis of input's shape. The dimension index axis starts at
// zero; if you specify a negative number for axis it is counted backward from
// the end.
//
// Inputs:
//   inputs[0]: required: input tensor
//   inputs[1]: required: 0-D (scalar). Specifies the dimension index at which
//              to expand the shape of input
//
// TensorFlow equivalent: ExpandDims
struct ExpandDimsOperator : Operator {
  ExpandDimsOperator() : Operator(OperatorType::kExpandDims) {}
};

// Creates a tensor of shape dims and fills it with the given scalar value.
// Output type will be the same as the given scalar value.
//
// Inputs:
//   inputs[0]: required: 1-D (int32) - the shape of the output tensor
//   inputs[1]: required: 0-D (scalar) - value to fill the tensor with
//
// TensorFlow equivalent: Fill
struct FillOperator : Operator {
  FillOperator() : Operator(OperatorType::kFill) {}
};

// Element-wise floor division operator.
//
// Inputs:
//   inputs[0]: required: the left-hand side array
//   inputs[1]: required: the right-hand side array
//
// TensorFlow equivalent: FloorDiv
struct FloorDivOperator : Operator {
  FloorDivOperator() : Operator(OperatorType::kFloorDiv) {}
};

// Element-wise floor mod operator.
//
// Inputs:
//   inputs[0]: required: the left-hand side array
//   inputs[1]: required: the right-hand side array
//
// TensorFlow equivalent: FloorMod
struct FloorModOperator : Operator {
  FloorModOperator() : Operator(OperatorType::kFloorMod) {}
};

struct RandomUniformOperator : Operator {
  RandomUniformOperator() : Operator(OperatorType::kRandomUniform) {}
  ArrayDataType dtype = ArrayDataType::kNone;
  int64_t seed;
  int64_t seed2;
};

// Creates a sequence of numbers that begins at start and extends by increments
// of delta up to but not including limit.
//
// The dtype of the resulting tensor is inferred from the inputs unless it is
// provided explicitly.
//
// Inputs:
//   inputs[0]: required: the start
//   inputs[1]: required: the limit
//   inputs[2]: required: the delta
//
// TensorFlow equivalent: Range
struct RangeOperator : Operator {
  RangeOperator() : Operator(OperatorType::kRange) {}
  ArrayDataType dtype = ArrayDataType::kNone;
};

// Rank operator. Extracts the rank of the tensor.
//
// Inputs:
//   inputs[0]: required: the input array
//
// This operation outputs a 0-D int32 Tensor representing the rank of input.
//
// TensorFlow equivalent: Rank.
struct TensorFlowRankOperator : Operator {
  TensorFlowRankOperator() : Operator(OperatorType::kRank) {}
  ArrayDataType output_data_type = ArrayDataType::kInt32;
};

// Element-wise negation (-x) operator.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Neg
struct NegOperator : Operator {
  NegOperator() : Operator(OperatorType::kNeg) {}
};

// Element-wise select operator, choosing elements from inputs[1] or inputs[2]
//
// Inputs:
//   inputs[0]: required: boolean mask per index
//   inputs[1]: required: tensor of values if true
//   inputs[2]: required: tensor of values if false
//
// TensorFlow equivalent: Select
struct SelectOperator : Operator {
  SelectOperator() : Operator(OperatorType::kSelect) {}
};

// Element-wise reciprocal-square-root (x^-0.5) operator.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Rsqrt
struct TensorFlowRsqrtOperator : Operator {
  TensorFlowRsqrtOperator() : Operator(OperatorType::kRsqrt) {}
};

// Stacks a list of rank-R tensors into one rank-(R+1) tensor.
//
// Packs the list of tensors in values into a tensor with rank one higher than
// each tensor in values, by packing them along the axis dimension. Given a
// list of length N of tensors of shape (A, B, C), the output with axis == 0
// has shape (N, A, B, C).
//
// Inputs: this operator accepts any number >= 1 of inputs.
//   inputs[i]: the i-th array to merge.
//
// TensorFlow equivalent: Pack
struct PackOperator : Operator {
  PackOperator() : Operator(OperatorType::kPack) {}
  int values_count;
  int axis = 0;
  ArrayDataType dtype = ArrayDataType::kNone;
};

// Shape operator. Extracts the shape of the tensor.
//
// Inputs:
//   inputs[0]: required: the input array
//
// This operation outputs a 1-D integer tensor representing the shape of
// the input.
//
// TensorFlow equivalent: Shape.
struct TensorFlowShapeOperator : Operator {
  TensorFlowShapeOperator() : Operator(OperatorType::kShape) {}
  ArrayDataType output_data_type = ArrayDataType::kInt32;
};

// Element-wise square-root (x^0.5) operator.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Sqrt
struct TensorFlowSqrtOperator : Operator {
  TensorFlowSqrtOperator() : Operator(OperatorType::kSqrt) {}
};

// Element-wise square (x*x) operator.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Square
struct TensorFlowSquareOperator : Operator {
  TensorFlowSquareOperator() : Operator(OperatorType::kSquare) {}
};

// Element-wise squared difference ((x-y)*(x-y)) operator.
//
// Inputs:
//   inputs[0]: required: the left-hand side array
//   inputs[1]: required: the right-hand side array
//
// TensorFlow equivalent: SquaredDifference
struct SquaredDifferenceOperator : Operator {
  SquaredDifferenceOperator() : Operator(OperatorType::kSquaredDifference) {}
};

// Transposes a tensor.
//
// By default, this operation performs a regular matrix transpose on 2-D input
// tensors.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Transpose
struct TransposeOperator : Operator {
  TransposeOperator() : Operator(OperatorType::kTranspose) {}
  std::vector<int> perm;
};

// Element-wise subtraction operator.
//
// Inputs:
//   inputs[0]: required: the left-hand side array
//   inputs[1]: required: the right-hand side array
//
// TensorFlow equivalent: Sub
struct SubOperator : Operator {
  SubOperator() : Operator(OperatorType::kSub) {}
};

// Sum reduction: computes the sum of all entries across the axes.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Sum
struct TensorFlowSumOperator : Operator {
  TensorFlowSumOperator() : Operator(OperatorType::kSum) {}
  std::vector<int> axis;
  bool keep_dims = false;
};

// Prod reduction: computes the product of all entries across the axes.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Prod
struct TensorFlowProdOperator : Operator {
  TensorFlowProdOperator() : Operator(OperatorType::kReduceProd) {}
  std::vector<int> axis;
  bool keep_dims = false;
};

// TensorFlow Tile equivalent. Refer to TensorFlow documentation for details.
//
// Inputs:
//   inputs[0]: required: the input array
//   inputs[1]: required: int array whose length equals rank(inputs[0])
struct TensorFlowTileOperator : Operator {
  TensorFlowTileOperator() : Operator(OperatorType::kTile) {}
};

// TensorFlow Slice equivalent. Refer to TensorFlow documentation for details.
struct SliceOperator : Operator {
  SliceOperator() : Operator(OperatorType::kSlice) {}

  std::vector<int> begin;
  std::vector<int> size;
};

// TensorFlow Split equivalent. Refer to TensorFlow documentation for details.
// Not fully supported, just a placeholder to handle TensorFlow graphs and
// support graph transformations to other operator types by matching sub-graphs.
struct TensorFlowSplitOperator : Operator {
  TensorFlowSplitOperator() : Operator(OperatorType::kSplit) {}
  int num_split = 0;
};

// TensorFlow SplitV equivalent. Refer to TensorFlow documentation for details.
struct TensorFlowSplitVOperator : Operator {
  TensorFlowSplitVOperator() : Operator(OperatorType::kSplitV) {}
  int num_split = 0;
};

// TensorFlow Concat equivalent. Refer to TensorFlow documentation for details.
// Not fully supported, just a placeholder to handle TensorFlow graphs and
// support graph transformations to other operator types by matching sub-graphs.
// Concretely, once the concat dim becomes known, if it is the depth
// dimension then we can change this op into a DepthConcatenation op.
// Otherwise, we hope for some other graph transformation to drop this node.
struct TensorFlowConcatOperator : Operator {
  TensorFlowConcatOperator() : Operator(OperatorType::kConcat) {}
};

// TensorFlow ConcatV2 equivalent. Refer to TensorFlow documentation for
// details.
// Not fully supported, just a placeholder to handle TensorFlow graphs and
// support graph transformations to other operator types by matching sub-graphs.
// Concretely, once the concat dim becomes known, if it is the depth
// dimension then we can change this op into a DepthConcatenation op.
// Otherwise, we hope for some other graph transformation to drop this node.
struct TensorFlowConcatV2Operator : Operator {
  TensorFlowConcatV2Operator() : Operator(OperatorType::kConcatV2) {}
};

// TensorFlow Merge equivalent. Refer to TensorFlow documentation for details.
//
// Inputs: this operator accepts any number >= 1 of inputs.
//   inputs[i]: the i-th array to merge.
//
// It is expected that graph transformations will drop all but exactly one
// of the inputs, at which point the Merge node will be equivalent to an
// Identity node forwarding the remaining input.
//
// Note: We do not currently support runtime control flow: we only support
// control flow that can be resolved at tooling time (independently of input
// activations).
struct TensorFlowMergeOperator : Operator {
  TensorFlowMergeOperator() : Operator(OperatorType::kMerge) {}
};

// TensorFlow Switch equivalent. Refer to TensorFlow documentation for details.
//
// Inputs:
//   inputs[0]: required: the input array
//   inputs[1]: required: the boolean predicate, given as an array of size 1
//              and of type kBool, will determine which output gets selected.
//
// Outputs: a TensorFlow Switch node always has exactly two outputs. Depending
// on the boolean value that the input predicate resolves to (see note below),
// one or the other of the outputs will be 'selected': the input array will be
// forwarded to the 'selected output' as if by an Identity node, while the
// other output will be discarded, and any graph edge connecting that discarded
// output will be dropped. The rule for selecting outputs is as follows:
//   outputs[0] will be selected if the input predicate resolves to 'true'.
//   outputs[1] will be selected if the input predicate resolves to 'false'.
//
// Note: We do not currently support runtime control flow: we only support
// control flow that can be resolved at tooling time (independently of input
// activations).
struct TensorFlowSwitchOperator : Operator {
  TensorFlowSwitchOperator() : Operator(OperatorType::kSwitch) {}
};

// TensorFlow All equivalent. Refer to TensorFlow documentation for details.
// Not fully supported, just a placeholder to handle TensorFlow graphs and
// support graph transformations to other operator types by matching sub-graphs.
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowAllOperator : Operator {
  TensorFlowAllOperator() : Operator(OperatorType::kAll) {}
};

// TensorFlow Assert equivalent. Refer to TensorFlow documentation for details.
// Not fully supported, just a placeholder to handle TensorFlow graphs and
// support graph transformations to other operator types by matching sub-graphs.
// Typically, we just drop Assert nodes.
struct TensorFlowAssertOperator : Operator {
  TensorFlowAssertOperator() : Operator(OperatorType::kAssert) {}
};

// TensorFlow Less equivalent. Refer to TensorFlow documentation for details.
// Not fully supported, just a placeholder to handle TensorFlow graphs and
// support graph transformations to other operator types by matching sub-graphs.
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowLessOperator : Operator {
  TensorFlowLessOperator() : Operator(OperatorType::kLess) {}
};

// TensorFlow LessEqual equivalent. Refer to TensorFlow documentation for
// details.
// Not fully supported, just a placeholder to handle TensorFlow graphs and
// support graph transformations to other operator types by matching sub-graphs.
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowLessEqualOperator : Operator {
  TensorFlowLessEqualOperator() : Operator(OperatorType::kLessEqual) {}
};

// TensorFlow Greater equivalent. Refer to TensorFlow documentation for
// details.
// Not fully supported, just a placeholder to handle TensorFlow graphs and
// support graph transformations to other operator types by matching sub-graphs.
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowGreaterOperator : Operator {
  TensorFlowGreaterOperator() : Operator(OperatorType::kGreater) {}
};

// TensorFlow GreaterEqual equivalent. Refer to TensorFlow documentation for
// details.
// Not fully supported, just a placeholder to handle TensorFlow graphs and
// support graph transformations to other operator types by matching sub-graphs.
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowGreaterEqualOperator : Operator {
  TensorFlowGreaterEqualOperator() : Operator(OperatorType::kGreaterEqual) {}
};

// TensorFlow Equal equivalent. Refer to TensorFlow documentation for
// details.
// Not fully supported, just a placeholder to handle TensorFlow graphs and
// support graph transformations to other operator types by matching sub-graphs.
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowEqualOperator : Operator {
  TensorFlowEqualOperator() : Operator(OperatorType::kEqual) {}
};

// TensorFlow NotEqual equivalent. Refer to TensorFlow documentation for
// details.
struct TensorFlowNotEqualOperator : Operator {
  TensorFlowNotEqualOperator() : Operator(OperatorType::kNotEqual) {}
};

// Max reduction: computes the max of all entries across the axes.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Max
struct TensorFlowMaxOperator : Operator {
  TensorFlowMaxOperator() : Operator(OperatorType::kReduceMax) {}
  std::vector<int> axis;
  bool keep_dims = false;
};

// Min reduction: computes the min of all entries across the axes.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Min
struct TensorFlowMinOperator : Operator {
  TensorFlowMinOperator() : Operator(OperatorType::kReduceMin) {}
  std::vector<int> axis;
  bool keep_dims = false;
};

// Element-wise maximum operator. Currently it only supports a scalar as
// the second operand.
//
// Inputs:
//   inputs[0]: required: the left-hand side array
//   inputs[1]: required: the right-hand side array
//
// TensorFlow equivalent: Maximum
struct TensorFlowMaximumOperator : Operator {
  TensorFlowMaximumOperator() : Operator(OperatorType::kMaximum) {}
};

// Element-wise minimum operator. Currently it only supports a scalar as
// the second operand.
//
// Inputs:
//   inputs[0]: required: the left-hand side array
//   inputs[1]: required: the right-hand side array
//
// TensorFlow equivalent: Minimum
struct TensorFlowMinimumOperator : Operator {
  TensorFlowMinimumOperator() : Operator(OperatorType::kMinimum) {}
};
1692
1693// General TF operation, unsupported by tf.mini. Expected to be dropped by
1694// graph transformations.
1695struct TensorFlowUnsupportedOperator : Operator {
1696 TensorFlowUnsupportedOperator() : Operator(OperatorType::kUnsupported) {}
1697
1698 // The original TF operation type. Used for diagnostic purposes.
1699 std::string tensorflow_op;
1700 // A boolean indicating if the unsupported op should be treated as quantized.
1701 bool quantized = false;
1702 // A boolean indicating if the unsupported op output should allow float values
1703 // in quantized mode.
1704 bool support_output_type_float_in_quantized_op = false;
1705 // Output data types
1706 std::vector<ArrayDataType> output_data_types;
1707 // Output shapes.
1708 std::vector<Shape> output_shapes;
1709};
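
// Illustrative sketch (not part of this header): an importer that encounters
// an op it does not recognize would typically record it like this, so that a
// later graph transformation can either legalize or drop it. The op name
// "SomeCustomOp" and the array names are made-up placeholders; 'model' is
// assumed to be a Model*.
//
//   auto* unsupported = new TensorFlowUnsupportedOperator;
//   unsupported->tensorflow_op = "SomeCustomOp";
//   unsupported->inputs = {"input0"};
//   unsupported->outputs = {"output0"};
//   model->operators.emplace_back(unsupported);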
1710
1711// Softmax activation function.
1712//
1713// Inputs:
1714// inputs[0]: required: the input array
1715//
1716// TensorFlow equivalent: Softmax
1717struct SoftmaxOperator : Operator {
1718 SoftmaxOperator() : Operator(OperatorType::kSoftmax) {}
1719 float beta = 0.f;
1720};
1721
1722// LogSoftmax activation function.
1723//
1724// Inputs:
1725// inputs[0]: required: the logits input array
1726//
1727// TensorFlow equivalent: LogSoftmax
1728struct LogSoftmaxOperator : Operator {
1729 LogSoftmaxOperator() : Operator(OperatorType::kLogSoftmax) {}
1730
  // LogSoftmax can in principle have very large negative outputs, depending
  // on the input size. However, an input x_i that is less than x_max - 10 is
  // accumulated as exp(x_i - x_max), which truncates to zero.
1734 //
1735 // Since we effectively disregard smallish inputs in the normalizing factor,
1736 // we also drop them in the output (set to minimum output), and in doing so
1737 // make better use of the quantization range / resolution.
1738 static constexpr float kOutputRangeMin = -16.0;
1739};
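
// Worked example of the range above (illustrative, uint8 quantization):
// quantizing outputs over [kOutputRangeMin, 0] = [-16, 0] gives
//   scale      = (0 - (-16)) / 255 = 16 / 255 ~= 0.0627
//   zero_point = 255   (real 0.0 -> quantized 255, real -16.0 -> quantized 0)
// which is ample resolution, given that inputs below x_max - 10 vanish from
// the normalizing sum anyway.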
1740
1741// Cast operator.
1742//
1743// Inputs:
1744// inputs[0]: required: the input array
1745//
1746// TensorFlow equivalent: Cast
1747struct CastOperator : Operator {
1748 CastOperator() : Operator(OperatorType::kCast) {}
1749 ArrayDataType src_data_type = ArrayDataType::kNone;
1750 ArrayDataType dst_data_type = ArrayDataType::kNone;
1751};
1752
1753// Floor operator.
1754//
1755// Inputs:
1756// inputs[0]: required: the input array
1757//
1758// TensorFlow equivalent: Floor
1759struct FloorOperator : Operator {
1760 FloorOperator() : Operator(OperatorType::kFloor) {}
1761};
1762
1763// Ceil operator.
1764//
1765// Inputs:
1766// inputs[0]: required: the input array
1767//
1768// TensorFlow equivalent: Ceil
1769struct CeilOperator : Operator {
1770 CeilOperator() : Operator(OperatorType::kCeil) {}
1771};
1772
1773// Round operator.
1774//
1775// Inputs:
1776// inputs[0]: required: the input array
1777//
1778// TensorFlow equivalent: Round
1779struct RoundOperator : Operator {
1780 RoundOperator() : Operator(OperatorType::kRound) {}
1781};
1782
1783// Gather operator. It gathers slices from params according to indices.
1784// Only 1-D indices are supported at the moment.
1785//
1786// Inputs:
1787// inputs[0]: required: the params array
1788// inputs[1]: required: the indices to gather
1789// inputs[2]: optional: axis
1790//
1791// TensorFlow equivalent: Gather
1792struct GatherOperator : Operator {
1793 GatherOperator() : Operator(OperatorType::kGather) {}
  // Axis is populated explicitly or implicitly from the axis input by
  // ResolveGatherAttributes. An empty axis indicates that the axis has not
  // yet been resolved.
1797 std::optional<int> axis;
1798
  // This field is not used by the standard TF Lite export but it is still
  // needed by legacy Gather implementations.
1801 int input_rank = 0;
1802};
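
// Illustrative sketch (not part of this header): before ResolveGatherAttributes
// runs, an importer may leave 'axis' unset and pass it as the optional third
// input; afterwards the attribute is populated and the input list is trimmed.
// The array names are made-up placeholders.
//
//   auto* gather = new GatherOperator;
//   gather->inputs = {"params", "indices", "axis_tensor"};  // axis as input
//   DCHECK(!gather->axis.has_value());  // not yet resolved
//   // ... after ResolveGatherAttributes:
//   //   gather->axis holds a value, gather->inputs == {"params", "indices"}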
1803
1804// GatherNd operator. It gathers slices from params according to indices.
1805//
1806// Inputs:
1807// inputs[0]: required: the params array
1808// inputs[1]: required: the indices to gather
1809//
1810// TensorFlow equivalent: GatherNd
1811struct GatherNdOperator : Operator {
1812 GatherNdOperator() : Operator(OperatorType::kGatherNd) {}
1813};
1814
1815// ArgMax operator. It returns the index of the maximum value along axis.
1816//
1817// Inputs:
1818// inputs[0]: required: the input tensor
1819// inputs[1]: optional: 0-D (scalar) axis
1820//
1821// TensorFlow equivalent: ArgMax
1822struct ArgMaxOperator : Operator {
1823 ArgMaxOperator() : Operator(OperatorType::kArgMax) {}
1824 ArrayDataType output_data_type = ArrayDataType::kInt64;
1825};
1826
1827// ArgMin operator. It returns the index of the minimum value along axis.
1828//
1829// Inputs:
1830// inputs[0]: required: the input tensor
1831// inputs[1]: optional: 0-D (scalar) axis
1832//
1833// TensorFlow equivalent: ArgMin
1834struct ArgMinOperator : Operator {
1835 ArgMinOperator() : Operator(OperatorType::kArgMin) {}
1836 ArrayDataType output_data_type = ArrayDataType::kInt64;
1837};
1838
// ResizeBilinear operator. It resizes input images with bilinear
// interpolation; see the align_corners and half_pixel_centers fields below.
1841//
1842// Inputs:
1843// inputs[0]: required: the input array
1844// inputs[1]: required: the new image size
1845//
1846// TensorFlow equivalent: ResizeBilinear
1847struct ResizeBilinearOperator : Operator {
1848 ResizeBilinearOperator() : Operator(OperatorType::kResizeBilinear) {}
1849
1850 bool align_corners = false;
1851 bool half_pixel_centers = false;
1852};
1853
// ResizeNearestNeighbor operator. It resizes input images with
// nearest-neighbor interpolation; see the align_corners and
// half_pixel_centers fields below.
1856//
1857// Inputs:
1858// inputs[0]: required: the input array
1859// inputs[1]: required: the new image size
1860//
1861// TensorFlow equivalent: ResizeNearestNeighbor
1862struct ResizeNearestNeighborOperator : Operator {
1863 ResizeNearestNeighborOperator()
1864 : Operator(OperatorType::kResizeNearestNeighbor) {}
1865
1866 bool align_corners = false;
1867 bool half_pixel_centers = false;
1868};
1869
1870// SpaceToBatchND operator. It divides spatial dimensions into a grid of
1871// blocks and interleaves these blocks with the batch dimension. Currently,
1872// only 2-d blocks are supported.
1873//
1874// Inputs:
1875// inputs[0]: required: the input array
1876// inputs[1]: required: the block shape
1877// inputs[2]: required: the paddings
1878//
1879// TensorFlow equivalent: SpaceToBatchND
1880struct SpaceToBatchNDOperator : Operator {
1881 SpaceToBatchNDOperator() : Operator(OperatorType::kSpaceToBatchND) {}
1882
1883 std::vector<int> block_shape;
1884 std::vector<int> before_paddings;
1885 std::vector<int> after_paddings;
1886};
1887
1888// BatchToSpaceND operator. Rearranges data from batch into blocks of
1889// spatial data. Currently, only 2-d blocks are supported.
1890//
1891// Inputs:
1892// inputs[0]: required: the input array
1893// inputs[1]: required: the block shape
1894// inputs[2]: required: the crops
1895//
1896// TensorFlow equivalent: BatchToSpaceND
1897struct BatchToSpaceNDOperator : Operator {
1898 BatchToSpaceNDOperator() : Operator(OperatorType::kBatchToSpaceND) {}
1899
1900 std::vector<int> block_shape;
1901 std::vector<int> before_crops;
1902 std::vector<int> after_crops;
1903};
1904
1905// Mean operator.
1906//
1907// Inputs:
1908// inputs[0]: required: the input array
1909//
1910// TensorFlow equivalent: Mean
1911struct MeanOperator : Operator {
1912 MeanOperator() : Operator(OperatorType::kMean) {}
1913
1914 std::vector<int> axis;
1915 bool keep_dims = false;
1916};
1917
1918// Svdf operator:
1919//
1920// Inputs:
1921// inputs[0]: required: the input array
1922// inputs[1]: required: weights_feature
1923// inputs[2]: required: weights_time
1924// inputs[3]: optional: bias
1925struct SvdfOperator : Operator {
1926 SvdfOperator() : Operator(OperatorType::kSvdf) {}
  int rank = 0;
1928};
1929
1930// TopKV2 operator.
1931//
// Inputs:
// inputs[0]: required: the input tensor
// inputs[1]: required: the top_k scalar
1934struct TopKV2Operator : Operator {
1935 TopKV2Operator() : Operator(OperatorType::kTopK_V2) {}
1936};
1937
1938// DynamicPartition operator:
1939//
1940// Inputs:
1941// inputs[0]: required: data.
1942// inputs[1]: required: partitions.
1943//
1944// TensorFlow equivalent: DynamicPartition
1945struct DynamicPartitionOperator : Operator {
1946 DynamicPartitionOperator() : Operator(OperatorType::kDynamicPartition) {}
  int num_partitions = 0;
1948};
1949
1950// DynamicStitch operator:
1951//
1952// Inputs:
1953// inputs[0,N): required: indices.
1954// inputs[N,2N): required: data.
1955//
1956// TensorFlow equivalent: DynamicStitch/ParallelDynamicStitch
1957struct DynamicStitchOperator : Operator {
1958 DynamicStitchOperator() : Operator(OperatorType::kDynamicStitch) {}
  int num_partitions = 0;
1960};
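
// Worked example (illustrative): with num_partitions == 2, the inputs list is
// laid out as {indices0, indices1, data0, data1}, i.e. all indices tensors
// first, followed by the matching data tensors in the same order.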
1961
1962// SparseToDense operator:
1963//
1964// Inputs:
1965// Inputs[0]: required: sparse_indices.
1966// Inputs[1]: required: output_shape.
1967// Inputs[2]: required: sparse_values.
1968//
1969// TensorFlow equivalent: SparseToDense.
1970struct SparseToDenseOperator : Operator {
1971 SparseToDenseOperator() : Operator(OperatorType::kSparseToDense) {}
  // Initialized to TensorFlow's default for SparseToDense.
  bool validate_indices = true;
1973};
1974
1975// Pow operator:
1976//
1977// Inputs:
1978// Inputs[0]: required: A tensor.
1979// Inputs[1]: required: A tensor.
1980//
1981// TensorFlow equivalent: Pow.
1982struct PowOperator : Operator {
1983 PowOperator() : Operator(OperatorType::kPow) {}
1984};
1985
1986// Any operator:
1987//
1988// Inputs:
1989// Inputs[0]: required: A boolean input tensor.
1990// Inputs[1]: required: reduction_indices.
1991//
1992// TensorFlow equivalent: tf.reduce_any.
1993struct TensorFlowAnyOperator : Operator {
1994 TensorFlowAnyOperator() : Operator(OperatorType::kAny) {}
1995 std::vector<int> axis;
1996 bool keep_dims = false;
1997};
1998
1999// LogicalAnd operator:
2000//
2001// Inputs:
2002// Inputs[0]: required: A boolean tensor.
2003// Inputs[1]: required: A boolean tensor.
2004//
2005// TensorFlow equivalent: tf.logical_and.
2006struct LogicalAndOperator : Operator {
2007 LogicalAndOperator() : Operator(OperatorType::kLogicalAnd) {}
2008};
2009
2010// LogicalNot operator:
2011//
2012// Inputs:
2013// Inputs[0]: required: A boolean tensor.
2014//
2015// TensorFlow equivalent: tf.logical_not.
2016struct LogicalNotOperator : Operator {
2017 LogicalNotOperator() : Operator(OperatorType::kLogicalNot) {}
2018};
2019
2020// OneHot operator:
2021//
2022// Inputs:
2023// Inputs[0]: required: indices.
2024// Inputs[1]: required: depth.
2025// Inputs[2]: required: on_value.
2026// Inputs[3]: required: off_value.
2027//
2028// TensorFlow equivalent: OneHot.
2029struct OneHotOperator : Operator {
2030 enum Inputs {
2031 INDICES_INPUT = 0,
2032 DEPTH_INPUT = 1,
2033 ON_VALUE_INPUT = 2,
2034 OFF_VALUE_INPUT = 3,
2035 };
2036
2037 OneHotOperator() : Operator(OperatorType::kOneHot) {}
2038 int axis = -1;
2039};
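
// Illustrative sketch (not part of this header): the Inputs enum above lets
// transformation code address 'inputs' by role instead of by magic number:
//
//   const OneHotOperator& op = ...;  // obtained from model.operators
//   const std::string& depth_name = op.inputs[OneHotOperator::DEPTH_INPUT];
//   const std::string& on_name = op.inputs[OneHotOperator::ON_VALUE_INPUT];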
2040
2041// LogicalOr operator:
2042//
2043// Inputs:
2044// Inputs[0]: required: A Bool tensor.
2045// Inputs[1]: required: A Bool tensor.
2046//
2047// TensorFlow equivalent: LogicalOr.
2048struct LogicalOrOperator : Operator {
2049 LogicalOrOperator() : Operator(OperatorType::kLogicalOr) {}
2050};
2051
2052// Unpack operator:
2053//
2054// Inputs:
// Inputs[0]: required: the tensor to be unpacked along 'axis' into 'num'
// output tensors.
2057//
2058// TensorFlow equivalent: tf.unstack.
2059struct UnpackOperator : Operator {
2060 UnpackOperator() : Operator(OperatorType::kUnpack) {}
  int num = 0;
  int axis = 0;
2063 ArrayDataType dtype = ArrayDataType::kNone;
2064};
2065
2066// ZerosLike operator:
2067//
2068// Inputs:
2069// inputs[0]: required: the input array
2070//
2071// TensorFlow equivalent: tf.zeros_like
2072struct TensorFlowZerosLikeOperator : Operator {
2073 TensorFlowZerosLikeOperator() : Operator(OperatorType::kZerosLike) {}
2074};
2075
2076// ReverseV2 operator:
2077//
2078// Inputs:
2079// Inputs[0]: required: the input array.
2080//
2081// TensorFlow equivalent: ReverseV2.
2082struct ReverseV2Operator : Operator {
2083 ReverseV2Operator() : Operator(OperatorType::kReverseV2) {}
2084};
2085
2086enum class MirrorPadMode { kNone, kSymmetric, kReflect };
2087
2088// MirrorPad Operator:
2089//
2090// Inputs:
2091// Inputs[0]: required: input tensor to be padded.
// Inputs[1]: required: a two-column matrix specifying padding sizes; the
// number of rows must equal the rank of the input.
2094// Inputs[2]: required: REFLECT or SYMMETRIC.
2095//
2096// TensorFlow equivalent: MirrorPad.
2097struct MirrorPadOperator : Operator {
2098 MirrorPadOperator() : Operator(OperatorType::kMirrorPad) {}
  // mode is either SYMMETRIC or REFLECT; kNone means it has not been set yet.
  MirrorPadMode mode = MirrorPadMode::kNone;
2101};
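
// Worked example (illustrative): for a rank-2 input, Inputs[1] could hold
// {{1, 1}, {2, 2}}, i.e. pad one row on top and bottom and two columns on each
// side, with the padded values mirrored according to 'mode'.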
2102
2103// ReverseSequence operator:
2104//
2105// Inputs:
2106// Inputs[0]: required: the input array.
2107// Inputs[1]: required: the lengths of the elements to be reversed.
2108//
2109// TensorFlow equivalent: tf.reverse_sequence.
2110struct ReverseSequenceOperator : Operator {
2111 ReverseSequenceOperator() : Operator(OperatorType::kReverseSequence) {}
  int seq_dim = 0;
2113 int batch_dim = 0;
2114};
2115
2116// Unique Operator:
2117//
2118// Inputs:
2119// inputs[0]: required: the input array
2120//
2121// TensorFlow equivalent: Unique
2122struct UniqueOperator : Operator {
2123 UniqueOperator() : Operator(OperatorType::kUnique) {}
2124 ArrayDataType idx_out_type = ArrayDataType::kInt32;
2125};
2126
// UnidirectionalSequenceRnn operator: a fused RNN cell applied over a
// sequence, stepping along the time dimension (see time_major).
struct UnidirectionalSequenceRnnOperator : Operator {
  UnidirectionalSequenceRnnOperator()
      : Operator(OperatorType::kUnidirectionalSequenceRnn) {}
  bool time_major = false;
  FusedActivationFunctionType fused_activation_function =
      FusedActivationFunctionType::kNone;
};
2133
2134// Where Operator:
// Returns the coordinates of the true values in the condition tensor, in
// row-major order.
2137//
2138// Inputs:
2139// inputs[0]: required: boolean condition tensor
2140//
2141// TensorFlow equivalent: Where
2142struct WhereOperator : Operator {
2143 WhereOperator() : Operator(OperatorType::kWhere) {}
2144};
2145
// Matrix Diag Operator:
// Construct a batched diagonal tensor with given batched diagonal values.
//
// Inputs:
// inputs[0]: required: a tensor of values that will be on the diagonal of the
// returned tensor.
2150struct MatrixDiagOperator : Operator {
2151 MatrixDiagOperator() : Operator(OperatorType::kMatrixDiag) {}
2152};
2153
2154// Matrix Diag Operator V2:
2155// Construct a batched diagonal tensor with given batched diagonal values.
// Not fully supported; it has 4 extra inputs compared to MatrixDiag, and
// behaves like MatrixDiag when the default parameters are used.
2158struct MatrixDiagV2Operator : Operator {
2159 MatrixDiagV2Operator() : Operator(OperatorType::kMatrixDiagV2) {}
2160};
2161
2162// Matrix Diag Operator V3:
2163// Construct a batched diagonal tensor with given batched diagonal values.
// Not fully supported; it has 5 extra inputs compared to MatrixDiag, and
// behaves like MatrixDiag when the default parameters are used.
// V3 differs from V2 only in its extra 'align' attribute, which controls the
// alignment of diagonals in the band-matrix (compact) format.
// The default alignment in V2 contradicts the default alignment in V3, so V2
// is skipped. (It has never been, and should never be, exposed in the public
// API.)
2170struct MatrixDiagV3Operator : Operator {
2171 MatrixDiagV3Operator() : Operator(OperatorType::kMatrixDiagV3) {}
2172};
2173
// Matrix Set Diag Operator:
// Construct a batched diagonal tensor with given input and diagonal values.
// The input is a rank (k+1) tensor of values, and the diagonal is a rank k
// tensor of values that will be placed on the diagonal of the returned
// output, which is also of rank (k+1).
2180struct MatrixSetDiagOperator : Operator {
2181 MatrixSetDiagOperator() : Operator(OperatorType::kMatrixSetDiag) {}
2182};
2183
2184// Matrix Set Diag Operator V2:
2185// Construct a batched diagonal tensor with given input and diagonal values.
// Not fully supported; it has 1 extra input compared to MatrixSetDiag, and
// behaves like MatrixSetDiag when the default parameters are used.
2188struct MatrixSetDiagV2Operator : Operator {
2189 MatrixSetDiagV2Operator() : Operator(OperatorType::kMatrixSetDiagV2) {}
2190};
2191
2192// Matrix Set Diag Operator V3:
2193// Construct a batched diagonal tensor with given input and diagonal values.
// Not fully supported; it has 2 extra inputs compared to MatrixSetDiag, and
// behaves like MatrixSetDiag when the default parameters are used.
// V3 differs from V2 only in its extra 'align' attribute, which controls the
// alignment of diagonals in the band-matrix (compact) format.
// The default alignment in V2 contradicts the default alignment in V3, so V2
// is skipped. (It has never been, and should never be, exposed in the public
// API.)
2200struct MatrixSetDiagV3Operator : Operator {
2201 MatrixSetDiagV3Operator() : Operator(OperatorType::kMatrixSetDiagV3) {}
2202};
2203
// ScatterNd operator. Scatters 'updates' into a tensor of the given shape,
// according to 'indices'. TensorFlow equivalent: ScatterNd.
struct ScatterNdOperator : Operator {
  ScatterNdOperator() : Operator(OperatorType::kScatterNd) {}
};
2207
// SegmentSum operator. Sums tensor elements along segments given by a
// segment_ids input. TensorFlow equivalent: SegmentSum.
struct SegmentSumOperator : Operator {
  SegmentSumOperator() : Operator(OperatorType::kSegmentSum) {}
};
2211
// Allocs are used for transient arrays only. An Alloc specifies which
// interval of the "transient_data" workspace buffer, passed to inference
// functions, is to be used for the transient array at hand. The 'start' and
// 'end' values are offsets from the start of the workspace buffer, expressed
// in bytes.
2216struct Alloc {
2217 int64_t start = 0;
2218 int64_t end = 0;
2219};
2220
2221inline bool operator<(const Alloc& a, const Alloc& b) {
2222 return a.start < b.start;
2223}
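
// Worked example (illustrative): two transient arrays laid out back-to-back in
// the shared workspace could carry Alloc{0, 1024} and Alloc{1024, 5120}; the
// Model's transient_data_size would then have to be at least 5120 bytes. The
// operator< above orders allocations by start offset, e.g. for overlap checks:
//
//   Alloc a{0, 1024};
//   Alloc b{1024, 5120};
//   std::vector<Alloc> allocs = {b, a};
//   std::sort(allocs.begin(), allocs.end());  // now ordered {a, b}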
2224
2225// Array represents an array (either a constant parameter array or an
2226// activations array) in a Model.
2227struct Array {
2228 template <ArrayDataType A>
2229 const Buffer<A>& GetBuffer() const {
2230 DCHECK(buffer);
2231 DCHECK(buffer->type == A);
2232 return *static_cast<const Buffer<A>*>(buffer.get());
2233 }
2234 template <ArrayDataType A>
2235 Buffer<A>& GetMutableBuffer() {
    if (!buffer) {
      buffer = std::make_unique<Buffer<A>>();
    }
2240 DCHECK(buffer);
2241 DCHECK(buffer->type == A);
2242 return *static_cast<Buffer<A>*>(buffer.get());
2243 }
2244 Alloc& GetOrCreateAlloc() {
2245 if (!alloc) {
2246 alloc = std::make_unique<Alloc>();
2247 }
2248 return *alloc;
2249 }
2250 MinMax& GetOrCreateMinMax() {
2251 if (!minmax) {
2252 minmax = std::make_unique<MinMax>();
2253 }
2254 return *minmax;
2255 }
2256 MinMax& GetMinMax() const {
2257 DCHECK(minmax);
2258 return *minmax;
2259 }
2260 QuantizationParams& GetOrCreateQuantizationParams() {
2261 if (!quantization_params) {
2262 quantization_params = std::make_unique<QuantizationParams>();
2263 }
2264 return *quantization_params;
2265 }
2266 QuantizationParams& GetQuantizationParams() const {
2267 DCHECK(quantization_params);
2268 return *quantization_params;
2269 }
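
  // Usage sketch (illustrative): a constant parameter array's payload is
  // reached through the typed accessors above; Buffer<A> is assumed to expose
  // its elements as a 'data' vector:
  //
  //   Array array;
  //   array.data_type = ArrayDataType::kFloat;
  //   auto& buf = array.GetMutableBuffer<ArrayDataType::kFloat>();
  //   buf.data = {1.0f, 2.0f, 3.0f};
  //   const auto& ro = array.GetBuffer<ArrayDataType::kFloat>();  // read-only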
2270
2271 // The data type of the actual elements of this array, that is:
2272 // - If there is a buffer (see 'buffer' member), it must be of the same
2273 // type.
2274 // - If there is no buffer, meaning that this is a runtime (i.e. activations)
2275 // array, then this specifies the type of elements that there will be
2276 // at runtime.
2277 //
2278 // Note that this only specifies the storage type of elements; this does
2279 // not specify whether these are to be treated as 'real' or 'quantized'
2280 // values.
2281 // That is decided by whether the 'quantization_params' member is null.
2282 ArrayDataType data_type = ArrayDataType::kNone;
  // The final value that data_type should have at the end of graph
  // transformations.
2285 ArrayDataType final_data_type = ArrayDataType::kNone;
2286 // The dimensions of this array --- this specifies both sizes and strides
2287 // (the storage layout).
2288 //
2289 // Issues with shape handling that remain include:
2290 // - No way to distinguish between 0-dimensional dims and missing dims.
2291 // - No way to describe dims that may be runtime-variable.
2292 // - Addressing of dims by integer index differs in different graph formats
2293 // (TensorFlow vs. other frameworks vs. what we have informally grown
2294 // within toco).
  // This is currently quite messy; see ReorderAxesOperator, which is how we
  // bridge some of these discrepancies at the moment. This is overdue for a
  // redesign: it would be nice to have more flexible dims that cleanly map
  // 1:1 to dims as they are in the various formats, and then to convert
  // explicitly between the different conventions.
2301
2302 // Proto-style accessors
2303 bool has_shape() const { return array_shape != nullptr; }
2304 const Shape& shape() const {
2305 CHECK(has_shape());
2306 return *array_shape;
2307 }
2308 Shape* mutable_shape() {
2309 if (!array_shape) {
2310 array_shape = std::make_unique<Shape>();
2311 }
2312 return array_shape.get();
2313 }
2314 void copy_shape(const Shape& src_shape) { *mutable_shape() = src_shape; }
2315 void clear_shape() { array_shape = nullptr; }
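
  // Usage sketch (illustrative): the shape is optional and must be tested
  // before reading, since shape() CHECK-fails when no shape has been set:
  //
  //   Array array;
  //   if (!array.has_shape()) {
  //     array.mutable_shape();  // creates an empty shape
  //   }
  //   const Shape& shape = array.shape();  // safe now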
2316
2317 // The constant buffer backing this array. This is non-null if and only if
2318 // this is a constant parameter array. Conversely, this is null for
2319 // activations arrays.
2320 //
  // Note that this buffer is pure storage. In the case of quantized values,
  // it only stores the quantized values; it does not know by itself about the
  // quantization parameters necessary to interpret these values, which live
  // in the separate 'quantization_params' field. In fact, this 'buffer' field
  // does not even know whether the values are quantized. It only has a
  // data_type, which must equal the 'data_type' member here, and which only
  // describes the storage type of the elements; it does not tell whether they
  // are quantized, i.e. whether they are to be interpreted with
  // quantization_params.
2329 std::unique_ptr<GenericBuffer> buffer;
  // Only for activations arrays (i.e. when 'buffer' is null), and only for
  // code generation.
  //
  // Describes the allocation of this array within the workspace buffer
  // allocated for all transient arrays.
2336 std::unique_ptr<Alloc> alloc;
2337 // Describes the [min, max] range of values
2338 // to be assumed when determining quantization_params.
2339 //
2340 // Only used for quantization. In fact, only used for determining
2341 // quantization_params.
2342 //
2343 // Used for both constant arrays (those having a 'buffer') and non-constant
2344 // arrays (activations). Indeed, it is important to use the same min-max range
2345 // as was used during training, even if that min-max range is slightly wrong
2346 // w.r.t. actual buffer elements. Doing otherwise would defeat the point of
2347 // re-training for quantization.
2348 std::unique_ptr<MinMax> minmax;
2349 // Quantization parameters. The non-null-ness of this pointer is what
2350 // defines whether this array is quantized or not.
2351 //
2352 // If this is non-null, then these quantization parameters are to be used
2353 // to assign a meaning as real numbers to the elements of this array.
2354 std::unique_ptr<QuantizationParams> quantization_params;
2355 // narrow_range is a detail of how toco handles FakeQuant operators with
2356 // narrow_range, see
2357 // https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars
2358 //
2359 // For more context about what that is useful for, see the big comment in
2360 // graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
2361 //
2362 // The narrow_range flag applies only to quantized arrays, and changes
2363 // their quantization in the following way when it is set to 'true':
2364 // 1. The computation of {zero_point, scale} from {min, max} needs to be
2365 // amended so that the real min value will get quantized to
2366 // (min_quantized_value + 1) instead of just (min_quantized_value).
2367 // E.g. for uint8 quantization, the real min value should get quantized to
2368 // the uint8 value 1, not 0.
  // 2. Quantized values should get clamped to the interval
  //    [min_quantized_value + 1, max_quantized_value]. Equivalently, the
  //    min_quantized_value should get nudged to (min_quantized_value + 1).
2372 // The reason why 1. does not imply 2. is that real values may not belong to
2373 // the stated [min, max] interval. Concretely, weights recorded at the last
2374 // learning step may not fall in the [min, max] interval recorded over
2375 // previous learning steps, as the values evolve across learning steps.
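  //
  // Worked example (illustrative): for uint8 with minmax = {-1.0, 1.0} and
  // narrow_range = true, scale = (1.0 - (-1.0)) / (255 - 1) = 2.0 / 254, and
  // the real value -1.0 quantizes to 1 rather than 0, so quantized values
  // occupy [1, 255] instead of [0, 255].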
2376 //
2377 // Rationale why this is directly a field on Array:
2378 // - This can't be just a field on FakeQuantOperator, because
2379 // FakeQuantOperators are gone (DropFakeQuant) before we get to using that
2380 // information (Quantize). We need a place to store that bit in the interim.
2381 // - This can't be in QuantizationParams because we need to record this
2382 // ahead of quantization, and QuantizationParams are only created during
2383 // quantization.
2384 // - This could be in MinMax, but that would be an abuse of what MinMax is
2385 // about, and would break existing code that assumes that a MinMax is just
2386 // a min and a max. Unlike MinMax which is agnostic as to the quantized
2387 // data type, narrow_range refers to values in the quantized data type.
2388 bool narrow_range = false;
2389
2390 private:
2391 std::unique_ptr<Shape> array_shape;
2392};
2393
2394// Our Model struct, represents an entire model (our "top-level" struct).
2395// Owns everything.
2396class Model {
2397 public:
2398 using ArrayMap = std::unordered_map<std::string, std::unique_ptr<Array>>;
2399
2400 bool HasArray(const std::string& name) const {
2401 return arrays.count(name) > 0;
2402 }
2403 Array& GetArray(const std::string& name) const {
2404 DCHECK(HasArray(name)) << "Array not found: " << name;
2405 return *arrays.at(name);
2406 }
  Array& GetOrCreateArray(const std::string& name) {
    // Make sure the name is not already taken by an optional array.
    DCHECK(!optional_arrays.count(name));
    if (!HasArray(name)) {
      arrays[name] = std::make_unique<Array>();
    }
    return GetArray(name);
  }
2417 void CreateOptionalArray(const std::string& name) {
2418 DCHECK(!arrays.count(name) && !optional_arrays.count(name));
2419 optional_arrays.insert(name);
2420 }
2421 bool IsOptionalArray(const std::string& name) const {
2422 return optional_arrays.count(name);
2423 }
2424
2425 // Note that this invalidates all array iterators.
2426 void EraseArray(const std::string& name) { arrays.erase(name); }
2427 void EraseArrays(std::function<bool(const std::string&)> discardable) {
2428 for (auto it = arrays.begin(); it != arrays.end();) {
2429 if (discardable(it->first)) {
2430 it = arrays.erase(it);
2431 } else {
2432 ++it;
2433 }
2434 }
2435 }
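
  // Usage sketch (illustrative): dropping every array whose name starts with
  // a hypothetical temporary prefix:
  //
  //   model.EraseArrays([](const std::string& name) {
  //     return name.rfind("tmp_", 0) == 0;
  //   });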
2436 const ArrayMap& GetArrayMap() const { return arrays; }
2437 ArrayMap& GetMutableArrayMap() { return arrays; }
2438
2439 int64_t ArithmeticOpsCount() const { return ops_count; }
2440
  void AddInvalidInputArray(std::string invalid_input_array) {
    invalid_input_arrays_.insert(std::move(invalid_input_array));
  }
2444
2445 const std::unordered_set<std::string>& GetInvalidInputArrays() const {
2446 return invalid_input_arrays_;
2447 }
2448
  // Optional arrays are used for optional tensors: such tensors carry no
  // data, but their reserved names may appear as op inputs.
2451 std::set<std::string> optional_arrays;
2452
  // The list of operators. Notice how it's a list of unique_ptrs, implying
  // that the Model is what owns Operators and keeps them alive.
2455 std::vector<std::unique_ptr<Operator>> operators;
2456
2457 // Generic flags, a place where we combine information passed to us via
2458 // command-line parameters (e.g. --input_width=N) with information that
2459 // we may or may not find in the input model file.
2460 ModelFlags flags;
2461 // For code-generation only: required size of the transient_data buffer
2462 std::size_t transient_data_size = 0;
2463 // For code-generation only: required alignment of the transient_data buffer
2464 std::size_t transient_data_alignment = 0;
2465 // Arithmetic operations performed in the model.
2466 int64_t ops_count = 0;
2467
2468 private:
  // The associative array mapping names to Arrays.
  // Notice how it's a container of unique_ptrs, implying that the Model is
  // what owns Arrays and keeps them alive.
  // Operators refer to these Arrays by their name strings, not by their
  // addresses. See Operator::inputs, Operator::outputs.
  ArrayMap arrays;
2475
2476 // Invalid input arrays.
2477 std::unordered_set<std::string> invalid_input_arrays_;
2478};
2479
// OperatorSignature contains the information required to make versioning
// decisions.
2482struct OperatorSignature {
2483 // The operator.
2484 const Operator* op;
2485
2486 // The model in which the operator resides.
2487 const Model* model;
2488};
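
// Illustrative sketch (not part of this header): assembling a minimal model.
// Operators refer to Arrays purely by name, and the Model owns both. This
// assumes the ReluOperator struct defined earlier in this file; the array
// names are made-up placeholders.
//
//   Model model;
//   model.GetOrCreateArray("input");
//   model.GetOrCreateArray("output");
//   auto* relu = new ReluOperator;
//   relu->inputs = {"input"};
//   relu->outputs = {"output"};
//   model.operators.emplace_back(relu);
//   OperatorSignature sig{model.operators[0].get(), &model};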
2489} // namespace toco
2490
2491#endif // TENSORFLOW_LITE_TOCO_MODEL_H_
2492