1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_ |
16 | #define TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_ |
17 | |
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/flexbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
25 | |
26 | namespace toco { |
27 | |
28 | namespace tflite { |
29 | |
30 | class BaseOperator; |
31 | |
32 | // Return a map contained all know TF Lite Operators, keyed by their names. |
33 | // TODO(ycling): The pattern to propagate parameters (e.g. enable_select_tf_ops) |
34 | // is ugly here. Consider refactoring. |
35 | std::map<std::string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap( |
36 | bool enable_select_tf_ops = false); |
37 | |
38 | // Return a map contained all know TF Lite Operators, keyed by the type of |
39 | // their tf.mini counterparts. |
40 | std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap( |
41 | bool enable_select_tf_ops = false); |
42 | |
43 | // Write the custom option FlexBuffer with a serialized TensorFlow NodeDef |
44 | // for a Flex op. |
45 | std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions( |
46 | const std::string& tensorflow_node_def); |
47 | |
48 | // These are the flatbuffer types for custom and builtin options. |
49 | using CustomOptions = flatbuffers::Vector<uint8_t>; |
50 | using BuiltinOptions = void; |
51 | |
52 | // A simple wrapper around the flatbuffer objects used to describe options that |
53 | // configure operators. |
54 | struct Options { |
55 | // Build custom options. |
56 | static Options Custom(flatbuffers::Offset<CustomOptions> offset) { |
57 | return {::tflite::BuiltinOptions_NONE, 0, offset}; |
58 | } |
59 | |
60 | // Build builtin options of the given type. |
61 | static Options Builtin(::tflite::BuiltinOptions type, |
62 | flatbuffers::Offset<BuiltinOptions> offset) { |
63 | return {type, offset, 0}; |
64 | } |
65 | |
66 | ::tflite::BuiltinOptions type; |
67 | flatbuffers::Offset<BuiltinOptions> builtin; |
68 | flatbuffers::Offset<CustomOptions> custom; |
69 | }; |
70 | |
71 | // A BaseOperator encapsulates the relationship between operators in tf.mini |
72 | // and TF lite, and provides methods for converting between those two formats. |
73 | class BaseOperator { |
74 | public: |
75 | // Build an operator with the given TF Lite name and tf.mini type. |
76 | BaseOperator(const std::string& name, OperatorType type) |
77 | : name_(name), type_(type) {} |
78 | virtual ~BaseOperator() = default; |
79 | |
80 | std::string name() const { return name_; } |
81 | OperatorType type() const { return type_; } |
82 | |
83 | // Given a tf.mini operator, create the corresponding flatbuffer options and |
84 | // return their offsets. |
85 | virtual Options Serialize(const Operator& op, |
86 | flatbuffers::FlatBufferBuilder* builder) const = 0; |
87 | |
88 | // Read TF Lite options and create the appropriate tf.mini operator. |
89 | virtual std::unique_ptr<Operator> Deserialize( |
90 | const BuiltinOptions* builtin_options, |
91 | const CustomOptions* custom_options) const = 0; |
92 | |
93 | // Get the op version using the OperatorSignature. |
94 | // The function needs to be overridden to return the op version based on the |
95 | // parameters. Note: |
96 | // * The first version for each op should be 1 (to be consistent with the |
97 | // default value in Flatbuffer. `return 1;` is okay for newly implemented |
98 | // ops. |
99 | // * When multiple versions are defined for an op, this function could be |
100 | // overridden. (See example in `operator_test.cc` and |
101 | // 'tools/versioning/op_version.cc`) |
102 | virtual int GetVersion(const OperatorSignature& op_signature) const = 0; |
103 | |
104 | // Given a Toco `Operator`, return a list of booleans indicating the op |
105 | // mutates which input variables. |
106 | // * If the op mutates any input variables, it should return a list of bool |
107 | // with the same length as inputs. |
108 | // * Otherwise, it will return an empty list. |
109 | virtual std::vector<bool> GetMutatingInputVariables( |
110 | const Operator& op) const { |
111 | // Most ops don't have variable tensors. This function can be overridden. |
112 | return std::vector<bool>(); |
113 | } |
114 | |
115 | private: |
116 | std::string name_; |
117 | OperatorType type_; |
118 | }; |
119 | |
120 | // Helper function to create ::tflite::OpSignature from the given |
121 | // ::tflite::BuiltinOperator and OperatorSignature. |
122 | ::tflite::OpSignature GetVersioningOpSig(const ::tflite::BuiltinOperator op, |
123 | const OperatorSignature& op_signature); |
124 | |
125 | // Helper function to determine if a unsupported TensorFlow op should be |
126 | // exported as an Flex op or a regular custom op. |
127 | bool ShouldExportAsFlexOp(bool enable_select_tf_ops, |
128 | const std::string& tensorflow_op_name); |
129 | |
130 | } // namespace tflite |
131 | |
132 | } // namespace toco |
133 | |
134 | #endif // TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_ |
135 | |