1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/cc/gradients/grad_helper.h"
17
18#include "tensorflow/cc/ops/array_ops.h"
19#include "tensorflow/cc/ops/data_flow_ops.h"
20#include "tensorflow/cc/ops/standard_ops.h"
21
22namespace tensorflow {
23
24using tensorflow::ops::Add;
25using tensorflow::ops::Const;
26using tensorflow::ops::DynamicStitch;
27using tensorflow::ops::Mod;
28using tensorflow::ops::OnesLike;
29using tensorflow::ops::Range;
30using tensorflow::ops::Size;
31
32Output ReducedShapeHelper(const Scope& scope, const Output& input_shape,
33 const Output& reduction_axes) {
34 auto zero = Const(scope, 0);
35 auto one = Const(scope, 1);
36
37 // Running example in comments
38 // input_shape = [2, 3, 5, 7]
39 // axes = [1, 2]
40 // The result (a shape after a reduction with keep_dims=True)
41 // [2, 1, 1, 7]
42 //
43 // We can treat each entry in axes as an index into input_shape that
44 // should be replaced by 1.
45 // We use DynamicStitch to do this.
46
47 // input_rank = 4
48 auto input_rank = Size(scope, input_shape);
49
50 // Normalize any negative indices in the reduction_axes to positive
51 // values.
52 auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank);
53
54 // This [0..input_rank) range of integers is used in DynamicStitch to
55 // first copy input_shape to the result.
56 // input_rank_range = [0, 1, 2, 3]
57 auto input_rank_range = Range(scope, zero, input_rank, one);
58
59 // A 1-filled tensor with the same shape as axes. DynamicStitch will
60 // merge these 1s (using axes for indices) to the correct
61 // position in the result.
62 // axes_ones = [1, 1]
63 auto axes_ones = OnesLike(scope, axes);
64
65 // using DynamicStitch:
66 // indices = { input_rank_range, axes }
67 // = { [0, 1, 2, 3], [1, 2] }
68 // data = { input_shape, axes_ones }
69 // = { [2, 3, 5, 7], [1, 1] }
70 // The input_rank_range entry in indices first replicates the
71 // input_shape to the result.
72 // The axes entry in indices then moves a 1 to each of its entries,
73 // resulting in
74 // [2, 1, 1, 7]
75 std::vector<Output> indices = {input_rank_range, axes};
76 std::vector<Output> data = {input_shape, axes_ones};
77 return DynamicStitch(scope, indices, data);
78}
79
80} // namespace tensorflow
81