1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_LITE_MUTABLE_OP_RESOLVER_H_
16#define TENSORFLOW_LITE_MUTABLE_OP_RESOLVER_H_
17
#include <stddef.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/util.h"
29
30namespace tflite {
31
32// Some versions of gcc don't support partial specialization in class scope,
33// so these are defined in a namescope.
34namespace op_resolver_hasher {
35template <typename V>
36struct ValueHasher {
37 size_t operator()(const V& v) const { return std::hash<V>()(v); }
38};
39
40template <>
41struct ValueHasher<tflite::BuiltinOperator> {
42 size_t operator()(const tflite::BuiltinOperator& v) const {
43 return std::hash<int>()(static_cast<int>(v));
44 }
45};
46
47template <typename T>
48struct OperatorKeyHasher {
49 size_t operator()(const T& x) const {
50 size_t a = ValueHasher<typename T::first_type>()(x.first);
51 size_t b = ValueHasher<typename T::second_type>()(x.second);
52 return CombineHashes({a, b});
53 }
54};
55} // namespace op_resolver_hasher
56
57/// An OpResolver that is mutable, also used as the op in gen_op_registration.
58/// A typical usage:
59/// MutableOpResolver resolver;
60/// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
61/// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
62/// InterpreterBuilder(model, resolver)(&interpreter);
63class MutableOpResolver : public OpResolver {
64 public:
65 const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
66 int version) const override;
67 const TfLiteRegistration* FindOp(const char* op, int version) const override;
68
69 /// Registers the specified `version` of the specified builtin operator `op`.
70 /// Replaces any previous registration for the same operator version.
71 void AddBuiltin(tflite::BuiltinOperator op,
72 const TfLiteRegistration* registration, int version = 1);
73
74 /// Registers the specified version range (versions `min_version` to
75 /// `max_version`, inclusive) of the specified builtin operator `op`.
76 /// Replaces any previous registration for the same operator version.
77 void AddBuiltin(tflite::BuiltinOperator op,
78 const TfLiteRegistration* registration, int min_version,
79 int max_version);
80
81 /// Registers the specified `version` of the specified builtin operator `op`.
82 /// Replaces any previous registration for the same operator version.
83 void AddCustom(const char* name, const TfLiteRegistration* registration,
84 int version = 1);
85
86 /// Registers the specified version range (versions `min_version` to
87 /// `max_version`, inclusive) of the specified custom operator `name`.
88 /// Replaces any previous registration for the same operator version.
89 void AddCustom(const char* name, const TfLiteRegistration* registration,
90 int min_version, int max_version);
91
92 /// Registers all operator versions supported by another MutableOpResolver.
93 /// Replaces any previous registrations for the same operator versions,
94 /// except that registrations made with `AddBuiltin` or `AddCustom` always
95 /// take precedence over registrations made with `ChainOpResolver`.
96 void AddAll(const MutableOpResolver& other);
97
98 OpResolver::TfLiteDelegateCreators GetDelegateCreators() const final {
99 return delegate_creators_;
100 }
101
102 OpResolver::TfLiteOpaqueDelegateCreators GetOpaqueDelegateCreators()
103 const final {
104 return opaque_delegate_creators_;
105 }
106
107 protected:
108 /// Registers all operator versions supported by another OpResolver,
109 /// except any already registered in this MutableOpResolver.
110 /// `other` must point to an OpResolver whose lifetime is at least as long
111 /// as the lifetime of the MutableOpResolver pointed to by `this`.
112 /// The OpResolver pointed to by `other` should not be modified during the
113 /// lifetime of this MutableOpResolver.
114 void ChainOpResolver(const OpResolver* other);
115
116 /// True if this OpResolver itself (as opposed to chained op resolvers
117 /// registed with ChainOpResolver) may contain user defined ops.
118 ///
119 /// By "user defined" ops, we mean any op definitions other than those
120 /// contained in tflite::ops::builtin::BuiltinOpResolver.
121 bool may_directly_contain_user_defined_ops_ = false;
122
123 /// A vector of delegate creators to create optional delegates for resolving
124 /// and handling ops in the flatbuffer model. This may be used in addition to
125 /// the standard TfLiteRegistration lookup for graph resolution.
126 TfLiteDelegateCreators delegate_creators_;
127
128 /// A vector of opaque delegate creators to create optional opaque delegates
129 /// for resolving and handling ops in the flatbuffer model. This may be used
130 /// in addition to the standard TfLiteRegistration lookup for graph
131 /// resolution. This is used for TF Lite in Google Play Services.
132 TfLiteOpaqueDelegateCreators opaque_delegate_creators_;
133
134 private:
135 bool MayContainUserDefinedOps() const override;
136
137 typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
138 typedef std::pair<std::string, int> CustomOperatorKey;
139
140 std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
141 op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
142 builtins_;
143 std::unordered_map<CustomOperatorKey, TfLiteRegistration,
144 op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
145 custom_ops_;
146 std::vector<const OpResolver*> other_op_resolvers_;
147};
148
149} // namespace tflite
150
151#endif // TENSORFLOW_LITE_MUTABLE_OP_RESOLVER_H_
152