/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/data_flow_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

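// The ops below manipulate stateful resources (queues, stacks, and tensors
// stored in the session), so no gradient function is defined for them.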
REGISTER_NO_GRADIENT_OP("Queue");
REGISTER_NO_GRADIENT_OP("QueueEnqueue");
REGISTER_NO_GRADIENT_OP("QueueEnqueueMany");
REGISTER_NO_GRADIENT_OP("QueueDequeue");
REGISTER_NO_GRADIENT_OP("QueueDequeueMany");
REGISTER_NO_GRADIENT_OP("QueueDequeueUpTo");
REGISTER_NO_GRADIENT_OP("QueueClose");
REGISTER_NO_GRADIENT_OP("QueueSize");
REGISTER_NO_GRADIENT_OP("Stack");
REGISTER_NO_GRADIENT_OP("StackPush");
REGISTER_NO_GRADIENT_OP("StackPop");
REGISTER_NO_GRADIENT_OP("StackClose");
REGISTER_NO_GRADIENT_OP("GetSessionHandle");
REGISTER_NO_GRADIENT_OP("GetSessionHandleV2");
REGISTER_NO_GRADIENT_OP("GetSessionTensor");
REGISTER_NO_GRADIENT_OP("DeleteSessionTensor");

Status DynamicPartitionGrad(const Scope& scope, const Operation& op,
                            const std::vector<Output>& grad_inputs,
                            std::vector<Output>* grad_outputs) {
  // DynamicPartition only moves input values into various positions
  // in the output, so the gradient operation only has to map incoming
  // gradients into their input source locations.
  // Running example:
  //   data = [10, 20, 30, 40, 50]
  //   partitions = [0, 0, 1, 1, 0]
  //   num_partitions = 2
  //   dynamic_partition(data, partitions, num_partitions) = {
  //     [10, 20, 50],
  //     [30, 40]
  //   }
  //   grads = {
  //     [g1, g2, g3],
  //     [g4, g5]
  //   }
  // The desired propagation of the gradients back to the data inputs is:
  //   [g1, g2, g4, g5, g3]
  auto data = op.input(0);
  auto partitions = op.input(1);
  int32_t num_partitions;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "num_partitions", &num_partitions));

  // Note: the shape of the partitions is a prefix of the data shape.
  //   shape(partitions) = [5]
  auto partitions_shape = Shape(scope, partitions);
  // We now create a partitions-shaped tensor with integers from
  // [0..size(partitions)). This will be dynamic_partitioned with the
  // input parameters, providing the destination index for a given
  // source item.
  //   partitions_size = prod([5]) = 5
  //   reshape(range(partitions_size), [5]) = [0, 1, 2, 3, 4]
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);
  auto original_indices = Reshape(
      scope, Range(scope, zero, Prod(scope, partitions_shape, zero), one),
      partitions_shape);
  //   dynamic_partition(
  //       [0, 1, 2, 3, 4],
  //       [0, 0, 1, 1, 0], 2)
  //     = { [0, 1, 4],
  //         [2, 3] }
  auto partitioned_indices =
      DynamicPartition(scope, original_indices, partitions, num_partitions);

  // Invert these indices with dynamic_stitch to map the incoming
  // gradients to their source inputs.
  //   dynamic_stitch(
  //       { [0, 1, 4], [2, 3] },
  //       { [g1, g2, g3], [g4, g5] })
  //     = [g1, g2, g4, g5, g3]
  auto reconstructed =
      DynamicStitch(scope, partitioned_indices.outputs, grad_inputs);
  // Reshape back into a data-shaped tensor to propagate gradients for the
  // data input.
  grad_outputs->push_back(Reshape(scope, reconstructed, Shape(scope, data)));
  // Stop propagation along the partitions input.
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("DynamicPartition", DynamicPartitionGrad);
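
// A minimal usage sketch (illustrative only: `root`, `data`, and `partitions`
// are hypothetical names; the gradient graph is built with
// AddSymbolicGradients from tensorflow/cc/framework/gradients.h):
//
//   auto parts = DynamicPartition(root, data, partitions, /*num_partitions=*/2);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(root, parts.outputs, {data}, &grads));
//   // grads[0] has the shape of `data`, with each incoming gradient value
//   // stitched back to the position of its source element.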

Status DynamicStitchGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  // Running example:
  //   indices = {2, [1, 0]}
  //   data = {[d_1, d_2], [[d_3, d_4], [d_5, d_6]]}
  //   out = [[d_5, d_6], [d_3, d_4], [d_1, d_2]]
  //   grad = [[g_1, g_2], [g_3, g_4], [g_5, g_6]]

  // indices and data are two equal-sized lists passed
  // into DynamicStitch.
  //   num_values = 2
  int32_t num_values = op.num_inputs() / 2;

  // Stop propagation along the indices list.
  for (int32_t i = 0; i < num_values; i++) {
    grad_outputs->push_back(NoGradient());
  }

  // DynamicStitch shuffles its data to the output (using items in
  // indices), so the gradient propagated to a given data input simply
  // selects the gradient for its output position.
  for (int32_t i = 0; i < num_values; i++) {
    // index holds the destination positions for the i'th data
    // element. We cast it to int32 if necessary, so it can be used
    // as the indices of a Gather op.
    //   i = 0: index = 2
    //   i = 1: index = [1, 0]
    auto index = op.input(i);
    if (index.type() != DT_INT32) {
      index = Cast(scope, index, DT_INT32);
    }
    // Gather the locations specified by index from the gradient and
    // propagate the result as the gradient for the i'th data item.
    //   i = 0: gather(grad, 2) = [g_5, g_6]
    //   i = 1: gather(grad, [1, 0]) = [[g_3, g_4], [g_1, g_2]]
    grad_outputs->push_back(Gather(scope, grad_inputs[0], index));
  }

  return scope.status();
}
REGISTER_GRADIENT_OP("DynamicStitch", DynamicStitchGrad);
REGISTER_GRADIENT_OP("ParallelDynamicStitch", DynamicStitchGrad);
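
// A minimal usage sketch (illustrative only: `root`, `indices`, `d0`, and
// `d1` are hypothetical names):
//
//   std::vector<Output> data = {d0, d1};
//   auto stitched = DynamicStitch(root, indices, data);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(root, {stitched.merged}, data, &grads));
//   // grads[i] selects, via Gather, the rows of the output gradient at the
//   // positions given by indices[i], so it has the same shape as data[i].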

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow