1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/toco/tflite/import.h" |
16 | |
17 | #include <memory> |
18 | #include <string> |
19 | |
20 | #include "flatbuffers/flexbuffers.h" |
21 | #include "tensorflow/lite/model.h" |
22 | #include "tensorflow/lite/schema/schema_generated.h" |
23 | #include "tensorflow/lite/schema/schema_utils.h" |
24 | #include "tensorflow/lite/toco/tflite/operator.h" |
25 | #include "tensorflow/lite/toco/tflite/types.h" |
26 | #include "tensorflow/lite/toco/tooling_util.h" |
27 | #include "tensorflow/lite/tools/verifier.h" |
28 | |
29 | namespace toco { |
30 | |
31 | namespace tflite { |
32 | |
33 | namespace details { |
34 | void LoadTensorsTable(const ::tflite::Model& input_model, |
35 | TensorsTable* tensors_table) { |
36 | // TODO(aselle): add support to toco for multiple subgraphs. |
37 | auto tensors = (*input_model.subgraphs())[0]->tensors(); |
38 | if (!tensors) return; |
39 | for (const auto* tensor : *tensors) { |
40 | tensors_table->push_back(tensor->name()->c_str()); |
41 | } |
42 | } |
43 | |
44 | void LoadOperatorsTable(const ::tflite::Model& input_model, |
45 | OperatorsTable* operators_table) { |
46 | auto opcodes = input_model.operator_codes(); |
47 | if (!opcodes) return; |
48 | for (const auto* opcode : *opcodes) { |
49 | auto builtin_code = GetBuiltinCode(opcode); |
50 | if (builtin_code != ::tflite::BuiltinOperator_CUSTOM) { |
51 | operators_table->push_back(EnumNameBuiltinOperator(builtin_code)); |
52 | } else { |
53 | operators_table->push_back(opcode->custom_code()->c_str()); |
54 | } |
55 | } |
56 | } |
57 | } // namespace details |
58 | |
59 | void ImportTensors(const ::tflite::Model& input_model, Model* model) { |
60 | auto tensors = (*input_model.subgraphs())[0]->tensors(); |
61 | auto* buffers = input_model.buffers(); |
62 | // auto tensors = input_model.tensors(); |
63 | if (!tensors) return; |
64 | for (const auto* input_tensor : *tensors) { |
65 | Array& array = model->GetOrCreateArray(input_tensor->name()->c_str()); |
66 | array.data_type = DataType::Deserialize(input_tensor->type()); |
67 | int buffer_index = input_tensor->buffer(); |
68 | auto* buffer = buffers->Get(buffer_index); |
69 | DataBuffer::Deserialize(*input_tensor, *buffer, &array); |
70 | |
71 | auto shape = input_tensor->shape(); |
72 | if (shape) { |
73 | // If the shape is 0-dimensional, make sure to record it as such, |
74 | // as oppose to leaving the array without a shape. |
75 | array.mutable_shape()->mutable_dims()->clear(); |
76 | for (uint32_t i = 0; i < shape->Length(); ++i) { |
77 | auto d = shape->Get(i); |
78 | array.mutable_shape()->mutable_dims()->push_back(d); |
79 | } |
80 | } |
81 | |
82 | auto quantization = input_tensor->quantization(); |
83 | if (quantization) { |
84 | // Note that tf.mini only supports a single quantization parameters for |
85 | // the whole array. |
86 | if (quantization->min() && quantization->max()) { |
87 | CHECK_EQ(1, quantization->min()->Length()); |
88 | CHECK_EQ(1, quantization->max()->Length()); |
89 | MinMax& minmax = array.GetOrCreateMinMax(); |
90 | minmax.min = quantization->min()->Get(0); |
91 | minmax.max = quantization->max()->Get(0); |
92 | } |
93 | if (quantization->scale() && quantization->zero_point()) { |
94 | CHECK_EQ(1, quantization->scale()->Length()); |
95 | CHECK_EQ(1, quantization->zero_point()->Length()); |
96 | QuantizationParams& q = array.GetOrCreateQuantizationParams(); |
97 | q.scale = quantization->scale()->Get(0); |
98 | q.zero_point = quantization->zero_point()->Get(0); |
99 | } |
100 | } |
101 | } |
102 | } |
103 | |
104 | void ImportOperators( |
105 | const ::tflite::Model& input_model, |
106 | const std::map<std::string, std::unique_ptr<BaseOperator>>& ops_by_name, |
107 | const details::TensorsTable& tensors_table, |
108 | const details::OperatorsTable& operators_table, Model* model) { |
109 | // TODO(aselle): add support for multiple subgraphs. |
110 | auto ops = (*input_model.subgraphs())[0]->operators(); |
111 | |
112 | if (!ops) return; |
113 | for (const auto* input_op : *ops) { |
114 | uint32_t index = input_op->opcode_index(); |
115 | if (index > operators_table.size()) { |
116 | LOG(FATAL) << "Index " << index << " must be between zero and " |
117 | << operators_table.size(); |
118 | } |
119 | std::string opname = operators_table.at(index); |
120 | |
121 | // Find and use the appropriate operator deserialization factory. |
122 | std::unique_ptr<Operator> new_op = nullptr; |
123 | if (ops_by_name.count(opname) == 0) { |
124 | std::string effective_opname = "TENSORFLOW_UNSUPPORTED" ; |
125 | if (ops_by_name.count(effective_opname) == 0) { |
126 | LOG(FATAL) << "Internal logic error: TENSORFLOW_UNSUPPORTED not found." ; |
127 | } |
128 | new_op = ops_by_name.at(effective_opname) |
129 | ->Deserialize(input_op->builtin_options(), |
130 | input_op->custom_options()); |
131 | if (new_op->type == OperatorType::kUnsupported) { |
132 | auto* unsupported_op = |
133 | static_cast<TensorFlowUnsupportedOperator*>(new_op.get()); |
134 | unsupported_op->tensorflow_op = opname; |
135 | // TODO(b/109932940): Remove this when quantized is removed. |
136 | // For now, we assume all ops are quantized. |
137 | unsupported_op->quantized = true; |
138 | } else { |
139 | LOG(FATAL) << "Expected a TensorFlowUnsupportedOperator" ; |
140 | } |
141 | } else { |
142 | new_op = ops_by_name.at(opname)->Deserialize(input_op->builtin_options(), |
143 | input_op->custom_options()); |
144 | } |
145 | model->operators.emplace_back(new_op.release()); |
146 | auto* op = model->operators.back().get(); |
147 | |
148 | // Make sure all the inputs and outputs are hooked up. |
149 | auto inputs = input_op->inputs(); |
150 | for (uint32_t i = 0; i < inputs->Length(); i++) { |
151 | auto input_index = inputs->Get(i); |
152 | // input_index == -1 indicates optional tensor. |
153 | if (input_index != -1) { |
154 | const std::string& input_name = tensors_table.at(input_index); |
155 | op->inputs.push_back(input_name); |
156 | } else { |
157 | const std::string& tensor_name = |
158 | toco::AvailableArrayName(*model, "OptionalTensor" ); |
159 | model->CreateOptionalArray(tensor_name); |
160 | op->inputs.push_back(tensor_name); |
161 | } |
162 | } |
163 | auto outputs = input_op->outputs(); |
164 | for (int i = 0, end = outputs->Length(); i < end; i++) { |
165 | auto output_index = outputs->Get(i); |
166 | const std::string& output_name = tensors_table.at(output_index); |
167 | op->outputs.push_back(output_name); |
168 | } |
169 | } |
170 | } |
171 | |
172 | void ImportIOTensors(const ModelFlags& model_flags, |
173 | const ::tflite::Model& input_model, |
174 | const details::TensorsTable& tensors_table, Model* model) { |
175 | // Import from the first subgraph if input arrays have not been specified. |
176 | if (model_flags.input_arrays().empty()) { |
177 | auto inputs = (*input_model.subgraphs())[0]->inputs(); |
178 | if (inputs) { |
179 | for (int input : *inputs) { |
180 | const std::string& input_name = tensors_table.at(input); |
181 | model->flags.add_input_arrays()->set_name(input_name); |
182 | } |
183 | } |
184 | } |
185 | |
186 | // Import from the first subgraph if output arrays have not been specified. |
187 | if (model_flags.output_arrays().empty()) { |
188 | auto outputs = (*input_model.subgraphs())[0]->outputs(); |
189 | if (outputs) { |
190 | for (int output : *outputs) { |
191 | const std::string& output_name = tensors_table.at(output); |
192 | model->flags.add_output_arrays(output_name); |
193 | } |
194 | } |
195 | } |
196 | } |
197 | |
198 | namespace { |
199 | bool Verify(const void* buf, size_t len) { |
200 | ::flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len); |
201 | return ::tflite::VerifyModelBuffer(verifier); |
202 | } |
203 | } // namespace |
204 | |
205 | std::unique_ptr<Model> Import(const ModelFlags& model_flags, |
206 | const std::string& input_file_contents) { |
207 | ::tflite::AlwaysTrueResolver r; |
208 | if (!::tflite::Verify(input_file_contents.data(), input_file_contents.size(), |
209 | r, ::tflite::DefaultErrorReporter())) { |
210 | LOG(FATAL) << "Invalid flatbuffer." ; |
211 | } |
212 | const ::tflite::Model* input_model = |
213 | ::tflite::GetModel(input_file_contents.data()); |
214 | |
215 | // Full list of all known operators. |
216 | const auto ops_by_name = BuildOperatorByNameMap(); |
217 | |
218 | if (!input_model->subgraphs() || input_model->subgraphs()->size() != 1) { |
219 | LOG(FATAL) << "Number of subgraphs in tflite should be exactly 1." ; |
220 | } |
221 | std::unique_ptr<Model> model; |
222 | model = std::make_unique<Model>(); |
223 | |
224 | details::TensorsTable tensors_table; |
225 | details::LoadTensorsTable(*input_model, &tensors_table); |
226 | |
227 | details::OperatorsTable operators_table; |
228 | details::LoadOperatorsTable(*input_model, &operators_table); |
229 | |
230 | ImportTensors(*input_model, model.get()); |
231 | ImportOperators(*input_model, ops_by_name, tensors_table, operators_table, |
232 | model.get()); |
233 | |
234 | ImportIOTensors(model_flags, *input_model, tensors_table, model.get()); |
235 | |
236 | UndoWeightsShuffling(model.get()); |
237 | |
238 | return model; |
239 | } |
240 | |
241 | } // namespace tflite |
242 | |
243 | } // namespace toco |
244 | |