/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <memory>
#include <unordered_map>

#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
using str_util::Split;
using str_util::StringReplace;
using strings::StrCat;

namespace graph_transforms {

// Sparsifies a tensor of shape [N, 1], returning parallel index and value
// vectors for its non-zero entries.
Status SparsifyWeights(const Tensor& tensor, Tensor* indices_tensor,
                       Tensor* values_tensor) {
  if (tensor.dims() != 2 || tensor.dim_size(1) != 1) {
    return tensorflow::errors::FailedPrecondition(
        "Transform only applicable to subgraph with 'Const' with "
        "tensor of shape [N, 1], but instead got shape ",
        tensor.shape().DebugString(), ".");
  }

  auto flat = tensor.flat<float>();
  std::vector<int64_t> indices;
  std::vector<float> values;

  for (int64_t i = 0; i < flat.size(); i++) {
    float val = flat(i);
    if (std::abs(val) >= 1.0e-5) {
      indices.push_back(i);
      values.push_back(val);
    }
  }

  // During model initialization, InitializeTableOp makes use of
  // KeyValueTensorIterator, which does not accept empty keys or values.
  // Consequently, add a dummy pair of indices and values as a workaround.
  if (indices.empty() || values.empty()) {
    indices.push_back(0);
    values.push_back(0);
  }
  *indices_tensor = Tensor(DataTypeToEnum<int64_t>::value,
                           {static_cast<int64_t>(indices.size())});
  std::copy_n(indices.begin(), indices.size(),
              indices_tensor->flat<int64_t>().data());

  *values_tensor = Tensor(DataTypeToEnum<float>::value,
                          {static_cast<int64_t>(values.size())});
  std::copy_n(values.begin(), values.size(),
              values_tensor->flat<float>().data());

  return OkStatus();
}

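// Fills in `node_def` as a "Const" op named `name` whose "value" attribute
// holds `tensor`.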
void CreateConstNode(const Tensor& tensor, const string& name,
                     NodeDef* node_def) {
  node_def->set_op("Const");
  node_def->set_name(name);
  SetNodeTensorAttr<float>("value", tensor, node_def);
}

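// Maps a partitioned-variable slice name (one whose last path component is
// "part_N") back to the monolithic tensor key under which it is stored in
// the checkpoint.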
string GetMonolithicTensorKey(const string& tensor_slice_name) {
  std::vector<string> names = Split(tensor_slice_name, "/");
  if (absl::StartsWith(names[names.size() - 1], "part_")) {
    CHECK_GE(names.size(), 2);
    names.pop_back();
  }
  return absl::StrJoin(names, "/");
}

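// Walks the graph's save/restore subgraph to find the shape-and-slice string
// that the RestoreV2 op uses for the variable `target_name`.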
Status ObtainTensorSlice(const GraphDef& input_graph_def,
                         const string& target_name,
                         string* shape_slice_string) {
  string restore_node_name;
  for (const auto& node : input_graph_def.node()) {
    std::vector<string> node_name_parts = Split(node.name(), "/");
    if (node_name_parts.size() == 2 &&
        absl::StartsWith(node_name_parts[0], "save") &&
        absl::StartsWith(node_name_parts[1], "Assign") &&
        node.input(0) == target_name) {
      restore_node_name = node.input(1);
      break;
    }
  }

  std::vector<string> restore_node_parts = Split(restore_node_name, ":");
  CHECK_LE(restore_node_parts.size(), 2);
  string tensor_names_node;
  string shape_and_slices_node;
  for (const auto& node : input_graph_def.node()) {
    if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2")) {
      tensor_names_node = node.input(1);
      shape_and_slices_node = node.input(2);
      break;
    }
  }

  int offset = -1;
  for (const auto& node : input_graph_def.node()) {
    if (node.name() == tensor_names_node) {
      Tensor tensor_names_tensor;
      TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor));
      const auto& tensor_names_value = tensor_names_tensor.flat<tstring>();
      for (int i = 0; i < tensor_names_value.size(); i++) {
        if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) {
          offset = i;
          break;
        }
      }
    }
  }
  if (offset == -1) {
    return errors::Internal("Unable to find RestoreV2 entry for variable: ",
                            target_name);
  }
  for (const auto& node : input_graph_def.node()) {
    if (node.name() == shape_and_slices_node) {
      Tensor shape_and_slices_tensor;
      TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor));
      const auto& shape_and_slices_value =
          shape_and_slices_tensor.flat<tstring>();
      *shape_slice_string = shape_and_slices_value(offset);
      return OkStatus();
    }
  }
  return errors::Internal("Unable to find slice for variable: ", target_name);
}

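// Reads `tensor_name` from the checkpoint, loading either the full tensor or
// only the slice described by `shape_and_slice` when one is given.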
Status ReadTensorFromCheckpoint(
    const string& tensor_name, const std::unique_ptr<BundleReader>& ckpt_reader,
    const string& shape_and_slice, Tensor* tensor) {
  if (ckpt_reader) {
    TensorShape parsed_full_shape;
    TensorSlice parsed_slice;
    TensorShape parsed_slice_shape;

    bool get_slice = false;
    if (!shape_and_slice.empty()) {
      TF_RETURN_IF_ERROR(
          checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
                                         &parsed_slice, &parsed_slice_shape));
      get_slice = (parsed_full_shape != parsed_slice_shape);
    }
    if (get_slice) {
      TF_RETURN_IF_ERROR(ckpt_reader->LookupSlice(
          GetMonolithicTensorKey(tensor_name), parsed_slice, tensor));
    } else {
      TF_RETURN_IF_ERROR(
          ckpt_reader->Lookup(GetMonolithicTensorKey(tensor_name), tensor));
    }
    return OkStatus();
  }
  return errors::Internal("Checkpoint reader was not initialized.");
}

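// Opens a BundleReader over the checkpoint named by the "input_checkpoint"
// transform parameter, if one was supplied.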
Status InitializeCheckpointReader(const TransformFuncContext& context,
                                  std::unique_ptr<BundleReader>* ckpt_reader) {
  if (context.params.count("input_checkpoint")) {
    const string input_checkpoint = context.params.at("input_checkpoint")[0];
    ckpt_reader->reset(new BundleReader(Env::Default(), input_checkpoint));
    TF_RETURN_IF_ERROR((*ckpt_reader)->status());
  }
  return OkStatus();
}

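// Builds a map from each Variable/VariableV2 node name to the
// shape-and-slice string its restore op uses.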
Status ObtainVariableInfo(
    const GraphDef& input_graph_def,
    std::unique_ptr<std::unordered_map<string, string> >* shapes_and_slices) {
  shapes_and_slices->reset(new std::unordered_map<string, string>());
  for (const auto& node : input_graph_def.node()) {
    if ((node.op() == "Variable") || (node.op() == "VariableV2")) {
      string s;
      TF_RETURN_IF_ERROR(ObtainTensorSlice(input_graph_def, node.name(), &s));
      (**shapes_and_slices)[node.name()] = s;
    }
  }
  return OkStatus();
}

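// Removes the input at `index` from `n` by bubbling it to the end of the
// repeated field and dropping the last element, preserving the order of the
// remaining inputs.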
Status RemoveInputAtIndex(NodeDef* n, int index) {
  for (int i = index; i < n->input_size() - 1; i++) {
    n->mutable_input()->SwapElements(i, i + 1);
  }
  n->mutable_input()->RemoveLast();
  return OkStatus();
}

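// Removes the node at `index` from `g` in the same bubble-to-the-end manner.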
Status RemoveNodeAtIndex(GraphDef* g, int index) {
  for (int i = index; i < g->node_size() - 1; i++) {
    g->mutable_node()->SwapElements(i, i + 1);
  }
  g->mutable_node()->RemoveLast();
  return OkStatus();
}

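// Repeatedly matches `pattern` against the graph, replacing each matched
// embedding gather with a hash-table-backed sparse lookup, and prunes any
// nodes left unreferenced by the rewrite.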
Status SparsifyGatherInternal(
    const GraphDef& input_graph_def,
    const std::unique_ptr<std::unordered_map<string, string> >&
        shapes_and_slices,
    const TransformFuncContext& context, const OpTypePattern& pattern,
    const std::unique_ptr<BundleReader>& ckpt_reader,
    GraphDef* output_graph_def) {
  string group_init_node = "group_deps";
  if (context.params.count("group_init_node")) {
    group_init_node = context.params.at("group_init_node")[0];
  }
  GraphDef current_graph_def = input_graph_def;
  bool any_match_found = false;

  // Populate reference counts: the number of times each node is named as an
  // input, with the "^" control-input prefix stripped.
  std::unordered_map<string, int> refs;
  for (const auto& node : current_graph_def.node()) {
    for (const auto& input : node.input()) {
      auto parsed_input = StringReplace(input, "^", "", true);
      refs[parsed_input] += 1;
    }
  }

  // The subgraphs may have overlapping components, so GraphMatcher doesn't
  // return all subgraphs in one round -- this has to be a multi-round update.
  do {
    any_match_found = false;
    GraphDef replaced_graph_def = current_graph_def;
    std::vector<string> init_table_node_names;
    std::vector<string> removed_node_names;

    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def, pattern,
        [&ckpt_reader, &any_match_found, &init_table_node_names,
         &shapes_and_slices, &removed_node_names,
         &refs](const NodeMatch& match, const std::set<string>& input_nodes,
                const std::set<string>& output_nodes,
                std::vector<NodeDef>* new_nodes) {
          any_match_found = true;

          // The captured subgraph should be of the following pattern:
          // Const --> Identity --> Gather --> ...
          //                          ^
          //                          |
          //                        (ids)
          //
          // After transform, it becomes:
          //                   --> NoOp(group_deps)
          //                   |
          //  Const --> InitializeTable --> HashTable
          //                   ^              |
          //                   |              |
          //  Const -------------             |
          //                                  v
          //              (ids) ---> LookupTableFind <--- Const(default)
          //                                  |
          //                                  v
          //                                 ...

          // clang-format off
          // For each subgraph, do the following:
          // 1. Sparsify the `Const`, creating two `Const` nodes for the
          //    hashtable keys/values.
          // 2. Create an `InitializeTable` op connecting to the above two
          //    `Const` nodes.
          // 3. Create a `HashTable` op connecting to the `InitializeTable` op.
          // 4. Replace the `Gather` with a `LookupTableFind` op.
          // 5. Connect the `LookupTableFind` with
          //    a. the `HashTable`,
          //    b. the `Gather`'s ids input, and
          //    c. a `default_val` arg, valued at 0.
          // clang-format on
          const NodeDef& gather_node = match.node;

          // GatherV2 adds an "axis" parameter. sparsify_gather only supports
          // axis 0 gathers.
          if (gather_node.op() == "GatherV2") {
            // Per the OpTypePattern, the 3rd input to Gather must be a Const.
            const NodeDef& axis_node = match.inputs[2].node;

            Tensor axis_t;
            TF_RETURN_IF_ERROR(GetNodeAttr(axis_node, "value", &axis_t));
            int64_t axis = 0;
            if (axis_t.dtype() == DT_INT32) {
              axis = axis_t.scalar<int32>()();
            } else if (axis_t.dtype() == DT_INT64) {
              axis = axis_t.scalar<int64_t>()();
            } else {
              return tensorflow::errors::FailedPrecondition(
                  "Gather axis was not int32 or int64.");
            }

            if (axis != 0) {
              return tensorflow::errors::FailedPrecondition(
                  "Transform only applicable to subgraph with GatherV2 over "
                  "axis 0. Found axis ",
                  axis, ".");
            }
          }

          const NodeDef& weights_node = match.inputs[0].inputs[0].node;

          DataType data_type;
          TF_RETURN_IF_ERROR(GetNodeAttr(weights_node, "dtype", &data_type));
          if (data_type != DT_FLOAT) {
            return tensorflow::errors::FailedPrecondition(
                "Transform only applicable to subgraph with 'Const', "
                "'Variable', or 'VariableV2' of dtype "
                "'DT_FLOAT'. Found '" +
                    weights_node.op() + "' with name '",
                weights_node.name(), "' and dtype '", data_type, "'.");
          }

          Tensor weight;
          if (weights_node.op() == "Const") {
            weight = GetNodeTensorAttr(weights_node, "value");
          } else {
            TF_RETURN_IF_ERROR(ReadTensorFromCheckpoint(
                weights_node.name(), ckpt_reader,
                (*shapes_and_slices)[weights_node.name()], &weight));
          }
          // Mark both the weights node and its identity node for removal.
          removed_node_names.push_back(weights_node.name());
          removed_node_names.push_back(match.inputs[0].node.name());
          for (auto input_node : match.inputs[0].node.input()) {
            auto parsed_input = StringReplace(input_node, "^", "", true);
            refs[parsed_input]--;
          }
          Tensor indices_tensor;
          Tensor values_tensor;
          TF_RETURN_IF_ERROR(
              SparsifyWeights(weight, &indices_tensor, &values_tensor));

          // Indices and values of the sparsified `Const`.
          DataType key_dtype = DT_INT64;
          NodeDef indices_node;
          CreateConstNode(indices_tensor,
                          StrCat(weights_node.name(), "/indices"),
                          &indices_node);
          SetNodeAttr("dtype", key_dtype, &indices_node);

          NodeDef values_node;
          CreateConstNode(values_tensor,
                          StrCat(weights_node.name(), "/values"),
                          &values_node);
          SetNodeAttr("dtype", data_type, &values_node);

          // HashTable node
          NodeDef hashtable_node;
          hashtable_node.set_op("HashTable");
          hashtable_node.set_name(StrCat(weights_node.name(), "/HashTable"));
          SetNodeAttr("key_dtype", key_dtype, &hashtable_node);
          SetNodeAttr("value_dtype", data_type, &hashtable_node);

          // InitializeTable node
          NodeDef init_table_node;
          init_table_node.set_op("InitializeTable");
          init_table_node.set_name(
              StrCat(weights_node.name(), "/InitializeTable"));
          SetNodeAttr("Tkey", key_dtype, &init_table_node);
          SetNodeAttr("Tval", data_type, &init_table_node);
          init_table_node_names.push_back(init_table_node.name());

          // LookupTableFind node
          NodeDef lookup_node;
          lookup_node.set_op("LookupTableFind");
          lookup_node.set_name(StrCat(gather_node.name(), "/LookupTableFind"));
          SetNodeAttr("Tin", key_dtype, &lookup_node);
          SetNodeAttr("Tout", data_type, &lookup_node);

          // Default return value of hashtable lookup
          Tensor zero_tensor(data_type, TensorShape({}));
          zero_tensor.flat<float>()(0) = 0.0;
          NodeDef default_value_node;
          CreateConstNode(zero_tensor, StrCat(gather_node.name(), "/Const"),
                          &default_value_node);
          SetNodeAttr("dtype", data_type, &default_value_node);

          // ExpandDims argument
          Tensor dim_idx(DT_INT32, TensorShape({}));
          dim_idx.flat<int32>()(0) = -1;
          NodeDef dim_idx_node;
          dim_idx_node.set_op("Const");
          dim_idx_node.set_name(
              StrCat(gather_node.name(), "/ExpandDims/Const"));
          SetNodeAttr("value", dim_idx, &dim_idx_node);
          SetNodeAttr("dtype", DT_INT32, &dim_idx_node);

          // ExpandDims node
          NodeDef expand_dims_node;
          expand_dims_node.set_op("ExpandDims");
          // Reuse gather_node's name so as not to change its dependents'
          // inputs.
          expand_dims_node.set_name(gather_node.name());
          SetNodeAttr("T", data_type, &expand_dims_node);

          // Connect nodes
          AddNodeInput(hashtable_node.name(), &init_table_node);
          refs[hashtable_node.name()]++;
          AddNodeInput(indices_node.name(), &init_table_node);
          refs[indices_node.name()]++;
          AddNodeInput(values_node.name(), &init_table_node);
          refs[values_node.name()]++;

          AddNodeInput(hashtable_node.name(), &lookup_node);
          refs[hashtable_node.name()]++;
          AddNodeInput(gather_node.input(1), &lookup_node);
          refs[gather_node.input(1)]++;
          AddNodeInput(default_value_node.name(), &lookup_node);
          refs[default_value_node.name()]++;

          AddNodeInput(lookup_node.name(), &expand_dims_node);
          refs[lookup_node.name()]++;
          AddNodeInput(dim_idx_node.name(), &expand_dims_node);
          refs[dim_idx_node.name()]++;

          // Copy the 'ids' input of the original 'Gather'.
          new_nodes->push_back(match.inputs[1].node);
          new_nodes->push_back(indices_node);
          new_nodes->push_back(values_node);
          new_nodes->push_back(hashtable_node);
          new_nodes->push_back(init_table_node);
          new_nodes->push_back(lookup_node);
          new_nodes->push_back(default_value_node);
          new_nodes->push_back(dim_idx_node);
          new_nodes->push_back(expand_dims_node);

          return OkStatus();
        },
        {true}, &replaced_graph_def));

    NodeDef* init_op = nullptr;
    for (int i = 0; i < replaced_graph_def.node_size(); i++) {
      if (replaced_graph_def.node(i).name() == group_init_node &&
          replaced_graph_def.node(i).op() == "NoOp") {
        init_op = replaced_graph_def.mutable_node(i);
        break;
      }
    }
    if (!init_op) {
      // Init node
      init_op = replaced_graph_def.mutable_node()->Add();
      init_op->set_op("NoOp");
      init_op->set_name(group_init_node);
    }
    for (const string& name : init_table_node_names) {
      // Add a control dependency from the table initializer to the group
      // init node.
      AddNodeInput(StrCat("^", name), init_op);
      refs[name]++;
    }

    // Erase inputs and outputs as they are not considered for deletion.
    for (const auto& output : context.output_names) {
      refs.erase(output);
    }

    for (const auto& input : context.input_names) {
      refs.erase(input);
    }

    // Add nodes with a reference count of 0 for deletion.
    for (const auto& entry : refs) {
      if (entry.second == 0) {
        removed_node_names.push_back(entry.first);
      }
    }

    while (!removed_node_names.empty()) {
      auto name = removed_node_names.back();
      removed_node_names.pop_back();

      int i = 0;
      while (i < replaced_graph_def.node_size()) {
        // Revisit this to see if we can safely remove RestoreV2 nodes.
        if ((replaced_graph_def.node(i).name() == name) &&
            (replaced_graph_def.node(i).op() != "RestoreV2")) {
          for (const auto& input : replaced_graph_def.node(i).input()) {
            auto parsed_input = StringReplace(input, "^", "", true);
            refs[parsed_input] -= 1;
            if (refs[parsed_input] == 0) {
              removed_node_names.push_back(parsed_input);
            }
          }
          TF_RETURN_IF_ERROR(RemoveNodeAtIndex(&replaced_graph_def, i));
          continue;
        }
        int j = 0;
        bool deleted_inputs = false;
        while (j < replaced_graph_def.node(i).input_size()) {
          if (replaced_graph_def.node(i).input(j) == name ||
              replaced_graph_def.node(i).input(j) == ("^" + name)) {
            TF_RETURN_IF_ERROR(
                RemoveInputAtIndex(replaced_graph_def.mutable_node(i), j));
            deleted_inputs = true;
            continue;
          }
          j++;
        }
        if (deleted_inputs) {
          if (replaced_graph_def.node(i).op() == "ConcatV2") {
            if (replaced_graph_def.node(i).input_size() > 2) {
              SetNodeAttr("N", replaced_graph_def.node(i).input_size() - 1,
                          replaced_graph_def.mutable_node(i));
            } else if (replaced_graph_def.node(i).input_size() == 2) {
              if (refs[replaced_graph_def.node(i).input(1)] != 1) {
                return errors::Internal(
                    "Expected the axis tensor of a ConcatV2 node to be "
                    "referenced exactly once.");
              }
              refs[replaced_graph_def.node(i).input(1)] -= 1;
              removed_node_names.push_back(replaced_graph_def.node(i).input(1));
              replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast();
              replaced_graph_def.mutable_node(i)->mutable_attr()->erase("N");
              replaced_graph_def.mutable_node(i)->set_op("Identity");
            } else {
              return errors::Internal(
                  "ConcatV2 should have at least two inputs.");
            }
          }
          if ((replaced_graph_def.node(i).op() == "Assign" ||
               replaced_graph_def.node(i).op() == "Reshape" ||
               replaced_graph_def.node(i).op() == "Equal" ||
               replaced_graph_def.node(i).op() == "Mean" ||
               replaced_graph_def.node(i).op() == "ScalarSummary") &&
              replaced_graph_def.node(i).input_size() == 1) {
            removed_node_names.push_back(replaced_graph_def.node(i).name());
          }
          if (!replaced_graph_def.node(i).input_size()) {
            removed_node_names.push_back(replaced_graph_def.node(i).name());
          }
        }
        i++;
      }
    }
    current_graph_def = replaced_graph_def;
  } while (any_match_found);
  *output_graph_def = current_graph_def;
  return OkStatus();
}

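// Top-level transform: runs the sparsification rewrite once for Gather
// subgraphs and once for GatherV2 subgraphs.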
Status SparsifyGather(const GraphDef& input_graph_def,
                      const TransformFuncContext& context,
                      GraphDef* output_graph_def) {
  // clang-format off
  const OpTypePattern gather_pattern =
    {"Gather",
     {
       {"Identity",
        {
          {"Const|Variable|VariableV2"}
        }
       },
       {"*"},
     }
    };
  const OpTypePattern gather_v2_pattern =
    {"GatherV2",
     {
       {"Identity",
        {
          {"Const|Variable|VariableV2"}
        }
       },
       {"*"},
       // GatherV2's axis must be constant.
       {"Const"},
     }
    };
  // clang-format on

  GraphDef cleaned_input_graph_def;
  RemoveAttributes(input_graph_def, {"_output_shapes"},
                   &cleaned_input_graph_def);

  GraphDef temp_output;

  std::unique_ptr<BundleReader> ckpt_reader;
  TF_RETURN_IF_ERROR(InitializeCheckpointReader(context, &ckpt_reader));

  std::unique_ptr<std::unordered_map<string, string> > shapes_and_slices;
  TF_RETURN_IF_ERROR(
      ObtainVariableInfo(cleaned_input_graph_def, &shapes_and_slices));

  TF_RETURN_IF_ERROR(SparsifyGatherInternal(
      cleaned_input_graph_def, shapes_and_slices, context, gather_pattern,
      ckpt_reader, &temp_output));

  TF_RETURN_IF_ERROR(SparsifyGatherInternal(temp_output, shapes_and_slices,
                                            context, gather_v2_pattern,
                                            ckpt_reader, output_graph_def));

  return OkStatus();
}

REGISTER_GRAPH_TRANSFORM("sparsify_gather", SparsifyGather);

}  // namespace graph_transforms
}  // namespace tensorflow