1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <algorithm>
#include <cmath>
#include <string>
#include <tuple>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
25 | #include "tensorflow/cc/framework/grad_op_registry.h" |
26 | #include "tensorflow/cc/framework/gradients.h" |
27 | #include "tensorflow/cc/gradients/grad_helper.h" |
28 | #include "tensorflow/cc/ops/array_ops_internal.h" |
29 | #include "tensorflow/cc/ops/math_ops_internal.h" |
30 | #include "tensorflow/cc/ops/standard_ops.h" |
31 | |
32 | namespace tensorflow { |
33 | namespace ops { |
34 | namespace { |
35 | |
constexpr absl::string_view kEllipsis = "...";
37 | |
38 | // Returns the axis (possibly negative) corresponding to a label. |
39 | // |
40 | // Returns the axis index of the axis label if it is before an ellipsis (or if |
41 | // the ellipsis is not present), and the negative index if it occurs after the |
// ellipsis. E.g. the index of `b` in `ab...cd` is `1`, but that of `c` is `-2`.
43 | // |
44 | // For multiple occurrences, returns the leftmost one. If not found, returns |
45 | // absl::nullopt. |
46 | // |
47 | // Parameters: |
48 | // subscripts: A string denoting the einsum subscript (e.g. `ab...cd`) |
49 | // label: The single character axis label. |
50 | absl::optional<int> EinsumGetAxisFromLabel(absl::string_view subscripts, |
51 | char label) { |
52 | std::vector<absl::string_view> splits = absl::StrSplit(subscripts, kEllipsis); |
53 | auto index = splits[0].find(label); |
54 | if (index != splits[0].npos) { |
55 | return index; |
56 | } |
57 | if (splits.size() < 2) { |
58 | return absl::nullopt; |
59 | } |
60 | index = splits[1].find(label); |
61 | if (index != splits[1].npos) { |
62 | return index - splits[1].length(); |
63 | } |
64 | return absl::nullopt; |
65 | } |
66 | |
67 | // Returns a tuple denoting the slice mapping to ellipsis. |
68 | // |
69 | // For a given subscript, returns a tuple (start, end) denoting the start |
70 | // axis index and the (negative) end axis index respectively. For any input |
71 | // Tensor `x` described by the subscript, `x[start:end]` would be the slice |
// represented by the ellipsis. E.g. for `ab...cd` returns `[2, -2]`.
73 | // |
74 | // If ellipsis is not present in `subscripts`, returns `(0, 0)`. |
75 | // |
// Parameters:
//   subscripts: A string denoting the einsum subscript.
//
// Returns:
//   A tuple (start, end), where `end` is absl::nullopt if the slice extends
//   to the end of the tensor.
80 | std::tuple<int, absl::optional<int>> EinsumGetBcastSubshape( |
81 | absl::string_view subscripts) { |
82 | int start = subscripts.find(kEllipsis); |
83 | if (start == subscripts.npos) { |
84 | return std::make_tuple(0, 0); |
85 | } |
86 | int remaining = subscripts.length() - (start + kEllipsis.length()); |
87 | absl::optional<int> end; |
88 | if (remaining > 0) { |
89 | end = -remaining; |
90 | } else { |
91 | end = absl::nullopt; |
92 | } |
93 | return std::make_tuple(start, end); |
94 | } |
95 | |
// Slices elements of a 1d tensor over the range [start, end).
// If end is nullopt, the slice extends to the end of the tensor.
// Supports negative values for end.
// This attempts to give the same result as tensor[start:end] would give in
// Python.
101 | Output Slice1dHelper(const Scope& scope, Output tensor, int start, |
102 | absl::optional<int> end) { |
103 | if (end.has_value() && *end > 0) { |
104 | return Slice(scope, tensor, Const(scope, start, TensorShape({1})), |
105 | Const(scope, *end - start, TensorShape({1}))); |
106 | } else { |
107 | return Slice(scope, tensor, Const(scope, start, TensorShape({1})), |
108 | Add(scope, Shape(scope, tensor), end.value_or(0) - start)); |
109 | } |
110 | } |
111 | |
112 | // Returns reduced subscripts and their corresponding dimensions and axes. |
113 | // |
114 | // Given a set of axis labels, returns their concatenated subscript, their |
115 | // corresponding dimensions from input_shape, and their corresponding axes. |
// Note that the concatenated subscript `reduced_subs` contains the axis
// labels from `reduced_label_set` in sorted order. For example, for the
// reduced label set `{b, d}`, subscripts `aabbcd` and input shape
// `[2,2,5,5,3,4]`, returns subscripts `bd`, dimensions `[5,4]` and axes
// `[2,5]`.
120 | // |
121 | // Args: |
122 | // reduced_label_set: Set of axis labels which appear in `subscripts`. |
123 | // input_shape: A `Tensor` representing the shape of the einsum operand |
124 | // corresponding to `subscripts`. |
125 | // subscripts: A string denoting the einsum subscript. |
126 | // |
127 | // Returns: |
128 | // reduced_subs: Subscripts formed by a concatenation of labels in |
129 | // `reduced_label_set`. |
130 | // reduced_dims: Dimensions from `input_shape` corresponding to each label |
131 | // in `reduced_subs`. |
132 | // reduced_axes: Axes described by `subscripts` corresponding to each label |
133 | // in `reduced_subs`. If there are multiple occurrences in `subscripts`, |
134 | // we consider only the leftmost one. |
135 | std::tuple<std::string, Output, Output> EinsumGetReducedSubscripts( |
136 | const Scope& scope, const absl::btree_set<char>& reduced_label_set, |
137 | Output input_shape, absl::string_view subscripts) { |
138 | // Concatenate the sequence of reduced axis labels. |
139 | const std::string reduced_subs = |
140 | std::string(reduced_label_set.begin(), reduced_label_set.end()); |
141 | // Get the axis (may be positive, negative or zero) for each of the reduced |
142 | // labels. If the same label appears multiple times, get the left-most axis. |
143 | std::vector<int> reduced_axes; |
144 | reduced_axes.reserve(reduced_subs.size()); |
145 | for (const char s : reduced_subs) { |
146 | auto axis = EinsumGetAxisFromLabel(subscripts, s); |
147 | if (!axis.has_value()) { |
148 | // Should never happen. |
149 | scope.UpdateStatus(errors::Internal( |
150 | absl::StrCat("Missing axis" , absl::string_view(&s, 1)))); |
151 | } else { |
152 | reduced_axes.push_back(*axis); |
153 | } |
154 | } |
155 | // Get the corresponding dimensions for each reduced axis. |
156 | std::vector<Output> reduced_dims_inputs; |
157 | reduced_dims_inputs.reserve(reduced_axes.size()); |
158 | for (const int i : reduced_axes) { |
159 | if (i < 0) { |
160 | reduced_dims_inputs.push_back( |
161 | Gather(scope, input_shape, Add(scope, Size(scope, input_shape), i))); |
162 | } else { |
163 | reduced_dims_inputs.push_back(Gather(scope, input_shape, i)); |
164 | } |
165 | } |
166 | const Output reduced_dims = Stack(scope, reduced_dims_inputs); |
167 | Tensor reduced_axes_tensor( |
168 | DataType::DT_INT32, TensorShape({static_cast<int>(reduced_axes.size())})); |
169 | std::copy_n(reduced_axes.begin(), reduced_axes.size(), |
170 | reduced_axes_tensor.flat<int>().data()); |
171 | return std::make_tuple(reduced_subs, reduced_dims, |
172 | Const(scope, reduced_axes_tensor)); |
173 | } |
174 | |
175 | // Returns the gradient wrt input for a unary einsum with reductions. |
176 | // |
177 | // scope: Scope for grad operations. |
178 | // output_grad: The gradient wrt the output of a unary einsum operation. |
179 | // output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`). |
180 | // input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`). |
181 | // input_shape: The shape of the input operand. |
182 | // reduced_label_set: The set of axis labels appearing in `input_subs` but |
183 | // not in `output_subs`. |
184 | Output EinsumGradReducedHelper(const Scope& scope, const Output& output_grad, |
185 | absl::string_view output_subs, |
186 | absl::string_view input_subs, |
187 | const Output& input_shape, |
188 | const absl::btree_set<char>& reduced_label_set) { |
189 | // Let's say the einsum operation was "aabbcd->ca", where axis labels 'b' and |
190 | // 'd' are reduced with input_shape [2,2,5,5,3,4]. Then obtain the reduced |
191 | // subscripts "bd", corresponding dimensions [5,4] and axes [2,5]. |
192 | std::string reduced_subs; |
193 | Output reduced_dims, reduced_axes; |
194 | std::tie(reduced_subs, reduced_dims, reduced_axes) = |
195 | EinsumGetReducedSubscripts(scope, reduced_label_set, input_shape, |
196 | input_subs); |
197 | // Whether either the input or the output subscripts have a repeated label. |
198 | // This is true for "aabbcd->ca" or "abd->cca" but false for "abcd->ca". |
199 | const int distinct_input_labels = |
200 | absl::flat_hash_set<char>(input_subs.begin(), input_subs.end()).size(); |
201 | const int distinct_output_labels = |
202 | absl::flat_hash_set<char>(output_subs.begin(), output_subs.end()).size(); |
203 | const bool has_repeated_labels = |
204 | (distinct_input_labels + distinct_output_labels) < |
205 | input_subs.length() + output_subs.length(); |
206 | // Compute the input subscripts without the reduced axis labels, e.g. "aac" |
207 | // for the equation "aabbcd->ca". |
208 | std::string input_subs_without_reduced_labels; |
209 | for (const char s : input_subs) { |
210 | if (!absl::c_linear_search(reduced_label_set, s)) { |
211 | input_subs_without_reduced_labels.push_back(s); |
212 | } |
213 | } |
214 | |
215 | // The gradient wrt the input for the equation "abc->ac" (or, equivalently |
216 | // reduce_sum(..., axis=1)) is just the gradient of the output tiled N times |
217 | // along axis 1, where label 'b' represents a dimension of size N. |
218 | // |
  // If we're not dealing with repeated labels and the non-reduced labels
  // don't need to be transposed, then just tiling is enough and there is no
221 | // need to call another einsum. For example, tiling is sufficient for |
222 | // "abcd->ac". But for equations like "aabbcd->ac" (generalized traces) or |
223 | // "abc->ca" (transpose), we'd need another einsum operation after tiling. |
224 | if (!has_repeated_labels && |
225 | input_subs_without_reduced_labels == output_subs) { |
226 | // Obtain the shape of the output, as if keepdims=True on reduce sum. E.g. |
227 | // for the equation "abcd->ac" with input shape [2,5,3,4], we get the |
228 | // reduced shape [2,1,3,1]. |
229 | auto reduced_shape = ReducedShapeHelper(scope, input_shape, reduced_axes); |
230 | // Reshaping the gradient (wrt "ac") to [2,1,3,1] and broadcasting it to |
231 | // the shape [2,5,3,4] results in the gradient wrt "abcd". |
232 | return BroadcastTo(scope, Reshape(scope, output_grad, reduced_shape), |
233 | input_shape); |
234 | } |
235 | |
236 | // If we *do* have traces or transpose operations, then prepend the extra |
237 | // reduced dimensions to the front. E.g. Given the equation "aabbcd->ca" we'd |
238 | // first obtain the VJP for "bdca->ca", and then the VJP for "aabbcd->bdca". |
239 | // |
240 | // Obtain the input shape with reduced dimensions prepended, viz. [5,4,3,2]. |
241 | // This is the shape of the intermediate "bdca". |
242 | Output output_grad_shape = Shape(scope, output_grad); |
243 | auto grad_shape_with_reduced_labels = |
244 | Concat(scope, {reduced_dims, output_grad_shape}, /*axis=*/0); |
245 | |
246 | // Obtain the output shape of the reduction-only equation "bdca->ca" as if |
247 | // keepdims=True; viz. [1,1,3,2]. Since we prepended the reduced labels, |
248 | // we just have to prepend that many 1s to the output shape. |
249 | |
250 | auto reduced_shape = Concat( |
251 | scope, |
252 | {Const(scope, 1, TensorShape{static_cast<int>(reduced_label_set.size())}), |
253 | output_grad_shape}, |
254 | /*axis=*/0); |
255 | // Compute the VJP for the intermediate (viz. "bdca->ca") for which |
256 | // broadcasting is sufficient. |
257 | Output broadcasted_grad = |
258 | BroadcastTo(scope, Reshape(scope, output_grad, reduced_shape), |
259 | grad_shape_with_reduced_labels); |
260 | // Compute the VJP for the final step (viz. "aabbcd->bdca"). We can |
261 | // use einsum with the input and output subscripts reversed (viz. |
262 | // "bdca->aabbcd") since the output axis labels now appear in the |
263 | // input subscripts. |
264 | return Einsum(scope, {broadcasted_grad}, |
                absl::StrCat(reduced_subs, output_subs, "->", input_subs));
266 | } |
267 | |
268 | // Returns the gradient wrt an input operand for a binary einsum. |
269 | // |
270 | // This function does not handle (un)broadcasting. This must be done separately |
271 | // on the returned gradient. |
272 | // |
273 | // Args: |
274 | // output_grad: The gradient wrt the output of a binary einsum operation. |
//   other_operand: The complementary `Tensor` operand, i.e. the one that is
//     not the input operand.
277 | // input_shape: A `Tensor` representing the shape of input operand. |
278 | // input_subs: The subscripts of the input operand. |
279 | // other_subs: The subscripts of the complementary operand. |
280 | // output_subs: The output subscripts. |
281 | Output EinsumGradWrt(const Scope& scope, Output output_grad, |
282 | Output other_operand, Output input_shape, |
283 | absl::string_view input_subs, absl::string_view other_subs, |
284 | absl::string_view output_subs) { |
285 | // Claim: For the einsum operation z = einsum("{eq_x},{eq_y}->{eq_z}", x, y), |
286 | // where the equation involves only Tensor contractions, generalized traces |
287 | // and transposes, the input gradients are given by the vector-jacobian |
288 | // products (VJPs): |
289 | // |
  //  grad_wrt_x = einsum("{eq_y},{eq_z}->{eq_x}", y, grad_wrt_z)
  //  grad_wrt_y = einsum("{eq_x},{eq_z}->{eq_y}", x, grad_wrt_z)
292 | // |
293 | // where grad_wrt_x and grad_wrt_y are the gradients with respect to inputs |
294 | // x and y and grad_wrt_z is the given gradient with respect to output z. |
295 | // |
  // Proof: For unary einsum equations involving only transposes ("ij->ji") and
  // traces ("ii->i"), the linear mapping's Jacobian at input x is given
  // by the mapping itself. We can verify that the linear maps given by the
  // VJPs are einsums with the equations "ji->ij" and "i->ii" respectively,
  // where the latter represents 'un-tracing', i.e. filling the diagonal with
  // the input and setting the off-diagonal entries to zero.
302 | // Furthermore, recall that matrix multiplication, which is |
303 | // represented by the equation "ab,bc->ac", has its VJPs given by the |
304 | // einsum equations "ac,bc->ab" and "ab,ac->bc" (see, for example |
305 | // https://math.stackexchange.com/a/2755680). Combined with transposes and |
306 | // traces we can rewrite Tensor contractions as regular matrix |
  // multiplication. Since each of these operations has its VJP described
  // by einsums of the required pattern, the result follows.
309 | // |
  // Accordingly, einsum operations, except for those with reductions
  // (e.g. "abc,cd->ad"), have their VJPs defined by
  // "{output_subs},{other_subs}->{input_subs}".
313 | // |
314 | // But if there is a reduction, this would lead to the equation "ad,cd->abc" |
315 | // which is invalid because the reduced axis label 'b' is present in the |
316 | // output but not in any of the inputs. Therefore, we compute the VJP in two |
317 | // steps: first we obtain VJP for "ac,cd->ad" and then we compute the VJP of |
318 | // "abc->ac" or, equivalently, reduce_sum(..., axis=1). |
319 | // |
  // Compute the set of input axis labels which appear in neither the
  // output subscripts nor the other operand's subscripts. E.g. the set {'b'}
  // for the equation "abc,cd->ad".
323 | absl::btree_set<char> reduced_label_set(input_subs.begin(), input_subs.end()); |
324 | for (const char x : output_subs) { |
325 | reduced_label_set.erase(x); |
326 | } |
327 | for (const char x : other_subs) { |
328 | reduced_label_set.erase(x); |
329 | } |
330 | reduced_label_set.erase('.'); |
331 | |
332 | // Obtain the input subscripts with the reduced axis labels removed. E.g. |
333 | // "ac" in the above example. |
334 | std::string left_subs; |
335 | for (const char s : input_subs) { |
336 | if (!reduced_label_set.contains(s)) { |
337 | left_subs.push_back(s); |
338 | } |
339 | } |
340 | |
  // Compute the gradient wrt the input, without accounting for the reduction
  // "abc->ac". This yields the VJP of the operation "ac,cd->ad".
343 | Output grad_reduced = |
344 | Einsum(scope, {output_grad, other_operand}, |
345 | absl::StrCat(output_subs, "," , other_subs, "->" , left_subs)); |
346 | |
347 | // If the reduced_label_set is empty, then we already have the gradient |
348 | // wrt the input. |
349 | if (reduced_label_set.empty()) { |
350 | return grad_reduced; |
351 | } |
352 | // Otherwise, we currently have the gradient wrt the output of the reduction |
353 | // operation "abc->ac". Invoke the subroutine for the gradient for unary |
354 | // einsum with reductions. |
355 | return EinsumGradReducedHelper(scope, grad_reduced, left_subs, input_subs, |
356 | input_shape, reduced_label_set); |
357 | } |
358 | |
359 | Status EinsumGrad(const Scope& scope, const Operation& op, |
360 | const std::vector<Output>& grad_inputs, |
361 | std::vector<Output>* grad_outputs) { |
362 | if (grad_inputs.size() != 1) { |
363 | return errors::InvalidArgument("Expect 1 grad input." ); |
364 | } |
365 | const Output& grad = grad_inputs[0]; |
366 | |
367 | std::string equation; |
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "equation", &equation));
369 | std::vector<absl::string_view> equation_split = |
370 | absl::StrSplit(equation, "->" ); |
371 | if (equation_split.size() != 2) { |
372 | return errors::InvalidArgument("Equation must contain a single ->" ); |
373 | } |
374 | |
375 | const absl::string_view input_subs = equation_split[0]; |
376 | const absl::string_view output_subs = equation_split[1]; |
377 | if (op.num_inputs() == 1) { |
378 | // For the unary einsum z = einsum("{eq_x}->{eq_z}", x), the gradient wrt |
379 | // the input (VJP) is given by the reversed equation: |
380 | // grad_wrt_x = einsum("{eq_z}->{eq_x}", grad_wrt_z) |
    // (See the justification in EinsumGradWrt). This is valid unless there are
382 | // reduced axis labels; i.e. axis labels appearing in the input but not in |
383 | // the output subscripts. |
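    // E.g. for the transpose z = einsum("ab->ba", x), the gradient wrt x is
    // simply einsum("ba->ab", grad_wrt_z); for an equation such as "ab->a",
    // the label 'b' is reduced, so EinsumGradReducedHelper below is used
    // instead.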
384 | auto input_shape = Shape(scope, op.input(0)); |
385 | // Find the axis labels which appear only in the input. |
386 | absl::btree_set<char> reduced_label_set(input_subs.begin(), |
387 | input_subs.end()); |
388 | for (const char x : output_subs) { |
389 | reduced_label_set.erase(x); |
390 | } |
391 | reduced_label_set.erase('.'); |
392 | if (reduced_label_set.empty()) { |
393 | grad_outputs->push_back(Einsum( |
          scope, grad_inputs, absl::StrCat(output_subs, "->", input_subs)));
395 | return scope.status(); |
396 | } |
397 | // We do have reduced axes, so we invoke the subroutine for reduced unary |
398 | // einsums. |
399 | grad_outputs->push_back(EinsumGradReducedHelper( |
400 | scope, grad, output_subs, input_subs, input_shape, reduced_label_set)); |
401 | return scope.status(); |
402 | } |
403 | |
404 | std::vector<absl::string_view> subs = absl::StrSplit(input_subs, ','); |
405 | if (subs.size() != 2) { |
406 | return errors::InvalidArgument("Only 2 inputs are supported" ); |
407 | } |
408 | std::string x_subs(subs[0]); |
409 | std::string y_subs(subs[1]); |
410 | // Add ellipsis for broadcasted dimensions if any operand does not have it. |
411 | // This is because the equation "...ij,jk->ik" may be valid if the 0th input's |
412 | // batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid |
413 | // because only the output subscripts contain ellipsis. |
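  // E.g. for the equation "...ij,jk->...ik", y_subs is extended from "jk" to
  // "jk..." below, so that both VJP equations contain an ellipsis; the extra
  // batch dimensions that grad_y then picks up are summed away in the
  // unbroadcasting step at the end of this function.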
414 | if (absl::StrContains(output_subs, kEllipsis)) { |
415 | if (!absl::StrContains(x_subs, kEllipsis)) { |
416 | absl::StrAppend(&x_subs, kEllipsis); |
417 | } |
418 | if (!absl::StrContains(y_subs, kEllipsis)) { |
419 | absl::StrAppend(&y_subs, kEllipsis); |
420 | } |
421 | } |
422 | |
423 | // Obtain the gradients wrt the inputs x and y, without taking into account |
424 | // the unbroadcasting. |
425 | tensorflow::Output x = op.input(0); |
426 | tensorflow::Output y = op.input(1); |
427 | if (DataTypeIsComplex(grad.type())) { |
428 | x = Conj(scope, x); |
429 | y = Conj(scope, y); |
430 | } |
431 | |
432 | const auto x_shape = Shape(scope, x); |
433 | const auto y_shape = Shape(scope, y); |
434 | Output grad_x = |
435 | EinsumGradWrt(scope, grad, y, x_shape, x_subs, y_subs, output_subs); |
436 | Output grad_y = |
437 | EinsumGradWrt(scope, grad, x, y_shape, y_subs, x_subs, output_subs); |
438 | |
439 | if (!absl::StrContains(output_subs, kEllipsis)) { |
    // If there is no ellipsis in the output, there is no need to unbroadcast.
441 | grad_outputs->push_back(grad_x); |
442 | grad_outputs->push_back(grad_y); |
443 | return scope.status(); |
444 | } |
445 | |
446 | // Below we handle the case that broadcasting between x and y was necessary, |
447 | // with x and y having possibly different batch shapes. |
448 | |
  // Obtain the range of axes which map to the ellipsis. E.g. for subscripts
  // 'ab...c' and a shape of rank 10, the range [2:-1] denotes the broadcasted
451 | // axes. |
452 | int bx_start, by_start; |
453 | absl::optional<int> bx_end, by_end; |
454 | std::tie(bx_start, bx_end) = EinsumGetBcastSubshape(x_subs); |
455 | std::tie(by_start, by_end) = EinsumGetBcastSubshape(y_subs); |
456 | |
  // Sum the gradient across the broadcasted axes. BroadcastGradientArgs
  // returns, for each batch sub-shape, the axes along which it was broadcast
  // (relative to the broadcasted batch shape); adding bx_start/by_start turns
  // these into axis indices into grad_x/grad_y.
458 | auto args = internal::BroadcastGradientArgs( |
459 | scope, Slice1dHelper(scope, x_shape, bx_start, bx_end), |
460 | Slice1dHelper(scope, y_shape, by_start, by_end)); |
461 | grad_x = Reshape( |
462 | scope, ReduceSum(scope, grad_x, Add(scope, bx_start, args.r0)), x_shape); |
463 | grad_y = Reshape( |
464 | scope, ReduceSum(scope, grad_y, Add(scope, by_start, args.r1)), y_shape); |
465 | grad_outputs->push_back(grad_x); |
466 | grad_outputs->push_back(grad_y); |
467 | return scope.status(); |
468 | } |
469 | |
470 | REGISTER_GRADIENT_OP("Einsum" , EinsumGrad); |
471 | |
472 | } // namespace |
473 | } // namespace ops |
474 | } // namespace tensorflow |
475 | |