1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_
16#define TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_
17
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/flexbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
25
26namespace toco {
27
28namespace tflite {
29
30class BaseOperator;
31
32// Return a map contained all know TF Lite Operators, keyed by their names.
33// TODO(ycling): The pattern to propagate parameters (e.g. enable_select_tf_ops)
34// is ugly here. Consider refactoring.
35std::map<std::string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
36 bool enable_select_tf_ops = false);
37
38// Return a map contained all know TF Lite Operators, keyed by the type of
39// their tf.mini counterparts.
40std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
41 bool enable_select_tf_ops = false);
42
43// Write the custom option FlexBuffer with a serialized TensorFlow NodeDef
44// for a Flex op.
45std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
46 const std::string& tensorflow_node_def);
47
48// These are the flatbuffer types for custom and builtin options.
49using CustomOptions = flatbuffers::Vector<uint8_t>;
50using BuiltinOptions = void;
51
52// A simple wrapper around the flatbuffer objects used to describe options that
53// configure operators.
54struct Options {
55 // Build custom options.
56 static Options Custom(flatbuffers::Offset<CustomOptions> offset) {
57 return {::tflite::BuiltinOptions_NONE, 0, offset};
58 }
59
60 // Build builtin options of the given type.
61 static Options Builtin(::tflite::BuiltinOptions type,
62 flatbuffers::Offset<BuiltinOptions> offset) {
63 return {type, offset, 0};
64 }
65
66 ::tflite::BuiltinOptions type;
67 flatbuffers::Offset<BuiltinOptions> builtin;
68 flatbuffers::Offset<CustomOptions> custom;
69};
70
71// A BaseOperator encapsulates the relationship between operators in tf.mini
72// and TF lite, and provides methods for converting between those two formats.
73class BaseOperator {
74 public:
75 // Build an operator with the given TF Lite name and tf.mini type.
76 BaseOperator(const std::string& name, OperatorType type)
77 : name_(name), type_(type) {}
78 virtual ~BaseOperator() = default;
79
80 std::string name() const { return name_; }
81 OperatorType type() const { return type_; }
82
83 // Given a tf.mini operator, create the corresponding flatbuffer options and
84 // return their offsets.
85 virtual Options Serialize(const Operator& op,
86 flatbuffers::FlatBufferBuilder* builder) const = 0;
87
88 // Read TF Lite options and create the appropriate tf.mini operator.
89 virtual std::unique_ptr<Operator> Deserialize(
90 const BuiltinOptions* builtin_options,
91 const CustomOptions* custom_options) const = 0;
92
93 // Get the op version using the OperatorSignature.
94 // The function needs to be overridden to return the op version based on the
95 // parameters. Note:
96 // * The first version for each op should be 1 (to be consistent with the
97 // default value in Flatbuffer. `return 1;` is okay for newly implemented
98 // ops.
99 // * When multiple versions are defined for an op, this function could be
100 // overridden. (See example in `operator_test.cc` and
101 // 'tools/versioning/op_version.cc`)
102 virtual int GetVersion(const OperatorSignature& op_signature) const = 0;
103
104 // Given a Toco `Operator`, return a list of booleans indicating the op
105 // mutates which input variables.
106 // * If the op mutates any input variables, it should return a list of bool
107 // with the same length as inputs.
108 // * Otherwise, it will return an empty list.
109 virtual std::vector<bool> GetMutatingInputVariables(
110 const Operator& op) const {
111 // Most ops don't have variable tensors. This function can be overridden.
112 return std::vector<bool>();
113 }
114
115 private:
116 std::string name_;
117 OperatorType type_;
118};
119
120// Helper function to create ::tflite::OpSignature from the given
121// ::tflite::BuiltinOperator and OperatorSignature.
122::tflite::OpSignature GetVersioningOpSig(const ::tflite::BuiltinOperator op,
123 const OperatorSignature& op_signature);
124
125// Helper function to determine if a unsupported TensorFlow op should be
126// exported as an Flex op or a regular custom op.
127bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
128 const std::string& tensorflow_op_name);
129
130} // namespace tflite
131
132} // namespace toco
133
134#endif // TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_
135