1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/toco/tooling_util.h" |
16 | |
17 | #include <algorithm> |
18 | #include <functional> |
19 | #include <iterator> |
20 | #include <set> |
21 | #include <string> |
22 | #include <unordered_map> |
23 | #include <unordered_set> |
24 | #include <utility> |
25 | |
26 | #include "absl/strings/ascii.h" |
27 | #include "absl/strings/str_cat.h" |
28 | #include "absl/strings/str_join.h" |
29 | #include "absl/strings/str_replace.h" |
30 | #include "absl/strings/str_split.h" |
31 | #include "re2/re2.h" |
32 | #include "tensorflow/core/lib/core/status.h" |
33 | #include "tensorflow/core/platform/logging.h" |
34 | #include "tensorflow/lite/toco/dump_graphviz.h" |
35 | #include "tensorflow/lite/toco/model_flags.pb.h" |
36 | #include "tensorflow/lite/toco/toco_graphviz_dump_options.h" |
37 | |
38 | namespace toco { |
39 | |
40 | // Find the longest common prefix of two strings. |
41 | absl::string_view FindLongestCommonPrefix(absl::string_view a, |
42 | absl::string_view b) { |
43 | if (a.empty() || b.empty()) return absl::string_view(); |
44 | |
45 | const char* pa = a.data(); |
46 | const char* pb = b.data(); |
47 | size_t count = 0; |
48 | const size_t limit = std::min(a.size(), b.size()); |
49 | while (count < limit && *pa == *pb) { |
50 | ++pa; |
51 | ++pb; |
52 | ++count; |
53 | } |
54 | |
55 | return absl::string_view(a.data(), count); |
56 | } |
57 | |
58 | std::string LogName(const Operator& op) { |
59 | const std::string& opname = HelpfulOperatorTypeName(op); |
60 | if (op.outputs.empty()) { |
61 | return toco::port::StringF("{%s operator}" , opname); |
62 | } else { |
63 | return toco::port::StringF("{%s operator with output %s}" , opname, |
64 | op.outputs[0]); |
65 | } |
66 | } |
67 | |
// Returns a human-readable name for the given array data type.
// LOG(FATAL)s (and thus never returns) on an enum value not listed below.
std::string ArrayDataTypeName(ArrayDataType data_type) {
  switch (data_type) {
    case ArrayDataType::kFloat:
      return "float";
    case ArrayDataType::kInt8:
      return "int8";
    case ArrayDataType::kUint8:
      return "uint8";
    case ArrayDataType::kInt16:
      return "int16";
    case ArrayDataType::kUint16:
      return "uint16";
    case ArrayDataType::kInt32:
      return "int32";
    case ArrayDataType::kUint32:
      return "uint32";
    case ArrayDataType::kInt64:
      return "int64";
    case ArrayDataType::kUint64:
      return "uint64";
    case ArrayDataType::kString:
      return "string";
    case ArrayDataType::kBool:
      return "bool";
    case ArrayDataType::kComplex64:
      return "complex64";
    case ArrayDataType::kNone:
      return "None";
    default:
      LOG(FATAL) << "Unhandled array data type " << static_cast<int>(data_type);
  }
}
100 | |
101 | bool IsInputArray(const Model& model, const std::string& array_name) { |
102 | for (const auto& input_array : model.flags.input_arrays()) { |
103 | if (array_name == input_array.name()) { |
104 | return true; |
105 | } |
106 | } |
107 | return false; |
108 | } |
109 | |
110 | bool IsOutputArray(const Model& model, const std::string& array_name) { |
111 | for (const auto& output_array : model.flags.output_arrays()) { |
112 | if (array_name == output_array) { |
113 | return true; |
114 | } |
115 | } |
116 | return false; |
117 | } |
118 | |
119 | bool IsArrayConsumed(const Model& model, const std::string& name) { |
120 | if (GetOpWithInput(model, name)) { |
121 | return true; |
122 | } |
123 | if (IsOutputArray(model, name)) { |
124 | return true; |
125 | } |
126 | for (const auto& rnn_state : model.flags.rnn_states()) { |
127 | if (rnn_state.back_edge_source_array() == name) { |
128 | return true; |
129 | } |
130 | } |
131 | return false; |
132 | } |
133 | |
134 | int CountTrueOutputs(const Model& model, const Operator& op) { |
135 | int count = 0; |
136 | for (const std::string& output : op.outputs) { |
137 | if (IsArrayConsumed(model, output)) { |
138 | ++count; |
139 | } |
140 | } |
141 | return count; |
142 | } |
143 | |
144 | int CountOpsWithInput(const Model& model, const std::string& array_name) { |
145 | int count = 0; |
146 | for (const auto& op : model.operators) { |
147 | for (auto& input : op->inputs) { |
148 | if (input == array_name) { |
149 | count++; |
150 | // Breaking here is important: some graphs have ops that use the |
151 | // same array as more than one of their inputs, and in that case |
152 | // we want it counted only once. |
153 | break; |
154 | } |
155 | } |
156 | } |
157 | return count; |
158 | } |
159 | |
160 | bool DeleteArrayIfUnused(const std::string& array_name, Model* model) { |
161 | if (IsDiscardableArray(*model, array_name) && |
162 | CountOpsWithInput(*model, array_name) == 0 && |
163 | GetOpWithOutput(*model, array_name) == nullptr) { |
164 | model->EraseArray(array_name); |
165 | return true; |
166 | } |
167 | return false; |
168 | } |
169 | |
170 | bool DeleteArrayIfUnusedOutsideOfOp(const std::string& array_name, |
171 | const Operator* op, Model* model) { |
172 | if (!IsDiscardableArray(*model, array_name)) { |
173 | return false; |
174 | } |
175 | if (CountOpsWithInput(*model, array_name) > 1) { |
176 | return false; |
177 | } |
178 | const Operator* op_having_this_as_input = GetOpWithInput(*model, array_name); |
179 | if (op_having_this_as_input && op_having_this_as_input != op) { |
180 | return false; |
181 | } |
182 | const Operator* op_having_this_as_output = |
183 | GetOpWithOutput(*model, array_name); |
184 | if (op_having_this_as_output && op_having_this_as_output != op) { |
185 | return false; |
186 | } |
187 | model->EraseArray(array_name); |
188 | return true; |
189 | } |
190 | |
191 | void DeleteOpAndArrays(Model* model, const Operator* op) { |
192 | for (const std::string& array_name : op->inputs) { |
193 | DeleteArrayIfUnusedOutsideOfOp(array_name, op, model); |
194 | } |
195 | for (const std::string& array_name : op->outputs) { |
196 | DeleteArrayIfUnusedOutsideOfOp(array_name, op, model); |
197 | } |
198 | auto op_it = FindOp(*model, op); |
199 | CHECK(op_it != model->operators.end()); |
200 | model->operators.erase(op_it); |
201 | } |
202 | |
203 | std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput( |
204 | const Model& model, const std::string& array_name) { |
205 | for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { |
206 | for (auto& output : it->get()->outputs) { |
207 | if (output == array_name) { |
208 | return it; |
209 | } |
210 | } |
211 | } |
212 | return model.operators.end(); |
213 | } |
214 | |
215 | std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput( |
216 | Model& model, const std::string& array_name) { |
217 | for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { |
218 | for (auto& output : it->get()->outputs) { |
219 | if (output == array_name) { |
220 | return it; |
221 | } |
222 | } |
223 | } |
224 | return model.operators.end(); |
225 | } |
226 | |
227 | Operator* GetOpWithOutput(const Model& model, const std::string& array_name) { |
228 | auto it = FindOpWithOutput(model, array_name); |
229 | return it == model.operators.end() ? nullptr : it->get(); |
230 | } |
231 | |
232 | // GetFirstOpWithInput assumes that this finds the first op. |
233 | std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput( |
234 | const Model& model, const std::string& array_name) { |
235 | for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { |
236 | for (auto& input : it->get()->inputs) { |
237 | if (input == array_name) { |
238 | return it; |
239 | } |
240 | } |
241 | } |
242 | return model.operators.end(); |
243 | } |
244 | |
245 | std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput( |
246 | Model& model, const std::string& array_name) { |
247 | for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { |
248 | for (auto& input : it->get()->inputs) { |
249 | if (input == array_name) { |
250 | return it; |
251 | } |
252 | } |
253 | } |
254 | return model.operators.end(); |
255 | } |
256 | |
257 | std::vector<std::unique_ptr<Operator>>::const_iterator FindOp( |
258 | const Model& model, const Operator* op) { |
259 | for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { |
260 | if (it->get() == op) { |
261 | return it; |
262 | } |
263 | } |
264 | return model.operators.end(); |
265 | } |
266 | |
267 | std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model, |
268 | const Operator* op) { |
269 | for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { |
270 | if (it->get() == op) { |
271 | return it; |
272 | } |
273 | } |
274 | return model.operators.end(); |
275 | } |
276 | |
277 | Operator* GetOpWithInput(const Model& model, const std::string& array_name) { |
278 | auto it = FindOpWithInput(model, array_name); |
279 | return it == model.operators.end() ? nullptr : it->get(); |
280 | } |
281 | |
282 | Operator* GetFirstOpWithInput(const Model& model, |
283 | const std::string& array_name) { |
284 | auto it = FindOpWithInput(model, array_name); |
285 | return it == model.operators.end() ? nullptr : it->get(); |
286 | } |
287 | |
288 | void ReplaceArrayUsage(Model* model, const std::string& old_array_name, |
289 | const std::string& new_array_name) { |
290 | for (auto& op_it : model->operators) { |
291 | Operator* op = op_it.get(); |
292 | for (size_t i = 0; i < op->inputs.size(); ++i) { |
293 | if (op->inputs[i] == old_array_name) { |
294 | op->inputs[i] = new_array_name; |
295 | } |
296 | } |
297 | for (size_t i = 0; i < op->outputs.size(); ++i) { |
298 | if (op->outputs[i] == old_array_name) { |
299 | op->outputs[i] = new_array_name; |
300 | } |
301 | } |
302 | } |
303 | } |
304 | |
305 | std::string FormatArraysList(const Model& model, |
306 | const std::vector<std::string>& list) { |
307 | if (list.empty()) { |
308 | return "[]" ; |
309 | } |
310 | std::string result = "" ; |
311 | if (list.size() > 1) { |
312 | result += "[ " ; |
313 | } |
314 | for (std::size_t i = 0; i < list.size(); i++) { |
315 | if (i > 0) { |
316 | result += ", " ; |
317 | } |
318 | result += list[i]; |
319 | } |
320 | if (list.size() > 1) { |
321 | result += " ]" ; |
322 | } |
323 | return result; |
324 | } |
325 | |
// Maps an OperatorType enum value to its bare name (the enum constant
// without the leading "k"), e.g. kConv -> "Conv".
// LOG(FATAL)s (and thus never returns) on an unhandled value.
const char* OperatorTypeName(OperatorType type) {
  switch (type) {
// Expands to `case OperatorType::kFoo: return "Foo";` for each op below.
#define HANDLE_OPERATORTYPENAME_CASE(c) \
  case OperatorType::k##c:              \
    return #c;
    HANDLE_OPERATORTYPENAME_CASE(Abs)
    HANDLE_OPERATORTYPENAME_CASE(Add)
    HANDLE_OPERATORTYPENAME_CASE(AddN)
    HANDLE_OPERATORTYPENAME_CASE(AveragePool)
    HANDLE_OPERATORTYPENAME_CASE(BatchMatMul)
    HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Conv)
    HANDLE_OPERATORTYPENAME_CASE(Concatenation)
    HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv)
    HANDLE_OPERATORTYPENAME_CASE(DepthToSpace)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth)
    HANDLE_OPERATORTYPENAME_CASE(FullyConnected)
    HANDLE_OPERATORTYPENAME_CASE(HardSwish)
    HANDLE_OPERATORTYPENAME_CASE(Dequantize)
    HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
    HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Log)
    HANDLE_OPERATORTYPENAME_CASE(Logistic)
    HANDLE_OPERATORTYPENAME_CASE(LstmCell)
    HANDLE_OPERATORTYPENAME_CASE(MaxPool)
    HANDLE_OPERATORTYPENAME_CASE(L2Pool)
    HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
    HANDLE_OPERATORTYPENAME_CASE(Mul)
    HANDLE_OPERATORTYPENAME_CASE(RandomUniform)
    HANDLE_OPERATORTYPENAME_CASE(Elu)
    HANDLE_OPERATORTYPENAME_CASE(Relu)
    HANDLE_OPERATORTYPENAME_CASE(Relu1)
    HANDLE_OPERATORTYPENAME_CASE(Relu6)
    HANDLE_OPERATORTYPENAME_CASE(PRelu)
    HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
    HANDLE_OPERATORTYPENAME_CASE(Softmax)
    HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)
    HANDLE_OPERATORTYPENAME_CASE(Div)
    HANDLE_OPERATORTYPENAME_CASE(Tanh)
    HANDLE_OPERATORTYPENAME_CASE(Sin)
    HANDLE_OPERATORTYPENAME_CASE(All)
    HANDLE_OPERATORTYPENAME_CASE(Assert)
    HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
    HANDLE_OPERATORTYPENAME_CASE(Fill)
    HANDLE_OPERATORTYPENAME_CASE(FloorMod)
    HANDLE_OPERATORTYPENAME_CASE(FloorDiv)
    HANDLE_OPERATORTYPENAME_CASE(Greater)
    HANDLE_OPERATORTYPENAME_CASE(GreaterEqual)
    HANDLE_OPERATORTYPENAME_CASE(Identity)
    HANDLE_OPERATORTYPENAME_CASE(Less)
    HANDLE_OPERATORTYPENAME_CASE(LessEqual)
    HANDLE_OPERATORTYPENAME_CASE(MatMul)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMax)  //  Reduction Max
    HANDLE_OPERATORTYPENAME_CASE(Maximum)    //  Element-wise Maximum
    HANDLE_OPERATORTYPENAME_CASE(Merge)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMin)  //  Reduction Min
    HANDLE_OPERATORTYPENAME_CASE(Minimum)    //  Element-wise Minimum
    HANDLE_OPERATORTYPENAME_CASE(Neg)
    HANDLE_OPERATORTYPENAME_CASE(OneHot)
    HANDLE_OPERATORTYPENAME_CASE(Pack)
    HANDLE_OPERATORTYPENAME_CASE(Pad)
    HANDLE_OPERATORTYPENAME_CASE(PadV2)
    HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
    HANDLE_OPERATORTYPENAME_CASE(Range)
    HANDLE_OPERATORTYPENAME_CASE(Rank)
    HANDLE_OPERATORTYPENAME_CASE(Reshape)
    HANDLE_OPERATORTYPENAME_CASE(Squeeze)
    HANDLE_OPERATORTYPENAME_CASE(Rsqrt)
    HANDLE_OPERATORTYPENAME_CASE(SegmentSum)
    HANDLE_OPERATORTYPENAME_CASE(Shape)
    HANDLE_OPERATORTYPENAME_CASE(Slice)
    HANDLE_OPERATORTYPENAME_CASE(Split)
    HANDLE_OPERATORTYPENAME_CASE(SplitV)
    HANDLE_OPERATORTYPENAME_CASE(Sqrt)
    HANDLE_OPERATORTYPENAME_CASE(Square)
    HANDLE_OPERATORTYPENAME_CASE(Switch)
    HANDLE_OPERATORTYPENAME_CASE(Sub)
    HANDLE_OPERATORTYPENAME_CASE(Sum)
    HANDLE_OPERATORTYPENAME_CASE(Tile)
    HANDLE_OPERATORTYPENAME_CASE(Transpose)
    HANDLE_OPERATORTYPENAME_CASE(TransposeConv)
    HANDLE_OPERATORTYPENAME_CASE(Concat)
    HANDLE_OPERATORTYPENAME_CASE(ConcatV2)
    HANDLE_OPERATORTYPENAME_CASE(Cast)
    HANDLE_OPERATORTYPENAME_CASE(Floor)
    HANDLE_OPERATORTYPENAME_CASE(Ceil)
    HANDLE_OPERATORTYPENAME_CASE(Round)
    HANDLE_OPERATORTYPENAME_CASE(Gather)
    HANDLE_OPERATORTYPENAME_CASE(GatherNd)
    HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
    HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
    HANDLE_OPERATORTYPENAME_CASE(Mean)
    HANDLE_OPERATORTYPENAME_CASE(ReduceProd)
    HANDLE_OPERATORTYPENAME_CASE(Svdf)
    HANDLE_OPERATORTYPENAME_CASE(ArgMax)
    HANDLE_OPERATORTYPENAME_CASE(ArgMin)
    HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
    HANDLE_OPERATORTYPENAME_CASE(Unsupported)
    HANDLE_OPERATORTYPENAME_CASE(Exp)
    HANDLE_OPERATORTYPENAME_CASE(DynamicPartition)
    HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
    HANDLE_OPERATORTYPENAME_CASE(Select)
    HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
    HANDLE_OPERATORTYPENAME_CASE(Equal)
    HANDLE_OPERATORTYPENAME_CASE(NotEqual)
    HANDLE_OPERATORTYPENAME_CASE(Pow)
    HANDLE_OPERATORTYPENAME_CASE(Any)
    HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
    HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
    HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
    HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
    HANDLE_OPERATORTYPENAME_CASE(Unpack)
    HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
    HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceLstm)
    HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceLstm)
    HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceRnn)
    HANDLE_OPERATORTYPENAME_CASE(ResizeNearestNeighbor)
    HANDLE_OPERATORTYPENAME_CASE(LeakyRelu)
    HANDLE_OPERATORTYPENAME_CASE(SquaredDifference)
    HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
    HANDLE_OPERATORTYPENAME_CASE(Unique)
    HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
    HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
    HANDLE_OPERATORTYPENAME_CASE(Cos)
    HANDLE_OPERATORTYPENAME_CASE(Where)
    HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV2)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV2)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV3)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV3)
    HANDLE_OPERATORTYPENAME_CASE(ScatterNd)
    default:
      LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
  }
}
465 | |
466 | std::string HelpfulOperatorTypeName(const Operator& op) { |
467 | if (op.type == OperatorType::kUnsupported) { |
468 | return toco::port::StringF( |
469 | "(Unsupported TensorFlow op: %s)" , |
470 | static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op); |
471 | } |
472 | return OperatorTypeName(op.type); |
473 | } |
474 | |
// Returns whether ops of the given type can carry a fused activation
// function (e.g. Relu folded into the op itself).
bool OperatorSupportsFusedActivation(OperatorType type) {
  switch (type) {
    case OperatorType::kAdd:
    case OperatorType::kAveragePool:
    case OperatorType::kBatchNormalization:
    case OperatorType::kConv:
    case OperatorType::kDepthwiseConv:
    case OperatorType::kDiv:
    case OperatorType::kFullyConnected:
    case OperatorType::kL2Pool:
    case OperatorType::kMaxPool:
    case OperatorType::kMul:
    case OperatorType::kSub:
    case OperatorType::kSquaredDifference:
      return true;
    default:
      return false;
  }
}
494 | |
// VLOGs (at `log_level`) a summary of the model: total operator count
// followed by a per-operator-type count.
void LogSummary(int log_level, const Model& model) {
  VLOG(log_level) << "Operators summary (" << model.operators.size()
                  << " operators):";
  std::unordered_multiset<OperatorType> ops_by_type;
  for (const auto& op : model.operators) {
    ops_by_type.insert(op->type);
  }
  auto it = ops_by_type.begin();
  while (it != ops_by_type.end()) {
    int count = ops_by_type.count(*it);
    VLOG(log_level) << " " << OperatorTypeName(*it) << ": " << count;
    // Equal keys are stored adjacently in an unordered_multiset, so
    // advancing by `count` skips past all entries of the current type and
    // lands on the next distinct type (or end()).
    std::advance(it, count);
  }
}
509 | |
// VLOGs (at `log_level`) the interesting properties of the array `name`:
// data type, final type, storage kind (constant buffer / transient alloc),
// shape, min/max, and quantization parameters — each only if present.
void LogArray(int log_level, const Model& model, const std::string& name) {
  VLOG(log_level) << "Array: " << name;
  if (!model.HasArray(name)) {
    VLOG(log_level) << " DOES NOT EXIST";
    return;
  }
  const auto& array = model.GetArray(name);
  VLOG(log_level) << " Data type: " << ArrayDataTypeName(array.data_type);
  VLOG(log_level) << " Final type: "
                  << ArrayDataTypeName(array.final_data_type);
  if (array.buffer) {
    VLOG(log_level) << " Constant Buffer";
  }
  if (array.alloc) {
    VLOG(log_level) << " Transient Alloc";
  }
  if (array.has_shape()) {
    const Shape& array_shape = array.shape();
    if (array_shape.dimensions_count() == 0) {
      VLOG(log_level) << " (Zero dimensions)";
    } else {
      // Build a single "Dims: d0, d1, ..." line rather than one line per dim.
      std::string message = " Dims: ";
      bool first = true;
      for (const int dim : array_shape.dims()) {
        if (!first) {
          message += ", ";
        }
        first = false;
        toco::port::AppendF(&message, "%d", dim);
      }
      VLOG(log_level) << message;
    }
  }
  if (array.minmax) {
    VLOG(log_level) << " MinMax: " << array.minmax->min << " .. "
                    << array.minmax->max;
  }
  if (array.quantization_params) {
    VLOG(log_level) << " QuantizationParams: zero_point="
                    << static_cast<int>(array.quantization_params->zero_point)
                    << ", scale=" << array.quantization_params->scale;
  }
}
553 | |
// Writes the model as one numbered graphviz "video frame" file
// (toco_video_%05d.dot) into the --dump_graphviz directory, but only when
// --dump_graphviz_video is set and the graph actually changed since the
// last dumped frame (deduplicated via a hash of the dump text).
void DumpGraphvizVideoFrame(const Model& model) {
  namespace port = toco::port;

  const auto& dump_options = *GraphVizDumpOptions::singleton();
  if (!dump_options.dump_graphviz_video) {
    return;
  }
  CHECK(!dump_options.dump_graphviz.empty());
  // TODO(benoitjacob): the static data here means that this function
  // is stateful, not reentrant, and effectively leaks memory till exit
  // (since dump_hashes can only grow in size). It also means that it
  // really only is intended to be called for a single model during the
  // process' lifetime. So it's not great design at all. The overriding
  // design aspect here is to make the video-dumping code as unintrusive
  // and self-contained as possible. Eventually, we'll want to have that
  // cleaned-up, but that will require some form of general statefulness
  // in toco (some kind of 'tooling state' data structure) that does
  // not exist at present, and would be premature to design here just for
  // this new video-dumping feature.
  static int dump_id = 0;
  static std::unordered_set<std::size_t> dump_hashes;
  std::string graphviz_dump;
  DumpGraphviz(model, &graphviz_dump,
               toco::port::StringF("VIDEO frame:%05d", dump_id));
  std::size_t hash = std::hash<std::string>{}(graphviz_dump);
  // Only emit a frame (and advance dump_id) when this exact dump text has
  // not been written before.
  if (!dump_hashes.count(hash)) {
    LOG(INFO) << "DUMPING GRAPHVIZ VIDEO FRAME: " << dump_id;
    dump_hashes.insert(hash);
    const auto result = port::file::SetContents(
        port::file::JoinPath(
            dump_options.dump_graphviz,
            toco::port::StringF("toco_video_%05d.dot", dump_id)),
        graphviz_dump, port::file::Defaults());
    QCHECK(result.ok()) << result.error_message();
    dump_id++;
  }
}
591 | |
// Dumps the model for debugging: emits a graphviz video frame and a
// "toco_<message>.dot" file when --dump_graphviz is set, then (if VLOG is
// enabled at `log_level`) logs a textual walk of every operator with its
// input/output arrays, printing each array's details only the first time
// it is encountered.
void LogDump(int log_level, const std::string& message, const Model& model) {
  namespace port = toco::port;
  const auto& dump_options = *GraphVizDumpOptions::singleton();

  DumpGraphvizVideoFrame(model);
  if (!dump_options.dump_graphviz.empty()) {
    std::string graphviz_dump;

    DumpGraphviz(model, &graphviz_dump, message);
    // File name is derived from `message`, with spaces turned into '_'.
    const auto result = port::file::SetContents(
        port::file::JoinPath(
            dump_options.dump_graphviz,
            absl::StrCat("toco_", absl::StrReplaceAll(message, {{" ", "_"}}),
                         ".dot")),
        graphviz_dump, port::file::Defaults());
    QCHECK(result.ok()) << result.error_message();
  }

  if (!VLOG_IS_ON(log_level)) {
    return;
  }
  VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")";
  LogSummary(log_level, model);
  // Tracks arrays already described so each is logged at most once.
  std::unordered_set<std::string> already_printed_arrays;
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      if (!already_printed_arrays.count(input)) {
        already_printed_arrays.insert(input);
        LogArray(log_level, model, input);
      }
    }
    VLOG(log_level) << HelpfulOperatorTypeName(*op) << " :";
    VLOG(log_level) << " " << FormatArraysList(model, op->inputs) << " -> "
                    << FormatArraysList(model, op->outputs);
    if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
      VLOG(log_level) << " (with fused activation function)";
    }
    for (const auto& output : op->outputs) {
      if (!already_printed_arrays.count(output)) {
        already_printed_arrays.insert(output);
        LogArray(log_level, model, output);
      }
    }
  }
  VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")";
}
638 | |
639 | // Note remaining raw-array extension in ProcessTensorFlowReshapeOperator(). |
640 | void ExtendShape(Shape* shape, int new_shape_size) { |
641 | CHECK_GE(new_shape_size, shape->dimensions_count()); |
642 | const int size_increase = new_shape_size - shape->dimensions_count(); |
643 | auto* shape_dims = shape->mutable_dims(); |
644 | shape_dims->insert(shape_dims->begin(), size_increase, 1); |
645 | } |
646 | |
647 | // TODO(b/62904716) Remove along with remaining uses. |
648 | void UnextendShape(Shape* shape, int new_shape_size) { |
649 | CHECK_LE(new_shape_size, shape->dimensions_count()); |
650 | const int size_reduction = shape->dimensions_count() - new_shape_size; |
651 | for (int i = 0; i < size_reduction; i++) { |
652 | CHECK_EQ(shape->dims(i), 1); |
653 | } |
654 | std::vector<int>& shape_dims = *shape->mutable_dims(); |
655 | shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction); |
656 | } |
657 | |
// In general, zero-sized dimensions are disallowed, but there are exceptions,
// e.g., if the tensor data itself represents a scalar (rank 0) shape, its
// shape will have dimensions [0]. CheckNonEmptyShapeDimensions is more
// strict, and is appropriate for ops and comparisons where an empty shape
// doesn't make sense.
template <typename Dims>
void CheckValidShapeDimensions(const Dims& dims) {
  // The single-[0] form denotes a scalar shape and is explicitly allowed.
  const bool is_scalar_shape = dims.size() == 1 && dims[0] == 0;
  if (is_scalar_shape) {
    return;
  }
  // Otherwise every dimension must be strictly positive.
  for (const auto& dim : dims) {
    CHECK_GE(dim, 1);
  }
}
672 | |
// CHECK-fails unless `shape` is valid per CheckValidShapeDimensions (all
// dims >= 1, or the special scalar form [0]).
void CheckValidShape(const Shape& shape) {
  CheckValidShapeDimensions(shape.dims());
}
676 | |
677 | bool IsNonEmpty(const Shape& shape) { |
678 | for (int i = 0; i < shape.dimensions_count(); ++i) { |
679 | if (shape.dims(i) < 1) return false; |
680 | } |
681 | return true; |
682 | } |
683 | |
684 | void CheckNonEmptyShapeDimensions(const Shape& shape) { |
685 | for (int i = 0; i < shape.dimensions_count(); ++i) { |
686 | CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i |
687 | << ". shape = " << ShapeToString(shape); |
688 | } |
689 | } |
690 | |
691 | bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) { |
692 | CheckNonEmptyShapeDimensions(shape0); |
693 | CheckNonEmptyShapeDimensions(shape1); |
694 | |
695 | const Shape* longer = &shape0; |
696 | const Shape* shorter = &shape1; |
697 | if (shape1.dimensions_count() > shape0.dimensions_count()) { |
698 | longer = &shape1; |
699 | shorter = &shape0; |
700 | } |
701 | |
702 | // Walk dimensions back to front until we run out of dimensions in the shorter |
703 | // shape. |
704 | int longer_index = longer->dimensions_count() - 1; |
705 | int shorter_index = shorter->dimensions_count() - 1; |
706 | while (shorter_index >= 0) { |
707 | const int d_long = longer->dims(longer_index); |
708 | const int d_short = shorter->dims(shorter_index); |
709 | // Broadcasting fails if the dimensions are different *and* neither is 1. |
710 | if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) { |
711 | return false; |
712 | } |
713 | longer_index--; |
714 | shorter_index--; |
715 | } |
716 | return true; |
717 | } |
718 | |
719 | bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) { |
720 | CheckNonEmptyShapeDimensions(shape0); |
721 | CheckNonEmptyShapeDimensions(shape1); |
722 | |
723 | const Shape* longer = &shape0; |
724 | const Shape* shorter = &shape1; |
725 | if (shape1.dimensions_count() > shape0.dimensions_count()) { |
726 | longer = &shape1; |
727 | shorter = &shape0; |
728 | } |
729 | |
730 | // Walk dimensions back to front until we run out of dimensions in the shorter |
731 | // shape. |
732 | int longer_index = longer->dimensions_count() - 1; |
733 | int shorter_index = shorter->dimensions_count() - 1; |
734 | while (shorter_index >= 0) { |
735 | const int d_long = longer->dims(longer_index); |
736 | const int d_short = shorter->dims(shorter_index); |
737 | // Extending fails if the dimensions are different. |
738 | if (d_long != d_short) { |
739 | return false; |
740 | } |
741 | longer_index--; |
742 | shorter_index--; |
743 | } |
744 | |
745 | // The remaining dimensions in the longer shape must be 1. |
746 | while (longer_index >= 0) { |
747 | const int d_long = longer->dims(longer_index); |
748 | if (d_long != 1) { |
749 | return false; |
750 | } |
751 | longer_index--; |
752 | } |
753 | |
754 | return true; |
755 | } |
756 | |
757 | int RequiredBufferSizeForShape(const Shape& shape) { |
758 | CheckValidShape(shape); |
759 | int max_offset = 1; |
760 | for (const auto& dim : shape.dims()) { |
761 | max_offset *= dim; |
762 | } |
763 | return max_offset; |
764 | } |
765 | |
766 | bool IsConstantParameterArray(const Model& model, const std::string& name) { |
767 | if (!model.HasArray(name)) { |
768 | return false; |
769 | } |
770 | |
771 | return !!model.GetArray(name).buffer; |
772 | } |
773 | |
774 | namespace { |
775 | template <ArrayDataType A> |
776 | bool CompareArrayBuffers(const Array& lhs_array, const Array& rhs_array) { |
777 | CHECK(lhs_array.data_type == rhs_array.data_type) << "Data types must match" ; |
778 | CHECK(lhs_array.buffer) << "LHS must be constant" ; |
779 | CHECK(rhs_array.buffer) << "RHS must be constant" ; |
780 | const auto& lhs_data = lhs_array.GetBuffer<A>().data; |
781 | const auto& rhs_data = rhs_array.GetBuffer<A>().data; |
782 | CHECK_EQ(lhs_data.size(), rhs_data.size()) |
783 | << "Buffer sizes must match in element count" ; |
784 | for (int i = 0; i < lhs_data.size(); ++i) { |
785 | if (lhs_data[i] != rhs_data[i]) { |
786 | return false; |
787 | } |
788 | } |
789 | return true; |
790 | } |
791 | |
792 | bool HaveSameMinMax(const Array& lhs_array, const Array& rhs_array) { |
793 | if (lhs_array.minmax || rhs_array.minmax) { |
794 | if (!lhs_array.minmax || !rhs_array.minmax) { |
795 | return false; |
796 | } |
797 | if (!(*lhs_array.minmax == *rhs_array.minmax)) { |
798 | return false; |
799 | } |
800 | } |
801 | return true; |
802 | } |
803 | |
804 | bool HaveSameQuantizationParams(const Array& lhs_array, |
805 | const Array& rhs_array) { |
806 | if (lhs_array.quantization_params || rhs_array.quantization_params) { |
807 | if (!lhs_array.quantization_params || !rhs_array.quantization_params) { |
808 | return false; |
809 | } |
810 | if (!(*lhs_array.quantization_params == *rhs_array.quantization_params)) { |
811 | return false; |
812 | } |
813 | } |
814 | return true; |
815 | } |
816 | |
817 | } // namespace |
818 | |
// Deep-compares two constant arrays: first their attributes (shape, data
// type, final type, minmax, quantization params, narrow_range), then their
// buffer contents via the type-appropriate CompareArrayBuffers
// instantiation. LOG(FATAL)s on a data type with no comparison case.
bool CompareConstantArrays(const Array& lhs_array, const Array& rhs_array) {
  bool attrs_equal = lhs_array.shape() == rhs_array.shape() &&
                     lhs_array.data_type == rhs_array.data_type &&
                     lhs_array.final_data_type == rhs_array.final_data_type &&
                     HaveSameMinMax(lhs_array, rhs_array) &&
                     HaveSameQuantizationParams(lhs_array, rhs_array) &&
                     lhs_array.narrow_range == rhs_array.narrow_range;
  if (!attrs_equal) {
    return false;
  }
  switch (lhs_array.data_type) {
    case ArrayDataType::kBool:
      return CompareArrayBuffers<ArrayDataType::kBool>(lhs_array, rhs_array);
    case ArrayDataType::kFloat:
      return CompareArrayBuffers<ArrayDataType::kFloat>(lhs_array, rhs_array);
    case ArrayDataType::kInt8:
      return CompareArrayBuffers<ArrayDataType::kInt8>(lhs_array, rhs_array);
    case ArrayDataType::kUint8:
      return CompareArrayBuffers<ArrayDataType::kUint8>(lhs_array, rhs_array);
    case ArrayDataType::kInt16:
      return CompareArrayBuffers<ArrayDataType::kInt16>(lhs_array, rhs_array);
    case ArrayDataType::kUint16:
      return CompareArrayBuffers<ArrayDataType::kUint16>(lhs_array, rhs_array);
    case ArrayDataType::kInt32:
      return CompareArrayBuffers<ArrayDataType::kInt32>(lhs_array, rhs_array);
    case ArrayDataType::kUint32:
      return CompareArrayBuffers<ArrayDataType::kUint32>(lhs_array, rhs_array);
    case ArrayDataType::kInt64:
      return CompareArrayBuffers<ArrayDataType::kInt64>(lhs_array, rhs_array);
    case ArrayDataType::kUint64:
      return CompareArrayBuffers<ArrayDataType::kUint64>(lhs_array, rhs_array);
    case ArrayDataType::kString:
      return CompareArrayBuffers<ArrayDataType::kString>(lhs_array, rhs_array);
    case ArrayDataType::kComplex64:
      return CompareArrayBuffers<ArrayDataType::kComplex64>(lhs_array,
                                                            rhs_array);
    default:
      LOG(FATAL) << "Unsupported data type: "
                 << ArrayDataTypeName(lhs_array.data_type);
      return false;
  }
}
861 | |
862 | namespace { |
// Make an array name acceptable as a TensorFlow node name by replacing
// every ':' with '_' (e.g. "name:3_5" becomes "name_3_5").
std::string SanitizeNameForTFNode(const std::string& array_name) {
  std::string sanitized = array_name;
  for (char& c : sanitized) {
    if (c == ':') {
      c = '_';
    }
  }
  return sanitized;
}
870 | |
871 | void CheckInputArraysAreNotOutputArrays(const ModelFlags& model_flags) { |
872 | for (const auto& input_array : model_flags.input_arrays()) { |
873 | for (const std::string& output_array : model_flags.output_arrays()) { |
874 | QCHECK_NE(input_array.name(), output_array) |
875 | << "The array " << output_array |
876 | << " is listed in both --input_arrays and --output_arrays." ; |
877 | } |
878 | } |
879 | } |
880 | |
881 | bool IsAsciiPrintable(const std::string& name) { |
882 | for (char c : name) { |
883 | if (!absl::ascii_isprint(c)) { |
884 | return false; |
885 | } |
886 | } |
887 | return true; |
888 | } |
889 | |
890 | std::string DumpAscii(const std::string& name) { |
891 | std::string result; |
892 | port::AppendF(&result, "ASCII | Hex\n" ); |
893 | port::AppendF(&result, "------+----\n" ); |
894 | for (char c : name) { |
895 | if (absl::ascii_isprint(c)) { |
896 | port::AppendF(&result, "%c | %x\n" , c, c); |
897 | } else { |
898 | port::AppendF(&result, " | %x Not ASCII printable!\n" , c); |
899 | } |
900 | } |
901 | return result; |
902 | } |
903 | |
904 | void CheckNonAsciiIOArrays(const ModelFlags& model_flags) { |
905 | if (model_flags.allow_nonascii_arrays()) { |
906 | return; |
907 | } |
908 | for (const auto& input_array : model_flags.input_arrays()) { |
909 | QCHECK(IsAsciiPrintable(input_array.name())) |
910 | << "Non-ASCII-printable character found in --input_arrays: " |
911 | << input_array.name() |
912 | << ". Pass --allow_nonascii_arrays to allow that. " |
913 | << "Here is a dump of the string:\n\n" |
914 | << DumpAscii(input_array.name()); |
915 | } |
916 | for (const std::string& output_array : model_flags.output_arrays()) { |
917 | QCHECK(IsAsciiPrintable(output_array)) |
918 | << "Non-ASCII-printable character found in --output_arrays: " |
919 | << output_array << ". Pass --allow_nonascii_arrays to allow that. " |
920 | << "Here is a dump of the string:\n\n" |
921 | << DumpAscii(output_array); |
922 | } |
923 | } |
924 | |
925 | void CheckNonExistentIOArrays(const Model& model) { |
926 | // "non-existent" is interpreted in the stronger sense of |
927 | // "not actually produced/consumed by an op". |
928 | // Rationale: we have to artificially fix up TensorFlow graphs by creating |
929 | // any array that it refers to, so just checking that arrays exist isn't |
930 | // sufficient. The real invariant here is whether arrays are produced/consumed |
931 | // by something. |
932 | if (model.flags.allow_nonexistent_arrays()) { |
933 | return; |
934 | } |
935 | static constexpr char [] = |
936 | "Is it a typo? This should not happen. If you trigger this error " |
937 | "please send a bug report (with code to reproduce this error), to the " |
938 | "TensorFlow Lite team." ; |
939 | for (const std::string& output_array : model.flags.output_arrays()) { |
940 | if (IsConstantParameterArray(model, output_array)) { |
941 | continue; // It is OK to request that a constant be an output. |
942 | } |
943 | QCHECK(GetOpWithOutput(model, output_array)) |
944 | << "Specified output array \"" << output_array |
945 | << "\" is not produced by any op in this graph. " << general_comment; |
946 | } |
947 | for (const auto& rnn_state : model.flags.rnn_states()) { |
948 | if (!rnn_state.discardable()) { |
949 | // Check that all RNN states are consumed |
950 | QCHECK(GetOpWithInput(model, rnn_state.state_array())) |
951 | << "Specified RNN state \"" << rnn_state.state_array() |
952 | << "\" is not consumed by any op in this graph. " << general_comment; |
953 | // Check that all RNN back-edge source arrays are produced |
954 | QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array())) |
955 | << "Specified RNN back-edge source array \"" |
956 | << rnn_state.back_edge_source_array() |
957 | << "\" is not produced by any op in this graph. " << general_comment; |
958 | } |
959 | } |
960 | } |
961 | |
962 | } // namespace |
963 | |
964 | void CheckNoMissingArray(const Model& model) { |
965 | for (const auto& op : model.operators) { |
966 | for (const auto& input : op->inputs) { |
967 | CHECK(model.HasArray(input) || model.optional_arrays.count(input)) |
968 | << "Input: " << input << " missing for op: " << op->outputs[0] << "." ; |
969 | } |
970 | for (const auto& output : op->outputs) { |
971 | CHECK(model.HasArray(output)) << "Output: " << output << " missing." ; |
972 | } |
973 | } |
974 | CheckNonExistentIOArrays(model); |
975 | } |
976 | |
977 | void FixNoMissingArray(Model* model) { |
978 | for (const auto& op : model->operators) { |
979 | for (const auto& input : op->inputs) { |
980 | if (!model->HasArray(input) && !model->IsOptionalArray(input)) { |
981 | model->GetOrCreateArray(input); |
982 | } |
983 | } |
984 | for (const auto& output : op->outputs) { |
985 | if (!model->HasArray(output) && !model->IsOptionalArray(output)) { |
986 | model->GetOrCreateArray(output); |
987 | } |
988 | } |
989 | } |
990 | if (model->flags.allow_nonexistent_arrays()) { |
991 | for (const std::string& output_array : model->flags.output_arrays()) { |
992 | model->GetOrCreateArray(output_array); |
993 | } |
994 | for (const auto& rnn_state : model->flags.rnn_states()) { |
995 | model->GetOrCreateArray(rnn_state.state_array()); |
996 | model->GetOrCreateArray(rnn_state.back_edge_source_array()); |
997 | } |
998 | } |
999 | } |
1000 | |
1001 | void CheckNoOrphanedArray(const Model& model) { |
1002 | std::unordered_set<std::string> arrays_without_known_use; |
1003 | for (const auto& array : model.GetArrayMap()) { |
1004 | if (IsDiscardableArray(model, array.first)) { |
1005 | arrays_without_known_use.insert(array.first); |
1006 | } |
1007 | } |
1008 | for (const auto& op : model.operators) { |
1009 | for (const auto& input : op->inputs) { |
1010 | arrays_without_known_use.erase(input); |
1011 | } |
1012 | for (const auto& output : op->outputs) { |
1013 | arrays_without_known_use.erase(output); |
1014 | } |
1015 | } |
1016 | for (const auto& rnn_state : model.flags.rnn_states()) { |
1017 | arrays_without_known_use.erase(rnn_state.state_array()); |
1018 | arrays_without_known_use.erase(rnn_state.back_edge_source_array()); |
1019 | } |
1020 | if (!arrays_without_known_use.empty()) { |
1021 | for (const auto& array : arrays_without_known_use) { |
1022 | LOG(INFO) << "Error: Orphaned array: " << array; |
1023 | } |
1024 | } |
1025 | CHECK(arrays_without_known_use.empty()); |
1026 | } |
1027 | |
1028 | void FixNoOrphanedArray(Model* model) { |
1029 | std::unordered_set<std::string> arrays_without_known_use; |
1030 | for (const auto& array : model->GetArrayMap()) { |
1031 | arrays_without_known_use.insert(array.first); |
1032 | } |
1033 | for (const auto& op : model->operators) { |
1034 | for (const auto& input : op->inputs) { |
1035 | arrays_without_known_use.erase(input); |
1036 | } |
1037 | for (const auto& output : op->outputs) { |
1038 | arrays_without_known_use.erase(output); |
1039 | } |
1040 | } |
1041 | for (const auto& rnn_state : model->flags.rnn_states()) { |
1042 | arrays_without_known_use.erase(rnn_state.state_array()); |
1043 | arrays_without_known_use.erase(rnn_state.back_edge_source_array()); |
1044 | } |
1045 | for (const auto& array : arrays_without_known_use) { |
1046 | if (IsDiscardableArray(*model, array)) { |
1047 | model->EraseArray(array); |
1048 | } |
1049 | } |
1050 | } |
1051 | |
// Apply checks to arrays individually (for-each fashion).
//
// Check consistency of array fields, check name.
void CheckEachArray(const Model& model) {
  for (const auto& array_entry : model.GetArrayMap()) {
    const auto& array = array_entry.second;
    // It's OK to have a buffer or an alloc, but not both.
    // (Since allocs are for transient arrays without a buffer).
    CHECK(!array->buffer || !array->alloc) << "Tensor: " << array_entry.first;
    if (array->buffer) {
      // If there is a buffer, its type should be consistent with data_type.
      CHECK(array->buffer->type == array->data_type)
          << "Tensor: " << array_entry.first;
      // The presence of a fixed buffer should imply the presence of a fixed
      // shape.
      CHECK(array->has_shape()) << array_entry.first;
      // A constant buffer should have a valid shape.
      CheckValidShape(array->shape());
      // The shape flat-size should agree with the buffer length.
      CHECK_EQ(array->buffer->Length(),
               RequiredBufferSizeForShape(array->shape()))
          << "Tensor: " << array_entry.first;
    }

    // Check name. Either "name_with_suffix_8", "name_with_port:3", but not
    // "name_with_both:3_8".
    const std::string& name = array_entry.first;
    auto colon_pos = name.find_first_of(':');
    if (colon_pos != std::string::npos) {
      // Everything after the colon must be a decimal port number.
      CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789" ),
               std::string::npos)
          << "Array '" << name << "' has non-digit characters after colon." ;
    }
    // colon_pos is unsigned, so npos (no colon at all) compares greater than
    // 0 and passes; this only rejects names whose first character is a colon.
    CHECK_GT(colon_pos, 0) << "Array '" << name
                           << "' must not start with a colon." ;
  }
}
1089 | |
1090 | void CheckOperatorOrdering(const Model& model) { |
1091 | std::unordered_set<std::string> arrays_behind_us; |
1092 | for (const auto& array_entry : model.GetArrayMap()) { |
1093 | if (!GetOpWithOutput(model, array_entry.first)) { |
1094 | arrays_behind_us.insert(array_entry.first); |
1095 | } |
1096 | } |
1097 | arrays_behind_us.insert(model.optional_arrays.begin(), |
1098 | model.optional_arrays.end()); |
1099 | for (const auto& op : model.operators) { |
1100 | for (const auto& input : op->inputs) { |
1101 | if (!IsConstantParameterArray(model, input)) { |
1102 | CHECK(arrays_behind_us.count(input)); |
1103 | } |
1104 | } |
1105 | for (const auto& output : op->outputs) { |
1106 | CHECK(!arrays_behind_us.count(output)); |
1107 | arrays_behind_us.insert(output); |
1108 | } |
1109 | } |
1110 | for (const std::string& output_array : model.flags.output_arrays()) { |
1111 | CHECK(arrays_behind_us.count(output_array)); |
1112 | } |
1113 | } |
1114 | |
// Topologically re-sorts model->operators so that each operator appears
// after the producers of all of its non-constant inputs. If no such
// ordering exists (an input is never produced, or the graph is cyclic),
// logs a detailed trace of one problematic chain and dies with LOG(FATAL).
void FixOperatorOrdering(Model* model) {
  // Arrays available before any operator runs: those with no producer op,
  // plus the optional arrays.
  std::unordered_set<std::string> arrays_behind_us;
  for (const auto& array_entry : model->GetArrayMap()) {
    if (!GetOpWithOutput(*model, array_entry.first)) {
      arrays_behind_us.insert(array_entry.first);
    }
  }
  arrays_behind_us.insert(model->optional_arrays.begin(),
                          model->optional_arrays.end());
  // Move all operators out of the model; they get re-inserted one at a time
  // in a schedulable order.
  std::vector<std::unique_ptr<Operator>> old_operators;
  std::swap(old_operators, model->operators);
  std::set<std::size_t> remaining;
  for (std::size_t i = 0; i < old_operators.size(); i++) {
    remaining.insert(i);
  }
  // Maps each output of a not-yet-schedulable op to the input that blocked
  // it; used below to reconstruct a diagnostic trace on failure.
  std::unordered_map<std::string, std::string> reason_why_leftover;
  while (true) {
    bool inserted_something = false;
    for (const auto& i : remaining) {
      bool can_insert = true;
      auto& op = old_operators[i];
      CHECK(op);
      // An op can be scheduled once all its non-constant inputs are behind us.
      for (const auto& input : op->inputs) {
        if (!IsConstantParameterArray(*model, input) &&
            !arrays_behind_us.count(input)) {
          for (const std::string& output : op->outputs) {
            reason_why_leftover[output] = input;
          }
          can_insert = false;
          break;
        }
      }
      if (can_insert) {
        // Append the op and mark its outputs as now available.
        model->operators.emplace_back(nullptr);
        for (const auto& output : op->outputs) {
          arrays_behind_us.insert(output);
        }
        std::swap(op, model->operators.back());
        remaining.erase(i);
        inserted_something = true;
        break;
      }
    }
    // A full pass with no insertion means no further progress is possible.
    if (!inserted_something) {
      break;
    }
  }
  if (!remaining.empty()) {
    LOG(ERROR)
        << "No viable ordering of operators was found. "
        << "Here is a 'backtrace' of at least one part of the graph that is "
        << "problematic. It starts with the first operator that has as "
        << "problematic input array, and then walks back the graph to "
        << "the operator that produced that input array, etc., until we find "
        << "the root cause:" ;
    LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT" ;
    LOG(ERROR) << "Here is the first-encountered operator with a bad input: " ;
    const Operator* bad_op = old_operators[*remaining.begin()].get();
    std::unordered_set<std::string> bad_inputs_already_traced;
    // The following while(true) loop should always end with a LOG(FATAL).
    while (true) {
      LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : "
                 << FormatArraysList(*model, bad_op->inputs) << " -> "
                 << FormatArraysList(*model, bad_op->outputs);
      // Find which of bad_op's outputs recorded a blocking input.
      bool found_bad_output = false;
      std::string bad_output;
      for (const std::string& output : bad_op->outputs) {
        if (reason_why_leftover.count(output)) {
          found_bad_output = true;
          bad_output = output;
          break;
        }
      }
      CHECK(found_bad_output);
      const std::string& bad_input = reason_why_leftover[bad_output];
      LOG(ERROR) << "The bad input here is: " << bad_input;
      // Seeing the same blocking input twice means we walked a cycle.
      if (bad_inputs_already_traced.count(bad_input)) {
        LOG(FATAL)
            << "Cycle found! We already encountered that "
            << "input array, " << bad_input << ", earlier in the "
            << "above trace! We expect graphs to be acyclic, even "
            << "RNNs. Let us know if some graph actually needs to have "
            << "cycles, but first, please check if it really is "
            << "an *inference* graph. *Training* graphs are out-of-scope "
            << "for toco." ;
      }
      bad_inputs_already_traced.insert(bad_input);
      // Walk back to the (unscheduled) op that produces the blocking input.
      bad_op = nullptr;
      for (const auto& i : remaining) {
        const Operator* op = old_operators[i].get();
        for (const std::string& output : op->outputs) {
          if (bad_input == output) {
            bad_op = op;
            break;
          }
        }
        if (bad_op) {
          break;
        }
      }
      // No producer at all: the root cause is a never-produced array.
      if (!bad_op) {
        LOG(ERROR) << "And that's the root cause: "
                   << "that array, " << bad_input << ", isn't produced by any "
                   << "operator, or provided in any other way." ;
        LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT" ;
        LOG(FATAL) << "(The above was a multi-line fatal error)" ;
      }
      LOG(ERROR) << "And that array is the output of the following operator:" ;
    }
  }
  CHECK(remaining.empty())
      << "Should never get here! In case of bad graph, "
      << "the above code should have generated a FATAL error already!" ;
}
1229 | |
// Runs the full battery of model consistency checks; each helper
// CHECK/QCHECK-fails on the first violation it finds. Earlier checks
// validate conditions (e.g. that referenced arrays exist) that the later
// checks implicitly rely on.
void CheckInvariants(const Model& model) {
  CheckInputArraysAreNotOutputArrays(model.flags);
  CheckNonAsciiIOArrays(model.flags);
  CheckNoMissingArray(model);
  CheckNoOrphanedArray(model);
  CheckEachArray(model);
  CheckOperatorOrdering(model);
}
1238 | |
1239 | void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check, |
1240 | const int count, const std::string& count_description) { |
1241 | if (model_check.count_min() >= 0) { |
1242 | CHECK_GE(count, model_check.count_min()) |
1243 | << "Mismatch in " << count_description << ": count was " << count |
1244 | << ", but the specified " |
1245 | << (model_check.count_max() > model_check.count_min() ? "minimum" |
1246 | : "value" ) |
1247 | << " was " << model_check.count_min() << "." ; |
1248 | } |
1249 | if (model_check.count_max() > model_check.count_min()) { |
1250 | CHECK_LE(count, model_check.count_max()) |
1251 | << "Mismatch in " << count_description << ": count was " << count |
1252 | << ", but the specified maximum was " << model_check.count_max() << "." ; |
1253 | } |
1254 | } |
1255 | |
1256 | void CheckModelCounts(const Model& model) { |
1257 | std::unordered_multiset<OperatorType> ops_by_type; |
1258 | std::unordered_map<std::string, OperatorType> op_type_by_name; |
1259 | if (model.flags.model_checks_size() == 0) { |
1260 | return; |
1261 | } |
1262 | |
1263 | for (const auto& op : model.operators) { |
1264 | ops_by_type.insert(op->type); |
1265 | op_type_by_name[OperatorTypeName(op->type)] = op->type; |
1266 | } |
1267 | for (const auto& model_check : model.flags.model_checks()) { |
1268 | std::string count_type = model_check.count_type(); |
1269 | if (count_type == "None" ) { |
1270 | continue; |
1271 | } else if (count_type == "Arrays" ) { |
1272 | CheckCountInRange(model_check, model.GetArrayMap().size(), |
1273 | "count of arrays" ); |
1274 | } else if (count_type == "Total" ) { |
1275 | CheckCountInRange(model_check, model.operators.size(), |
1276 | "count of all operator instances" ); |
1277 | } else { |
1278 | // The check type is not itself checked against the set of valid |
1279 | // operators, mainly because the enum set cannot be iterated in C++. |
1280 | const int found_count = |
1281 | op_type_by_name.count(count_type) > 0 |
1282 | ? ops_by_type.count(op_type_by_name[count_type]) |
1283 | : 0; |
1284 | CheckCountInRange(model_check, found_count, |
1285 | "count of instances of " + count_type + " operator" ); |
1286 | } |
1287 | } |
1288 | } |
1289 | |
1290 | void FixEdgeArrays(Model* model) { |
1291 | for (const std::string& output_array_name : model->flags.output_arrays()) { |
1292 | if (!GetOpWithOutput(*model, output_array_name)) { |
1293 | // Output has no operator producing it. Change that by inserting a copy. |
1294 | LOG(WARNING) << "Fixing constant output array " << output_array_name |
1295 | << " by inserting a copy. This is not optimal." ; |
1296 | std::string intermediate_array_name = |
1297 | AvailableArrayName(*model, output_array_name + "_copy" ); |
1298 | CloneArray(model, output_array_name, intermediate_array_name); |
1299 | InsertCopyOperator(model, intermediate_array_name, output_array_name); |
1300 | } |
1301 | } |
1302 | } |
1303 | |
// Deduplicates constant arrays: wherever two constant arrays have identical
// attributes and buffer contents, all usages of the (discardable) duplicate
// are redirected to the first one and the duplicate is erased. Arrays whose
// estimated byte size is below 'min_size' are skipped; string arrays are
// always considered since their size cannot be estimated.
void DedupeConstantArrays(Model* model, size_t min_size) {
  // Walk all 0..N and compare with the remaining n+1..N.
  // This lets us avoid N^2 comparisons and erase duplicate arrays while
  // iterating.
  const auto& array_map = model->GetArrayMap();
  for (auto lhs_array_it = array_map.begin(); lhs_array_it != array_map.end();
       ++lhs_array_it) {
    const auto& lhs_array_name = lhs_array_it->first;
    const auto& lhs_array = *lhs_array_it->second;
    if (!IsConstantParameterArray(*model, lhs_array_name)) {
      // Not a constant array; skip.
      continue;
    }
    // Size estimation uses the type the array will eventually have, when one
    // is specified; otherwise its current type.
    ArrayDataType final_data_type =
        lhs_array.final_data_type != ArrayDataType::kNone
            ? lhs_array.final_data_type
            : lhs_array.data_type;
    // Ignore small arrays, don't check string arrays because it is not possible
    // to estimate its size.
    if (final_data_type != ArrayDataType::kString) {
      // lhs_array is a constant parameter array, so its buffer is present.
      size_t array_byte_size =
          lhs_array.buffer->Length() * ElementSize(final_data_type);
      if (array_byte_size < min_size) {
        // Too small; skip.
        continue;
      }
    }

    auto next_lhs_array_it = lhs_array_it;
    ++next_lhs_array_it;
    for (auto rhs_array_it = next_lhs_array_it;
         rhs_array_it != array_map.end();) {
      const auto& rhs_array_name = rhs_array_it->first;
      const auto& rhs_array = *rhs_array_it->second;
      // Advance now, before a possible EraseArray below invalidates the
      // iterator that points at the rhs entry.
      ++rhs_array_it;
      if (!IsConstantParameterArray(*model, rhs_array_name)) {
        // Not a constant array; skip.
        continue;
      }
      if (!IsDiscardableArray(*model, rhs_array_name)) {
        // Can't remove the array as it's not discardable (such as an IO edge).
        continue;
      }
      if (!CompareConstantArrays(lhs_array, rhs_array)) {
        // Arrays aren't equal; skip.
        continue;
      }

      // Arrays can be deduped!
      VLOG(1) << "Deduplicating arrays; using " << lhs_array_name
              << " in place of " << rhs_array_name;
      ReplaceArrayUsage(model, rhs_array_name, lhs_array_name);
      // Note: rhs_array_it above is already incremented so this is safe.
      model->EraseArray(rhs_array_name);
    }
  }
}
1361 | |
1362 | namespace { |
// Copies array metadata from source_array onto target_array: data types,
// shape (only when the source has one), and min/max and quantization
// parameters — which are cleared on the target when absent on the source,
// so the target ends up with exactly the source's attributes.
void CopyArrayAttribs(const Array& source_array, Array* target_array) {
  target_array->data_type = source_array.data_type;
  target_array->final_data_type = source_array.final_data_type;
  if (source_array.has_shape()) {
    target_array->copy_shape(source_array.shape());
  }

  if (source_array.minmax) {
    target_array->GetOrCreateMinMax() = source_array.GetMinMax();
  } else {
    // Source has no MinMax: make sure the target has none either.
    target_array->minmax.reset();
  }

  if (source_array.quantization_params) {
    target_array->GetOrCreateQuantizationParams() =
        source_array.GetQuantizationParams();
  } else {
    // Source is unquantized: clear any stale params on the target.
    target_array->quantization_params.reset();
  }
}
1383 | } // namespace |
1384 | |
1385 | void InsertCopyOperator(Model* model, const std::string& source_array_name, |
1386 | const std::string& target_array_name) { |
1387 | // Reshape to the same size. This should be a no-op. |
1388 | const Array& source_array = model->GetArray(source_array_name); |
1389 | std::vector<int> shape = source_array.shape().dims(); |
1390 | |
1391 | // Drop constant data from the target array as the copy will be done at |
1392 | // runtime. |
1393 | Array& target_array = model->GetOrCreateArray(target_array_name); |
1394 | target_array.buffer.reset(); |
1395 | CopyArrayAttribs(source_array, &target_array); |
1396 | |
1397 | // Insert copy operator. |
1398 | auto* copy_op = new TensorFlowReshapeOperator; |
1399 | copy_op->inputs = { |
1400 | source_array_name, |
1401 | CreateInt32Array( |
1402 | model, AvailableArrayName(*model, target_array_name + "_copy_shape" ), |
1403 | shape)}; |
1404 | copy_op->outputs = {target_array_name}; |
1405 | if (target_array.has_shape()) { |
1406 | copy_op->shape = target_array.shape().dims(); |
1407 | } |
1408 | model->operators.emplace_back(copy_op); |
1409 | } |
1410 | |
// Creates target_array_name as a full clone of source_array_name: attributes
// (via CopyArrayAttribs) plus, when present, the constant buffer contents.
// CHECK-fails if the target array already exists.
void CloneArray(Model* model, const std::string& source_array_name,
                const std::string& target_array_name) {
  CHECK(!model->HasArray(target_array_name));
  const Array& source_array = model->GetArray(source_array_name);
  Array& target_array = model->GetOrCreateArray(target_array_name);
  CopyArrayAttribs(source_array, &target_array);

  // Non-constant source: nothing beyond the attribute copy above.
  if (!source_array.buffer) {
    return;
  }

  // Dispatch on the runtime data type to copy the typed buffer.
  switch (source_array.data_type) {
    case ArrayDataType::kBool:
      CopyArrayBuffer<ArrayDataType::kBool>(source_array, &target_array);
      break;
    case ArrayDataType::kFloat:
      CopyArrayBuffer<ArrayDataType::kFloat>(source_array, &target_array);
      break;
    case ArrayDataType::kInt8:
      CopyArrayBuffer<ArrayDataType::kInt8>(source_array, &target_array);
      break;
    case ArrayDataType::kUint8:
      CopyArrayBuffer<ArrayDataType::kUint8>(source_array, &target_array);
      break;
    case ArrayDataType::kInt16:
      CopyArrayBuffer<ArrayDataType::kInt16>(source_array, &target_array);
      break;
    case ArrayDataType::kUint16:
      CopyArrayBuffer<ArrayDataType::kUint16>(source_array, &target_array);
      break;
    case ArrayDataType::kInt32:
      CopyArrayBuffer<ArrayDataType::kInt32>(source_array, &target_array);
      break;
    case ArrayDataType::kUint32:
      CopyArrayBuffer<ArrayDataType::kUint32>(source_array, &target_array);
      break;
    case ArrayDataType::kInt64:
      CopyArrayBuffer<ArrayDataType::kInt64>(source_array, &target_array);
      break;
    case ArrayDataType::kUint64:
      CopyArrayBuffer<ArrayDataType::kUint64>(source_array, &target_array);
      break;
    case ArrayDataType::kString:
      CopyArrayBuffer<ArrayDataType::kString>(source_array, &target_array);
      break;
    case ArrayDataType::kComplex64:
      CopyArrayBuffer<ArrayDataType::kComplex64>(source_array, &target_array);
      break;
    default:
      LOG(FATAL) << "Unsupported data type: "
                 << ArrayDataTypeName(source_array.data_type);
      return;
  }
}
1465 | |
1466 | void MakeArrayDims(int num_dims, int batch, int height, int width, int depth, |
1467 | std::vector<int>* out_dims) { |
1468 | CHECK(out_dims->empty()); |
1469 | if (num_dims == 0) { |
1470 | return; |
1471 | } else if (num_dims == 1) { |
1472 | CHECK_EQ(batch, 1); |
1473 | *out_dims = {depth}; |
1474 | } else if (num_dims == 2) { |
1475 | *out_dims = {batch, depth}; |
1476 | } else if (num_dims == 3) { |
1477 | CHECK_EQ(batch, 1); |
1478 | *out_dims = {height, width, depth}; |
1479 | } else if (num_dims == 4) { |
1480 | *out_dims = {batch, height, width, depth}; |
1481 | } else { |
1482 | LOG(FATAL) << "Should not get here: " << num_dims; |
1483 | } |
1484 | } |
1485 | |
// Ensures the RNN state array 'name' exists and, when possible, has a shape
// whose last dimension is 'size'. The rank comes from 'state_num_dims' when
// positive; otherwise it is inferred from the model's input arrays
// (preferring the input with the same name, else the last one seen), along
// with the batch dimension. An already-set shape on the array is left as is.
void CreateOrCheckRnnStateArray(const std::string& name, int size,
                                int state_num_dims, Model* model) {
  int batch = 1;
  int num_dims = -1;  // -1 means "unknown"; no shape will be set below.
  if (state_num_dims > 0) {
    num_dims = state_num_dims;
  } else {
    // state_num_dims is not given. We will infer it from an input tensor.
    for (const auto& input_array : model->flags.input_arrays()) {
      // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find
      // a better match by name.
      if (input_array.name() == name || num_dims == -1) {
        num_dims = input_array.shape().dims_size();
        if (num_dims > 0) {
          // Dimension 0 of the matched input is taken as the batch size.
          batch = input_array.shape().dims(0);
        }
      }
    }
  }
  Array& array = model->GetOrCreateArray(name);
  // A shape already present on the array takes precedence over anything
  // inferred above.
  if (array.has_shape()) {
    num_dims = array.shape().dimensions_count();
  }
  if (!array.has_shape() && num_dims >= 0) {
    Shape* shape = array.mutable_shape();
    std::vector<int> dims;
    MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
    *shape->mutable_dims() = dims;
  }
}
1516 | |
1517 | void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { |
1518 | // Merge info about input_arrays from model_flags into model->flags |
1519 | for (const auto& specified_input_array : model_flags.input_arrays()) { |
1520 | toco::InputArray* dst_input_array = nullptr; |
1521 | for (int i = 0; i < model->flags.input_arrays_size(); i++) { |
1522 | toco::InputArray* candidate_dst_input_array = |
1523 | model->flags.mutable_input_arrays(i); |
1524 | if (candidate_dst_input_array->name() == specified_input_array.name()) { |
1525 | // specified_input_array from model_flags maps to dst_input_array |
1526 | // in model->flags |
1527 | dst_input_array = candidate_dst_input_array; |
1528 | break; |
1529 | } |
1530 | } |
1531 | if (!dst_input_array) { |
1532 | // Specified_input_array from model_flags is not found in model->flags. |
1533 | // Match a name-less specified input array when there can be no ambiguity |
1534 | // as there is only 1 input array. |
1535 | if (model->flags.input_arrays_size() == 1 && |
1536 | model_flags.input_arrays_size() == 1 && |
1537 | !specified_input_array.has_name()) { |
1538 | dst_input_array = model->flags.mutable_input_arrays(0); |
1539 | } |
1540 | } |
1541 | if (!dst_input_array) { |
1542 | // Still no match, so create a new input array to copy |
1543 | // specified_input_array into. |
1544 | dst_input_array = model->flags.add_input_arrays(); |
1545 | dst_input_array->set_name(specified_input_array.name()); |
1546 | } |
1547 | |
1548 | #define RESOLVE_MODEL_FLAG(field_name) \ |
1549 | if (specified_input_array.has_##field_name()) { \ |
1550 | if (dst_input_array->has_##field_name()) { \ |
1551 | QCHECK_EQ(dst_input_array->field_name(), \ |
1552 | specified_input_array.field_name()) \ |
1553 | << "For input array '" << dst_input_array->name() << "', " \ |
1554 | << "specified " #field_name " flag with value: " \ |
1555 | << specified_input_array.field_name() \ |
1556 | << " does not agree with already defined " #field_name \ |
1557 | " of this model, with value: " \ |
1558 | << specified_input_array.field_name(); \ |
1559 | } else { \ |
1560 | dst_input_array->set_##field_name(specified_input_array.field_name()); \ |
1561 | } \ |
1562 | } |
1563 | RESOLVE_MODEL_FLAG(std_value); |
1564 | RESOLVE_MODEL_FLAG(mean_value); |
1565 | #undef RESOLVE_MODEL_FLAG |
1566 | |
1567 | if (specified_input_array.has_shape()) { |
1568 | if (dst_input_array->has_shape()) { |
1569 | QCHECK_EQ(specified_input_array.shape().dims_size(), |
1570 | dst_input_array->shape().dims_size()) |
1571 | << "For input array '" << specified_input_array.name() << "', " |
1572 | << "size of specified input shape flag with size: " |
1573 | << specified_input_array.shape().dims_size() |
1574 | << " does not agree with already defined input shape" |
1575 | " of this model, with size: " |
1576 | << dst_input_array->shape().dims_size(); |
1577 | // We treat the first dimension as a special case, since it is often |
1578 | // a batch size and the input_shape flag is effectively overriding |
1579 | // the model. |
1580 | for (int i = 1; i < specified_input_array.shape().dims_size(); i++) { |
1581 | QCHECK_EQ(specified_input_array.shape().dims(i), |
1582 | dst_input_array->shape().dims(i)) |
1583 | << "At dimension number " << i << " of input array " |
1584 | << specified_input_array.name() << ", the specified shape's " |
1585 | << "dimension flag with dimension: " |
1586 | << specified_input_array.shape().dims(i) |
1587 | << " does not agree with already defined shape" |
1588 | << " of this model, with dimension: " |
1589 | << dst_input_array->shape().dims(i); |
1590 | } |
1591 | } else { |
1592 | *dst_input_array->mutable_shape() = specified_input_array.shape(); |
1593 | } |
1594 | } |
1595 | |
1596 | if (specified_input_array.has_data_type()) { |
1597 | QCHECK(!dst_input_array->has_data_type()); |
1598 | dst_input_array->set_data_type(specified_input_array.data_type()); |
1599 | } |
1600 | } |
1601 | |
1602 | if (model_flags.output_arrays_size() > 0) { |
1603 | model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays()); |
1604 | } |
1605 | |
1606 | #define RESOLVE_MODEL_FLAG(name) \ |
1607 | if (model_flags.has_##name()) { \ |
1608 | if (model->flags.has_##name()) { \ |
1609 | QCHECK_EQ(model_flags.name(), model->flags.name()) \ |
1610 | << "Specified " #name " flag with value: " << model_flags.name() \ |
1611 | << " does not agree with already defined " #name \ |
1612 | " of this model, with value: " \ |
1613 | << model->flags.name(); \ |
1614 | } else { \ |
1615 | model->flags.set_##name(model_flags.name()); \ |
1616 | } \ |
1617 | } |
1618 | |
1619 | RESOLVE_MODEL_FLAG(variable_batch) |
1620 | |
1621 | #undef RESOLVE_MODEL_FLAG |
1622 | |
1623 | if (!model_flags.rnn_states().empty()) { |
1624 | model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states()); |
1625 | } |
1626 | |
1627 | if (model->flags.model_checks_size() == 0) { |
1628 | model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks()); |
1629 | } |
1630 | |
1631 | QCHECK_GT(model->flags.output_arrays_size(), 0) |
1632 | << "This model does not define output arrays, so a " |
1633 | "--output_arrays flag must be given on the command-line." ; |
1634 | |
1635 | for (auto& input_array_proto : *model->flags.mutable_input_arrays()) { |
1636 | auto& input_array = model->GetOrCreateArray(input_array_proto.name()); |
1637 | if (input_array_proto.has_data_type()) { |
1638 | const ArrayDataType specified_type = |
1639 | ConvertIODataTypeToArrayDataType(input_array_proto.data_type()); |
1640 | QCHECK(specified_type != ArrayDataType::kNone); |
1641 | if (input_array.data_type != ArrayDataType::kNone) { |
1642 | QCHECK(specified_type == input_array.data_type) |
1643 | << "For input array " << input_array_proto.name() |
1644 | << " the specified input data type " |
1645 | << IODataType_Name(input_array_proto.data_type()) |
1646 | << " conflicts with the existing type." ; |
1647 | } |
1648 | input_array.data_type = specified_type; |
1649 | } |
1650 | |
1651 | if (input_array.data_type == ArrayDataType::kNone) { |
1652 | // We start out with a float input array; |
1653 | // that may get replaced by a uint8 array later, by |
1654 | // MakeInitialDequantizeOp. |
1655 | input_array.data_type = ArrayDataType::kFloat; |
1656 | } |
1657 | |
1658 | // Compare/merge the model->flags describing the input_shape with |
1659 | // the actual input array's shape. |
1660 | if (!input_array.has_shape()) { |
1661 | if (input_array_proto.has_shape()) { |
1662 | auto& input_array_dims = *input_array.mutable_shape()->mutable_dims(); |
1663 | CheckValidShapeDimensions(input_array_proto.shape().dims()); |
1664 | for (const auto& dim : input_array_proto.shape().dims()) { |
1665 | input_array_dims.push_back(dim); |
1666 | } |
1667 | } |
1668 | } else { |
1669 | if (input_array_proto.has_shape()) { |
1670 | // If an input shape was specified on the flags ensure that it matches |
1671 | // the actual shape in the model. |
1672 | const auto& input_array_dims = |
1673 | *input_array.mutable_shape()->mutable_dims(); |
1674 | CHECK_EQ(input_array_dims.size(), |
1675 | input_array_proto.shape().dims_size()); |
1676 | for (int i = 0; i < input_array_dims.size(); i++) { |
1677 | CHECK_EQ(input_array_dims[i], input_array_proto.shape().dims(i)); |
1678 | } |
1679 | } else { |
1680 | for (int i = 0; i < input_array.shape().dimensions_count(); i++) { |
1681 | input_array_proto.mutable_shape()->add_dims( |
1682 | input_array.shape().dims(i)); |
1683 | } |
1684 | } |
1685 | } |
1686 | |
1687 | const float mean_value = input_array_proto.mean_value(); |
1688 | const float std_value = input_array_proto.std_value(); |
1689 | MinMax input_minmax; |
1690 | float qmin = 0, qmax = 255; |
1691 | if (input_array.data_type == ArrayDataType::kInt16) { |
1692 | qmin = -32768; |
1693 | qmax = 32767; |
1694 | } |
1695 | input_minmax.min = (qmin - mean_value) / std_value; |
1696 | input_minmax.max = (qmax - mean_value) / std_value; |
1697 | if (!input_array.minmax) { |
1698 | input_array.GetOrCreateMinMax() = input_minmax; |
1699 | } |
1700 | } |
1701 | |
1702 | // Creation of the RNN state arrays |
1703 | for (const auto& rnn_state : model->flags.rnn_states()) { |
1704 | CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(), |
1705 | rnn_state.num_dims(), model); |
1706 | } |
1707 | |
1708 | model->flags.set_change_concat_input_ranges( |
1709 | model_flags.change_concat_input_ranges()); |
1710 | model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays()); |
1711 | model->flags.set_allow_nonexistent_arrays( |
1712 | model_flags.allow_nonexistent_arrays()); |
1713 | |
1714 | CHECK(!model->flags.has_arrays_extra_info()); |
1715 | *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info(); |
1716 | } |
1717 | |
1718 | void CheckIsReadyForQuantization(const Model& model) { |
1719 | for (const auto& op : model.operators) { |
1720 | for (const auto& input : op->inputs) { |
1721 | const auto& input_array = model.GetArray(input); |
1722 | if (input_array.data_type != ArrayDataType::kFloat) { |
1723 | // The array is not floats, no quantization needed. |
1724 | continue; |
1725 | } |
1726 | if (input_array.minmax) { |
1727 | // The array has minmax, we're good. |
1728 | continue; |
1729 | } |
1730 | if (input_array.buffer) { |
1731 | // The array has a constant buffer, so we can |
1732 | // fall back to computing the minmax from actual array entries |
1733 | // (with a WARNING about possible accuracy implications). |
1734 | continue; |
1735 | } |
1736 | LOG(FATAL) |
1737 | << "Array " << input << ", which is an input to the " |
1738 | << HelpfulOperatorTypeName(*op) << " operator producing the output " |
1739 | << "array " << op->outputs[0] << ", is lacking min/max data, " |
1740 | << "which is necessary for quantization. If accuracy matters, either " |
1741 | << "target a non-quantized output format, or run quantized training " |
1742 | << "with your model from a floating point checkpoint to change the " |
1743 | << "input graph to contain min/max information. If you don't care " |
1744 | << "about accuracy, you can pass --default_ranges_min= and " |
1745 | << "--default_ranges_max= for easy experimentation." ; |
1746 | } |
1747 | } |
1748 | } |
1749 | |
1750 | int ElementSize(ArrayDataType data_type) { |
1751 | switch (data_type) { |
1752 | case ArrayDataType::kBool: |
1753 | return sizeof(bool); |
1754 | case ArrayDataType::kFloat: |
1755 | return 4; |
1756 | case ArrayDataType::kInt8: |
1757 | return 1; |
1758 | case ArrayDataType::kUint8: |
1759 | return 1; |
1760 | case ArrayDataType::kInt16: |
1761 | return 2; |
1762 | case ArrayDataType::kUint16: |
1763 | return 2; |
1764 | case ArrayDataType::kInt32: |
1765 | return 4; |
1766 | case ArrayDataType::kUint32: |
1767 | return 4; |
1768 | case ArrayDataType::kInt64: |
1769 | return 8; |
1770 | case ArrayDataType::kUint64: |
1771 | return 8; |
1772 | case ArrayDataType::kComplex64: |
1773 | return 8; |
1774 | case ArrayDataType::kComplex128: |
1775 | return 16; |
1776 | case ArrayDataType::kFloat64: |
1777 | return 8; |
1778 | |
1779 | // Usually not critical limitation because strings are only input and/or |
1780 | // output. |
1781 | case ArrayDataType::kString: |
1782 | LOG(FATAL) << "Transient arrays with strings are not supported yet" ; |
1783 | return 0; |
1784 | default: |
1785 | LOG(FATAL) << "Unknown data_type = " << static_cast<int>(data_type); |
1786 | return 0; |
1787 | } |
1788 | } |
1789 | |
1790 | void DropMinMax(Model* model, const std::string& array_name) { |
1791 | auto& array = model->GetArray(array_name); |
1792 | if (!!array.minmax) { |
1793 | LOG(WARNING) << "Dropping MinMax information in array " << array_name |
1794 | << ". Expect inaccuracy in quantized inference." ; |
1795 | array.minmax = nullptr; |
1796 | } |
1797 | } |
1798 | |
1799 | bool IsAllocatableTransientArray(const Model& model, |
1800 | const std::string& array_name) { |
1801 | // Optional array is not transient |
1802 | if (model.IsOptionalArray(array_name)) return false; |
1803 | // The model's input and output arrays are externally allocated. |
1804 | // They are not transient arrays. |
1805 | if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) { |
1806 | return false; |
1807 | } |
1808 | const auto& array = &model.GetArray(array_name); |
1809 | // An array with a constant buffer isn't a transient array. |
1810 | if (!!array->buffer) { |
1811 | return false; |
1812 | } |
1813 | // An array without shape isn't allocatable. |
1814 | if (!array->has_shape()) { |
1815 | return false; |
1816 | } |
1817 | |
1818 | // The size of string tensors is rarely known ahead of time, so all transient |
1819 | // tensors of this type will need to be dynamically allocated. |
1820 | if (array->final_data_type == ArrayDataType::kString || |
1821 | array->data_type == ArrayDataType::kString) { |
1822 | return false; |
1823 | } |
1824 | |
1825 | return true; |
1826 | } |
1827 | |
1828 | std::string AvailableArrayName(const Model& model, const std::string& name) { |
1829 | std::string sanitized_name = SanitizeNameForTFNode(name); |
1830 | if (!model.HasArray(sanitized_name) && |
1831 | !model.IsOptionalArray(sanitized_name)) { |
1832 | return sanitized_name; |
1833 | } |
1834 | const int kNumSuffixesToTry = 1000; |
1835 | for (int i = 0; i < kNumSuffixesToTry; i++) { |
1836 | const std::string& name_with_suffix = |
1837 | toco::port::StringF("%s_%d" , sanitized_name, i); |
1838 | if (!model.HasArray(name_with_suffix) && |
1839 | !model.IsOptionalArray(name_with_suffix)) { |
1840 | return name_with_suffix; |
1841 | } |
1842 | } |
1843 | LOG(FATAL) << "Could not find an available array name starting with " |
1844 | << sanitized_name << ". Tried " << kNumSuffixesToTry |
1845 | << " suffixes, all were taken!" ; |
1846 | return "" ; |
1847 | } |
1848 | |
1849 | std::string ShapeToString(const Shape& shape) { |
1850 | if (shape.dimensions_count() == 0) { |
1851 | return "[]" ; |
1852 | } |
1853 | |
1854 | return absl::StrCat("[ " , absl::StrJoin(shape.dims(), ", " ), " ]" ); |
1855 | } |
1856 | |
1857 | void PrintArrayShape(Model* model, const std::string& name) { |
1858 | if (!model->GetArray(name).has_shape()) { |
1859 | LOG(INFO) << name << " has no shape" ; |
1860 | return; |
1861 | } |
1862 | LOG(INFO) << name |
1863 | << " has shape: " << ShapeToString(model->GetArray(name).shape()); |
1864 | } |
1865 | |
1866 | bool IsArrayFullyConnectedWeights(const Model& model, const std::string& name) { |
1867 | bool is_fc_weights = false; |
1868 | bool is_something_else = false; |
1869 | for (const auto& op : model.operators) { |
1870 | for (int input_index = 0; input_index < op->inputs.size(); input_index++) { |
1871 | if (op->inputs[input_index] == name) { |
1872 | if (op->type == OperatorType::kFullyConnected && input_index == 1) { |
1873 | is_fc_weights = true; |
1874 | } else { |
1875 | is_something_else = true; |
1876 | } |
1877 | } |
1878 | } |
1879 | } |
1880 | CHECK(!(is_fc_weights && is_something_else)); |
1881 | return is_fc_weights; |
1882 | } |
1883 | |
1884 | std::string CreateInt32Array(Model* model, const std::string& param_name, |
1885 | const std::vector<int>& value) { |
1886 | auto param_array_name = AvailableArrayName(*model, param_name); |
1887 | auto& param_array = model->GetOrCreateArray(param_array_name); |
1888 | param_array.mutable_shape()->ReplaceDims({static_cast<int>(value.size())}); |
1889 | param_array.data_type = ArrayDataType::kInt32; |
1890 | auto& param_array_data = |
1891 | param_array.GetMutableBuffer<ArrayDataType::kInt32>().data; |
1892 | param_array_data.resize(RequiredBufferSizeForShape(param_array.shape())); |
1893 | for (int i = 0; i < value.size(); ++i) { |
1894 | param_array_data[i] = value[i]; |
1895 | } |
1896 | return param_array_name; |
1897 | } |
1898 | |
// Estimates the number of arithmetic operations (a multiply-add counts as
// two ops) performed by a single operator, writing the estimate to *result.
// Returns false when the estimate cannot be computed because a required
// input/output shape is not yet known; operators without a cost model here
// yield an estimate of 0 (and return true).
bool EstimateArithmeticOpsCount(const Model& model, const Operator& op,
                                int64_t* result) {
  switch (op.type) {
    case OperatorType::kFullyConnected:
    case OperatorType::kConv:
    case OperatorType::kDepthwiseConv: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      const auto& weights_array = model.GetArray(op.inputs[1]);
      if (!output_array.has_shape() || !weights_array.has_shape()) {
        return false;
      }
      // Number of output "columns": product of all output dims except the
      // innermost one. Each column consumes the entire weights array.
      int64_t cols = 1;
      for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
        cols *= output_array.shape().dims(i);
      }
      // Factor 2: one multiply and one add per weight element.
      const int64_t cost_per_col =
          2 * RequiredBufferSizeForShape(weights_array.shape());
      *result = cost_per_col * cols;
      if (op.inputs.size() > 2) {
        // There is a bias vector. One more op per output value.
        *result += RequiredBufferSizeForShape(output_array.shape());
      }
      break;
    }
    case OperatorType::kTransposeConv: {
      // Note: inputs[2] is the data input; inputs[0] is the output-shape
      // tensor for TransposeConv.
      const auto& input_array = model.GetArray(op.inputs[2]);
      const auto& weights_array = model.GetArray(op.inputs[1]);
      if (!input_array.has_shape() || !weights_array.has_shape()) {
        return false;
      }
      const Shape& input = input_array.shape();
      const Shape& weights = weights_array.shape();
      // Compute op count from the seven nested loops of
      // tflite::reference_ops::TransposeConv():
      *result = 2 * input.dims(0) * input.dims(1) * input.dims(2) *
                input.dims(3) * weights.dims(1) * weights.dims(2) *
                weights.dims(0);
      // Note that tflite::optimized_ops::TransposeConv() uses an im2col matrix
      // and has a higher op count, by a factor of (output_height*output_width)
      // vs. (input_height*input_width). Yet it generally performs better
      // because of coherent memory access. (At least for 2x2 striding. But not
      // likely for all cases.)
      break;
    }
    case OperatorType::kAdd:
    case OperatorType::kSub:
    case OperatorType::kMul: {
      // One elementwise op per output value.
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      *result = RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kAddN: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // AddN cost is roughly the same cost as N-1 Adds.
      const int64_t num_adds = op.inputs.size() - 1;
      *result = num_adds * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kLogistic:
    case OperatorType::kSoftmax:
    case OperatorType::kLogSoftmax:
    case OperatorType::kTanh: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // As a very rough ballpark, the cost of evaluating a math function
      // such as tanh or logistic is about 32 multiplications, and about as
      // many additions/subtractions. (Just a power-of-two order-of-magnitude
      // from looking at actual implementations that we use in runtime/ code).
      *result = 64 * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kMaxPool: {
      // One comparison per kernel element, per output value.
      const auto& maxpool = *static_cast<const MaxPoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      *result = RequiredBufferSizeForShape(output_array.shape()) *
                maxpool.kheight * maxpool.kwidth;
      break;
    }
    case OperatorType::kAveragePool: {
      // One accumulation per kernel element, per output value.
      const auto& avgpool = *static_cast<const AveragePoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      *result = RequiredBufferSizeForShape(output_array.shape()) *
                avgpool.kheight * avgpool.kwidth;
      break;
    }
    case OperatorType::kL2Pool: {
      // NOTE(review): reads kheight/kwidth via MaxPoolOperator — presumably
      // layout-compatible with the L2 pool operator struct; confirm.
      const auto* maxpool = static_cast<const MaxPoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // The sum of squares requires (kheight*kwidth) multiply-adds,
      // and then there is the sqrt which we ballpark at 32 ops.
      const int64_t cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32;
      *result = RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
      break;
    }
    case OperatorType::kL2Normalization: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // Computing the squared L2 norm is N multiply-adds so 2N ops,
      // then the single inverse-sqrt is negligible, then we multiply each
      // value by the resulting multiplier, so an extra N ops. count 3N ops.
      *result = 3 * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    default:
      // No cost model for this operator type: report zero ops.
      *result = 0;
      break;
  }
  return true;
}
2027 | |
2028 | bool EstimateArithmeticOpsCount(const Model& model, int64_t* result) { |
2029 | int64_t total = 0; |
2030 | for (const auto& op : model.operators) { |
2031 | int64_t num_ops; |
2032 | if (!EstimateArithmeticOpsCount(model, *op, &num_ops)) { |
2033 | return false; |
2034 | } |
2035 | total += num_ops; |
2036 | } |
2037 | *result = total; |
2038 | return true; |
2039 | } |
2040 | |
2041 | std::string FormattedNumber(int64_t x) { |
2042 | const int64_t million = 1000000; |
2043 | const int64_t billion = 1000000000; |
2044 | if (x < 10000) { |
2045 | return toco::port::StringF("%d " , x); |
2046 | } else if (x < billion) { |
2047 | return toco::port::StringF("%.3f M" , static_cast<double>(x) / million); |
2048 | } else { |
2049 | return toco::port::StringF("%.3f G" , static_cast<double>(x) / billion); |
2050 | } |
2051 | } |
2052 | |
// Computes the axis permutation taking input_axes_order to
// output_axes_order: after the call, output axis i draws from input axis
// (*shuffle)[i]. Only the identity, any 2-D transpose, and the specific 4-D
// reorderings listed below are supported; anything else is fatal.
void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
                     std::vector<int>* shuffle) {
  CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
  // Start from the rank-4 identity permutation; 2-D cases shrink it below.
  shuffle->resize(4);
  for (int i = 0; i < 4; i++) {
    (*shuffle)[i] = i;
  }
  if (input_axes_order == output_axes_order) {
    // nothing to do
  } else if (AxesCount(input_axes_order) == 2) {
    // Any 2-D reorder between distinct orders (kRC <-> kCR) is a transpose.
    shuffle->resize(2);
    (*shuffle)[0] = 1;
    (*shuffle)[1] = 0;
  } else if (input_axes_order == AxesOrder::kOHWI &&
             output_axes_order == AxesOrder::kHWIO) {
    // 3210 <- 3210
    // HWIO <- OHWI
    *shuffle = {1, 2, 3, 0};
  } else if (input_axes_order == AxesOrder::kHWIO &&
             output_axes_order == AxesOrder::kOHWI) {
    // 3210 <- 3210
    // OHWI <- HWIO
    *shuffle = {3, 0, 1, 2};
  } else if (input_axes_order == AxesOrder::kOHWI &&
             output_axes_order == AxesOrder::kHWOI) {
    *shuffle = {1, 2, 0, 3};
  } else {
    LOG(FATAL) << "Bad shuffle";
  }
}
2083 | |
2084 | void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim, |
2085 | std::vector<int>* extended_shuffle) { |
2086 | *extended_shuffle = input_shuffle; |
2087 | CHECK(newdim >= input_shuffle.size()); |
2088 | const int pad_size = newdim - input_shuffle.size(); |
2089 | extended_shuffle->resize(newdim); |
2090 | for (int i = 0; i < pad_size; i++) { |
2091 | (*extended_shuffle)[i] = i; |
2092 | } |
2093 | for (int i = pad_size; i < newdim; i++) { |
2094 | (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size; |
2095 | } |
2096 | } |
2097 | |
2098 | void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order, |
2099 | AxesOrder output_axes_order, Shape* output_shape) { |
2100 | if (input_axes_order == AxesOrder::kHWIM && |
2101 | output_axes_order == AxesOrder::k1HWO) { |
2102 | // This special case isn't just a permutation, the IM pair of dims get |
2103 | // merged into the 3 dim, so we have to special-case it. |
2104 | *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1), |
2105 | input_shape.dims(3) * input_shape.dims(2)}); |
2106 | } else { |
2107 | std::vector<int> shuffle; |
2108 | GetShuffleShape(input_axes_order, output_axes_order, &shuffle); |
2109 | std::vector<int>* output_dims = output_shape->mutable_dims(); |
2110 | output_dims->resize(input_shape.dimensions_count()); |
2111 | for (int i = 0; i < input_shape.dimensions_count(); i++) { |
2112 | (*output_dims)[i] = input_shape.dims(shuffle[i]); |
2113 | } |
2114 | } |
2115 | } |
2116 | |
// Permutes the data of an array (up to rank 4) from input_axes_order to
// output_axes_order. input_data and output_data must not alias.
// T should be trivially copyable (the HWIM->1HWO path uses memcpy).
template <typename T>
void ShuffleArrayTemplate(const Shape& input_shape, AxesOrder input_axes_order,
                          AxesOrder output_axes_order,
                          const Shape& output_shape, const T* input_data,
                          T* output_data) {
  if (input_axes_order == AxesOrder::kHWIM &&
      output_axes_order == AxesOrder::k1HWO) {
    // This special case isn't just a permutation, the IM pair of dims get
    // merged into the O dim, so we have to special-case it. Fortunately,
    // as far as array shuffling is concerned, it's just the identity
    // transformation.
    memcpy(output_data, input_data,
           RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0]));
    return;
  }
  CHECK(input_shape.dimensions_count() == output_shape.dimensions_count());
  const int dim = input_shape.dimensions_count();
  CHECK_LE(dim, 4);
  std::vector<int> shuffle;
  GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
  CHECK(shuffle.size() >= dim);
  for (int i = 0; i < dim; i++) {
    CHECK(shuffle[i] >= 0 && shuffle[i] < dim);
    CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i));
  }
  // Pad both shapes and the permutation to rank 4 (leading 1-dims) so the
  // fixed 4-deep loop nest below handles every supported rank.
  Shape extended_input_shape = input_shape;
  ExtendShape(&extended_input_shape, 4);
  Shape extended_output_shape = output_shape;
  ExtendShape(&extended_output_shape, 4);
  std::vector<int> extended_shuffle;
  ExtendShuffle(shuffle, 4, &extended_shuffle);

  const std::vector<int>& extended_input_dims = extended_input_shape.dims();
  const std::vector<int>& extended_output_dims = extended_output_shape.dims();

  // TODO(starka): Rework to handle different numbers of dimensions.
  // Row-major strides of the input; index 3 is the innermost axis.
  int input_strides[4];
  input_strides[3] = 1;
  input_strides[2] = extended_input_dims[3];
  input_strides[1] = input_strides[2] * extended_input_dims[2];
  input_strides[0] = input_strides[1] * extended_input_dims[1];
  // input_stride_k is how far to step through the input when the k-th
  // innermost output index advances: output axis (3-k) draws from input
  // axis extended_shuffle[3-k].
  const int input_stride_0 = input_strides[extended_shuffle[3]];
  const int input_stride_1 = input_strides[extended_shuffle[2]];
  const int input_stride_2 = input_strides[extended_shuffle[1]];
  const int input_stride_3 = input_strides[extended_shuffle[0]];

  // The output is traversed contiguously; _0 denotes its innermost axis.
  const int output_size_0 = extended_output_dims[3];
  const int output_size_1 = extended_output_dims[2];
  const int output_size_2 = extended_output_dims[1];
  const int output_size_3 = extended_output_dims[0];
  const int output_stride_0 = 1;
  const int output_stride_1 = output_size_0;
  const int output_stride_2 = output_stride_1 * output_size_1;
  const int output_stride_3 = output_stride_2 * output_size_2;

  // Walk the output in order, gathering from the permuted input positions.
  for (int i3 = 0; i3 < output_size_3; i3++) {
    const T* const input_ptr_3 = input_data + i3 * input_stride_3;
    T* const output_ptr_3 = output_data + i3 * output_stride_3;
    for (int i2 = 0; i2 < output_size_2; i2++) {
      const T* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2;
      T* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2;
      for (int i1 = 0; i1 < output_size_1; i1++) {
        const T* input_ptr = input_ptr_2 + i1 * input_stride_1;
        T* output_ptr = output_ptr_2 + i1 * output_stride_1;
        T* const output_ptr_end = output_ptr + output_size_0 * output_stride_0;
        while (output_ptr != output_ptr_end) {
          *output_ptr = *input_ptr;
          input_ptr += input_stride_0;
          output_ptr += output_stride_0;
        }
      }
    }
  }
}
2191 | |
2192 | void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order, |
2193 | AxesOrder output_axes_order, const Shape& output_shape, |
2194 | const uint8* input_data, uint8* output_data) { |
2195 | ShuffleArrayTemplate<uint8>(input_shape, input_axes_order, output_axes_order, |
2196 | output_shape, input_data, output_data); |
2197 | } |
2198 | |
2199 | void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order, |
2200 | AxesOrder output_axes_order, const Shape& output_shape, |
2201 | const float* input_data, float* output_data) { |
2202 | ShuffleArrayTemplate<float>(input_shape, input_axes_order, output_axes_order, |
2203 | output_shape, input_data, output_data); |
2204 | } |
2205 | |
2206 | int AxesCount(AxesOrder axes_order) { |
2207 | switch (axes_order) { |
2208 | case AxesOrder::kOneAxis: |
2209 | return 1; |
2210 | case AxesOrder::kRC: |
2211 | return 2; |
2212 | case AxesOrder::kCR: |
2213 | return 2; |
2214 | case AxesOrder::kHWIO: |
2215 | return 4; |
2216 | case AxesOrder::kOHWI: |
2217 | return 4; |
2218 | case AxesOrder::kHWIM: |
2219 | return 4; |
2220 | case AxesOrder::k1HWO: |
2221 | return 4; |
2222 | case AxesOrder::kNHWC: |
2223 | return 4; |
2224 | case AxesOrder::kHWOI: |
2225 | return 4; |
2226 | default: |
2227 | LOG(FATAL) << "Bad AxesOrder" ; |
2228 | return 0; |
2229 | } |
2230 | } |
2231 | |
2232 | bool IsDiscardableArray(const Model& model, const std::string& array_name) { |
2233 | if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) { |
2234 | return false; |
2235 | } |
2236 | for (const auto& rnn_state : model.flags.rnn_states()) { |
2237 | if (!rnn_state.discardable()) { |
2238 | if (array_name == rnn_state.state_array()) { |
2239 | return false; |
2240 | } |
2241 | if (array_name == rnn_state.back_edge_source_array()) { |
2242 | return false; |
2243 | } |
2244 | } |
2245 | } |
2246 | return true; |
2247 | } |
2248 | |
2249 | bool ReshapeIsEquivalentToTranspose(const Model& model, |
2250 | const TensorFlowReshapeOperator* op, |
2251 | bool ) { |
2252 | CHECK(!op->shape.empty()); |
2253 | CHECK(model.HasArray(op->inputs[0])); |
2254 | CHECK(model.HasArray(op->outputs[0])); |
2255 | |
2256 | const auto& input_array = model.GetArray(op->inputs[0]); |
2257 | const auto& output_array = model.GetArray(op->outputs[0]); |
2258 | |
2259 | CHECK(input_array.has_shape()); |
2260 | CHECK(output_array.has_shape()); |
2261 | |
2262 | std::vector<int> in_shape = input_array.shape().dims(); |
2263 | std::vector<int> out_shape = output_array.shape().dims(); |
2264 | |
2265 | // If the reshape changes the number of dimensions so it cannot be interpreted |
2266 | // as a transpose. |
2267 | if (!allow_extra_unary_dims && in_shape.size() != out_shape.size()) { |
2268 | return false; |
2269 | } |
2270 | |
2271 | in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1), |
2272 | in_shape.end()); |
2273 | out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1), |
2274 | out_shape.end()); |
2275 | return in_shape == out_shape; |
2276 | } |
2277 | |
2278 | void CheckFinalDataTypesSatisfied(const Model& model) { |
2279 | for (const auto& array_entry : model.GetArrayMap()) { |
2280 | const auto& array = *array_entry.second; |
2281 | if (array.data_type == ArrayDataType::kBool) { |
2282 | // Boolean values are never quantized. |
2283 | continue; |
2284 | } |
2285 | |
2286 | // If the final data type is int16, the data type may be float, for example |
2287 | // after dequantization. |
2288 | if (array.final_data_type != ArrayDataType::kNone && |
2289 | array.final_data_type != ArrayDataType::kInt16) { |
2290 | CHECK(array.data_type == array.final_data_type) |
2291 | << "Array \"" << array_entry.first |
2292 | << "\" has mis-matching actual and final data types (data_type=" |
2293 | << ArrayDataTypeName(array.data_type) |
2294 | << ", final_data_type=" << ArrayDataTypeName(array.final_data_type) |
2295 | << ")." ; |
2296 | } |
2297 | } |
2298 | } |
2299 | |
2300 | ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) { |
2301 | switch (type) { |
2302 | case FLOAT: |
2303 | return ArrayDataType::kFloat; |
2304 | case UINT8: |
2305 | case QUANTIZED_UINT8: |
2306 | return ArrayDataType::kUint8; |
2307 | case INT8: |
2308 | case QUANTIZED_INT8: |
2309 | return ArrayDataType::kInt8; |
2310 | case INT16: |
2311 | case QUANTIZED_INT16: |
2312 | return ArrayDataType::kInt16; |
2313 | case UINT16: |
2314 | return ArrayDataType::kUint16; |
2315 | case INT32: |
2316 | return ArrayDataType::kInt32; |
2317 | case UINT32: |
2318 | return ArrayDataType::kUint32; |
2319 | case INT64: |
2320 | return ArrayDataType::kInt64; |
2321 | case UINT64: |
2322 | return ArrayDataType::kUint64; |
2323 | case BOOL: |
2324 | return ArrayDataType::kBool; |
2325 | case STRING: |
2326 | return ArrayDataType::kString; |
2327 | case COMPLEX64: |
2328 | return ArrayDataType::kComplex64; |
2329 | case COMPLEX128: |
2330 | return ArrayDataType::kComplex128; |
2331 | case FLOAT16: |
2332 | return ArrayDataType::kFloat16; |
2333 | case FLOAT64: |
2334 | return ArrayDataType::kFloat64; |
2335 | case RESOURCE: |
2336 | case VARIANT: |
2337 | default: |
2338 | return ArrayDataType::kNone; |
2339 | } |
2340 | } |
2341 | |
2342 | void FinishBuildingRNNStates(Model* model) { |
2343 | for (const auto& rnn_state : model->flags.rnn_states()) { |
2344 | if (!model->HasArray(rnn_state.back_edge_source_array()) || |
2345 | !model->HasArray(rnn_state.state_array())) { |
2346 | CHECK(model->HasArray(rnn_state.back_edge_source_array())); |
2347 | CHECK(model->HasArray(rnn_state.state_array())); |
2348 | continue; |
2349 | } |
2350 | const auto& src_array = model->GetArray(rnn_state.back_edge_source_array()); |
2351 | auto& dst_array = model->GetArray(rnn_state.state_array()); |
2352 | if (src_array.data_type == ArrayDataType::kNone && |
2353 | dst_array.data_type == ArrayDataType::kNone) { |
2354 | dst_array.data_type = ArrayDataType::kFloat; |
2355 | } |
2356 | } |
2357 | } |
2358 | |
2359 | // Returns the array names that match the ArraysExtraInfo's name and |
2360 | // name_regexp. The regexp match is for a full match. |
2361 | std::unordered_set<std::string> ( |
2362 | const Model& model, const toco::ArraysExtraInfo_Entry& entry) { |
2363 | std::unordered_set<std::string> matches; |
2364 | if (model.HasArray(entry.name())) { |
2365 | matches.insert(entry.name()); |
2366 | } |
2367 | if (!entry.name_regexp().empty()) { |
2368 | const auto& arrays = model.GetArrayMap(); |
2369 | const RE2 name_regexp = {entry.name_regexp()}; |
2370 | for (auto it = arrays.begin(); it != arrays.end(); ++it) { |
2371 | if (RE2::FullMatch(it->first, name_regexp)) { |
2372 | matches.insert(it->first); |
2373 | } |
2374 | } |
2375 | } |
2376 | return matches; |
2377 | } |
2378 | |
2379 | void (Model* model, bool quantize_output) { |
2380 | for (const auto& entry : model->flags.arrays_extra_info().entries()) { |
2381 | const auto matches = ScanArrayNames(*model, entry); |
2382 | if (matches.empty()) { |
2383 | LOG(ERROR) << "arrays_extra_info_file: No matching arrays found for " |
2384 | << (entry.has_name() ? entry.name() : "" ) |
2385 | << (entry.has_name_regexp() ? entry.name_regexp() : "" ); |
2386 | continue; |
2387 | } |
2388 | for (const auto& matched_name : matches) { |
2389 | auto& array = model->GetArray(matched_name); |
2390 | if (entry.has_min() || entry.has_max()) { |
2391 | CHECK_EQ(entry.has_min(), entry.has_max()); |
2392 | auto& minmax = array.GetOrCreateMinMax(); |
2393 | minmax.min = entry.min(); |
2394 | minmax.max = entry.max(); |
2395 | } |
2396 | if (entry.has_data_type() && quantize_output) { |
2397 | array.final_data_type = |
2398 | ConvertIODataTypeToArrayDataType(entry.data_type()); |
2399 | } |
2400 | if (entry.has_shape()) { |
2401 | array.clear_shape(); |
2402 | // Make sure to create the shape even if there are no dims, to |
2403 | // correctly record 0-D shapes. |
2404 | array.mutable_shape(); |
2405 | for (const auto& dim : entry.shape().dims()) { |
2406 | array.mutable_shape()->mutable_dims()->push_back(dim); |
2407 | } |
2408 | } |
2409 | if (entry.has_constant_float_value()) { |
2410 | CHECK(array.has_shape()); |
2411 | if (array.data_type == ArrayDataType::kFloat) { |
2412 | auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data; |
2413 | data.resize(RequiredBufferSizeForShape(array.shape())); |
2414 | for (float& f : data) { |
2415 | f = entry.constant_float_value(); |
2416 | } |
2417 | } |
2418 | } |
2419 | } |
2420 | } |
2421 | } |
2422 | |
2423 | void UndoWeightsShuffling(Model* model) { |
2424 | for (const auto& op : model->operators) { |
2425 | if (op->type != toco::OperatorType::kFullyConnected) { |
2426 | continue; |
2427 | } |
2428 | const auto& fc_op = static_cast<toco::FullyConnectedOperator&>(*op); |
2429 | if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) { |
2430 | continue; |
2431 | } |
2432 | const std::string& weights_name = fc_op.inputs[1]; |
2433 | QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1); |
2434 | auto& weights_array = model->GetArray(weights_name); |
2435 | QCHECK(weights_array.data_type == ArrayDataType::kUint8); |
2436 | auto& weights_data = |
2437 | weights_array.GetMutableBuffer<toco::ArrayDataType::kUint8>().data; |
2438 | const auto& weights_shape = weights_array.shape(); |
2439 | QCHECK_EQ(weights_shape.dimensions_count(), 2); |
2440 | const int rows = weights_shape.dims(0); |
2441 | const int cols = weights_shape.dims(1); |
2442 | QCHECK_EQ(rows % 4, 0); |
2443 | QCHECK_EQ(cols % 16, 0); |
2444 | CHECK_EQ(rows * cols, weights_data.size()); |
2445 | // Compute the de-shuffled weights |
2446 | std::vector<uint8> deshuffled_data(weights_data.size()); |
2447 | uint8* shuffled_data_ptr = weights_data.data(); |
2448 | for (int r = 0; r < rows; r += 4) { |
2449 | for (int c = 0; c < cols; c += 16) { |
2450 | for (int i = 0; i < 4; i++) { |
2451 | uint8* deshuffled_data_ptr = |
2452 | deshuffled_data.data() + (r + i) * cols + c; |
2453 | for (int j = 0; j < 16; j++) { |
2454 | uint8 shuffled_val = *shuffled_data_ptr++; |
2455 | // Deshuffling isn't only about deshuffling the storage layout, |
2456 | // it's also about undoing the flipping of the sign bit, which is |
2457 | // performed on the shuffled weights. |
2458 | uint8 deshuffled_val = shuffled_val ^ 0x80; |
2459 | *deshuffled_data_ptr++ = deshuffled_val; |
2460 | } |
2461 | } |
2462 | } |
2463 | } |
2464 | CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols); |
2465 | // Switch this FC op to using the deshuffled weights. |
2466 | weights_data = std::move(deshuffled_data); |
2467 | } |
2468 | } |
2469 | |
2470 | void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) { |
2471 | if (src.minmax) { |
2472 | dst->GetOrCreateMinMax() = src.GetMinMax(); |
2473 | } |
2474 | if (src.quantization_params) { |
2475 | dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams(); |
2476 | } |
2477 | dst->narrow_range = src.narrow_range; |
2478 | } |
2479 | |
2480 | } // namespace toco |
2481 | |