1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/lite/mutable_op_resolver.h" |
17 | |
18 | #include <string> |
19 | #include <unordered_map> |
20 | #include <utility> |
21 | |
22 | #include "tensorflow/lite/c/common.h" |
23 | #include "tensorflow/lite/core/api/op_resolver_internal.h" |
24 | #include "tensorflow/lite/schema/schema_generated.h" |
25 | |
26 | namespace tflite { |
27 | |
28 | const TfLiteRegistration* MutableOpResolver::FindOp(tflite::BuiltinOperator op, |
29 | int version) const { |
30 | auto it = builtins_.find(std::make_pair(op, version)); |
31 | if (it != builtins_.end()) { |
32 | return &it->second; |
33 | } |
34 | for (const OpResolver* other : other_op_resolvers_) { |
35 | const TfLiteRegistration* result = other->FindOp(op, version); |
36 | if (result != nullptr) { |
37 | return result; |
38 | } |
39 | } |
40 | return nullptr; |
41 | } |
42 | |
43 | const TfLiteRegistration* MutableOpResolver::FindOp(const char* op, |
44 | int version) const { |
45 | auto it = custom_ops_.find(std::make_pair(op, version)); |
46 | if (it != custom_ops_.end()) { |
47 | return &it->second; |
48 | } |
49 | for (const OpResolver* other : other_op_resolvers_) { |
50 | const TfLiteRegistration* result = other->FindOp(op, version); |
51 | if (result != nullptr) { |
52 | return result; |
53 | } |
54 | } |
55 | return nullptr; |
56 | } |
57 | |
58 | void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op, |
59 | const TfLiteRegistration* registration, |
60 | int version) { |
61 | if (registration == nullptr) { |
62 | // Under certain conditions, builtin TfLiteRegistration factory methods may |
63 | // return null in the client library. This is generally benign, and we |
64 | // silently suppress resulting AddBuiltin calls here. |
65 | return; |
66 | } |
67 | TfLiteRegistration new_registration = *registration; |
68 | new_registration.custom_name = nullptr; |
69 | new_registration.builtin_code = op; |
70 | new_registration.version = version; |
71 | auto op_key = std::make_pair(op, version); |
72 | builtins_[op_key] = new_registration; |
73 | // The builtin op that is being added may be one that is not supported by |
74 | // tflite::ops::builtin::BuiltinOpResolver. Or the TfLiteRegistration for this |
75 | // builtin may be different than the one that BuiltinOpResolver would use, |
76 | // which could lead to different semantics. Both of those cases are considered |
77 | // "user defined ops". |
78 | may_directly_contain_user_defined_ops_ = true; |
79 | } |
80 | |
81 | void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op, |
82 | const TfLiteRegistration* registration, |
83 | int min_version, int max_version) { |
84 | for (int version = min_version; version <= max_version; ++version) { |
85 | AddBuiltin(op, registration, version); |
86 | } |
87 | } |
88 | |
89 | void MutableOpResolver::AddCustom(const char* name, |
90 | const TfLiteRegistration* registration, |
91 | int version) { |
92 | TfLiteRegistration new_registration = *registration; |
93 | new_registration.builtin_code = BuiltinOperator_CUSTOM; |
94 | new_registration.custom_name = name; |
95 | new_registration.version = version; |
96 | auto op_key = std::make_pair(name, version); |
97 | custom_ops_[op_key] = new_registration; |
98 | may_directly_contain_user_defined_ops_ = true; |
99 | } |
100 | |
101 | void MutableOpResolver::AddCustom(const char* name, |
102 | const TfLiteRegistration* registration, |
103 | int min_version, int max_version) { |
104 | for (int version = min_version; version <= max_version; ++version) { |
105 | AddCustom(name, registration, version); |
106 | } |
107 | } |
108 | |
109 | void MutableOpResolver::AddAll(const MutableOpResolver& other) { |
110 | // map::insert does not replace existing elements, and map::insert_or_assign |
111 | // wasn't added until C++17. |
112 | for (const auto& other_builtin : other.builtins_) { |
113 | builtins_[other_builtin.first] = other_builtin.second; |
114 | } |
115 | for (const auto& other_custom_op : other.custom_ops_) { |
116 | custom_ops_[other_custom_op.first] = other_custom_op.second; |
117 | } |
118 | other_op_resolvers_.insert(other_op_resolvers_.begin(), |
119 | other.other_op_resolvers_.begin(), |
120 | other.other_op_resolvers_.end()); |
121 | } |
122 | |
123 | void MutableOpResolver::ChainOpResolver(const OpResolver* other) { |
124 | other_op_resolvers_.push_back(other); |
125 | } |
126 | |
127 | bool MutableOpResolver::MayContainUserDefinedOps() const { |
128 | if (may_directly_contain_user_defined_ops_) { |
129 | return true; |
130 | } |
131 | for (const OpResolver* other : other_op_resolvers_) { |
132 | if (OpResolverInternal::MayContainUserDefinedOps(*other)) { |
133 | return true; |
134 | } |
135 | } |
136 | return false; |
137 | } |
138 | |
139 | } // namespace tflite |
140 | |