/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cmath>
#include <string>
#include <tuple>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/gradients/grad_helper.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

constexpr absl::string_view kEllipsis = "...";

// Returns the axis (possibly negative) corresponding to a label.
//
// Returns the axis index of the axis label if it is before an ellipsis (or if
// the ellipsis is not present), and the negative index if it occurs after the
// ellipsis. E.g. the index of `b` in `ab...cd` is `1`, while that of `c` is
// `-2`.
//
// For multiple occurrences, returns the leftmost one. If not found, returns
// absl::nullopt.
//
// Parameters:
//   subscripts: A string denoting the einsum subscript (e.g. `ab...cd`)
//   label: The single character axis label.
absl::optional<int> EinsumGetAxisFromLabel(absl::string_view subscripts,
                                           char label) {
  std::vector<absl::string_view> splits =
      absl::StrSplit(subscripts, kEllipsis);
  auto index = splits[0].find(label);
  if (index != splits[0].npos) {
    return index;
  }
  if (splits.size() < 2) {
    return absl::nullopt;
  }
  index = splits[1].find(label);
  if (index != splits[1].npos) {
    return index - splits[1].length();
  }
  return absl::nullopt;
}

// Returns a tuple denoting the slice mapping to ellipsis.
//
// For a given subscript, returns a tuple (start, end) denoting the start
// axis index and the (negative) end axis index respectively. For any input
// Tensor `x` described by the subscript, `x[start:end]` would be the slice
// represented by the ellipsis. E.g. for `ab...cd` returns `[1, -2]`.
//
// If ellipsis is not present in `subscripts`, returns `(0, 0)`.
//
// Parameters:
//   subscripts: A string denoting the einsum subscript.
//
// Returns:
//   start: The start axis index of the ellipsis slice.
//   end: The (negative) end axis index, or nullopt to slice to the end.
std::tuple<int, absl::optional<int>> EinsumGetBcastSubshape(
    absl::string_view subscripts) {
  int start = subscripts.find(kEllipsis);
  if (start == subscripts.npos) {
    return std::make_tuple(0, 0);
  }
  int remaining = subscripts.length() - (start + kEllipsis.length());
  absl::optional<int> end;
  if (remaining > 0) {
    end = -remaining;
  } else {
    end = absl::nullopt;
  }
  return std::make_tuple(start, end);
}

// Slices elements of a 1d tensor from [start, end).
// If end is nullopt, it slices to the end of the tensor.
// Supports negative values for end.
// This attempts to give the same result as tensor[start:end] would give in
// Python.
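// E.g. Slice1dHelper(scope, t, 1, -2) corresponds to t[1:-2], and
// Slice1dHelper(scope, t, 1, absl::nullopt) corresponds to t[1:].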
Output Slice1dHelper(const Scope& scope, Output tensor, int start,
                     absl::optional<int> end) {
  if (end.has_value() && *end > 0) {
    return Slice(scope, tensor, Const(scope, start, TensorShape({1})),
                 Const(scope, *end - start, TensorShape({1})));
  } else {
    return Slice(scope, tensor, Const(scope, start, TensorShape({1})),
                 Add(scope, Shape(scope, tensor), end.value_or(0) - start));
  }
}

// Returns reduced subscripts and their corresponding dimensions and axes.
//
// Given a set of axis labels, returns their concatenated subscript, their
// corresponding dimensions from input_shape, and their corresponding axes.
// Note that the concatenated subscript `reduced_subs` may have axis labels
// from `reduced_label_set` in any order. For example, for the reduced label
// set `{b, d}`, subscripts `aabbcd` and input shape `[2,2,5,5,3,4]`, returns
// subscripts `bd`, dimensions `[5,4]` and axes `[2,5]`.
//
// Args:
//   reduced_label_set: Set of axis labels which appear in `subscripts`.
//   input_shape: A `Tensor` representing the shape of the einsum operand
//     corresponding to `subscripts`.
//   subscripts: A string denoting the einsum subscript.
//
// Returns:
//   reduced_subs: Subscripts formed by a concatenation of labels in
//     `reduced_label_set`.
//   reduced_dims: Dimensions from `input_shape` corresponding to each label
//     in `reduced_subs`.
//   reduced_axes: Axes described by `subscripts` corresponding to each label
//     in `reduced_subs`. If there are multiple occurrences in `subscripts`,
//     we consider only the leftmost one.
std::tuple<std::string, Output, Output> EinsumGetReducedSubscripts(
    const Scope& scope, const absl::btree_set<char>& reduced_label_set,
    Output input_shape, absl::string_view subscripts) {
  // Concatenate the sequence of reduced axis labels.
  const std::string reduced_subs =
      std::string(reduced_label_set.begin(), reduced_label_set.end());
  // Get the axis (may be positive, negative or zero) for each of the reduced
  // labels. If the same label appears multiple times, get the left-most axis.
  std::vector<int> reduced_axes;
  reduced_axes.reserve(reduced_subs.size());
  for (const char s : reduced_subs) {
    auto axis = EinsumGetAxisFromLabel(subscripts, s);
    if (!axis.has_value()) {
      // Should never happen.
      scope.UpdateStatus(errors::Internal(
          absl::StrCat("Missing axis ", absl::string_view(&s, 1))));
    } else {
      reduced_axes.push_back(*axis);
    }
  }
  // Get the corresponding dimensions for each reduced axis.
  std::vector<Output> reduced_dims_inputs;
  reduced_dims_inputs.reserve(reduced_axes.size());
  for (const int i : reduced_axes) {
    if (i < 0) {
      reduced_dims_inputs.push_back(
          Gather(scope, input_shape, Add(scope, Size(scope, input_shape), i)));
    } else {
      reduced_dims_inputs.push_back(Gather(scope, input_shape, i));
    }
  }
  const Output reduced_dims = Stack(scope, reduced_dims_inputs);
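  // Pack the (possibly negative) reduced axes into a 1-D int32 constant.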
  Tensor reduced_axes_tensor(
      DataType::DT_INT32,
      TensorShape({static_cast<int>(reduced_axes.size())}));
  std::copy_n(reduced_axes.begin(), reduced_axes.size(),
              reduced_axes_tensor.flat<int>().data());
  return std::make_tuple(reduced_subs, reduced_dims,
                         Const(scope, reduced_axes_tensor));
}

// Returns the gradient wrt input for a unary einsum with reductions.
//
//   scope: Scope for grad operations.
//   output_grad: The gradient wrt the output of a unary einsum operation.
//   output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`).
//   input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`).
//   input_shape: The shape of the input operand.
//   reduced_label_set: The set of axis labels appearing in `input_subs` but
//     not in `output_subs`.
Output EinsumGradReducedHelper(const Scope& scope, const Output& output_grad,
                               absl::string_view output_subs,
                               absl::string_view input_subs,
                               const Output& input_shape,
                               const absl::btree_set<char>& reduced_label_set) {
  // Let's say the einsum operation was "aabbcd->ca", where axis labels 'b' and
  // 'd' are reduced with input_shape [2,2,5,5,3,4]. Then obtain the reduced
  // subscripts "bd", corresponding dimensions [5,4] and axes [2,5].
  std::string reduced_subs;
  Output reduced_dims, reduced_axes;
  std::tie(reduced_subs, reduced_dims, reduced_axes) =
      EinsumGetReducedSubscripts(scope, reduced_label_set, input_shape,
                                 input_subs);
  // Whether either the input or the output subscripts have a repeated label.
  // This is true for "aabbcd->ca" or "abd->cca" but false for "abcd->ca".
  const int distinct_input_labels =
      absl::flat_hash_set<char>(input_subs.begin(), input_subs.end()).size();
  const int distinct_output_labels =
      absl::flat_hash_set<char>(output_subs.begin(), output_subs.end()).size();
  const bool has_repeated_labels =
      (distinct_input_labels + distinct_output_labels) <
      input_subs.length() + output_subs.length();
  // Compute the input subscripts without the reduced axis labels, e.g. "aac"
  // for the equation "aabbcd->ca".
  std::string input_subs_without_reduced_labels;
  for (const char s : input_subs) {
    if (!absl::c_linear_search(reduced_label_set, s)) {
      input_subs_without_reduced_labels.push_back(s);
    }
  }

  // The gradient wrt the input for the equation "abc->ac" (or, equivalently
  // reduce_sum(..., axis=1)) is just the gradient of the output tiled N times
  // along axis 1, where label 'b' represents a dimension of size N.
  //
  // If we're not dealing with repeated labels, and the non-reduced labels
  // don't need to be transposed, then just tiling is enough and there is no
  // need to call another einsum. For example, tiling is sufficient for
  // "abcd->ac". But for equations like "aabbcd->ac" (generalized traces) or
  // "abc->ca" (transpose), we'd need another einsum operation after tiling.
  if (!has_repeated_labels &&
      input_subs_without_reduced_labels == output_subs) {
    // Obtain the shape of the output, as if keepdims=True on reduce sum. E.g.
    // for the equation "abcd->ac" with input shape [2,5,3,4], we get the
    // reduced shape [2,1,3,1].
    auto reduced_shape = ReducedShapeHelper(scope, input_shape, reduced_axes);
    // Reshaping the gradient (wrt "ac") to [2,1,3,1] and broadcasting it to
    // the shape [2,5,3,4] results in the gradient wrt "abcd".
    return BroadcastTo(scope, Reshape(scope, output_grad, reduced_shape),
                       input_shape);
  }

  // If we *do* have traces or transpose operations, then prepend the extra
  // reduced dimensions to the front. E.g. given the equation "aabbcd->ca" we'd
  // first obtain the VJP for "bdca->ca", and then the VJP for "aabbcd->bdca".
  //
  // Obtain the input shape with reduced dimensions prepended, viz. [5,4,3,2].
  // This is the shape of the intermediate "bdca".
  Output output_grad_shape = Shape(scope, output_grad);
  auto grad_shape_with_reduced_labels =
      Concat(scope, {reduced_dims, output_grad_shape}, /*axis=*/0);

  // Obtain the output shape of the reduction-only equation "bdca->ca" as if
  // keepdims=True; viz. [1,1,3,2]. Since we prepended the reduced labels,
  // we just have to prepend that many 1s to the output shape.
  auto reduced_shape = Concat(
      scope,
      {Const(scope, 1, TensorShape{static_cast<int>(reduced_label_set.size())}),
       output_grad_shape},
      /*axis=*/0);
  // Compute the VJP for the intermediate (viz. "bdca->ca") for which
  // broadcasting is sufficient.
  Output broadcasted_grad =
      BroadcastTo(scope, Reshape(scope, output_grad, reduced_shape),
                  grad_shape_with_reduced_labels);
  // Compute the VJP for the final step (viz. "aabbcd->bdca"). We can
  // use einsum with the input and output subscripts reversed (viz.
  // "bdca->aabbcd") since the output axis labels now appear in the
  // input subscripts.
  return Einsum(scope, {broadcasted_grad},
                absl::StrCat(reduced_subs, output_subs, "->", input_subs));
}

// Returns the gradient wrt an input operand for a binary einsum.
//
// This function does not handle (un)broadcasting. This must be done separately
// on the returned gradient.
//
// Args:
//   output_grad: The gradient wrt the output of a binary einsum operation.
//   other_operand: The complementary `Tensor` operand, i.e. the one which is
//     not the input operand.
//   input_shape: A `Tensor` representing the shape of the input operand.
//   input_subs: The subscripts of the input operand.
//   other_subs: The subscripts of the complementary operand.
//   output_subs: The output subscripts.
Output EinsumGradWrt(const Scope& scope, Output output_grad,
                     Output other_operand, Output input_shape,
                     absl::string_view input_subs, absl::string_view other_subs,
                     absl::string_view output_subs) {
  // Claim: For the einsum operation z = einsum("{eq_x},{eq_y}->{eq_z}", x, y),
  // where the equation involves only Tensor contractions, generalized traces
  // and transposes, the input gradients are given by the vector-jacobian
  // products (VJPs):
  //
  //   grad_wrt_x = einsum("{eq_y},{eq_z}->{eq_x}", y, grad_wrt_z)
  //   grad_wrt_y = einsum("{eq_x},{eq_z}->{eq_y}", x, grad_wrt_z)
  //
  // where grad_wrt_x and grad_wrt_y are the gradients with respect to inputs
  // x and y and grad_wrt_z is the given gradient with respect to output z.
  //
  // Proof: For unary einsum equations involving only transpose ("ij->ji") and
  // traces ("ii->i"), the linear mapping's Jacobian at input x is given
  // by the function itself. We can verify that the linear maps given by the
  // VJPs are einsums with the equations "ji->ij" and "i->ii" respectively,
  // where the latter represents 'un-tracing', i.e. filling the diagonal with
  // the input and setting the off-diagonal entries to zero.
  // Furthermore, recall that matrix multiplication, which is
  // represented by the equation "ab,bc->ac", has its VJPs given by the
  // einsum equations "ac,bc->ab" and "ab,ac->bc" (see, for example,
  // https://math.stackexchange.com/a/2755680). Combined with transposes and
  // traces we can rewrite Tensor contractions as regular matrix
  // multiplication. Since each of these operations has its VJP described
  // by einsums of the required pattern, the result follows.
  //
  // Accordingly, einsum operations, except those with reductions (e.g.
  // "abc,cd->ad"), have their VJPs defined by:
  //   "{output_subs},{other_subs}->{input_subs}".
  //
  // But if there is a reduction, this would lead to the equation "ad,cd->abc",
  // which is invalid because the reduced axis label 'b' is present in the
  // output but not in any of the inputs. Therefore, we compute the VJP in two
  // steps: first we obtain the VJP for "ac,cd->ad" and then we compute the VJP
  // of "abc->ac" or, equivalently, reduce_sum(..., axis=1).
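  //
  // For example, for z = einsum("abc,cd->ad", x, y), grad_wrt_x is computed
  // below as einsum("ad,cd->ac", grad_wrt_z, y) and then broadcast back over
  // the reduced label 'b' to the shape of x (see EinsumGradReducedHelper).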
  //
  // Compute the set of input axis labels which appear in neither the
  // output subscripts nor the other operand's subscripts. E.g. the set {'b'}
  // for the equation "abc,cd->ad".
  absl::btree_set<char> reduced_label_set(input_subs.begin(),
                                          input_subs.end());
  for (const char x : output_subs) {
    reduced_label_set.erase(x);
  }
  for (const char x : other_subs) {
    reduced_label_set.erase(x);
  }
  reduced_label_set.erase('.');
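  // (Any '.' characters come from the ellipsis and do not denote axis labels.)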

  // Obtain the input subscripts with the reduced axis labels removed. E.g.
  // "ac" in the above example.
  std::string left_subs;
  for (const char s : input_subs) {
    if (!reduced_label_set.contains(s)) {
      left_subs.push_back(s);
    }
  }

  // Compute the gradient wrt the input, without accounting for the operation
  // "abc->ac". So, now we have the VJP of the operation "ac,cd->ad".
  Output grad_reduced =
      Einsum(scope, {output_grad, other_operand},
             absl::StrCat(output_subs, ",", other_subs, "->", left_subs));

  // If the reduced_label_set is empty, then we already have the gradient
  // wrt the input.
  if (reduced_label_set.empty()) {
    return grad_reduced;
  }
  // Otherwise, we currently have the gradient wrt the output of the reduction
  // operation "abc->ac". Invoke the subroutine for the gradient for unary
  // einsum with reductions.
  return EinsumGradReducedHelper(scope, grad_reduced, left_subs, input_subs,
                                 input_shape, reduced_label_set);
}

Status EinsumGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Expect 1 grad input.");
  }
  const Output& grad = grad_inputs[0];

  std::string equation;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "equation", &equation));
  std::vector<absl::string_view> equation_split =
      absl::StrSplit(equation, "->");
  if (equation_split.size() != 2) {
    return errors::InvalidArgument("Equation must contain a single ->");
  }

  const absl::string_view input_subs = equation_split[0];
  const absl::string_view output_subs = equation_split[1];
  if (op.num_inputs() == 1) {
    // For the unary einsum z = einsum("{eq_x}->{eq_z}", x), the gradient wrt
    // the input (VJP) is given by the reversed equation:
    //   grad_wrt_x = einsum("{eq_z}->{eq_x}", grad_wrt_z)
    // (See the justification in EinsumGradWrt.) This is valid unless there are
    // reduced axis labels, i.e. axis labels appearing in the input but not in
    // the output subscripts.
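    // E.g. the VJP of "ab->ba" is simply "ba->ab", whereas "abc->ac" reduces
    // over the label 'b' and is handled by EinsumGradReducedHelper below.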
    auto input_shape = Shape(scope, op.input(0));
    // Find the axis labels which appear only in the input.
    absl::btree_set<char> reduced_label_set(input_subs.begin(),
                                            input_subs.end());
    for (const char x : output_subs) {
      reduced_label_set.erase(x);
    }
    reduced_label_set.erase('.');
    if (reduced_label_set.empty()) {
      grad_outputs->push_back(Einsum(
          scope, grad_inputs, absl::StrCat(output_subs, "->", input_subs)));
      return scope.status();
    }
    // We do have reduced axes, so we invoke the subroutine for reduced unary
    // einsums.
    grad_outputs->push_back(EinsumGradReducedHelper(
        scope, grad, output_subs, input_subs, input_shape, reduced_label_set));
    return scope.status();
  }

  std::vector<absl::string_view> subs = absl::StrSplit(input_subs, ',');
  if (subs.size() != 2) {
    return errors::InvalidArgument("Only 2 inputs are supported");
  }
  std::string x_subs(subs[0]);
  std::string y_subs(subs[1]);
  // Add ellipsis for broadcasted dimensions if any operand does not have it.
  // This is because the equation "...ij,jk->ik" may be valid if the 0th
  // input's batch shape is empty, but the VJP equation "jk,ik->...ij" is not
  // valid because only the output subscripts contain ellipsis.
  if (absl::StrContains(output_subs, kEllipsis)) {
    if (!absl::StrContains(x_subs, kEllipsis)) {
      absl::StrAppend(&x_subs, kEllipsis);
    }
    if (!absl::StrContains(y_subs, kEllipsis)) {
      absl::StrAppend(&y_subs, kEllipsis);
    }
  }

  // Obtain the gradients wrt the inputs x and y, without taking into account
  // the unbroadcasting.
  tensorflow::Output x = op.input(0);
  tensorflow::Output y = op.input(1);
  if (DataTypeIsComplex(grad.type())) {
    x = Conj(scope, x);
    y = Conj(scope, y);
  }

  const auto x_shape = Shape(scope, x);
  const auto y_shape = Shape(scope, y);
  Output grad_x =
      EinsumGradWrt(scope, grad, y, x_shape, x_subs, y_subs, output_subs);
  Output grad_y =
      EinsumGradWrt(scope, grad, x, y_shape, y_subs, x_subs, output_subs);

  if (!absl::StrContains(output_subs, kEllipsis)) {
    // If there is no ellipsis in the output, there is no need to unbroadcast.
    grad_outputs->push_back(grad_x);
    grad_outputs->push_back(grad_y);
    return scope.status();
  }

  // Below we handle the case that broadcasting between x and y was necessary,
  // with x and y having possibly different batch shapes.

  // Obtain the range of axes which map to ellipsis. E.g. for subscripts
  // 'ab...c' and a shape of rank 10, the range [2:-1] denotes the broadcasted
  // axes.
  int bx_start, by_start;
  absl::optional<int> bx_end, by_end;
  std::tie(bx_start, bx_end) = EinsumGetBcastSubshape(x_subs);
  std::tie(by_start, by_end) = EinsumGetBcastSubshape(y_subs);

  // Sum the gradient across the broadcasted axes.
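  // BroadcastGradientArgs returns, for each batch sub-shape, the axes along
  // which it was broadcast; summing the gradient over those axes (offset by
  // the position of the first batch axis) and reshaping undoes the broadcast.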
  auto args = internal::BroadcastGradientArgs(
      scope, Slice1dHelper(scope, x_shape, bx_start, bx_end),
      Slice1dHelper(scope, y_shape, by_start, by_end));
  grad_x = Reshape(
      scope, ReduceSum(scope, grad_x, Add(scope, bx_start, args.r0)), x_shape);
  grad_y = Reshape(
      scope, ReduceSum(scope, grad_y, Add(scope, by_start, args.r1)), y_shape);
  grad_outputs->push_back(grad_x);
  grad_outputs->push_back(grad_y);
  return scope.status();
}

REGISTER_GRADIENT_OP("Einsum", EinsumGrad);

}  // namespace
}  // namespace ops
}  // namespace tensorflow