1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_LITE_MUTABLE_OP_RESOLVER_H_ |
16 | #define TENSORFLOW_LITE_MUTABLE_OP_RESOLVER_H_ |
17 | |
18 | #include <stddef.h> |
19 | |
20 | #include <string> |
21 | #include <unordered_map> |
22 | #include <utility> |
23 | #include <vector> |
24 | |
25 | #include "tensorflow/lite/c/common.h" |
26 | #include "tensorflow/lite/core/api/op_resolver.h" |
27 | #include "tensorflow/lite/schema/schema_generated.h" |
28 | #include "tensorflow/lite/util.h" |
29 | |
30 | namespace tflite { |
31 | |
32 | // Some versions of gcc don't support partial specialization in class scope, |
33 | // so these are defined in a namescope. |
34 | namespace op_resolver_hasher { |
35 | template <typename V> |
36 | struct ValueHasher { |
37 | size_t operator()(const V& v) const { return std::hash<V>()(v); } |
38 | }; |
39 | |
40 | template <> |
41 | struct ValueHasher<tflite::BuiltinOperator> { |
42 | size_t operator()(const tflite::BuiltinOperator& v) const { |
43 | return std::hash<int>()(static_cast<int>(v)); |
44 | } |
45 | }; |
46 | |
47 | template <typename T> |
48 | struct OperatorKeyHasher { |
49 | size_t operator()(const T& x) const { |
50 | size_t a = ValueHasher<typename T::first_type>()(x.first); |
51 | size_t b = ValueHasher<typename T::second_type>()(x.second); |
52 | return CombineHashes({a, b}); |
53 | } |
54 | }; |
55 | } // namespace op_resolver_hasher |
56 | |
57 | /// An OpResolver that is mutable, also used as the op in gen_op_registration. |
58 | /// A typical usage: |
59 | /// MutableOpResolver resolver; |
60 | /// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD()); |
61 | /// resolver.AddCustom("CustomOp", Register_CUSTOM_OP()); |
62 | /// InterpreterBuilder(model, resolver)(&interpreter); |
63 | class MutableOpResolver : public OpResolver { |
64 | public: |
65 | const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, |
66 | int version) const override; |
67 | const TfLiteRegistration* FindOp(const char* op, int version) const override; |
68 | |
69 | /// Registers the specified `version` of the specified builtin operator `op`. |
70 | /// Replaces any previous registration for the same operator version. |
71 | void AddBuiltin(tflite::BuiltinOperator op, |
72 | const TfLiteRegistration* registration, int version = 1); |
73 | |
74 | /// Registers the specified version range (versions `min_version` to |
75 | /// `max_version`, inclusive) of the specified builtin operator `op`. |
76 | /// Replaces any previous registration for the same operator version. |
77 | void AddBuiltin(tflite::BuiltinOperator op, |
78 | const TfLiteRegistration* registration, int min_version, |
79 | int max_version); |
80 | |
81 | /// Registers the specified `version` of the specified builtin operator `op`. |
82 | /// Replaces any previous registration for the same operator version. |
83 | void AddCustom(const char* name, const TfLiteRegistration* registration, |
84 | int version = 1); |
85 | |
86 | /// Registers the specified version range (versions `min_version` to |
87 | /// `max_version`, inclusive) of the specified custom operator `name`. |
88 | /// Replaces any previous registration for the same operator version. |
89 | void AddCustom(const char* name, const TfLiteRegistration* registration, |
90 | int min_version, int max_version); |
91 | |
92 | /// Registers all operator versions supported by another MutableOpResolver. |
93 | /// Replaces any previous registrations for the same operator versions, |
94 | /// except that registrations made with `AddBuiltin` or `AddCustom` always |
95 | /// take precedence over registrations made with `ChainOpResolver`. |
96 | void AddAll(const MutableOpResolver& other); |
97 | |
98 | OpResolver::TfLiteDelegateCreators GetDelegateCreators() const final { |
99 | return delegate_creators_; |
100 | } |
101 | |
102 | OpResolver::TfLiteOpaqueDelegateCreators GetOpaqueDelegateCreators() |
103 | const final { |
104 | return opaque_delegate_creators_; |
105 | } |
106 | |
107 | protected: |
108 | /// Registers all operator versions supported by another OpResolver, |
109 | /// except any already registered in this MutableOpResolver. |
110 | /// `other` must point to an OpResolver whose lifetime is at least as long |
111 | /// as the lifetime of the MutableOpResolver pointed to by `this`. |
112 | /// The OpResolver pointed to by `other` should not be modified during the |
113 | /// lifetime of this MutableOpResolver. |
114 | void ChainOpResolver(const OpResolver* other); |
115 | |
116 | /// True if this OpResolver itself (as opposed to chained op resolvers |
117 | /// registed with ChainOpResolver) may contain user defined ops. |
118 | /// |
119 | /// By "user defined" ops, we mean any op definitions other than those |
120 | /// contained in tflite::ops::builtin::BuiltinOpResolver. |
121 | bool may_directly_contain_user_defined_ops_ = false; |
122 | |
123 | /// A vector of delegate creators to create optional delegates for resolving |
124 | /// and handling ops in the flatbuffer model. This may be used in addition to |
125 | /// the standard TfLiteRegistration lookup for graph resolution. |
126 | TfLiteDelegateCreators delegate_creators_; |
127 | |
128 | /// A vector of opaque delegate creators to create optional opaque delegates |
129 | /// for resolving and handling ops in the flatbuffer model. This may be used |
130 | /// in addition to the standard TfLiteRegistration lookup for graph |
131 | /// resolution. This is used for TF Lite in Google Play Services. |
132 | TfLiteOpaqueDelegateCreators opaque_delegate_creators_; |
133 | |
134 | private: |
135 | bool MayContainUserDefinedOps() const override; |
136 | |
137 | typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey; |
138 | typedef std::pair<std::string, int> CustomOperatorKey; |
139 | |
140 | std::unordered_map<BuiltinOperatorKey, TfLiteRegistration, |
141 | op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> > |
142 | builtins_; |
143 | std::unordered_map<CustomOperatorKey, TfLiteRegistration, |
144 | op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> > |
145 | custom_ops_; |
146 | std::vector<const OpResolver*> other_op_resolvers_; |
147 | }; |
148 | |
149 | } // namespace tflite |
150 | |
151 | #endif // TENSORFLOW_LITE_MUTABLE_OP_RESOLVER_H_ |
152 | |