1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/lite/mutable_op_resolver.h"
17
18#include <string>
19#include <unordered_map>
20#include <utility>
21
22#include "tensorflow/lite/c/common.h"
23#include "tensorflow/lite/core/api/op_resolver_internal.h"
24#include "tensorflow/lite/schema/schema_generated.h"
25
26namespace tflite {
27
28const TfLiteRegistration* MutableOpResolver::FindOp(tflite::BuiltinOperator op,
29 int version) const {
30 auto it = builtins_.find(std::make_pair(op, version));
31 if (it != builtins_.end()) {
32 return &it->second;
33 }
34 for (const OpResolver* other : other_op_resolvers_) {
35 const TfLiteRegistration* result = other->FindOp(op, version);
36 if (result != nullptr) {
37 return result;
38 }
39 }
40 return nullptr;
41}
42
43const TfLiteRegistration* MutableOpResolver::FindOp(const char* op,
44 int version) const {
45 auto it = custom_ops_.find(std::make_pair(op, version));
46 if (it != custom_ops_.end()) {
47 return &it->second;
48 }
49 for (const OpResolver* other : other_op_resolvers_) {
50 const TfLiteRegistration* result = other->FindOp(op, version);
51 if (result != nullptr) {
52 return result;
53 }
54 }
55 return nullptr;
56}
57
58void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
59 const TfLiteRegistration* registration,
60 int version) {
61 if (registration == nullptr) {
62 // Under certain conditions, builtin TfLiteRegistration factory methods may
63 // return null in the client library. This is generally benign, and we
64 // silently suppress resulting AddBuiltin calls here.
65 return;
66 }
67 TfLiteRegistration new_registration = *registration;
68 new_registration.custom_name = nullptr;
69 new_registration.builtin_code = op;
70 new_registration.version = version;
71 auto op_key = std::make_pair(op, version);
72 builtins_[op_key] = new_registration;
73 // The builtin op that is being added may be one that is not supported by
74 // tflite::ops::builtin::BuiltinOpResolver. Or the TfLiteRegistration for this
75 // builtin may be different than the one that BuiltinOpResolver would use,
76 // which could lead to different semantics. Both of those cases are considered
77 // "user defined ops".
78 may_directly_contain_user_defined_ops_ = true;
79}
80
81void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
82 const TfLiteRegistration* registration,
83 int min_version, int max_version) {
84 for (int version = min_version; version <= max_version; ++version) {
85 AddBuiltin(op, registration, version);
86 }
87}
88
89void MutableOpResolver::AddCustom(const char* name,
90 const TfLiteRegistration* registration,
91 int version) {
92 TfLiteRegistration new_registration = *registration;
93 new_registration.builtin_code = BuiltinOperator_CUSTOM;
94 new_registration.custom_name = name;
95 new_registration.version = version;
96 auto op_key = std::make_pair(name, version);
97 custom_ops_[op_key] = new_registration;
98 may_directly_contain_user_defined_ops_ = true;
99}
100
101void MutableOpResolver::AddCustom(const char* name,
102 const TfLiteRegistration* registration,
103 int min_version, int max_version) {
104 for (int version = min_version; version <= max_version; ++version) {
105 AddCustom(name, registration, version);
106 }
107}
108
109void MutableOpResolver::AddAll(const MutableOpResolver& other) {
110 // map::insert does not replace existing elements, and map::insert_or_assign
111 // wasn't added until C++17.
112 for (const auto& other_builtin : other.builtins_) {
113 builtins_[other_builtin.first] = other_builtin.second;
114 }
115 for (const auto& other_custom_op : other.custom_ops_) {
116 custom_ops_[other_custom_op.first] = other_custom_op.second;
117 }
118 other_op_resolvers_.insert(other_op_resolvers_.begin(),
119 other.other_op_resolvers_.begin(),
120 other.other_op_resolvers_.end());
121}
122
123void MutableOpResolver::ChainOpResolver(const OpResolver* other) {
124 other_op_resolvers_.push_back(other);
125}
126
127bool MutableOpResolver::MayContainUserDefinedOps() const {
128 if (may_directly_contain_user_defined_ops_) {
129 return true;
130 }
131 for (const OpResolver* other : other_op_resolvers_) {
132 if (OpResolverInternal::MayContainUserDefinedOps(*other)) {
133 return true;
134 }
135 }
136 return false;
137}
138
139} // namespace tflite
140