1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/toco/tooling_util.h"
16
17#include <algorithm>
18#include <functional>
19#include <iterator>
20#include <set>
21#include <string>
22#include <unordered_map>
23#include <unordered_set>
24#include <utility>
25
26#include "absl/strings/ascii.h"
27#include "absl/strings/str_cat.h"
28#include "absl/strings/str_join.h"
29#include "absl/strings/str_replace.h"
30#include "absl/strings/str_split.h"
31#include "re2/re2.h"
32#include "tensorflow/core/lib/core/status.h"
33#include "tensorflow/core/platform/logging.h"
34#include "tensorflow/lite/toco/dump_graphviz.h"
35#include "tensorflow/lite/toco/model_flags.pb.h"
36#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
37
38namespace toco {
39
40// Find the longest common prefix of two strings.
41absl::string_view FindLongestCommonPrefix(absl::string_view a,
42 absl::string_view b) {
43 if (a.empty() || b.empty()) return absl::string_view();
44
45 const char* pa = a.data();
46 const char* pb = b.data();
47 size_t count = 0;
48 const size_t limit = std::min(a.size(), b.size());
49 while (count < limit && *pa == *pb) {
50 ++pa;
51 ++pb;
52 ++count;
53 }
54
55 return absl::string_view(a.data(), count);
56}
57
58std::string LogName(const Operator& op) {
59 const std::string& opname = HelpfulOperatorTypeName(op);
60 if (op.outputs.empty()) {
61 return toco::port::StringF("{%s operator}", opname);
62 } else {
63 return toco::port::StringF("{%s operator with output %s}", opname,
64 op.outputs[0]);
65 }
66}
67
// Returns the human-readable name of the given array data type, e.g.
// ArrayDataType::kFloat -> "float". Dies on an unhandled enum value.
std::string ArrayDataTypeName(ArrayDataType data_type) {
  switch (data_type) {
    case ArrayDataType::kFloat:
      return "float";
    case ArrayDataType::kInt8:
      return "int8";
    case ArrayDataType::kUint8:
      return "uint8";
    case ArrayDataType::kInt16:
      return "int16";
    case ArrayDataType::kUint16:
      return "uint16";
    case ArrayDataType::kInt32:
      return "int32";
    case ArrayDataType::kUint32:
      return "uint32";
    case ArrayDataType::kInt64:
      return "int64";
    case ArrayDataType::kUint64:
      return "uint64";
    case ArrayDataType::kString:
      return "string";
    case ArrayDataType::kBool:
      return "bool";
    case ArrayDataType::kComplex64:
      return "complex64";
    case ArrayDataType::kNone:
      return "None";
    default:
      LOG(FATAL) << "Unhandled array data type " << static_cast<int>(data_type);
  }
}
100
101bool IsInputArray(const Model& model, const std::string& array_name) {
102 for (const auto& input_array : model.flags.input_arrays()) {
103 if (array_name == input_array.name()) {
104 return true;
105 }
106 }
107 return false;
108}
109
110bool IsOutputArray(const Model& model, const std::string& array_name) {
111 for (const auto& output_array : model.flags.output_arrays()) {
112 if (array_name == output_array) {
113 return true;
114 }
115 }
116 return false;
117}
118
119bool IsArrayConsumed(const Model& model, const std::string& name) {
120 if (GetOpWithInput(model, name)) {
121 return true;
122 }
123 if (IsOutputArray(model, name)) {
124 return true;
125 }
126 for (const auto& rnn_state : model.flags.rnn_states()) {
127 if (rnn_state.back_edge_source_array() == name) {
128 return true;
129 }
130 }
131 return false;
132}
133
134int CountTrueOutputs(const Model& model, const Operator& op) {
135 int count = 0;
136 for (const std::string& output : op.outputs) {
137 if (IsArrayConsumed(model, output)) {
138 ++count;
139 }
140 }
141 return count;
142}
143
144int CountOpsWithInput(const Model& model, const std::string& array_name) {
145 int count = 0;
146 for (const auto& op : model.operators) {
147 for (auto& input : op->inputs) {
148 if (input == array_name) {
149 count++;
150 // Breaking here is important: some graphs have ops that use the
151 // same array as more than one of their inputs, and in that case
152 // we want it counted only once.
153 break;
154 }
155 }
156 }
157 return count;
158}
159
160bool DeleteArrayIfUnused(const std::string& array_name, Model* model) {
161 if (IsDiscardableArray(*model, array_name) &&
162 CountOpsWithInput(*model, array_name) == 0 &&
163 GetOpWithOutput(*model, array_name) == nullptr) {
164 model->EraseArray(array_name);
165 return true;
166 }
167 return false;
168}
169
170bool DeleteArrayIfUnusedOutsideOfOp(const std::string& array_name,
171 const Operator* op, Model* model) {
172 if (!IsDiscardableArray(*model, array_name)) {
173 return false;
174 }
175 if (CountOpsWithInput(*model, array_name) > 1) {
176 return false;
177 }
178 const Operator* op_having_this_as_input = GetOpWithInput(*model, array_name);
179 if (op_having_this_as_input && op_having_this_as_input != op) {
180 return false;
181 }
182 const Operator* op_having_this_as_output =
183 GetOpWithOutput(*model, array_name);
184 if (op_having_this_as_output && op_having_this_as_output != op) {
185 return false;
186 }
187 model->EraseArray(array_name);
188 return true;
189}
190
191void DeleteOpAndArrays(Model* model, const Operator* op) {
192 for (const std::string& array_name : op->inputs) {
193 DeleteArrayIfUnusedOutsideOfOp(array_name, op, model);
194 }
195 for (const std::string& array_name : op->outputs) {
196 DeleteArrayIfUnusedOutsideOfOp(array_name, op, model);
197 }
198 auto op_it = FindOp(*model, op);
199 CHECK(op_it != model->operators.end());
200 model->operators.erase(op_it);
201}
202
203std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
204 const Model& model, const std::string& array_name) {
205 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
206 for (auto& output : it->get()->outputs) {
207 if (output == array_name) {
208 return it;
209 }
210 }
211 }
212 return model.operators.end();
213}
214
215std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
216 Model& model, const std::string& array_name) {
217 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
218 for (auto& output : it->get()->outputs) {
219 if (output == array_name) {
220 return it;
221 }
222 }
223 }
224 return model.operators.end();
225}
226
227Operator* GetOpWithOutput(const Model& model, const std::string& array_name) {
228 auto it = FindOpWithOutput(model, array_name);
229 return it == model.operators.end() ? nullptr : it->get();
230}
231
232// GetFirstOpWithInput assumes that this finds the first op.
233std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
234 const Model& model, const std::string& array_name) {
235 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
236 for (auto& input : it->get()->inputs) {
237 if (input == array_name) {
238 return it;
239 }
240 }
241 }
242 return model.operators.end();
243}
244
245std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput(
246 Model& model, const std::string& array_name) {
247 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
248 for (auto& input : it->get()->inputs) {
249 if (input == array_name) {
250 return it;
251 }
252 }
253 }
254 return model.operators.end();
255}
256
257std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
258 const Model& model, const Operator* op) {
259 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
260 if (it->get() == op) {
261 return it;
262 }
263 }
264 return model.operators.end();
265}
266
267std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
268 const Operator* op) {
269 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
270 if (it->get() == op) {
271 return it;
272 }
273 }
274 return model.operators.end();
275}
276
277Operator* GetOpWithInput(const Model& model, const std::string& array_name) {
278 auto it = FindOpWithInput(model, array_name);
279 return it == model.operators.end() ? nullptr : it->get();
280}
281
282Operator* GetFirstOpWithInput(const Model& model,
283 const std::string& array_name) {
284 auto it = FindOpWithInput(model, array_name);
285 return it == model.operators.end() ? nullptr : it->get();
286}
287
288void ReplaceArrayUsage(Model* model, const std::string& old_array_name,
289 const std::string& new_array_name) {
290 for (auto& op_it : model->operators) {
291 Operator* op = op_it.get();
292 for (size_t i = 0; i < op->inputs.size(); ++i) {
293 if (op->inputs[i] == old_array_name) {
294 op->inputs[i] = new_array_name;
295 }
296 }
297 for (size_t i = 0; i < op->outputs.size(); ++i) {
298 if (op->outputs[i] == old_array_name) {
299 op->outputs[i] = new_array_name;
300 }
301 }
302 }
303}
304
305std::string FormatArraysList(const Model& model,
306 const std::vector<std::string>& list) {
307 if (list.empty()) {
308 return "[]";
309 }
310 std::string result = "";
311 if (list.size() > 1) {
312 result += "[ ";
313 }
314 for (std::size_t i = 0; i < list.size(); i++) {
315 if (i > 0) {
316 result += ", ";
317 }
318 result += list[i];
319 }
320 if (list.size() > 1) {
321 result += " ]";
322 }
323 return result;
324}
325
// Returns the human-readable name of the given operator type, e.g.
// OperatorType::kConv -> "Conv". Dies on an unhandled enum value.
const char* OperatorTypeName(OperatorType type) {
  switch (type) {
// Expands to a case label returning the stringified enum tag for
// OperatorType::k<c>.
#define HANDLE_OPERATORTYPENAME_CASE(c) \
  case OperatorType::k##c:              \
    return #c;
    HANDLE_OPERATORTYPENAME_CASE(Abs)
    HANDLE_OPERATORTYPENAME_CASE(Add)
    HANDLE_OPERATORTYPENAME_CASE(AddN)
    HANDLE_OPERATORTYPENAME_CASE(AveragePool)
    HANDLE_OPERATORTYPENAME_CASE(BatchMatMul)
    HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Conv)
    HANDLE_OPERATORTYPENAME_CASE(Concatenation)
    HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv)
    HANDLE_OPERATORTYPENAME_CASE(DepthToSpace)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth)
    HANDLE_OPERATORTYPENAME_CASE(FullyConnected)
    HANDLE_OPERATORTYPENAME_CASE(HardSwish)
    HANDLE_OPERATORTYPENAME_CASE(Dequantize)
    HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
    HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Log)
    HANDLE_OPERATORTYPENAME_CASE(Logistic)
    HANDLE_OPERATORTYPENAME_CASE(LstmCell)
    HANDLE_OPERATORTYPENAME_CASE(MaxPool)
    HANDLE_OPERATORTYPENAME_CASE(L2Pool)
    HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
    HANDLE_OPERATORTYPENAME_CASE(Mul)
    HANDLE_OPERATORTYPENAME_CASE(RandomUniform)
    HANDLE_OPERATORTYPENAME_CASE(Elu)
    HANDLE_OPERATORTYPENAME_CASE(Relu)
    HANDLE_OPERATORTYPENAME_CASE(Relu1)
    HANDLE_OPERATORTYPENAME_CASE(Relu6)
    HANDLE_OPERATORTYPENAME_CASE(PRelu)
    HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
    HANDLE_OPERATORTYPENAME_CASE(Softmax)
    HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)
    HANDLE_OPERATORTYPENAME_CASE(Div)
    HANDLE_OPERATORTYPENAME_CASE(Tanh)
    HANDLE_OPERATORTYPENAME_CASE(Sin)
    HANDLE_OPERATORTYPENAME_CASE(All)
    HANDLE_OPERATORTYPENAME_CASE(Assert)
    HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
    HANDLE_OPERATORTYPENAME_CASE(Fill)
    HANDLE_OPERATORTYPENAME_CASE(FloorMod)
    HANDLE_OPERATORTYPENAME_CASE(FloorDiv)
    HANDLE_OPERATORTYPENAME_CASE(Greater)
    HANDLE_OPERATORTYPENAME_CASE(GreaterEqual)
    HANDLE_OPERATORTYPENAME_CASE(Identity)
    HANDLE_OPERATORTYPENAME_CASE(Less)
    HANDLE_OPERATORTYPENAME_CASE(LessEqual)
    HANDLE_OPERATORTYPENAME_CASE(MatMul)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMax)  //  Reduction Max
    HANDLE_OPERATORTYPENAME_CASE(Maximum)    //  Element-wise Maximum
    HANDLE_OPERATORTYPENAME_CASE(Merge)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMin)  //  Reduction Min
    HANDLE_OPERATORTYPENAME_CASE(Minimum)    //  Element-wise Minimum
    HANDLE_OPERATORTYPENAME_CASE(Neg)
    HANDLE_OPERATORTYPENAME_CASE(OneHot)
    HANDLE_OPERATORTYPENAME_CASE(Pack)
    HANDLE_OPERATORTYPENAME_CASE(Pad)
    HANDLE_OPERATORTYPENAME_CASE(PadV2)
    HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
    HANDLE_OPERATORTYPENAME_CASE(Range)
    HANDLE_OPERATORTYPENAME_CASE(Rank)
    HANDLE_OPERATORTYPENAME_CASE(Reshape)
    HANDLE_OPERATORTYPENAME_CASE(Squeeze)
    HANDLE_OPERATORTYPENAME_CASE(Rsqrt)
    HANDLE_OPERATORTYPENAME_CASE(SegmentSum)
    HANDLE_OPERATORTYPENAME_CASE(Shape)
    HANDLE_OPERATORTYPENAME_CASE(Slice)
    HANDLE_OPERATORTYPENAME_CASE(Split)
    HANDLE_OPERATORTYPENAME_CASE(SplitV)
    HANDLE_OPERATORTYPENAME_CASE(Sqrt)
    HANDLE_OPERATORTYPENAME_CASE(Square)
    HANDLE_OPERATORTYPENAME_CASE(Switch)
    HANDLE_OPERATORTYPENAME_CASE(Sub)
    HANDLE_OPERATORTYPENAME_CASE(Sum)
    HANDLE_OPERATORTYPENAME_CASE(Tile)
    HANDLE_OPERATORTYPENAME_CASE(Transpose)
    HANDLE_OPERATORTYPENAME_CASE(TransposeConv)
    HANDLE_OPERATORTYPENAME_CASE(Concat)
    HANDLE_OPERATORTYPENAME_CASE(ConcatV2)
    HANDLE_OPERATORTYPENAME_CASE(Cast)
    HANDLE_OPERATORTYPENAME_CASE(Floor)
    HANDLE_OPERATORTYPENAME_CASE(Ceil)
    HANDLE_OPERATORTYPENAME_CASE(Round)
    HANDLE_OPERATORTYPENAME_CASE(Gather)
    HANDLE_OPERATORTYPENAME_CASE(GatherNd)
    HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
    HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
    HANDLE_OPERATORTYPENAME_CASE(Mean)
    HANDLE_OPERATORTYPENAME_CASE(ReduceProd)
    HANDLE_OPERATORTYPENAME_CASE(Svdf)
    HANDLE_OPERATORTYPENAME_CASE(ArgMax)
    HANDLE_OPERATORTYPENAME_CASE(ArgMin)
    HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
    HANDLE_OPERATORTYPENAME_CASE(Unsupported)
    HANDLE_OPERATORTYPENAME_CASE(Exp)
    HANDLE_OPERATORTYPENAME_CASE(DynamicPartition)
    HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
    HANDLE_OPERATORTYPENAME_CASE(Select)
    HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
    HANDLE_OPERATORTYPENAME_CASE(Equal)
    HANDLE_OPERATORTYPENAME_CASE(NotEqual)
    HANDLE_OPERATORTYPENAME_CASE(Pow)
    HANDLE_OPERATORTYPENAME_CASE(Any)
    HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
    HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
    HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
    HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
    HANDLE_OPERATORTYPENAME_CASE(Unpack)
    HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
    HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceLstm)
    HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceLstm)
    HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceRnn)
    HANDLE_OPERATORTYPENAME_CASE(ResizeNearestNeighbor)
    HANDLE_OPERATORTYPENAME_CASE(LeakyRelu)
    HANDLE_OPERATORTYPENAME_CASE(SquaredDifference)
    HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
    HANDLE_OPERATORTYPENAME_CASE(Unique)
    HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
    HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
    HANDLE_OPERATORTYPENAME_CASE(Cos)
    HANDLE_OPERATORTYPENAME_CASE(Where)
    HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV2)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV2)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV3)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV3)
    HANDLE_OPERATORTYPENAME_CASE(ScatterNd)
    default:
      LOG(FATAL) << "Unhandled op type";
// The macro is only needed within this switch statement.
#undef HANDLE_OPERATORTYPENAME_CASE
  }
}
465
466std::string HelpfulOperatorTypeName(const Operator& op) {
467 if (op.type == OperatorType::kUnsupported) {
468 return toco::port::StringF(
469 "(Unsupported TensorFlow op: %s)",
470 static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op);
471 }
472 return OperatorTypeName(op.type);
473}
474
// Returns whether ops of the given type may carry a fused activation
// function (e.g. a ReLU folded into the op itself).
bool OperatorSupportsFusedActivation(OperatorType type) {
  switch (type) {
    case OperatorType::kAdd:
    case OperatorType::kAveragePool:
    case OperatorType::kBatchNormalization:
    case OperatorType::kConv:
    case OperatorType::kDepthwiseConv:
    case OperatorType::kDiv:
    case OperatorType::kFullyConnected:
    case OperatorType::kL2Pool:
    case OperatorType::kMaxPool:
    case OperatorType::kMul:
    case OperatorType::kSub:
    case OperatorType::kSquaredDifference:
      return true;
    default:
      return false;
  }
}
494
// Logs, at `log_level`, one line per operator type present in the model
// together with its number of occurrences.
void LogSummary(int log_level, const Model& model) {
  VLOG(log_level) << "Operators summary (" << model.operators.size()
                  << " operators):";
  // In an unordered_multiset, all elements with equal keys are stored
  // adjacently, so each group can be emitted once and skipped with
  // std::advance below.
  std::unordered_multiset<OperatorType> ops_by_type;
  for (const auto& op : model.operators) {
    ops_by_type.insert(op->type);
  }
  auto it = ops_by_type.begin();
  while (it != ops_by_type.end()) {
    int count = ops_by_type.count(*it);
    VLOG(log_level) << " " << OperatorTypeName(*it) << ": " << count;
    // Jump past the whole group of equal keys to the next distinct type.
    std::advance(it, count);
  }
}
509
// Logs, at `log_level`, a multi-line description of the named array: data
// type, final type, storage kind, shape, min/max, and quantization params.
void LogArray(int log_level, const Model& model, const std::string& name) {
  VLOG(log_level) << "Array: " << name;
  if (!model.HasArray(name)) {
    VLOG(log_level) << " DOES NOT EXIST";
    return;
  }
  const auto& array = model.GetArray(name);
  VLOG(log_level) << " Data type: " << ArrayDataTypeName(array.data_type);
  VLOG(log_level) << " Final type: "
                  << ArrayDataTypeName(array.final_data_type);
  if (array.buffer) {
    // Constant array: contents are stored in a buffer.
    VLOG(log_level) << " Constant Buffer";
  }
  if (array.alloc) {
    // Transient array: storage is allocated at inference time.
    VLOG(log_level) << " Transient Alloc";
  }
  if (array.has_shape()) {
    const Shape& array_shape = array.shape();
    if (array_shape.dimensions_count() == 0) {
      VLOG(log_level) << " (Zero dimensions)";
    } else {
      // Build a single comma-separated "Dims:" line.
      std::string message = " Dims: ";
      bool first = true;
      for (const int dim : array_shape.dims()) {
        if (!first) {
          message += ", ";
        }
        first = false;
        toco::port::AppendF(&message, "%d", dim);
      }
      VLOG(log_level) << message;
    }
  }
  if (array.minmax) {
    VLOG(log_level) << " MinMax: " << array.minmax->min << " .. "
                    << array.minmax->max;
  }
  if (array.quantization_params) {
    VLOG(log_level) << " QuantizationParams: zero_point="
                    << static_cast<int>(array.quantization_params->zero_point)
                    << ", scale=" << array.quantization_params->scale;
  }
}
553
// If --dump_graphviz_video is enabled, writes one numbered graphviz "video
// frame" dump of `model` into the --dump_graphviz directory. Frames whose
// dump text hashes identically to a previously written frame are skipped.
void DumpGraphvizVideoFrame(const Model& model) {
  namespace port = toco::port;

  const auto& dump_options = *GraphVizDumpOptions::singleton();
  if (!dump_options.dump_graphviz_video) {
    return;
  }
  // A destination directory is required for video dumping.
  CHECK(!dump_options.dump_graphviz.empty());
  // TODO(benoitjacob): the static data here means that this function
  // is stateful, not reentrant, and effectively leaks memory till exit
  // (since dump_hashes can only grow in size). It also means that it
  // really only is intended to be called for a single model during the
  // process' lifetime. So it's not great design at all. The overriding
  // design aspect here is to make the video-dumping code as unintrusive
  // and self-contained as possible. Eventually, we'll want to have that
  // cleaned-up, but that will require some form of general statefulness
  // in toco (some kind of 'tooling state' data structure) that does
  // not exist at present, and would be premature to design here just for
  // this new video-dumping feature.
  static int dump_id = 0;
  static std::unordered_set<std::size_t> dump_hashes;
  std::string graphviz_dump;
  DumpGraphviz(model, &graphviz_dump,
               toco::port::StringF("VIDEO frame:%05d", dump_id));
  // Hash the dump text to detect (and skip) frames identical to one
  // already written.
  std::size_t hash = std::hash<std::string>{}(graphviz_dump);
  if (!dump_hashes.count(hash)) {
    LOG(INFO) << "DUMPING GRAPHVIZ VIDEO FRAME: " << dump_id;
    dump_hashes.insert(hash);
    const auto result = port::file::SetContents(
        port::file::JoinPath(
            dump_options.dump_graphviz,
            toco::port::StringF("toco_video_%05d.dot", dump_id)),
        graphviz_dump, port::file::Defaults());
    QCHECK(result.ok()) << result.error_message();
    dump_id++;
  }
}
591
// Dumps `model` as graphviz (when configured via GraphVizDumpOptions) and,
// if VLOG is enabled at `log_level`, logs a textual dump of every operator
// and every array it touches. `message` labels the dump and is used to name
// the .dot file.
void LogDump(int log_level, const std::string& message, const Model& model) {
  namespace port = toco::port;
  const auto& dump_options = *GraphVizDumpOptions::singleton();

  DumpGraphvizVideoFrame(model);
  if (!dump_options.dump_graphviz.empty()) {
    std::string graphviz_dump;

    DumpGraphviz(model, &graphviz_dump, message);
    // File name is derived from the message, with spaces replaced so it is
    // filesystem-safe.
    const auto result = port::file::SetContents(
        port::file::JoinPath(
            dump_options.dump_graphviz,
            absl::StrCat("toco_", absl::StrReplaceAll(message, {{" ", "_"}}),
                         ".dot")),
        graphviz_dump, port::file::Defaults());
    QCHECK(result.ok()) << result.error_message();
  }

  // Avoid the cost of building the textual dump when it would be discarded.
  if (!VLOG_IS_ON(log_level)) {
    return;
  }
  VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")";
  LogSummary(log_level, model);
  // Log each array only the first time it is seen as an input or output.
  std::unordered_set<std::string> already_printed_arrays;
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      if (!already_printed_arrays.count(input)) {
        already_printed_arrays.insert(input);
        LogArray(log_level, model, input);
      }
    }
    VLOG(log_level) << HelpfulOperatorTypeName(*op) << " :";
    VLOG(log_level) << " " << FormatArraysList(model, op->inputs) << " -> "
                    << FormatArraysList(model, op->outputs);
    if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
      VLOG(log_level) << " (with fused activation function)";
    }
    for (const auto& output : op->outputs) {
      if (!already_printed_arrays.count(output)) {
        already_printed_arrays.insert(output);
        LogArray(log_level, model, output);
      }
    }
  }
  VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")";
}
638
639// Note remaining raw-array extension in ProcessTensorFlowReshapeOperator().
640void ExtendShape(Shape* shape, int new_shape_size) {
641 CHECK_GE(new_shape_size, shape->dimensions_count());
642 const int size_increase = new_shape_size - shape->dimensions_count();
643 auto* shape_dims = shape->mutable_dims();
644 shape_dims->insert(shape_dims->begin(), size_increase, 1);
645}
646
647// TODO(b/62904716) Remove along with remaining uses.
648void UnextendShape(Shape* shape, int new_shape_size) {
649 CHECK_LE(new_shape_size, shape->dimensions_count());
650 const int size_reduction = shape->dimensions_count() - new_shape_size;
651 for (int i = 0; i < size_reduction; i++) {
652 CHECK_EQ(shape->dims(i), 1);
653 }
654 std::vector<int>& shape_dims = *shape->mutable_dims();
655 shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
656}
657
// In general, zero-sized dimensions are disallowed, but there are exceptions,
// e.g., if the tensor data itself represents a scalar (rank 0) shape, its
// shape will have dimensions [0]. CheckNonEmptyShapeDimensions is more
// strict, and is appropriate for ops and comparisons where an empty shape
// doesn't make sense.
template <typename Dims>
void CheckValidShapeDimensions(const Dims& dims) {
  // The one-element shape [0] encodes a scalar (rank-0) tensor; accept it.
  const bool is_scalar_encoding = (dims.size() == 1 && dims[0] == 0);
  if (is_scalar_encoding) {
    return;
  }
  // Otherwise every dimension must be strictly positive.
  for (const auto& dim : dims) {
    CHECK_GE(dim, 1);
  }
}
672
// CHECK-fails unless `shape` is valid: every dimension >= 1, or the special
// scalar-encoding shape [0] (see CheckValidShapeDimensions).
void CheckValidShape(const Shape& shape) {
  CheckValidShapeDimensions(shape.dims());
}
676
677bool IsNonEmpty(const Shape& shape) {
678 for (int i = 0; i < shape.dimensions_count(); ++i) {
679 if (shape.dims(i) < 1) return false;
680 }
681 return true;
682}
683
684void CheckNonEmptyShapeDimensions(const Shape& shape) {
685 for (int i = 0; i < shape.dimensions_count(); ++i) {
686 CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i
687 << ". shape = " << ShapeToString(shape);
688 }
689}
690
691bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
692 CheckNonEmptyShapeDimensions(shape0);
693 CheckNonEmptyShapeDimensions(shape1);
694
695 const Shape* longer = &shape0;
696 const Shape* shorter = &shape1;
697 if (shape1.dimensions_count() > shape0.dimensions_count()) {
698 longer = &shape1;
699 shorter = &shape0;
700 }
701
702 // Walk dimensions back to front until we run out of dimensions in the shorter
703 // shape.
704 int longer_index = longer->dimensions_count() - 1;
705 int shorter_index = shorter->dimensions_count() - 1;
706 while (shorter_index >= 0) {
707 const int d_long = longer->dims(longer_index);
708 const int d_short = shorter->dims(shorter_index);
709 // Broadcasting fails if the dimensions are different *and* neither is 1.
710 if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) {
711 return false;
712 }
713 longer_index--;
714 shorter_index--;
715 }
716 return true;
717}
718
719bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
720 CheckNonEmptyShapeDimensions(shape0);
721 CheckNonEmptyShapeDimensions(shape1);
722
723 const Shape* longer = &shape0;
724 const Shape* shorter = &shape1;
725 if (shape1.dimensions_count() > shape0.dimensions_count()) {
726 longer = &shape1;
727 shorter = &shape0;
728 }
729
730 // Walk dimensions back to front until we run out of dimensions in the shorter
731 // shape.
732 int longer_index = longer->dimensions_count() - 1;
733 int shorter_index = shorter->dimensions_count() - 1;
734 while (shorter_index >= 0) {
735 const int d_long = longer->dims(longer_index);
736 const int d_short = shorter->dims(shorter_index);
737 // Extending fails if the dimensions are different.
738 if (d_long != d_short) {
739 return false;
740 }
741 longer_index--;
742 shorter_index--;
743 }
744
745 // The remaining dimensions in the longer shape must be 1.
746 while (longer_index >= 0) {
747 const int d_long = longer->dims(longer_index);
748 if (d_long != 1) {
749 return false;
750 }
751 longer_index--;
752 }
753
754 return true;
755}
756
757int RequiredBufferSizeForShape(const Shape& shape) {
758 CheckValidShape(shape);
759 int max_offset = 1;
760 for (const auto& dim : shape.dims()) {
761 max_offset *= dim;
762 }
763 return max_offset;
764}
765
766bool IsConstantParameterArray(const Model& model, const std::string& name) {
767 if (!model.HasArray(name)) {
768 return false;
769 }
770
771 return !!model.GetArray(name).buffer;
772}
773
774namespace {
775template <ArrayDataType A>
776bool CompareArrayBuffers(const Array& lhs_array, const Array& rhs_array) {
777 CHECK(lhs_array.data_type == rhs_array.data_type) << "Data types must match";
778 CHECK(lhs_array.buffer) << "LHS must be constant";
779 CHECK(rhs_array.buffer) << "RHS must be constant";
780 const auto& lhs_data = lhs_array.GetBuffer<A>().data;
781 const auto& rhs_data = rhs_array.GetBuffer<A>().data;
782 CHECK_EQ(lhs_data.size(), rhs_data.size())
783 << "Buffer sizes must match in element count";
784 for (int i = 0; i < lhs_data.size(); ++i) {
785 if (lhs_data[i] != rhs_data[i]) {
786 return false;
787 }
788 }
789 return true;
790}
791
792bool HaveSameMinMax(const Array& lhs_array, const Array& rhs_array) {
793 if (lhs_array.minmax || rhs_array.minmax) {
794 if (!lhs_array.minmax || !rhs_array.minmax) {
795 return false;
796 }
797 if (!(*lhs_array.minmax == *rhs_array.minmax)) {
798 return false;
799 }
800 }
801 return true;
802}
803
804bool HaveSameQuantizationParams(const Array& lhs_array,
805 const Array& rhs_array) {
806 if (lhs_array.quantization_params || rhs_array.quantization_params) {
807 if (!lhs_array.quantization_params || !rhs_array.quantization_params) {
808 return false;
809 }
810 if (!(*lhs_array.quantization_params == *rhs_array.quantization_params)) {
811 return false;
812 }
813 }
814 return true;
815}
816
817} // namespace
818
// Returns whether two constant arrays are fully identical: same shape, data
// types, min/max, quantization params, narrow_range, and buffer contents.
// Dies on an unsupported data type.
bool CompareConstantArrays(const Array& lhs_array, const Array& rhs_array) {
  // Cheap attribute comparison first; buffers are only compared when all
  // metadata already matches.
  bool attrs_equal = lhs_array.shape() == rhs_array.shape() &&
                     lhs_array.data_type == rhs_array.data_type &&
                     lhs_array.final_data_type == rhs_array.final_data_type &&
                     HaveSameMinMax(lhs_array, rhs_array) &&
                     HaveSameQuantizationParams(lhs_array, rhs_array) &&
                     lhs_array.narrow_range == rhs_array.narrow_range;
  if (!attrs_equal) {
    return false;
  }
  // Dispatch the element-wise buffer comparison on the runtime data type.
  switch (lhs_array.data_type) {
    case ArrayDataType::kBool:
      return CompareArrayBuffers<ArrayDataType::kBool>(lhs_array, rhs_array);
    case ArrayDataType::kFloat:
      return CompareArrayBuffers<ArrayDataType::kFloat>(lhs_array, rhs_array);
    case ArrayDataType::kInt8:
      return CompareArrayBuffers<ArrayDataType::kInt8>(lhs_array, rhs_array);
    case ArrayDataType::kUint8:
      return CompareArrayBuffers<ArrayDataType::kUint8>(lhs_array, rhs_array);
    case ArrayDataType::kInt16:
      return CompareArrayBuffers<ArrayDataType::kInt16>(lhs_array, rhs_array);
    case ArrayDataType::kUint16:
      return CompareArrayBuffers<ArrayDataType::kUint16>(lhs_array, rhs_array);
    case ArrayDataType::kInt32:
      return CompareArrayBuffers<ArrayDataType::kInt32>(lhs_array, rhs_array);
    case ArrayDataType::kUint32:
      return CompareArrayBuffers<ArrayDataType::kUint32>(lhs_array, rhs_array);
    case ArrayDataType::kInt64:
      return CompareArrayBuffers<ArrayDataType::kInt64>(lhs_array, rhs_array);
    case ArrayDataType::kUint64:
      return CompareArrayBuffers<ArrayDataType::kUint64>(lhs_array, rhs_array);
    case ArrayDataType::kString:
      return CompareArrayBuffers<ArrayDataType::kString>(lhs_array, rhs_array);
    case ArrayDataType::kComplex64:
      return CompareArrayBuffers<ArrayDataType::kComplex64>(lhs_array,
                                                            rhs_array);
    default:
      LOG(FATAL) << "Unsupported data type: "
                 << ArrayDataTypeName(lhs_array.data_type);
      return false;
  }
}
861
862namespace {
// Take an array name, which may be something like "name:3_5" and make it
// acceptable as a TF node name, say "name_3_5";
std::string SanitizeNameForTFNode(const std::string& array_name) {
  std::string node_name = array_name;
  // TF node names may not contain ':', so map each one to '_'.
  for (char& c : node_name) {
    if (c == ':') {
      c = '_';
    }
  }
  return node_name;
}
870
871void CheckInputArraysAreNotOutputArrays(const ModelFlags& model_flags) {
872 for (const auto& input_array : model_flags.input_arrays()) {
873 for (const std::string& output_array : model_flags.output_arrays()) {
874 QCHECK_NE(input_array.name(), output_array)
875 << "The array " << output_array
876 << " is listed in both --input_arrays and --output_arrays.";
877 }
878 }
879}
880
881bool IsAsciiPrintable(const std::string& name) {
882 for (char c : name) {
883 if (!absl::ascii_isprint(c)) {
884 return false;
885 }
886 }
887 return true;
888}
889
// Renders `name` as a two-column ASCII/hex table, flagging bytes that are
// not printable ASCII. Used to build diagnostics for bad array names.
std::string DumpAscii(const std::string& name) {
  std::string result;
  port::AppendF(&result, "ASCII | Hex\n");
  port::AppendF(&result, "------+----\n");
  for (char c : name) {
    if (absl::ascii_isprint(c)) {
      port::AppendF(&result, "%c | %x\n", c, c);
    } else {
      // Non-printable byte: leave the ASCII column blank and flag it.
      port::AppendF(&result, " | %x Not ASCII printable!\n", c);
    }
  }
  return result;
}
903
// QCHECK-fails if any input or output array name contains a non-printable
// ASCII character, unless --allow_nonascii_arrays is set. The failure
// message includes a byte-by-byte dump of the offending name.
void CheckNonAsciiIOArrays(const ModelFlags& model_flags) {
  if (model_flags.allow_nonascii_arrays()) {
    return;
  }
  for (const auto& input_array : model_flags.input_arrays()) {
    QCHECK(IsAsciiPrintable(input_array.name()))
        << "Non-ASCII-printable character found in --input_arrays: "
        << input_array.name()
        << ". Pass --allow_nonascii_arrays to allow that. "
        << "Here is a dump of the string:\n\n"
        << DumpAscii(input_array.name());
  }
  for (const std::string& output_array : model_flags.output_arrays()) {
    QCHECK(IsAsciiPrintable(output_array))
        << "Non-ASCII-printable character found in --output_arrays: "
        << output_array << ". Pass --allow_nonascii_arrays to allow that. "
        << "Here is a dump of the string:\n\n"
        << DumpAscii(output_array);
  }
}
924
// QCHECK-fails if a declared output array is not produced by any op, or if a
// non-discardable RNN state array is not consumed / its back-edge source not
// produced. Disabled by --allow_nonexistent_arrays.
void CheckNonExistentIOArrays(const Model& model) {
  // "non-existent" is interpreted in the stronger sense of
  // "not actually produced/consumed by an op".
  // Rationale: we have to artificially fix up TensorFlow graphs by creating
  // any array that it refers to, so just checking that arrays exist isn't
  // sufficient. The real invariant here is whether arrays are produced/consumed
  // by something.
  if (model.flags.allow_nonexistent_arrays()) {
    return;
  }
  static constexpr char general_comment[] =
      "Is it a typo? This should not happen. If you trigger this error "
      "please send a bug report (with code to reproduce this error), to the "
      "TensorFlow Lite team.";
  for (const std::string& output_array : model.flags.output_arrays()) {
    if (IsConstantParameterArray(model, output_array)) {
      continue;  // It is OK to request that a constant be an output.
    }
    QCHECK(GetOpWithOutput(model, output_array))
        << "Specified output array \"" << output_array
        << "\" is not produced by any op in this graph. " << general_comment;
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    if (!rnn_state.discardable()) {
      // Check that all RNN states are consumed
      QCHECK(GetOpWithInput(model, rnn_state.state_array()))
          << "Specified RNN state \"" << rnn_state.state_array()
          << "\" is not consumed by any op in this graph. " << general_comment;
      // Check that all RNN back-edge source arrays are produced
      QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array()))
          << "Specified RNN back-edge source array \""
          << rnn_state.back_edge_source_array()
          << "\" is not produced by any op in this graph. " << general_comment;
    }
  }
}
961
962} // namespace
963
964void CheckNoMissingArray(const Model& model) {
965 for (const auto& op : model.operators) {
966 for (const auto& input : op->inputs) {
967 CHECK(model.HasArray(input) || model.optional_arrays.count(input))
968 << "Input: " << input << " missing for op: " << op->outputs[0] << ".";
969 }
970 for (const auto& output : op->outputs) {
971 CHECK(model.HasArray(output)) << "Output: " << output << " missing.";
972 }
973 }
974 CheckNonExistentIOArrays(model);
975}
976
977void FixNoMissingArray(Model* model) {
978 for (const auto& op : model->operators) {
979 for (const auto& input : op->inputs) {
980 if (!model->HasArray(input) && !model->IsOptionalArray(input)) {
981 model->GetOrCreateArray(input);
982 }
983 }
984 for (const auto& output : op->outputs) {
985 if (!model->HasArray(output) && !model->IsOptionalArray(output)) {
986 model->GetOrCreateArray(output);
987 }
988 }
989 }
990 if (model->flags.allow_nonexistent_arrays()) {
991 for (const std::string& output_array : model->flags.output_arrays()) {
992 model->GetOrCreateArray(output_array);
993 }
994 for (const auto& rnn_state : model->flags.rnn_states()) {
995 model->GetOrCreateArray(rnn_state.state_array());
996 model->GetOrCreateArray(rnn_state.back_edge_source_array());
997 }
998 }
999}
1000
1001void CheckNoOrphanedArray(const Model& model) {
1002 std::unordered_set<std::string> arrays_without_known_use;
1003 for (const auto& array : model.GetArrayMap()) {
1004 if (IsDiscardableArray(model, array.first)) {
1005 arrays_without_known_use.insert(array.first);
1006 }
1007 }
1008 for (const auto& op : model.operators) {
1009 for (const auto& input : op->inputs) {
1010 arrays_without_known_use.erase(input);
1011 }
1012 for (const auto& output : op->outputs) {
1013 arrays_without_known_use.erase(output);
1014 }
1015 }
1016 for (const auto& rnn_state : model.flags.rnn_states()) {
1017 arrays_without_known_use.erase(rnn_state.state_array());
1018 arrays_without_known_use.erase(rnn_state.back_edge_source_array());
1019 }
1020 if (!arrays_without_known_use.empty()) {
1021 for (const auto& array : arrays_without_known_use) {
1022 LOG(INFO) << "Error: Orphaned array: " << array;
1023 }
1024 }
1025 CHECK(arrays_without_known_use.empty());
1026}
1027
1028void FixNoOrphanedArray(Model* model) {
1029 std::unordered_set<std::string> arrays_without_known_use;
1030 for (const auto& array : model->GetArrayMap()) {
1031 arrays_without_known_use.insert(array.first);
1032 }
1033 for (const auto& op : model->operators) {
1034 for (const auto& input : op->inputs) {
1035 arrays_without_known_use.erase(input);
1036 }
1037 for (const auto& output : op->outputs) {
1038 arrays_without_known_use.erase(output);
1039 }
1040 }
1041 for (const auto& rnn_state : model->flags.rnn_states()) {
1042 arrays_without_known_use.erase(rnn_state.state_array());
1043 arrays_without_known_use.erase(rnn_state.back_edge_source_array());
1044 }
1045 for (const auto& array : arrays_without_known_use) {
1046 if (IsDiscardableArray(*model, array)) {
1047 model->EraseArray(array);
1048 }
1049 }
1050}
1051
// Apply checks to arrays individually (for-each fashion).
//
// Check consistency of array fields, check name.
void CheckEachArray(const Model& model) {
  for (const auto& array_entry : model.GetArrayMap()) {
    const auto& array = array_entry.second;
    // It's OK to have a buffer or an alloc, but not both.
    // (Since allocs are for transient arrays without a buffer).
    CHECK(!array->buffer || !array->alloc) << "Tensor: " << array_entry.first;
    if (array->buffer) {
      // If there is a buffer, its type should be consistent with data_type.
      CHECK(array->buffer->type == array->data_type)
          << "Tensor: " << array_entry.first;
      // The presence of a fixed buffer should imply the presence of a fixed
      // shape.
      CHECK(array->has_shape()) << array_entry.first;
      // A constant buffer should have a valid shape.
      CheckValidShape(array->shape());
      // The shape flat-size should agree with the buffer length.
      CHECK_EQ(array->buffer->Length(),
               RequiredBufferSizeForShape(array->shape()))
          << "Tensor: " << array_entry.first;
    }

    // Check name. Either "name_with_suffix_8", "name_with_port:3", but not
    // "name_with_both:3_8".
    const std::string& name = array_entry.first;
    auto colon_pos = name.find_first_of(':');
    if (colon_pos != std::string::npos) {
      // Everything after the colon must be digits (a TF-style port number).
      CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"),
               std::string::npos)
          << "Array '" << name << "' has non-digit characters after colon.";
    }
    // colon_pos is unsigned and std::string::npos > 0, so this only fails
    // when the name literally starts with ':'.
    CHECK_GT(colon_pos, 0) << "Array '" << name
                           << "' must not start with a colon.";
  }
}
1089
1090void CheckOperatorOrdering(const Model& model) {
1091 std::unordered_set<std::string> arrays_behind_us;
1092 for (const auto& array_entry : model.GetArrayMap()) {
1093 if (!GetOpWithOutput(model, array_entry.first)) {
1094 arrays_behind_us.insert(array_entry.first);
1095 }
1096 }
1097 arrays_behind_us.insert(model.optional_arrays.begin(),
1098 model.optional_arrays.end());
1099 for (const auto& op : model.operators) {
1100 for (const auto& input : op->inputs) {
1101 if (!IsConstantParameterArray(model, input)) {
1102 CHECK(arrays_behind_us.count(input));
1103 }
1104 }
1105 for (const auto& output : op->outputs) {
1106 CHECK(!arrays_behind_us.count(output));
1107 arrays_behind_us.insert(output);
1108 }
1109 }
1110 for (const std::string& output_array : model.flags.output_arrays()) {
1111 CHECK(arrays_behind_us.count(output_array));
1112 }
1113}
1114
// Re-sorts model->operators into a topological order: an operator is only
// emitted once all of its non-constant inputs have been produced. If no
// complete ordering exists (missing producer or cycle), logs a diagnostic
// trace walking back through the blocking arrays and dies with LOG(FATAL).
void FixOperatorOrdering(Model* model) {
  // Arrays already "behind us": those no operator produces (graph inputs,
  // constants, ...) plus the optional arrays.
  std::unordered_set<std::string> arrays_behind_us;
  for (const auto& array_entry : model->GetArrayMap()) {
    if (!GetOpWithOutput(*model, array_entry.first)) {
      arrays_behind_us.insert(array_entry.first);
    }
  }
  arrays_behind_us.insert(model->optional_arrays.begin(),
                          model->optional_arrays.end());
  // Take ownership of all operators; they are moved back one at a time in
  // topological order.
  std::vector<std::unique_ptr<Operator>> old_operators;
  std::swap(old_operators, model->operators);
  std::set<std::size_t> remaining;
  for (std::size_t i = 0; i < old_operators.size(); i++) {
    remaining.insert(i);
  }
  // Maps each output of a blocked operator to the input that blocked it;
  // used below to reconstruct a diagnostic trace on failure.
  std::unordered_map<std::string, std::string> reason_why_leftover;
  while (true) {
    bool inserted_something = false;
    for (const auto& i : remaining) {
      bool can_insert = true;
      auto& op = old_operators[i];
      CHECK(op);
      for (const auto& input : op->inputs) {
        if (!IsConstantParameterArray(*model, input) &&
            !arrays_behind_us.count(input)) {
          for (const std::string& output : op->outputs) {
            reason_why_leftover[output] = input;
          }
          can_insert = false;
          break;
        }
      }
      if (can_insert) {
        // All inputs available: move this operator into the sorted list and
        // mark its outputs as produced. Restart the scan from the beginning.
        model->operators.emplace_back(nullptr);
        for (const auto& output : op->outputs) {
          arrays_behind_us.insert(output);
        }
        std::swap(op, model->operators.back());
        remaining.erase(i);
        inserted_something = true;
        break;
      }
    }
    if (!inserted_something) {
      break;
    }
  }
  if (!remaining.empty()) {
    LOG(ERROR)
        << "No viable ordering of operators was found. "
        << "Here is a 'backtrace' of at least one part of the graph that is "
        << "problematic. It starts with the first operator that has as "
        << "problematic input array, and then walks back the graph to "
        << "the operator that produced that input array, etc., until we find "
        << "the root cause:";
    LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT";
    LOG(ERROR) << "Here is the first-encountered operator with a bad input: ";
    const Operator* bad_op = old_operators[*remaining.begin()].get();
    std::unordered_set<std::string> bad_inputs_already_traced;
    // The following while(true) loop should always end with a LOG(FATAL).
    while (true) {
      LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : "
                 << FormatArraysList(*model, bad_op->inputs) << " -> "
                 << FormatArraysList(*model, bad_op->outputs);
      bool found_bad_output = false;
      std::string bad_output;
      for (const std::string& output : bad_op->outputs) {
        if (reason_why_leftover.count(output)) {
          found_bad_output = true;
          bad_output = output;
          break;
        }
      }
      CHECK(found_bad_output);
      const std::string& bad_input = reason_why_leftover[bad_output];
      LOG(ERROR) << "The bad input here is: " << bad_input;
      // Seeing the same blocking input twice means we walked a cycle.
      if (bad_inputs_already_traced.count(bad_input)) {
        LOG(FATAL)
            << "Cycle found! We already encountered that "
            << "input array, " << bad_input << ", earlier in the "
            << "above trace! We expect graphs to be acyclic, even "
            << "RNNs. Let us know if some graph actually needs to have "
            << "cycles, but first, please check if it really is "
            << "an *inference* graph. *Training* graphs are out-of-scope "
            << "for toco.";
      }
      bad_inputs_already_traced.insert(bad_input);
      // Walk back to the (still unplaced) operator producing the bad input.
      bad_op = nullptr;
      for (const auto& i : remaining) {
        const Operator* op = old_operators[i].get();
        for (const std::string& output : op->outputs) {
          if (bad_input == output) {
            bad_op = op;
            break;
          }
        }
        if (bad_op) {
          break;
        }
      }
      if (!bad_op) {
        // Nothing produces the bad input at all: that is the root cause.
        LOG(ERROR) << "And that's the root cause: "
                   << "that array, " << bad_input << ", isn't produced by any "
                   << "operator, or provided in any other way.";
        LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT";
        LOG(FATAL) << "(The above was a multi-line fatal error)";
      }
      LOG(ERROR) << "And that array is the output of the following operator:";
    }
  }
  CHECK(remaining.empty())
      << "Should never get here! In case of bad graph, "
      << "the above code should have generated a FATAL error already!";
}
1229
// Runs all graph-level consistency checks on the model. Each sub-check is
// fatal (CHECK/QCHECK) on violation, so returning normally means the model
// passed every invariant.
void CheckInvariants(const Model& model) {
  CheckInputArraysAreNotOutputArrays(model.flags);
  CheckNonAsciiIOArrays(model.flags);
  CheckNoMissingArray(model);
  CheckNoOrphanedArray(model);
  CheckEachArray(model);
  CheckOperatorOrdering(model);
}
1238
// Verifies that `count` satisfies the bounds in `model_check`, dying with a
// message naming `count_description` on violation. A negative count_min
// disables the lower-bound check; the upper bound applies only when
// count_max > count_min. When the two are equal, count_min acts as an exact
// expected value, which is why the message then says "value" rather than
// "minimum".
void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check,
                       const int count, const std::string& count_description) {
  if (model_check.count_min() >= 0) {
    CHECK_GE(count, model_check.count_min())
        << "Mismatch in " << count_description << ": count was " << count
        << ", but the specified "
        << (model_check.count_max() > model_check.count_min() ? "minimum"
                                                              : "value")
        << " was " << model_check.count_min() << ".";
  }
  if (model_check.count_max() > model_check.count_min()) {
    CHECK_LE(count, model_check.count_max())
        << "Mismatch in " << count_description << ": count was " << count
        << ", but the specified maximum was " << model_check.count_max() << ".";
  }
}
1255
// Enforces the user-specified model checks (count constraints) from
// model.flags. Supported count_type values: "None" (disabled), "Arrays"
// (total array count), "Total" (total operator count), or an operator type
// name (count of instances of that operator).
void CheckModelCounts(const Model& model) {
  std::unordered_multiset<OperatorType> ops_by_type;
  std::unordered_map<std::string, OperatorType> op_type_by_name;
  if (model.flags.model_checks_size() == 0) {
    return;
  }

  for (const auto& op : model.operators) {
    ops_by_type.insert(op->type);
    op_type_by_name[OperatorTypeName(op->type)] = op->type;
  }
  for (const auto& model_check : model.flags.model_checks()) {
    std::string count_type = model_check.count_type();
    if (count_type == "None") {
      // "None" disables this particular check.
      continue;
    } else if (count_type == "Arrays") {
      CheckCountInRange(model_check, model.GetArrayMap().size(),
                        "count of arrays");
    } else if (count_type == "Total") {
      CheckCountInRange(model_check, model.operators.size(),
                        "count of all operator instances");
    } else {
      // The check type is not itself checked against the set of valid
      // operators, mainly because the enum set cannot be iterated in C++.
      // An unknown operator name simply counts as zero instances.
      const int found_count =
          op_type_by_name.count(count_type) > 0
              ? ops_by_type.count(op_type_by_name[count_type])
              : 0;
      CheckCountInRange(model_check, found_count,
                        "count of instances of " + count_type + " operator");
    }
  }
}
1289
1290void FixEdgeArrays(Model* model) {
1291 for (const std::string& output_array_name : model->flags.output_arrays()) {
1292 if (!GetOpWithOutput(*model, output_array_name)) {
1293 // Output has no operator producing it. Change that by inserting a copy.
1294 LOG(WARNING) << "Fixing constant output array " << output_array_name
1295 << " by inserting a copy. This is not optimal.";
1296 std::string intermediate_array_name =
1297 AvailableArrayName(*model, output_array_name + "_copy");
1298 CloneArray(model, output_array_name, intermediate_array_name);
1299 InsertCopyOperator(model, intermediate_array_name, output_array_name);
1300 }
1301 }
1302}
1303
// Deduplicates constant arrays: when two constant arrays of at least
// `min_size` bytes compare equal, all usages of the later (discardable) one
// are redirected to the earlier one and the duplicate is erased.
void DedupeConstantArrays(Model* model, size_t min_size) {
  // Walk all 0..N and compare with the remaining n+1..N.
  // This lets us avoid N^2 comparisons and erase duplicate arrays while
  // iterating.
  const auto& array_map = model->GetArrayMap();
  for (auto lhs_array_it = array_map.begin(); lhs_array_it != array_map.end();
       ++lhs_array_it) {
    const auto& lhs_array_name = lhs_array_it->first;
    const auto& lhs_array = *lhs_array_it->second;
    if (!IsConstantParameterArray(*model, lhs_array_name)) {
      // Not a constant array; skip.
      continue;
    }
    ArrayDataType final_data_type =
        lhs_array.final_data_type != ArrayDataType::kNone
            ? lhs_array.final_data_type
            : lhs_array.data_type;
    // Ignore small arrays, don't check string arrays because it is not possible
    // to estimate its size.
    if (final_data_type != ArrayDataType::kString) {
      // NOTE(review): lhs_array.buffer is dereferenced here; presumably
      // IsConstantParameterArray above guarantees it is non-null — confirm.
      size_t array_byte_size =
          lhs_array.buffer->Length() * ElementSize(final_data_type);
      if (array_byte_size < min_size) {
        // Too small; skip.
        continue;
      }
    }

    auto next_lhs_array_it = lhs_array_it;
    ++next_lhs_array_it;
    // rhs_array_it is advanced before the body so that erasing the current
    // rhs array does not invalidate the iterator we continue from.
    for (auto rhs_array_it = next_lhs_array_it;
         rhs_array_it != array_map.end();) {
      const auto& rhs_array_name = rhs_array_it->first;
      const auto& rhs_array = *rhs_array_it->second;
      ++rhs_array_it;
      if (!IsConstantParameterArray(*model, rhs_array_name)) {
        // Not a constant array; skip.
        continue;
      }
      if (!IsDiscardableArray(*model, rhs_array_name)) {
        // Can't remove the array as it's not discardable (such as an IO edge).
        continue;
      }
      if (!CompareConstantArrays(lhs_array, rhs_array)) {
        // Arrays aren't equal; skip.
        continue;
      }

      // Arrays can be deduped!
      VLOG(1) << "Deduplicating arrays; using " << lhs_array_name
              << " in place of " << rhs_array_name;
      ReplaceArrayUsage(model, rhs_array_name, lhs_array_name);
      // Note: rhs_array_it above is already incremented so this is safe.
      model->EraseArray(rhs_array_name);
    }
  }
}
1361
namespace {
// Copies array metadata — data types, shape, minmax, and quantization
// parameters — from source_array to target_array. Buffer contents are NOT
// copied. When the source lacks minmax or quantization params, they are
// cleared on the target so the two agree exactly.
void CopyArrayAttribs(const Array& source_array, Array* target_array) {
  target_array->data_type = source_array.data_type;
  target_array->final_data_type = source_array.final_data_type;
  if (source_array.has_shape()) {
    target_array->copy_shape(source_array.shape());
  }

  if (source_array.minmax) {
    target_array->GetOrCreateMinMax() = source_array.GetMinMax();
  } else {
    target_array->minmax.reset();
  }

  if (source_array.quantization_params) {
    target_array->GetOrCreateQuantizationParams() =
        source_array.GetQuantizationParams();
  } else {
    target_array->quantization_params.reset();
  }
}
}  // namespace
1384
// Appends an operator that copies source_array into target_array at runtime,
// implemented as a Reshape to the source's own shape (a no-op reshape).
// Requires source_array to have a known shape. Any constant buffer on the
// target is dropped, since the value will now be produced at runtime.
void InsertCopyOperator(Model* model, const std::string& source_array_name,
                        const std::string& target_array_name) {
  // Reshape to the same size. This should be a no-op.
  const Array& source_array = model->GetArray(source_array_name);
  std::vector<int> shape = source_array.shape().dims();

  // Drop constant data from the target array as the copy will be done at
  // runtime.
  Array& target_array = model->GetOrCreateArray(target_array_name);
  target_array.buffer.reset();
  CopyArrayAttribs(source_array, &target_array);

  // Insert copy operator. The second input is a new constant int32 array
  // holding the target shape, as TensorFlowReshapeOperator expects.
  auto* copy_op = new TensorFlowReshapeOperator;
  copy_op->inputs = {
      source_array_name,
      CreateInt32Array(
          model, AvailableArrayName(*model, target_array_name + "_copy_shape"),
          shape)};
  copy_op->outputs = {target_array_name};
  if (copy_op->outputs.empty() ? false : target_array.has_shape()) {
    copy_op->shape = target_array.shape().dims();
  }
  model->operators.emplace_back(copy_op);
}
1410
// Creates target_array as a clone of source_array: copies all metadata and,
// when the source has a constant buffer, deep-copies the buffer contents for
// the supported data types. Dies if target_array already exists or the
// source buffer has an unsupported data type.
void CloneArray(Model* model, const std::string& source_array_name,
                const std::string& target_array_name) {
  CHECK(!model->HasArray(target_array_name));
  const Array& source_array = model->GetArray(source_array_name);
  Array& target_array = model->GetOrCreateArray(target_array_name);
  CopyArrayAttribs(source_array, &target_array);

  if (!source_array.buffer) {
    // No constant data to copy; metadata alone suffices.
    return;
  }

  // Dispatch on the runtime data type to copy the typed buffer.
  switch (source_array.data_type) {
    case ArrayDataType::kBool:
      CopyArrayBuffer<ArrayDataType::kBool>(source_array, &target_array);
      break;
    case ArrayDataType::kFloat:
      CopyArrayBuffer<ArrayDataType::kFloat>(source_array, &target_array);
      break;
    case ArrayDataType::kInt8:
      CopyArrayBuffer<ArrayDataType::kInt8>(source_array, &target_array);
      break;
    case ArrayDataType::kUint8:
      CopyArrayBuffer<ArrayDataType::kUint8>(source_array, &target_array);
      break;
    case ArrayDataType::kInt16:
      CopyArrayBuffer<ArrayDataType::kInt16>(source_array, &target_array);
      break;
    case ArrayDataType::kUint16:
      CopyArrayBuffer<ArrayDataType::kUint16>(source_array, &target_array);
      break;
    case ArrayDataType::kInt32:
      CopyArrayBuffer<ArrayDataType::kInt32>(source_array, &target_array);
      break;
    case ArrayDataType::kUint32:
      CopyArrayBuffer<ArrayDataType::kUint32>(source_array, &target_array);
      break;
    case ArrayDataType::kInt64:
      CopyArrayBuffer<ArrayDataType::kInt64>(source_array, &target_array);
      break;
    case ArrayDataType::kUint64:
      CopyArrayBuffer<ArrayDataType::kUint64>(source_array, &target_array);
      break;
    case ArrayDataType::kString:
      CopyArrayBuffer<ArrayDataType::kString>(source_array, &target_array);
      break;
    case ArrayDataType::kComplex64:
      CopyArrayBuffer<ArrayDataType::kComplex64>(source_array, &target_array);
      break;
    default:
      LOG(FATAL) << "Unsupported data type: "
                 << ArrayDataTypeName(source_array.data_type);
      return;
  }
}
1465
1466void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
1467 std::vector<int>* out_dims) {
1468 CHECK(out_dims->empty());
1469 if (num_dims == 0) {
1470 return;
1471 } else if (num_dims == 1) {
1472 CHECK_EQ(batch, 1);
1473 *out_dims = {depth};
1474 } else if (num_dims == 2) {
1475 *out_dims = {batch, depth};
1476 } else if (num_dims == 3) {
1477 CHECK_EQ(batch, 1);
1478 *out_dims = {height, width, depth};
1479 } else if (num_dims == 4) {
1480 *out_dims = {batch, height, width, depth};
1481 } else {
1482 LOG(FATAL) << "Should not get here: " << num_dims;
1483 }
1484}
1485
// Ensures the RNN state array `name` exists and, when possible, has a shape
// of `state_num_dims` dimensions ending in `depth` = size. When
// state_num_dims is not given (<= 0), the rank and batch size are inferred
// from the model's input arrays. An already-shaped existing array keeps its
// shape.
void CreateOrCheckRnnStateArray(const std::string& name, int size,
                                int state_num_dims, Model* model) {
  int batch = 1;
  int num_dims = -1;
  if (state_num_dims > 0) {
    num_dims = state_num_dims;
  } else {
    // state_num_dims is not given. We will infer it from an input tensor.
    for (const auto& input_array : model->flags.input_arrays()) {
      // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find
      // a better match by name.
      if (input_array.name() == name || num_dims == -1) {
        num_dims = input_array.shape().dims_size();
        if (num_dims > 0) {
          batch = input_array.shape().dims(0);
        }
      }
    }
  }
  Array& array = model->GetOrCreateArray(name);
  if (array.has_shape()) {
    // An existing shape wins; record its rank.
    num_dims = array.shape().dimensions_count();
  }
  if (!array.has_shape() && num_dims >= 0) {
    // No shape yet: build one of the inferred rank, with `size` as depth.
    Shape* shape = array.mutable_shape();
    std::vector<int> dims;
    MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
    *shape->mutable_dims() = dims;
  }
}
1516
1517void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
1518 // Merge info about input_arrays from model_flags into model->flags
1519 for (const auto& specified_input_array : model_flags.input_arrays()) {
1520 toco::InputArray* dst_input_array = nullptr;
1521 for (int i = 0; i < model->flags.input_arrays_size(); i++) {
1522 toco::InputArray* candidate_dst_input_array =
1523 model->flags.mutable_input_arrays(i);
1524 if (candidate_dst_input_array->name() == specified_input_array.name()) {
1525 // specified_input_array from model_flags maps to dst_input_array
1526 // in model->flags
1527 dst_input_array = candidate_dst_input_array;
1528 break;
1529 }
1530 }
1531 if (!dst_input_array) {
1532 // Specified_input_array from model_flags is not found in model->flags.
1533 // Match a name-less specified input array when there can be no ambiguity
1534 // as there is only 1 input array.
1535 if (model->flags.input_arrays_size() == 1 &&
1536 model_flags.input_arrays_size() == 1 &&
1537 !specified_input_array.has_name()) {
1538 dst_input_array = model->flags.mutable_input_arrays(0);
1539 }
1540 }
1541 if (!dst_input_array) {
1542 // Still no match, so create a new input array to copy
1543 // specified_input_array into.
1544 dst_input_array = model->flags.add_input_arrays();
1545 dst_input_array->set_name(specified_input_array.name());
1546 }
1547
1548#define RESOLVE_MODEL_FLAG(field_name) \
1549 if (specified_input_array.has_##field_name()) { \
1550 if (dst_input_array->has_##field_name()) { \
1551 QCHECK_EQ(dst_input_array->field_name(), \
1552 specified_input_array.field_name()) \
1553 << "For input array '" << dst_input_array->name() << "', " \
1554 << "specified " #field_name " flag with value: " \
1555 << specified_input_array.field_name() \
1556 << " does not agree with already defined " #field_name \
1557 " of this model, with value: " \
1558 << specified_input_array.field_name(); \
1559 } else { \
1560 dst_input_array->set_##field_name(specified_input_array.field_name()); \
1561 } \
1562 }
1563 RESOLVE_MODEL_FLAG(std_value);
1564 RESOLVE_MODEL_FLAG(mean_value);
1565#undef RESOLVE_MODEL_FLAG
1566
1567 if (specified_input_array.has_shape()) {
1568 if (dst_input_array->has_shape()) {
1569 QCHECK_EQ(specified_input_array.shape().dims_size(),
1570 dst_input_array->shape().dims_size())
1571 << "For input array '" << specified_input_array.name() << "', "
1572 << "size of specified input shape flag with size: "
1573 << specified_input_array.shape().dims_size()
1574 << " does not agree with already defined input shape"
1575 " of this model, with size: "
1576 << dst_input_array->shape().dims_size();
1577 // We treat the first dimension as a special case, since it is often
1578 // a batch size and the input_shape flag is effectively overriding
1579 // the model.
1580 for (int i = 1; i < specified_input_array.shape().dims_size(); i++) {
1581 QCHECK_EQ(specified_input_array.shape().dims(i),
1582 dst_input_array->shape().dims(i))
1583 << "At dimension number " << i << " of input array "
1584 << specified_input_array.name() << ", the specified shape's "
1585 << "dimension flag with dimension: "
1586 << specified_input_array.shape().dims(i)
1587 << " does not agree with already defined shape"
1588 << " of this model, with dimension: "
1589 << dst_input_array->shape().dims(i);
1590 }
1591 } else {
1592 *dst_input_array->mutable_shape() = specified_input_array.shape();
1593 }
1594 }
1595
1596 if (specified_input_array.has_data_type()) {
1597 QCHECK(!dst_input_array->has_data_type());
1598 dst_input_array->set_data_type(specified_input_array.data_type());
1599 }
1600 }
1601
1602 if (model_flags.output_arrays_size() > 0) {
1603 model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays());
1604 }
1605
1606#define RESOLVE_MODEL_FLAG(name) \
1607 if (model_flags.has_##name()) { \
1608 if (model->flags.has_##name()) { \
1609 QCHECK_EQ(model_flags.name(), model->flags.name()) \
1610 << "Specified " #name " flag with value: " << model_flags.name() \
1611 << " does not agree with already defined " #name \
1612 " of this model, with value: " \
1613 << model->flags.name(); \
1614 } else { \
1615 model->flags.set_##name(model_flags.name()); \
1616 } \
1617 }
1618
1619 RESOLVE_MODEL_FLAG(variable_batch)
1620
1621#undef RESOLVE_MODEL_FLAG
1622
1623 if (!model_flags.rnn_states().empty()) {
1624 model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
1625 }
1626
1627 if (model->flags.model_checks_size() == 0) {
1628 model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks());
1629 }
1630
1631 QCHECK_GT(model->flags.output_arrays_size(), 0)
1632 << "This model does not define output arrays, so a "
1633 "--output_arrays flag must be given on the command-line.";
1634
1635 for (auto& input_array_proto : *model->flags.mutable_input_arrays()) {
1636 auto& input_array = model->GetOrCreateArray(input_array_proto.name());
1637 if (input_array_proto.has_data_type()) {
1638 const ArrayDataType specified_type =
1639 ConvertIODataTypeToArrayDataType(input_array_proto.data_type());
1640 QCHECK(specified_type != ArrayDataType::kNone);
1641 if (input_array.data_type != ArrayDataType::kNone) {
1642 QCHECK(specified_type == input_array.data_type)
1643 << "For input array " << input_array_proto.name()
1644 << " the specified input data type "
1645 << IODataType_Name(input_array_proto.data_type())
1646 << " conflicts with the existing type.";
1647 }
1648 input_array.data_type = specified_type;
1649 }
1650
1651 if (input_array.data_type == ArrayDataType::kNone) {
1652 // We start out with a float input array;
1653 // that may get replaced by a uint8 array later, by
1654 // MakeInitialDequantizeOp.
1655 input_array.data_type = ArrayDataType::kFloat;
1656 }
1657
1658 // Compare/merge the model->flags describing the input_shape with
1659 // the actual input array's shape.
1660 if (!input_array.has_shape()) {
1661 if (input_array_proto.has_shape()) {
1662 auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
1663 CheckValidShapeDimensions(input_array_proto.shape().dims());
1664 for (const auto& dim : input_array_proto.shape().dims()) {
1665 input_array_dims.push_back(dim);
1666 }
1667 }
1668 } else {
1669 if (input_array_proto.has_shape()) {
1670 // If an input shape was specified on the flags ensure that it matches
1671 // the actual shape in the model.
1672 const auto& input_array_dims =
1673 *input_array.mutable_shape()->mutable_dims();
1674 CHECK_EQ(input_array_dims.size(),
1675 input_array_proto.shape().dims_size());
1676 for (int i = 0; i < input_array_dims.size(); i++) {
1677 CHECK_EQ(input_array_dims[i], input_array_proto.shape().dims(i));
1678 }
1679 } else {
1680 for (int i = 0; i < input_array.shape().dimensions_count(); i++) {
1681 input_array_proto.mutable_shape()->add_dims(
1682 input_array.shape().dims(i));
1683 }
1684 }
1685 }
1686
1687 const float mean_value = input_array_proto.mean_value();
1688 const float std_value = input_array_proto.std_value();
1689 MinMax input_minmax;
1690 float qmin = 0, qmax = 255;
1691 if (input_array.data_type == ArrayDataType::kInt16) {
1692 qmin = -32768;
1693 qmax = 32767;
1694 }
1695 input_minmax.min = (qmin - mean_value) / std_value;
1696 input_minmax.max = (qmax - mean_value) / std_value;
1697 if (!input_array.minmax) {
1698 input_array.GetOrCreateMinMax() = input_minmax;
1699 }
1700 }
1701
1702 // Creation of the RNN state arrays
1703 for (const auto& rnn_state : model->flags.rnn_states()) {
1704 CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
1705 rnn_state.num_dims(), model);
1706 }
1707
1708 model->flags.set_change_concat_input_ranges(
1709 model_flags.change_concat_input_ranges());
1710 model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
1711 model->flags.set_allow_nonexistent_arrays(
1712 model_flags.allow_nonexistent_arrays());
1713
1714 CHECK(!model->flags.has_arrays_extra_info());
1715 *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info();
1716}
1717
// Verifies that every float operator input either has min/max data or a
// constant buffer (from which minmax can be computed later). Any float
// input lacking both is fatal, with guidance on how to supply ranges.
void CheckIsReadyForQuantization(const Model& model) {
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      const auto& input_array = model.GetArray(input);
      if (input_array.data_type != ArrayDataType::kFloat) {
        // The array is not floats, no quantization needed.
        continue;
      }
      if (input_array.minmax) {
        // The array has minmax, we're good.
        continue;
      }
      if (input_array.buffer) {
        // The array has a constant buffer, so we can
        // fall back to computing the minmax from actual array entries
        // (with a WARNING about possible accuracy implications).
        continue;
      }
      LOG(FATAL)
          << "Array " << input << ", which is an input to the "
          << HelpfulOperatorTypeName(*op) << " operator producing the output "
          << "array " << op->outputs[0] << ", is lacking min/max data, "
          << "which is necessary for quantization. If accuracy matters, either "
          << "target a non-quantized output format, or run quantized training "
          << "with your model from a floating point checkpoint to change the "
          << "input graph to contain min/max information. If you don't care "
          << "about accuracy, you can pass --default_ranges_min= and "
          << "--default_ranges_max= for easy experimentation.";
    }
  }
}
1749
1750int ElementSize(ArrayDataType data_type) {
1751 switch (data_type) {
1752 case ArrayDataType::kBool:
1753 return sizeof(bool);
1754 case ArrayDataType::kFloat:
1755 return 4;
1756 case ArrayDataType::kInt8:
1757 return 1;
1758 case ArrayDataType::kUint8:
1759 return 1;
1760 case ArrayDataType::kInt16:
1761 return 2;
1762 case ArrayDataType::kUint16:
1763 return 2;
1764 case ArrayDataType::kInt32:
1765 return 4;
1766 case ArrayDataType::kUint32:
1767 return 4;
1768 case ArrayDataType::kInt64:
1769 return 8;
1770 case ArrayDataType::kUint64:
1771 return 8;
1772 case ArrayDataType::kComplex64:
1773 return 8;
1774 case ArrayDataType::kComplex128:
1775 return 16;
1776 case ArrayDataType::kFloat64:
1777 return 8;
1778
1779 // Usually not critical limitation because strings are only input and/or
1780 // output.
1781 case ArrayDataType::kString:
1782 LOG(FATAL) << "Transient arrays with strings are not supported yet";
1783 return 0;
1784 default:
1785 LOG(FATAL) << "Unknown data_type = " << static_cast<int>(data_type);
1786 return 0;
1787 }
1788}
1789
1790void DropMinMax(Model* model, const std::string& array_name) {
1791 auto& array = model->GetArray(array_name);
1792 if (!!array.minmax) {
1793 LOG(WARNING) << "Dropping MinMax information in array " << array_name
1794 << ". Expect inaccuracy in quantized inference.";
1795 array.minmax = nullptr;
1796 }
1797}
1798
1799bool IsAllocatableTransientArray(const Model& model,
1800 const std::string& array_name) {
1801 // Optional array is not transient
1802 if (model.IsOptionalArray(array_name)) return false;
1803 // The model's input and output arrays are externally allocated.
1804 // They are not transient arrays.
1805 if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
1806 return false;
1807 }
1808 const auto& array = &model.GetArray(array_name);
1809 // An array with a constant buffer isn't a transient array.
1810 if (!!array->buffer) {
1811 return false;
1812 }
1813 // An array without shape isn't allocatable.
1814 if (!array->has_shape()) {
1815 return false;
1816 }
1817
1818 // The size of string tensors is rarely known ahead of time, so all transient
1819 // tensors of this type will need to be dynamically allocated.
1820 if (array->final_data_type == ArrayDataType::kString ||
1821 array->data_type == ArrayDataType::kString) {
1822 return false;
1823 }
1824
1825 return true;
1826}
1827
1828std::string AvailableArrayName(const Model& model, const std::string& name) {
1829 std::string sanitized_name = SanitizeNameForTFNode(name);
1830 if (!model.HasArray(sanitized_name) &&
1831 !model.IsOptionalArray(sanitized_name)) {
1832 return sanitized_name;
1833 }
1834 const int kNumSuffixesToTry = 1000;
1835 for (int i = 0; i < kNumSuffixesToTry; i++) {
1836 const std::string& name_with_suffix =
1837 toco::port::StringF("%s_%d", sanitized_name, i);
1838 if (!model.HasArray(name_with_suffix) &&
1839 !model.IsOptionalArray(name_with_suffix)) {
1840 return name_with_suffix;
1841 }
1842 }
1843 LOG(FATAL) << "Could not find an available array name starting with "
1844 << sanitized_name << ". Tried " << kNumSuffixesToTry
1845 << " suffixes, all were taken!";
1846 return "";
1847}
1848
1849std::string ShapeToString(const Shape& shape) {
1850 if (shape.dimensions_count() == 0) {
1851 return "[]";
1852 }
1853
1854 return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]");
1855}
1856
1857void PrintArrayShape(Model* model, const std::string& name) {
1858 if (!model->GetArray(name).has_shape()) {
1859 LOG(INFO) << name << " has no shape";
1860 return;
1861 }
1862 LOG(INFO) << name
1863 << " has shape: " << ShapeToString(model->GetArray(name).shape());
1864}
1865
1866bool IsArrayFullyConnectedWeights(const Model& model, const std::string& name) {
1867 bool is_fc_weights = false;
1868 bool is_something_else = false;
1869 for (const auto& op : model.operators) {
1870 for (int input_index = 0; input_index < op->inputs.size(); input_index++) {
1871 if (op->inputs[input_index] == name) {
1872 if (op->type == OperatorType::kFullyConnected && input_index == 1) {
1873 is_fc_weights = true;
1874 } else {
1875 is_something_else = true;
1876 }
1877 }
1878 }
1879 }
1880 CHECK(!(is_fc_weights && is_something_else));
1881 return is_fc_weights;
1882}
1883
1884std::string CreateInt32Array(Model* model, const std::string& param_name,
1885 const std::vector<int>& value) {
1886 auto param_array_name = AvailableArrayName(*model, param_name);
1887 auto& param_array = model->GetOrCreateArray(param_array_name);
1888 param_array.mutable_shape()->ReplaceDims({static_cast<int>(value.size())});
1889 param_array.data_type = ArrayDataType::kInt32;
1890 auto& param_array_data =
1891 param_array.GetMutableBuffer<ArrayDataType::kInt32>().data;
1892 param_array_data.resize(RequiredBufferSizeForShape(param_array.shape()));
1893 for (int i = 0; i < value.size(); ++i) {
1894 param_array_data[i] = value[i];
1895 }
1896 return param_array_name;
1897}
1898
1899bool EstimateArithmeticOpsCount(const Model& model, const Operator& op,
1900 int64_t* result) {
1901 switch (op.type) {
1902 case OperatorType::kFullyConnected:
1903 case OperatorType::kConv:
1904 case OperatorType::kDepthwiseConv: {
1905 const auto& output_array = model.GetArray(op.outputs[0]);
1906 const auto& weights_array = model.GetArray(op.inputs[1]);
1907 if (!output_array.has_shape() || !weights_array.has_shape()) {
1908 return false;
1909 }
1910 int64_t cols = 1;
1911 for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
1912 cols *= output_array.shape().dims(i);
1913 }
1914 const int64_t cost_per_col =
1915 2 * RequiredBufferSizeForShape(weights_array.shape());
1916 *result = cost_per_col * cols;
1917 if (op.inputs.size() > 2) {
1918 // There is a bias vector. One more op per output value.
1919 *result += RequiredBufferSizeForShape(output_array.shape());
1920 }
1921 break;
1922 }
1923 case OperatorType::kTransposeConv: {
1924 const auto& input_array = model.GetArray(op.inputs[2]);
1925 const auto& weights_array = model.GetArray(op.inputs[1]);
1926 if (!input_array.has_shape() || !weights_array.has_shape()) {
1927 return false;
1928 }
1929 const Shape& input = input_array.shape();
1930 const Shape& weights = weights_array.shape();
1931 // Compute op count from the seven nested loops of
1932 // tflite::reference_ops::TransposeConv():
1933 *result = 2 * input.dims(0) * input.dims(1) * input.dims(2) *
1934 input.dims(3) * weights.dims(1) * weights.dims(2) *
1935 weights.dims(0);
1936 // Note that tflite::optimized_ops::TransposeConv() uses an im2col matrix
1937 // and has a higher op count, by a factor of (output_height*output_width)
1938 // vs. (input_height*input_width). Yet it generally performs better
1939 // because of coherent memory access. (At least for 2x2 striding. But not
1940 // likely for all cases.)
1941 break;
1942 }
1943 case OperatorType::kAdd:
1944 case OperatorType::kSub:
1945 case OperatorType::kMul: {
1946 const auto& output_array = model.GetArray(op.outputs[0]);
1947 if (!output_array.has_shape()) {
1948 return false;
1949 }
1950 *result = RequiredBufferSizeForShape(output_array.shape());
1951 break;
1952 }
1953 case OperatorType::kAddN: {
1954 const auto& output_array = model.GetArray(op.outputs[0]);
1955 if (!output_array.has_shape()) {
1956 return false;
1957 }
1958 // AddN cost is roughly the same cost as N-1 Adds.
1959 const int64_t num_adds = op.inputs.size() - 1;
1960 *result = num_adds * RequiredBufferSizeForShape(output_array.shape());
1961 break;
1962 }
1963 case OperatorType::kLogistic:
1964 case OperatorType::kSoftmax:
1965 case OperatorType::kLogSoftmax:
1966 case OperatorType::kTanh: {
1967 const auto& output_array = model.GetArray(op.outputs[0]);
1968 if (!output_array.has_shape()) {
1969 return false;
1970 }
1971 // As a very rough ballpark, the cost of evaluating a math function
1972 // such as tanh or logistic is about 32 multiplications, and about as
1973 // many additions/subtractions. (Just a power-of-two order-of-magnitude
1974 // from looking at actual implementations that we use in runtime/ code).
1975 *result = 64 * RequiredBufferSizeForShape(output_array.shape());
1976 break;
1977 }
1978 case OperatorType::kMaxPool: {
1979 const auto& maxpool = *static_cast<const MaxPoolOperator*>(&op);
1980 const auto& output_array = model.GetArray(op.outputs[0]);
1981 if (!output_array.has_shape()) {
1982 return false;
1983 }
1984 *result = RequiredBufferSizeForShape(output_array.shape()) *
1985 maxpool.kheight * maxpool.kwidth;
1986 break;
1987 }
1988 case OperatorType::kAveragePool: {
1989 const auto& avgpool = *static_cast<const AveragePoolOperator*>(&op);
1990 const auto& output_array = model.GetArray(op.outputs[0]);
1991 if (!output_array.has_shape()) {
1992 return false;
1993 }
1994 *result = RequiredBufferSizeForShape(output_array.shape()) *
1995 avgpool.kheight * avgpool.kwidth;
1996 break;
1997 }
1998 case OperatorType::kL2Pool: {
1999 const auto* maxpool = static_cast<const MaxPoolOperator*>(&op);
2000 const auto& output_array = model.GetArray(op.outputs[0]);
2001 if (!output_array.has_shape()) {
2002 return false;
2003 }
2004 // The sum of squares requires (kheight*kwidth) multiply-adds,
2005 // and then there is the sqrt which we ballpark at 32 ops.
2006 const int64_t cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32;
2007 *result = RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
2008 break;
2009 }
2010 case OperatorType::kL2Normalization: {
2011 const auto& output_array = model.GetArray(op.outputs[0]);
2012 if (!output_array.has_shape()) {
2013 return false;
2014 }
2015 // Computing the squared L2 norm is N multiply-adds so 2N ops,
2016 // then the single inverse-sqrt is negligible, then we multiply each
2017 // value by the resulting multiplier, so an extra N ops. count 3N ops.
2018 *result = 3 * RequiredBufferSizeForShape(output_array.shape());
2019 break;
2020 }
2021 default:
2022 *result = 0;
2023 break;
2024 }
2025 return true;
2026}
2027
2028bool EstimateArithmeticOpsCount(const Model& model, int64_t* result) {
2029 int64_t total = 0;
2030 for (const auto& op : model.operators) {
2031 int64_t num_ops;
2032 if (!EstimateArithmeticOpsCount(model, *op, &num_ops)) {
2033 return false;
2034 }
2035 total += num_ops;
2036 }
2037 *result = total;
2038 return true;
2039}
2040
2041std::string FormattedNumber(int64_t x) {
2042 const int64_t million = 1000000;
2043 const int64_t billion = 1000000000;
2044 if (x < 10000) {
2045 return toco::port::StringF("%d ", x);
2046 } else if (x < billion) {
2047 return toco::port::StringF("%.3f M", static_cast<double>(x) / million);
2048 } else {
2049 return toco::port::StringF("%.3f G", static_cast<double>(x) / billion);
2050 }
2051}
2052
2053void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
2054 std::vector<int>* shuffle) {
2055 CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
2056 shuffle->resize(4);
2057 for (int i = 0; i < 4; i++) {
2058 (*shuffle)[i] = i;
2059 }
2060 if (input_axes_order == output_axes_order) {
2061 // nothing to do
2062 } else if (AxesCount(input_axes_order) == 2) {
2063 shuffle->resize(2);
2064 (*shuffle)[0] = 1;
2065 (*shuffle)[1] = 0;
2066 } else if (input_axes_order == AxesOrder::kOHWI &&
2067 output_axes_order == AxesOrder::kHWIO) {
2068 // 3210 <- 3210
2069 // HWIO <- OHWI
2070 *shuffle = {1, 2, 3, 0};
2071 } else if (input_axes_order == AxesOrder::kHWIO &&
2072 output_axes_order == AxesOrder::kOHWI) {
2073 // 3210 <- 3210
2074 // OHWI <- HWIO
2075 *shuffle = {3, 0, 1, 2};
2076 } else if (input_axes_order == AxesOrder::kOHWI &&
2077 output_axes_order == AxesOrder::kHWOI) {
2078 *shuffle = {1, 2, 0, 3};
2079 } else {
2080 LOG(FATAL) << "Bad shuffle";
2081 }
2082}
2083
2084void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
2085 std::vector<int>* extended_shuffle) {
2086 *extended_shuffle = input_shuffle;
2087 CHECK(newdim >= input_shuffle.size());
2088 const int pad_size = newdim - input_shuffle.size();
2089 extended_shuffle->resize(newdim);
2090 for (int i = 0; i < pad_size; i++) {
2091 (*extended_shuffle)[i] = i;
2092 }
2093 for (int i = pad_size; i < newdim; i++) {
2094 (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size;
2095 }
2096}
2097
2098void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
2099 AxesOrder output_axes_order, Shape* output_shape) {
2100 if (input_axes_order == AxesOrder::kHWIM &&
2101 output_axes_order == AxesOrder::k1HWO) {
2102 // This special case isn't just a permutation, the IM pair of dims get
2103 // merged into the 3 dim, so we have to special-case it.
2104 *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1),
2105 input_shape.dims(3) * input_shape.dims(2)});
2106 } else {
2107 std::vector<int> shuffle;
2108 GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
2109 std::vector<int>* output_dims = output_shape->mutable_dims();
2110 output_dims->resize(input_shape.dimensions_count());
2111 for (int i = 0; i < input_shape.dimensions_count(); i++) {
2112 (*output_dims)[i] = input_shape.dims(shuffle[i]);
2113 }
2114 }
2115}
2116
// Copies input_data into output_data while permuting axes from
// input_axes_order to output_axes_order. Both shapes must already agree
// with the permutation; at most 4 dimensions are supported.
template <typename T>
void ShuffleArrayTemplate(const Shape& input_shape, AxesOrder input_axes_order,
                          AxesOrder output_axes_order,
                          const Shape& output_shape, const T* input_data,
                          T* output_data) {
  if (input_axes_order == AxesOrder::kHWIM &&
      output_axes_order == AxesOrder::k1HWO) {
    // This special case isn't just a permutation, the IM pair of dims get
    // merged into the O dim, so we have to special-case it. Fortunately,
    // as far as array shuffling is concerned, it's just the identity
    // transformation.
    memcpy(output_data, input_data,
           RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0]));
    return;
  }
  CHECK(input_shape.dimensions_count() == output_shape.dimensions_count());
  const int dim = input_shape.dimensions_count();
  CHECK_LE(dim, 4);
  std::vector<int> shuffle;
  GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
  CHECK(shuffle.size() >= dim);
  // Sanity-check the permutation: output dim i reads input dim shuffle[i],
  // so the two shapes must agree accordingly.
  for (int i = 0; i < dim; i++) {
    CHECK(shuffle[i] >= 0 && shuffle[i] < dim);
    CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i));
  }
  // Pad shapes and permutation with leading unit dims up to 4 axes so the
  // single fixed 4-deep loop nest below handles every rank.
  Shape extended_input_shape = input_shape;
  ExtendShape(&extended_input_shape, 4);
  Shape extended_output_shape = output_shape;
  ExtendShape(&extended_output_shape, 4);
  std::vector<int> extended_shuffle;
  ExtendShuffle(shuffle, 4, &extended_shuffle);

  const std::vector<int>& extended_input_dims = extended_input_shape.dims();
  const std::vector<int>& extended_output_dims = extended_output_shape.dims();

  // TODO(starka): Rework to handle different numbers of dimensions.
  // Row-major strides of the (4-D) input.
  int input_strides[4];
  input_strides[3] = 1;
  input_strides[2] = extended_input_dims[3];
  input_strides[1] = input_strides[2] * extended_input_dims[2];
  input_strides[0] = input_strides[1] * extended_input_dims[1];
  // input_stride_k is the input step taken at loop level k below. Note the
  // reversed numbering: level 0 is the innermost loop, which walks output
  // dim 3, hence advances by the stride of the input dim feeding it.
  const int input_stride_0 = input_strides[extended_shuffle[3]];
  const int input_stride_1 = input_strides[extended_shuffle[2]];
  const int input_stride_2 = input_strides[extended_shuffle[1]];
  const int input_stride_3 = input_strides[extended_shuffle[0]];

  // Output sizes and strides, also numbered innermost (0) to outermost (3).
  const int output_size_0 = extended_output_dims[3];
  const int output_size_1 = extended_output_dims[2];
  const int output_size_2 = extended_output_dims[1];
  const int output_size_3 = extended_output_dims[0];
  const int output_stride_0 = 1;
  const int output_stride_1 = output_size_0;
  const int output_stride_2 = output_stride_1 * output_size_1;
  const int output_stride_3 = output_stride_2 * output_size_2;

  // Walk the output contiguously, gathering from the permuted input.
  for (int i3 = 0; i3 < output_size_3; i3++) {
    const T* const input_ptr_3 = input_data + i3 * input_stride_3;
    T* const output_ptr_3 = output_data + i3 * output_stride_3;
    for (int i2 = 0; i2 < output_size_2; i2++) {
      const T* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2;
      T* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2;
      for (int i1 = 0; i1 < output_size_1; i1++) {
        const T* input_ptr = input_ptr_2 + i1 * input_stride_1;
        T* output_ptr = output_ptr_2 + i1 * output_stride_1;
        T* const output_ptr_end = output_ptr + output_size_0 * output_stride_0;
        while (output_ptr != output_ptr_end) {
          *output_ptr = *input_ptr;
          input_ptr += input_stride_0;
          output_ptr += output_stride_0;
        }
      }
    }
  }
}
2191
2192void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
2193 AxesOrder output_axes_order, const Shape& output_shape,
2194 const uint8* input_data, uint8* output_data) {
2195 ShuffleArrayTemplate<uint8>(input_shape, input_axes_order, output_axes_order,
2196 output_shape, input_data, output_data);
2197}
2198
2199void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
2200 AxesOrder output_axes_order, const Shape& output_shape,
2201 const float* input_data, float* output_data) {
2202 ShuffleArrayTemplate<float>(input_shape, input_axes_order, output_axes_order,
2203 output_shape, input_data, output_data);
2204}
2205
2206int AxesCount(AxesOrder axes_order) {
2207 switch (axes_order) {
2208 case AxesOrder::kOneAxis:
2209 return 1;
2210 case AxesOrder::kRC:
2211 return 2;
2212 case AxesOrder::kCR:
2213 return 2;
2214 case AxesOrder::kHWIO:
2215 return 4;
2216 case AxesOrder::kOHWI:
2217 return 4;
2218 case AxesOrder::kHWIM:
2219 return 4;
2220 case AxesOrder::k1HWO:
2221 return 4;
2222 case AxesOrder::kNHWC:
2223 return 4;
2224 case AxesOrder::kHWOI:
2225 return 4;
2226 default:
2227 LOG(FATAL) << "Bad AxesOrder";
2228 return 0;
2229 }
2230}
2231
2232bool IsDiscardableArray(const Model& model, const std::string& array_name) {
2233 if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
2234 return false;
2235 }
2236 for (const auto& rnn_state : model.flags.rnn_states()) {
2237 if (!rnn_state.discardable()) {
2238 if (array_name == rnn_state.state_array()) {
2239 return false;
2240 }
2241 if (array_name == rnn_state.back_edge_source_array()) {
2242 return false;
2243 }
2244 }
2245 }
2246 return true;
2247}
2248
2249bool ReshapeIsEquivalentToTranspose(const Model& model,
2250 const TensorFlowReshapeOperator* op,
2251 bool allow_extra_unary_dims) {
2252 CHECK(!op->shape.empty());
2253 CHECK(model.HasArray(op->inputs[0]));
2254 CHECK(model.HasArray(op->outputs[0]));
2255
2256 const auto& input_array = model.GetArray(op->inputs[0]);
2257 const auto& output_array = model.GetArray(op->outputs[0]);
2258
2259 CHECK(input_array.has_shape());
2260 CHECK(output_array.has_shape());
2261
2262 std::vector<int> in_shape = input_array.shape().dims();
2263 std::vector<int> out_shape = output_array.shape().dims();
2264
2265 // If the reshape changes the number of dimensions so it cannot be interpreted
2266 // as a transpose.
2267 if (!allow_extra_unary_dims && in_shape.size() != out_shape.size()) {
2268 return false;
2269 }
2270
2271 in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1),
2272 in_shape.end());
2273 out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1),
2274 out_shape.end());
2275 return in_shape == out_shape;
2276}
2277
2278void CheckFinalDataTypesSatisfied(const Model& model) {
2279 for (const auto& array_entry : model.GetArrayMap()) {
2280 const auto& array = *array_entry.second;
2281 if (array.data_type == ArrayDataType::kBool) {
2282 // Boolean values are never quantized.
2283 continue;
2284 }
2285
2286 // If the final data type is int16, the data type may be float, for example
2287 // after dequantization.
2288 if (array.final_data_type != ArrayDataType::kNone &&
2289 array.final_data_type != ArrayDataType::kInt16) {
2290 CHECK(array.data_type == array.final_data_type)
2291 << "Array \"" << array_entry.first
2292 << "\" has mis-matching actual and final data types (data_type="
2293 << ArrayDataTypeName(array.data_type)
2294 << ", final_data_type=" << ArrayDataTypeName(array.final_data_type)
2295 << ").";
2296 }
2297 }
2298}
2299
// Maps a toco_flags IODataType to the internal ArrayDataType. Quantized
// variants map to the same storage type as their plain counterparts.
// Unsupported types (RESOURCE, VARIANT, anything unknown) map to kNone.
ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
  switch (type) {
    case FLOAT:
      return ArrayDataType::kFloat;
    case UINT8:
    case QUANTIZED_UINT8:
      return ArrayDataType::kUint8;
    case INT8:
    case QUANTIZED_INT8:
      return ArrayDataType::kInt8;
    case INT16:
    case QUANTIZED_INT16:
      return ArrayDataType::kInt16;
    case UINT16:
      return ArrayDataType::kUint16;
    case INT32:
      return ArrayDataType::kInt32;
    case UINT32:
      return ArrayDataType::kUint32;
    case INT64:
      return ArrayDataType::kInt64;
    case UINT64:
      return ArrayDataType::kUint64;
    case BOOL:
      return ArrayDataType::kBool;
    case STRING:
      return ArrayDataType::kString;
    case COMPLEX64:
      return ArrayDataType::kComplex64;
    case COMPLEX128:
      return ArrayDataType::kComplex128;
    case FLOAT16:
      return ArrayDataType::kFloat16;
    case FLOAT64:
      return ArrayDataType::kFloat64;
    case RESOURCE:
    case VARIANT:
    default:
      return ArrayDataType::kNone;
  }
}
2341
2342void FinishBuildingRNNStates(Model* model) {
2343 for (const auto& rnn_state : model->flags.rnn_states()) {
2344 if (!model->HasArray(rnn_state.back_edge_source_array()) ||
2345 !model->HasArray(rnn_state.state_array())) {
2346 CHECK(model->HasArray(rnn_state.back_edge_source_array()));
2347 CHECK(model->HasArray(rnn_state.state_array()));
2348 continue;
2349 }
2350 const auto& src_array = model->GetArray(rnn_state.back_edge_source_array());
2351 auto& dst_array = model->GetArray(rnn_state.state_array());
2352 if (src_array.data_type == ArrayDataType::kNone &&
2353 dst_array.data_type == ArrayDataType::kNone) {
2354 dst_array.data_type = ArrayDataType::kFloat;
2355 }
2356 }
2357}
2358
2359// Returns the array names that match the ArraysExtraInfo's name and
2360// name_regexp. The regexp match is for a full match.
2361std::unordered_set<std::string> ScanArrayNames(
2362 const Model& model, const toco::ArraysExtraInfo_Entry& entry) {
2363 std::unordered_set<std::string> matches;
2364 if (model.HasArray(entry.name())) {
2365 matches.insert(entry.name());
2366 }
2367 if (!entry.name_regexp().empty()) {
2368 const auto& arrays = model.GetArrayMap();
2369 const RE2 name_regexp = {entry.name_regexp()};
2370 for (auto it = arrays.begin(); it != arrays.end(); ++it) {
2371 if (RE2::FullMatch(it->first, name_regexp)) {
2372 matches.insert(it->first);
2373 }
2374 }
2375 }
2376 return matches;
2377}
2378
// Applies the model flags' arrays_extra_info entries to every matching
// array: overrides min/max, final data type (only when quantizing),
// shape, and optionally fills float buffers with a constant value.
// Entries that match no array are reported and skipped.
void UseArraysExtraInfo(Model* model, bool quantize_output) {
  for (const auto& entry : model->flags.arrays_extra_info().entries()) {
    const auto matches = ScanArrayNames(*model, entry);
    if (matches.empty()) {
      LOG(ERROR) << "arrays_extra_info_file: No matching arrays found for "
                 << (entry.has_name() ? entry.name() : "")
                 << (entry.has_name_regexp() ? entry.name_regexp() : "");
      continue;
    }
    for (const auto& matched_name : matches) {
      auto& array = model->GetArray(matched_name);
      if (entry.has_min() || entry.has_max()) {
        // min and max only make sense as a pair.
        CHECK_EQ(entry.has_min(), entry.has_max());
        auto& minmax = array.GetOrCreateMinMax();
        minmax.min = entry.min();
        minmax.max = entry.max();
      }
      // The data type override only applies to quantizing conversions.
      if (entry.has_data_type() && quantize_output) {
        array.final_data_type =
            ConvertIODataTypeToArrayDataType(entry.data_type());
      }
      if (entry.has_shape()) {
        array.clear_shape();
        // Make sure to create the shape even if there are no dims, to
        // correctly record 0-D shapes.
        array.mutable_shape();
        for (const auto& dim : entry.shape().dims()) {
          array.mutable_shape()->mutable_dims()->push_back(dim);
        }
      }
      // Constant fill requires a shape (possibly just set above) to size
      // the buffer, and is only applied to float arrays.
      if (entry.has_constant_float_value()) {
        CHECK(array.has_shape());
        if (array.data_type == ArrayDataType::kFloat) {
          auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
          data.resize(RequiredBufferSizeForShape(array.shape()));
          for (float& f : data) {
            f = entry.constant_float_value();
          }
        }
      }
    }
  }
}
2422
// Reverts the 4x16-block shuffling (and sign-bit flipping) that was
// applied to uint8 FullyConnected weights for the kShuffled4x16Int8
// weights format, restoring plain row-major storage. Only FC ops with a
// non-default weights format are touched.
void UndoWeightsShuffling(Model* model) {
  for (const auto& op : model->operators) {
    if (op->type != toco::OperatorType::kFullyConnected) {
      continue;
    }
    const auto& fc_op = static_cast<toco::FullyConnectedOperator&>(*op);
    if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) {
      // Weights are already in plain layout.
      continue;
    }
    const std::string& weights_name = fc_op.inputs[1];
    // The weights must not be shared with any other op, since we rewrite
    // them in place below.
    QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1);
    auto& weights_array = model->GetArray(weights_name);
    QCHECK(weights_array.data_type == ArrayDataType::kUint8);
    auto& weights_data =
        weights_array.GetMutableBuffer<toco::ArrayDataType::kUint8>().data;
    const auto& weights_shape = weights_array.shape();
    QCHECK_EQ(weights_shape.dimensions_count(), 2);
    const int rows = weights_shape.dims(0);
    const int cols = weights_shape.dims(1);
    // The shuffled layout packs weights in 4-row x 16-col blocks.
    QCHECK_EQ(rows % 4, 0);
    QCHECK_EQ(cols % 16, 0);
    CHECK_EQ(rows * cols, weights_data.size());
    // Compute the de-shuffled weights
    std::vector<uint8> deshuffled_data(weights_data.size());
    // Shuffled data is consumed sequentially, block by block.
    uint8* shuffled_data_ptr = weights_data.data();
    for (int r = 0; r < rows; r += 4) {
      for (int c = 0; c < cols; c += 16) {
        for (int i = 0; i < 4; i++) {
          // Destination of row (r+i) of the current 16-wide block.
          uint8* deshuffled_data_ptr =
              deshuffled_data.data() + (r + i) * cols + c;
          for (int j = 0; j < 16; j++) {
            uint8 shuffled_val = *shuffled_data_ptr++;
            // Deshuffling isn't only about deshuffling the storage layout,
            // it's also about undoing the flipping of the sign bit, which is
            // performed on the shuffled weights.
            uint8 deshuffled_val = shuffled_val ^ 0x80;
            *deshuffled_data_ptr++ = deshuffled_val;
          }
        }
      }
    }
    // All shuffled bytes must have been consumed exactly once.
    CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols);
    // Switch this FC op to using the deshuffled weights.
    weights_data = std::move(deshuffled_data);
  }
}
2469
// Copies min/max, quantization params, and the narrow_range flag from
// src to dst. MinMax and quantization params are created on dst only
// when present on src; narrow_range is always copied.
void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) {
  if (src.minmax) {
    dst->GetOrCreateMinMax() = src.GetMinMax();
  }
  if (src.quantization_params) {
    dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams();
  }
  dst->narrow_range = src.narrow_range;
}
2479
2480} // namespace toco
2481