1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/toco/tflite/import.h"
16
17#include <memory>
18#include <string>
19
20#include "flatbuffers/flexbuffers.h"
21#include "tensorflow/lite/model.h"
22#include "tensorflow/lite/schema/schema_generated.h"
23#include "tensorflow/lite/schema/schema_utils.h"
24#include "tensorflow/lite/toco/tflite/operator.h"
25#include "tensorflow/lite/toco/tflite/types.h"
26#include "tensorflow/lite/toco/tooling_util.h"
27#include "tensorflow/lite/tools/verifier.h"
28
29namespace toco {
30
31namespace tflite {
32
33namespace details {
34void LoadTensorsTable(const ::tflite::Model& input_model,
35 TensorsTable* tensors_table) {
36 // TODO(aselle): add support to toco for multiple subgraphs.
37 auto tensors = (*input_model.subgraphs())[0]->tensors();
38 if (!tensors) return;
39 for (const auto* tensor : *tensors) {
40 tensors_table->push_back(tensor->name()->c_str());
41 }
42}
43
44void LoadOperatorsTable(const ::tflite::Model& input_model,
45 OperatorsTable* operators_table) {
46 auto opcodes = input_model.operator_codes();
47 if (!opcodes) return;
48 for (const auto* opcode : *opcodes) {
49 auto builtin_code = GetBuiltinCode(opcode);
50 if (builtin_code != ::tflite::BuiltinOperator_CUSTOM) {
51 operators_table->push_back(EnumNameBuiltinOperator(builtin_code));
52 } else {
53 operators_table->push_back(opcode->custom_code()->c_str());
54 }
55 }
56}
57} // namespace details
58
59void ImportTensors(const ::tflite::Model& input_model, Model* model) {
60 auto tensors = (*input_model.subgraphs())[0]->tensors();
61 auto* buffers = input_model.buffers();
62 // auto tensors = input_model.tensors();
63 if (!tensors) return;
64 for (const auto* input_tensor : *tensors) {
65 Array& array = model->GetOrCreateArray(input_tensor->name()->c_str());
66 array.data_type = DataType::Deserialize(input_tensor->type());
67 int buffer_index = input_tensor->buffer();
68 auto* buffer = buffers->Get(buffer_index);
69 DataBuffer::Deserialize(*input_tensor, *buffer, &array);
70
71 auto shape = input_tensor->shape();
72 if (shape) {
73 // If the shape is 0-dimensional, make sure to record it as such,
74 // as oppose to leaving the array without a shape.
75 array.mutable_shape()->mutable_dims()->clear();
76 for (uint32_t i = 0; i < shape->Length(); ++i) {
77 auto d = shape->Get(i);
78 array.mutable_shape()->mutable_dims()->push_back(d);
79 }
80 }
81
82 auto quantization = input_tensor->quantization();
83 if (quantization) {
84 // Note that tf.mini only supports a single quantization parameters for
85 // the whole array.
86 if (quantization->min() && quantization->max()) {
87 CHECK_EQ(1, quantization->min()->Length());
88 CHECK_EQ(1, quantization->max()->Length());
89 MinMax& minmax = array.GetOrCreateMinMax();
90 minmax.min = quantization->min()->Get(0);
91 minmax.max = quantization->max()->Get(0);
92 }
93 if (quantization->scale() && quantization->zero_point()) {
94 CHECK_EQ(1, quantization->scale()->Length());
95 CHECK_EQ(1, quantization->zero_point()->Length());
96 QuantizationParams& q = array.GetOrCreateQuantizationParams();
97 q.scale = quantization->scale()->Get(0);
98 q.zero_point = quantization->zero_point()->Get(0);
99 }
100 }
101 }
102}
103
104void ImportOperators(
105 const ::tflite::Model& input_model,
106 const std::map<std::string, std::unique_ptr<BaseOperator>>& ops_by_name,
107 const details::TensorsTable& tensors_table,
108 const details::OperatorsTable& operators_table, Model* model) {
109 // TODO(aselle): add support for multiple subgraphs.
110 auto ops = (*input_model.subgraphs())[0]->operators();
111
112 if (!ops) return;
113 for (const auto* input_op : *ops) {
114 uint32_t index = input_op->opcode_index();
115 if (index > operators_table.size()) {
116 LOG(FATAL) << "Index " << index << " must be between zero and "
117 << operators_table.size();
118 }
119 std::string opname = operators_table.at(index);
120
121 // Find and use the appropriate operator deserialization factory.
122 std::unique_ptr<Operator> new_op = nullptr;
123 if (ops_by_name.count(opname) == 0) {
124 std::string effective_opname = "TENSORFLOW_UNSUPPORTED";
125 if (ops_by_name.count(effective_opname) == 0) {
126 LOG(FATAL) << "Internal logic error: TENSORFLOW_UNSUPPORTED not found.";
127 }
128 new_op = ops_by_name.at(effective_opname)
129 ->Deserialize(input_op->builtin_options(),
130 input_op->custom_options());
131 if (new_op->type == OperatorType::kUnsupported) {
132 auto* unsupported_op =
133 static_cast<TensorFlowUnsupportedOperator*>(new_op.get());
134 unsupported_op->tensorflow_op = opname;
135 // TODO(b/109932940): Remove this when quantized is removed.
136 // For now, we assume all ops are quantized.
137 unsupported_op->quantized = true;
138 } else {
139 LOG(FATAL) << "Expected a TensorFlowUnsupportedOperator";
140 }
141 } else {
142 new_op = ops_by_name.at(opname)->Deserialize(input_op->builtin_options(),
143 input_op->custom_options());
144 }
145 model->operators.emplace_back(new_op.release());
146 auto* op = model->operators.back().get();
147
148 // Make sure all the inputs and outputs are hooked up.
149 auto inputs = input_op->inputs();
150 for (uint32_t i = 0; i < inputs->Length(); i++) {
151 auto input_index = inputs->Get(i);
152 // input_index == -1 indicates optional tensor.
153 if (input_index != -1) {
154 const std::string& input_name = tensors_table.at(input_index);
155 op->inputs.push_back(input_name);
156 } else {
157 const std::string& tensor_name =
158 toco::AvailableArrayName(*model, "OptionalTensor");
159 model->CreateOptionalArray(tensor_name);
160 op->inputs.push_back(tensor_name);
161 }
162 }
163 auto outputs = input_op->outputs();
164 for (int i = 0, end = outputs->Length(); i < end; i++) {
165 auto output_index = outputs->Get(i);
166 const std::string& output_name = tensors_table.at(output_index);
167 op->outputs.push_back(output_name);
168 }
169 }
170}
171
172void ImportIOTensors(const ModelFlags& model_flags,
173 const ::tflite::Model& input_model,
174 const details::TensorsTable& tensors_table, Model* model) {
175 // Import from the first subgraph if input arrays have not been specified.
176 if (model_flags.input_arrays().empty()) {
177 auto inputs = (*input_model.subgraphs())[0]->inputs();
178 if (inputs) {
179 for (int input : *inputs) {
180 const std::string& input_name = tensors_table.at(input);
181 model->flags.add_input_arrays()->set_name(input_name);
182 }
183 }
184 }
185
186 // Import from the first subgraph if output arrays have not been specified.
187 if (model_flags.output_arrays().empty()) {
188 auto outputs = (*input_model.subgraphs())[0]->outputs();
189 if (outputs) {
190 for (int output : *outputs) {
191 const std::string& output_name = tensors_table.at(output);
192 model->flags.add_output_arrays(output_name);
193 }
194 }
195 }
196}
197
namespace {
// Returns true iff buf/len hold a structurally valid TFLite flatbuffer
// according to the generated schema verifier.
// NOTE(review): this helper appears unused within this file — Import() below
// calls the fully-qualified ::tflite::Verify overload (from
// tools/verifier.h) instead. Confirm no other use before removing.
bool Verify(const void* buf, size_t len) {
  ::flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
  return ::tflite::VerifyModelBuffer(verifier);
}
}  // namespace
204
205std::unique_ptr<Model> Import(const ModelFlags& model_flags,
206 const std::string& input_file_contents) {
207 ::tflite::AlwaysTrueResolver r;
208 if (!::tflite::Verify(input_file_contents.data(), input_file_contents.size(),
209 r, ::tflite::DefaultErrorReporter())) {
210 LOG(FATAL) << "Invalid flatbuffer.";
211 }
212 const ::tflite::Model* input_model =
213 ::tflite::GetModel(input_file_contents.data());
214
215 // Full list of all known operators.
216 const auto ops_by_name = BuildOperatorByNameMap();
217
218 if (!input_model->subgraphs() || input_model->subgraphs()->size() != 1) {
219 LOG(FATAL) << "Number of subgraphs in tflite should be exactly 1.";
220 }
221 std::unique_ptr<Model> model;
222 model = std::make_unique<Model>();
223
224 details::TensorsTable tensors_table;
225 details::LoadTensorsTable(*input_model, &tensors_table);
226
227 details::OperatorsTable operators_table;
228 details::LoadOperatorsTable(*input_model, &operators_table);
229
230 ImportTensors(*input_model, model.get());
231 ImportOperators(*input_model, ops_by_name, tensors_table, operators_table,
232 model.get());
233
234 ImportIOTensors(model_flags, *input_model, tensors_table, model.get());
235
236 UndoWeightsShuffling(model.get());
237
238 return model;
239}
240
241} // namespace tflite
242
243} // namespace toco
244