1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/cc/gradients/grad_helper.h" |
17 | |
18 | #include "tensorflow/cc/ops/array_ops.h" |
19 | #include "tensorflow/cc/ops/data_flow_ops.h" |
20 | #include "tensorflow/cc/ops/standard_ops.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | using tensorflow::ops::Add; |
25 | using tensorflow::ops::Const; |
26 | using tensorflow::ops::DynamicStitch; |
27 | using tensorflow::ops::Mod; |
28 | using tensorflow::ops::OnesLike; |
29 | using tensorflow::ops::Range; |
30 | using tensorflow::ops::Size; |
31 | |
32 | Output ReducedShapeHelper(const Scope& scope, const Output& input_shape, |
33 | const Output& reduction_axes) { |
34 | auto zero = Const(scope, 0); |
35 | auto one = Const(scope, 1); |
36 | |
37 | // Running example in comments |
38 | // input_shape = [2, 3, 5, 7] |
39 | // axes = [1, 2] |
40 | // The result (a shape after a reduction with keep_dims=True) |
41 | // [2, 1, 1, 7] |
42 | // |
43 | // We can treat each entry in axes as an index into input_shape that |
44 | // should be replaced by 1. |
45 | // We use DynamicStitch to do this. |
46 | |
47 | // input_rank = 4 |
48 | auto input_rank = Size(scope, input_shape); |
49 | |
50 | // Normalize any negative indices in the reduction_axes to positive |
51 | // values. |
52 | auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank); |
53 | |
54 | // This [0..input_rank) range of integers is used in DynamicStitch to |
55 | // first copy input_shape to the result. |
56 | // input_rank_range = [0, 1, 2, 3] |
57 | auto input_rank_range = Range(scope, zero, input_rank, one); |
58 | |
59 | // A 1-filled tensor with the same shape as axes. DynamicStitch will |
60 | // merge these 1s (using axes for indices) to the correct |
61 | // position in the result. |
62 | // axes_ones = [1, 1] |
63 | auto axes_ones = OnesLike(scope, axes); |
64 | |
65 | // using DynamicStitch: |
66 | // indices = { input_rank_range, axes } |
67 | // = { [0, 1, 2, 3], [1, 2] } |
68 | // data = { input_shape, axes_ones } |
69 | // = { [2, 3, 5, 7], [1, 1] } |
70 | // The input_rank_range entry in indices first replicates the |
71 | // input_shape to the result. |
72 | // The axes entry in indices then moves a 1 to each of its entries, |
73 | // resulting in |
74 | // [2, 1, 1, 7] |
75 | std::vector<Output> indices = {input_rank_range, axes}; |
76 | std::vector<Output> data = {input_shape, axes_ones}; |
77 | return DynamicStitch(scope, indices, data); |
78 | } |
79 | |
80 | } // namespace tensorflow |
81 | |