1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <cmath> |
17 | #include <memory> |
18 | #include <unordered_map> |
19 | |
20 | #include "tensorflow/c/checkpoint_reader.h" |
21 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
22 | #include "tensorflow/core/framework/tensor.h" |
23 | #include "tensorflow/core/graph/node_builder.h" |
24 | #include "tensorflow/core/graph/subgraph.h" |
25 | #include "tensorflow/core/lib/strings/str_util.h" |
26 | #include "tensorflow/core/platform/init_main.h" |
27 | #include "tensorflow/core/public/session.h" |
28 | #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" |
29 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
30 | |
31 | namespace tensorflow { |
32 | using str_util::Split; |
33 | using str_util::StringReplace; |
34 | using strings::StrCat; |
35 | |
36 | namespace graph_transforms { |
37 | |
38 | // Sparsify Tensor of shape [N, 1]. Return the indices and values vectors for |
39 | // non-zero tensor content. |
40 | Status SparsifyWeights(const Tensor& tensor, Tensor* indices_tensor, |
41 | Tensor* values_tensor) { |
42 | if (tensor.dims() != 2 || tensor.dim_size(1) != 1) { |
43 | return tensorflow::errors::FailedPrecondition( |
44 | "Transform only applicable to subgraph with 'Const' with " |
45 | "tensor of shape [N, 1]. But instead get shape " , |
46 | tensor.shape().DebugString(), "." ); |
47 | } |
48 | |
49 | auto flat = tensor.flat<float>(); |
50 | std::vector<int64_t> indices; |
51 | std::vector<float> values; |
52 | |
53 | for (int64_t i = 0; i < flat.size(); i++) { |
54 | float val = flat(i); |
55 | if (std::abs(val) >= 1.0e-5) { |
56 | indices.push_back(i); |
57 | values.push_back(val); |
58 | } |
59 | } |
60 | |
61 | // During model initialization, InitializeTableOp makes use of |
62 | // KeyValueTensorIterator, which does not accept empty keys or values. |
63 | // Consequently, adding a dummy pair of indices and values as a walkaround. |
64 | if (indices.empty() || values.empty()) { |
65 | indices.push_back(0); |
66 | values.push_back(0); |
67 | } |
68 | *indices_tensor = Tensor(DataTypeToEnum<int64_t>::value, |
69 | {static_cast<int64_t>(indices.size())}); |
70 | std::copy_n(indices.begin(), indices.size(), |
71 | indices_tensor->flat<int64_t>().data()); |
72 | |
73 | *values_tensor = Tensor(DataTypeToEnum<float>::value, |
74 | {static_cast<int64_t>(values.size())}); |
75 | std::copy_n(values.begin(), values.size(), |
76 | values_tensor->flat<float>().data()); |
77 | |
78 | return OkStatus(); |
79 | } |
80 | |
81 | void CreateConstNode(const Tensor& tensor, const string& name, |
82 | NodeDef* node_def) { |
83 | node_def->set_op("Const" ); |
84 | node_def->set_name(name); |
85 | SetNodeTensorAttr<float>("value" , tensor, node_def); |
86 | } |
87 | |
88 | string GetMonolithicTensorKey(const string& tensor_slice_name) { |
89 | std::vector<string> names = Split(tensor_slice_name, "/" ); |
90 | if (absl::StartsWith(names[names.size() - 1], "part_" )) { |
91 | CHECK_GE(names.size(), 2); |
92 | names.pop_back(); |
93 | } |
94 | return absl::StrJoin(names, "/" ); |
95 | } |
96 | |
97 | Status ObtainTensorSlice(const GraphDef& input_graph_def, |
98 | const string& target_name, |
99 | string* shape_slice_string) { |
100 | string restore_node_name; |
101 | for (const auto& node : input_graph_def.node()) { |
102 | std::vector<string> node_name_parts = Split(node.name(), "/" ); |
103 | if (node_name_parts.size() == 2 && |
104 | absl::StartsWith(node_name_parts[0], "save" ) && |
105 | absl::StartsWith(node_name_parts[1], "Assign" ) && |
106 | node.input(0) == target_name) { |
107 | restore_node_name = node.input(1); |
108 | break; |
109 | } |
110 | } |
111 | |
112 | std::vector<string> restore_node_parts = Split(restore_node_name, ":" ); |
113 | CHECK_LE(restore_node_parts.size(), 2); |
114 | string tensor_names_node; |
115 | string shape_and_slices_node; |
116 | for (const auto& node : input_graph_def.node()) { |
117 | if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2" )) { |
118 | tensor_names_node = node.input(1); |
119 | shape_and_slices_node = node.input(2); |
120 | break; |
121 | } |
122 | } |
123 | |
124 | int offset = -1; |
125 | for (const auto& node : input_graph_def.node()) { |
126 | if (node.name() == tensor_names_node) { |
127 | Tensor tensor_names_tensor; |
128 | TF_RETURN_IF_ERROR(GetNodeAttr(node, "value" , &tensor_names_tensor)); |
129 | const auto& tensor_names_value = tensor_names_tensor.flat<tstring>(); |
130 | for (int i = 0; i < tensor_names_value.size(); i++) { |
131 | if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) { |
132 | offset = i; |
133 | break; |
134 | } |
135 | } |
136 | } |
137 | } |
138 | if (offset == -1) { |
139 | return errors::Internal("Unable to find RestoreV2 entry for variable: " , |
140 | target_name); |
141 | } |
142 | for (const auto& node : input_graph_def.node()) { |
143 | if (node.name() == shape_and_slices_node) { |
144 | Tensor shape_and_slices_tensor; |
145 | TF_RETURN_IF_ERROR(GetNodeAttr(node, "value" , &shape_and_slices_tensor)); |
146 | const auto& shape_and_slices_value = |
147 | shape_and_slices_tensor.flat<tstring>(); |
148 | *shape_slice_string = shape_and_slices_value(offset); |
149 | return OkStatus(); |
150 | } |
151 | } |
152 | return errors::Internal("Unable to find slice for variable: " , target_name); |
153 | } |
154 | |
155 | Status ReadTensorFromCheckpoint( |
156 | const string& tensor_name, const std::unique_ptr<BundleReader>& ckpt_reader, |
157 | const string& shape_and_slice, Tensor* tensor) { |
158 | if (ckpt_reader) { |
159 | TensorShape parsed_full_shape; |
160 | TensorSlice parsed_slice; |
161 | TensorShape parsed_slice_shape; |
162 | |
163 | bool get_slice = false; |
164 | if (!shape_and_slice.empty()) { |
165 | TF_RETURN_IF_ERROR( |
166 | checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape, |
167 | &parsed_slice, &parsed_slice_shape)); |
168 | get_slice = (parsed_full_shape != parsed_slice_shape); |
169 | } |
170 | if (get_slice) { |
171 | TF_RETURN_IF_ERROR(ckpt_reader->LookupSlice( |
172 | GetMonolithicTensorKey(tensor_name), parsed_slice, tensor)); |
173 | } else { |
174 | TF_RETURN_IF_ERROR( |
175 | ckpt_reader->Lookup(GetMonolithicTensorKey(tensor_name), tensor)); |
176 | } |
177 | return OkStatus(); |
178 | } |
179 | return errors::Internal("Checkpoint reader was not initialized. " ); |
180 | } |
181 | |
182 | Status InitializeCheckpointReader(const TransformFuncContext& context, |
183 | std::unique_ptr<BundleReader>* ckpt_reader) { |
184 | if (context.params.count("input_checkpoint" )) { |
185 | const string input_checkpoint = context.params.at("input_checkpoint" )[0]; |
186 | ckpt_reader->reset(new BundleReader(Env::Default(), input_checkpoint)); |
187 | TF_RETURN_IF_ERROR((*ckpt_reader)->status()); |
188 | } |
189 | return OkStatus(); |
190 | } |
191 | |
192 | Status ObtainVariableInfo( |
193 | const GraphDef& input_graph_def, |
194 | std::unique_ptr<std::unordered_map<string, string> >* shapes_and_slices) { |
195 | shapes_and_slices->reset(new std::unordered_map<string, string>()); |
196 | for (const auto& node : input_graph_def.node()) { |
197 | if ((node.op() == "Variable" ) || (node.op() == "VariableV2" )) { |
198 | string s; |
199 | TF_RETURN_IF_ERROR(ObtainTensorSlice(input_graph_def, node.name(), &s)); |
200 | (**shapes_and_slices)[node.name()] = s; |
201 | } |
202 | } |
203 | return OkStatus(); |
204 | } |
205 | |
206 | Status RemoveInputAtIndex(NodeDef* n, int index) { |
207 | for (int i = index; i < n->input_size() - 1; i++) { |
208 | n->mutable_input()->SwapElements(i, i + 1); |
209 | } |
210 | n->mutable_input()->RemoveLast(); |
211 | return OkStatus(); |
212 | } |
213 | |
214 | Status RemoveNodeAtIndex(GraphDef* g, int index) { |
215 | for (int i = index; i < g->node_size() - 1; i++) { |
216 | g->mutable_node()->SwapElements(i, i + 1); |
217 | } |
218 | g->mutable_node()->RemoveLast(); |
219 | return OkStatus(); |
220 | } |
221 | |
// Core rewrite pass: replaces every subgraph matching `pattern`
// (Gather/GatherV2 over an Identity of a Const/Variable weight) with a
// sparse hash-table lookup, then garbage-collects nodes orphaned by the
// rewrite via reference counting. Because matched subgraphs can overlap,
// matching is repeated until a round finds no match.
//
// `shapes_and_slices` and `ckpt_reader` supply checkpoint values for
// Variable weights; `context` provides input/output names (never deleted)
// and the optional "group_init_node" parameter.
Status SparsifyGatherInternal(
    const GraphDef& input_graph_def,
    const std::unique_ptr<std::unordered_map<string, string> >&
        shapes_and_slices,
    const TransformFuncContext& context, const OpTypePattern& pattern,
    const std::unique_ptr<BundleReader>& ckpt_reader,
    GraphDef* output_graph_def) {
  string group_init_node = "group_deps";
  if (context.params.count("group_init_node")) {
    group_init_node = context.params.at("group_init_node")[0];
  }
  GraphDef current_graph_def = input_graph_def;
  bool any_match_found = false;

  // Populate references.
  // refs[name] = number of node inputs (data or control; leading "^"
  // stripped) that reference `name`. A count of 0 marks a dead node.
  std::unordered_map<string, int> refs;
  for (const auto& node : current_graph_def.node()) {
    for (const auto& input : node.input()) {
      auto parsed_input = StringReplace(input, "^", "", true);
      refs[parsed_input] += 1;
    }
  }

  // The subgraphs may have overlapping components, therefore GraphMatcher
  // doesn't return all subgraphs in one round -- this has to be multi-round
  // update.
  do {
    any_match_found = false;
    GraphDef replaced_graph_def = current_graph_def;
    std::vector<string> init_table_node_names;
    std::vector<string> removed_node_names;

    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def, pattern,
        [&ckpt_reader, &any_match_found, &init_table_node_names,
         &shapes_and_slices, &removed_node_names,
         &refs](const NodeMatch& match, const std::set<string>& input_nodes,
                const std::set<string>& output_nodes,
                std::vector<NodeDef>* new_nodes) {
          any_match_found = true;

          // The captured subgraph should be of the following pattern:
          // Const --> Identity --> Gather --> ...
          //                          ^
          //                          |
          //                        (ids)
          //
          // After transform, it becomes:
          //                   --> NoOp(group_deps)
          //                   |
          // Const --> InitializeTable --> HashTable
          //                   ^              |
          //                   |              |
          // Const -------------              |
          //                                  v
          //               (ids) ---> LookupTableFind <--- Const(default)
          //                                  |
          //                                  v
          //                                 ...

          // clang-format off
          // For each subgraph, do the following
          // 1. Sparsify the `Const`, creating two `Const`, for hashtable
          // key/val.
          // 2. Create a `InitializeTable` op connecting to the above 2 `Const`.
          // 3. Create a `HashTable` op connecting to `InitializeTable` op.
          // 4. Replace the `Gather` with a `LookupTableFind` op.
          // 5. Connect the `LookupTableFind` with
          //    a. `HashTable`
          //    b. `Gather`'s ids input
          //    c. a `default_val` arg, valued at 0
          // clang-format on
          const NodeDef& gather_node = match.node;

          // GatherV2 adds an "axis" parameter. sparsify_gather only supports
          // axis 0 gathers.
          if (gather_node.op() == "GatherV2") {
            // Per the OpTypePattern, the 3rd input to Gather must be a Const.
            const NodeDef& axis_node = match.inputs[2].node;

            Tensor axis_t;
            TF_RETURN_IF_ERROR(GetNodeAttr(axis_node, "value", &axis_t));
            int64_t axis = 0;
            if (axis_t.dtype() == DT_INT32) {
              axis = axis_t.scalar<int32>()();
            } else if (axis_t.dtype() == DT_INT64) {
              axis = axis_t.scalar<int64_t>()();
            } else {
              return tensorflow::errors::FailedPrecondition(
                  "Gather axis was not int32 or int64.");
            }

            if (axis != 0) {
              return tensorflow::errors::FailedPrecondition(
                  "Transform only applicable to subgraph with GatherV2 over "
                  "axis 0. Found axis ",
                  axis, ".");
            }
          }

          // match.inputs[0] is the Identity; its input [0] is the weight.
          const NodeDef& weights_node = match.inputs[0].inputs[0].node;

          DataType data_type;
          TF_RETURN_IF_ERROR(GetNodeAttr(weights_node, "dtype", &data_type));
          if (data_type != DT_FLOAT) {
            return tensorflow::errors::FailedPrecondition(
                "Transform only applicable to subgraph with 'Const',"
                "'Variable', or 'VariableV2' of dtype "
                "'DT_FLOAT'. Found '" +
                    weights_node.op() + "' with name '",
                weights_node.name(), "' and dtype '", data_type, "'.");
          }

          // Obtain the dense weight: inline for Const, otherwise read the
          // variable's value from the checkpoint.
          Tensor weight;
          if (weights_node.op() == "Const") {
            weight = GetNodeTensorAttr(weights_node, "value");
          } else {
            TF_RETURN_IF_ERROR(ReadTensorFromCheckpoint(
                weights_node.name(), ckpt_reader,
                (*shapes_and_slices)[weights_node.name()], &weight));
          }
          // Add both weight and identity node names.
          removed_node_names.push_back(weights_node.name());
          removed_node_names.push_back(match.inputs[0].node.name());
          // The Identity node is going away: release its references so its
          // upstream nodes can be garbage-collected.
          for (auto input_node : match.inputs[0].node.input()) {
            auto parsed_input = StringReplace(input_node, "^", "", true);
            refs[parsed_input]--;
          }
          Tensor indices_tensor;
          Tensor values_tensor;
          TF_RETURN_IF_ERROR(
              SparsifyWeights(weight, &indices_tensor, &values_tensor));

          // indices and values of sparsified `Const`
          DataType key_dtype = DT_INT64;
          NodeDef indices_node;
          CreateConstNode(indices_tensor,
                          StrCat(weights_node.name(), "/indices"),
                          &indices_node);
          SetNodeAttr("dtype", key_dtype, &indices_node);

          NodeDef values_node;
          CreateConstNode(values_tensor, StrCat(weights_node.name(), "/values"),
                          &values_node);
          SetNodeAttr("dtype", data_type, &values_node);

          // HashTable node
          NodeDef hashtable_node;
          hashtable_node.set_op("HashTable");
          hashtable_node.set_name(StrCat(weights_node.name(), "/HashTable"));
          SetNodeAttr("key_dtype", key_dtype, &hashtable_node);
          SetNodeAttr("value_dtype", data_type, &hashtable_node);

          // InitializeTable node
          NodeDef init_table_node;
          init_table_node.set_op("InitializeTable");
          init_table_node.set_name(
              StrCat(weights_node.name(), "/InitializeTable"));
          SetNodeAttr("Tkey", key_dtype, &init_table_node);
          SetNodeAttr("Tval", data_type, &init_table_node);
          // Remember it so a control edge to the group init NoOp can be
          // added after this round of matching.
          init_table_node_names.push_back(init_table_node.name());

          // LookupTableFind node
          NodeDef lookup_node;
          lookup_node.set_op("LookupTableFind");
          lookup_node.set_name(StrCat(gather_node.name(), "/LookupTableFind"));
          SetNodeAttr("Tin", key_dtype, &lookup_node);
          SetNodeAttr("Tout", data_type, &lookup_node);

          // Default return value of hashtable lookup
          Tensor zero_tensor(data_type, TensorShape({}));
          zero_tensor.flat<float>()(0) = 0.0;
          NodeDef default_value_node;
          CreateConstNode(zero_tensor, StrCat(gather_node.name(), "/Const"),
                          &default_value_node);
          SetNodeAttr("dtype", data_type, &default_value_node);

          // ExpandDims argument
          Tensor dim_idx(DT_INT32, TensorShape({}));
          dim_idx.flat<int32>()(0) = -1;
          NodeDef dim_idx_node;
          dim_idx_node.set_op("Const");
          dim_idx_node.set_name(
              StrCat(gather_node.name(), "/ExpandDims/Const"));
          SetNodeAttr("value", dim_idx, &dim_idx_node);
          SetNodeAttr("dtype", DT_INT32, &dim_idx_node);

          // ExpandDims node
          NodeDef expand_dims_node;
          expand_dims_node.set_op("ExpandDims");
          // Reuse gather_node's name so not to change dependent's inputs
          expand_dims_node.set_name(gather_node.name());
          SetNodeAttr("T", data_type, &expand_dims_node);

          // Connect nodes (each AddNodeInput is mirrored by a refs increment
          // so the reference counts stay consistent with the new edges).
          AddNodeInput(hashtable_node.name(), &init_table_node);
          refs[hashtable_node.name()]++;
          AddNodeInput(indices_node.name(), &init_table_node);
          refs[indices_node.name()]++;
          AddNodeInput(values_node.name(), &init_table_node);
          refs[values_node.name()]++;

          AddNodeInput(hashtable_node.name(), &lookup_node);
          refs[hashtable_node.name()]++;
          AddNodeInput(gather_node.input(1), &lookup_node);
          refs[gather_node.input(1)]++;
          AddNodeInput(default_value_node.name(), &lookup_node);
          refs[default_value_node.name()]++;

          AddNodeInput(lookup_node.name(), &expand_dims_node);
          refs[lookup_node.name()]++;
          AddNodeInput(dim_idx_node.name(), &expand_dims_node);
          refs[dim_idx_node.name()]++;

          // Copy 'ids' input of original 'Gather'
          new_nodes->push_back(match.inputs[1].node);
          new_nodes->push_back(indices_node);
          new_nodes->push_back(values_node);
          new_nodes->push_back(hashtable_node);
          new_nodes->push_back(init_table_node);
          new_nodes->push_back(lookup_node);
          new_nodes->push_back(default_value_node);
          new_nodes->push_back(dim_idx_node);
          new_nodes->push_back(expand_dims_node);

          return OkStatus();
        },
        {true}, &replaced_graph_def));

    // Find (or create) the group-init NoOp that runs all table initializers.
    NodeDef* init_op = nullptr;
    for (int i = 0; i < replaced_graph_def.node_size(); i++) {
      if (replaced_graph_def.node(i).name() == group_init_node &&
          replaced_graph_def.node(i).op() == "NoOp") {
        init_op = replaced_graph_def.mutable_node(i);
        break;
      }
    }
    if (!init_op) {
      // Init node
      init_op = replaced_graph_def.mutable_node()->Add();
      init_op->set_op("NoOp");
      init_op->set_name(group_init_node);
    }
    for (const string& name : init_table_node_names) {
      // Add control dependence from init_table_node to group_deps_node
      AddNodeInput(StrCat("^", name), init_op);
      refs[name]++;
    }

    // Erase inputs and outputs as they are not considered for deletion.
    for (const auto& output : context.output_names) {
      refs.erase(output);
    }

    for (const auto& input : context.input_names) {
      refs.erase(input);
    }

    // Add nodes with a reference count of 0 for deletion.
    for (const auto& entry : refs) {
      if (entry.second == 0) {
        removed_node_names.push_back(entry.first);
      }
    }

    // Worklist-style garbage collection: removing a node decrements its
    // inputs' refcounts, which may in turn queue more nodes for removal.
    while (!removed_node_names.empty()) {
      auto name = removed_node_names.back();
      removed_node_names.pop_back();

      int i = 0;
      while (i < replaced_graph_def.node_size()) {
        // Revisit this to see if we can safely remove RestoreV2 nodes.
        if ((replaced_graph_def.node(i).name() == name) &&
            (replaced_graph_def.node(i).op() != "RestoreV2")) {
          for (const auto& input : replaced_graph_def.node(i).input()) {
            auto parsed_input = StringReplace(input, "^", "", true);
            refs[parsed_input] -= 1;
            if (refs[parsed_input] == 0) {
              removed_node_names.push_back(parsed_input);
            }
          }
          TF_RETURN_IF_ERROR(RemoveNodeAtIndex(&replaced_graph_def, i));
          // Do not advance i: the next node shifted into slot i.
          continue;
        }
        // Strip data and control ("^name") edges that referenced the removed
        // node from every surviving node.
        int j = 0;
        bool deleted_inputs = false;
        while (j < replaced_graph_def.node(i).input_size()) {
          if (replaced_graph_def.node(i).input(j) == name ||
              replaced_graph_def.node(i).input(j) == ("^" + name)) {
            TF_RETURN_IF_ERROR(
                RemoveInputAtIndex(replaced_graph_def.mutable_node(i), j));
            deleted_inputs = true;
            continue;
          }
          j++;
        }
        if (deleted_inputs) {
          // ConcatV2 keeps an "N" attr and a trailing axis input that must
          // track the number of data inputs; with a single data input left it
          // degenerates to Identity.
          if (replaced_graph_def.node(i).op() == "ConcatV2") {
            if (replaced_graph_def.node(i).input_size() > 2) {
              SetNodeAttr("N", replaced_graph_def.node(i).input_size() - 1,
                          replaced_graph_def.mutable_node(i));
            } else if (replaced_graph_def.node(i).input_size() == 2) {
              if (refs[replaced_graph_def.node(i).input(1)] != 1) {
                return errors::Internal(
                    "Expect axis tensor of ConcatV2 node to only be referenced "
                    "once.");
              }
              refs[replaced_graph_def.node(i).input(1)] -= 1;
              removed_node_names.push_back(replaced_graph_def.node(i).input(1));
              replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast();
              replaced_graph_def.mutable_node(i)->mutable_attr()->erase("N");
              replaced_graph_def.mutable_node(i)->set_op("Identity");
            } else {
              return errors::Internal(
                  "ConcatV2 should have at least two elements");
            }
          }
          // These ops are meaningless once reduced to a single input; queue
          // them for removal too.
          if ((replaced_graph_def.node(i).op() == "Assign" ||
               replaced_graph_def.node(i).op() == "Reshape" ||
               replaced_graph_def.node(i).op() == "Equal" ||
               replaced_graph_def.node(i).op() == "Mean" ||
               replaced_graph_def.node(i).op() == "ScalarSummary") &&
              replaced_graph_def.node(i).input_size() == 1) {
            removed_node_names.push_back(replaced_graph_def.node(i).name());
          }
          // A node left with no inputs at all is dead as well.
          if (!replaced_graph_def.node(i).input_size()) {
            removed_node_names.push_back(replaced_graph_def.node(i).name());
          }
        }
        i++;
      }
    }
    current_graph_def = replaced_graph_def;
  } while (any_match_found);
  *output_graph_def = current_graph_def;
  return OkStatus();
}
559 | |
560 | Status SparsifyGather(const GraphDef& input_graph_def, |
561 | const TransformFuncContext& context, |
562 | GraphDef* output_graph_def) { |
563 | // clang-format off |
564 | const OpTypePattern gather_pattern = |
565 | {"Gather" , |
566 | { |
567 | {"Identity" , |
568 | { |
569 | {"Const|Variable|VariableV2" } |
570 | } |
571 | }, |
572 | {"*" }, |
573 | } |
574 | }; |
575 | const OpTypePattern gather_v2_pattern = |
576 | {"GatherV2" , |
577 | { |
578 | {"Identity" , |
579 | { |
580 | {"Const|Variable|VariableV2" } |
581 | } |
582 | }, |
583 | {"*" }, |
584 | // GatherV2's axis must be constant. |
585 | {"Const" }, |
586 | } |
587 | }; |
588 | // clang-format on |
589 | |
590 | GraphDef cleaned_input_graph_def; |
591 | RemoveAttributes(input_graph_def, {"_output_shapes" }, |
592 | &cleaned_input_graph_def); |
593 | |
594 | GraphDef temp_output; |
595 | |
596 | std::unique_ptr<BundleReader> ckpt_reader; |
597 | TF_RETURN_IF_ERROR(InitializeCheckpointReader(context, &ckpt_reader)); |
598 | |
599 | std::unique_ptr<std::unordered_map<string, string> > shapes_and_slices; |
600 | TF_RETURN_IF_ERROR( |
601 | ObtainVariableInfo(cleaned_input_graph_def, &shapes_and_slices)); |
602 | |
603 | TF_RETURN_IF_ERROR(SparsifyGatherInternal( |
604 | cleaned_input_graph_def, shapes_and_slices, context, gather_pattern, |
605 | ckpt_reader, &temp_output)); |
606 | |
607 | TF_RETURN_IF_ERROR(SparsifyGatherInternal(temp_output, shapes_and_slices, |
608 | context, gather_v2_pattern, |
609 | ckpt_reader, output_graph_def)); |
610 | |
611 | return OkStatus(); |
612 | } |
613 | |
// Registers this transform with the Graph Transform Tool under the name used
// on the command line, e.g. --transforms="sparsify_gather(...)".
REGISTER_GRAPH_TRANSFORM("sparsify_gather", SparsifyGather);
615 | |
616 | } // namespace graph_transforms |
617 | } // namespace tensorflow |
618 | |