1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_ |
16 | #define TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_ |
17 | |
18 | #include "tensorflow/core/framework/tensor.h" |
19 | #include "tensorflow/core/framework/tensor_shape.h" |
20 | #include "tensorflow/core/framework/types.h" |
21 | #include "tensorflow/core/lib/core/status.h" |
22 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
23 | |
24 | namespace tensorflow { |
25 | |
26 | struct StridedSliceShapeSpec { |
27 | // Begin mask canonlized in dense form. |
28 | int32_t begin_dense_mask; |
29 | // End mask canonlized in dense form. |
30 | int32_t end_dense_mask; |
31 | // Shrink axis mask canonlized in dense form. |
32 | int32_t shrink_axis_dense_mask; |
33 | // output_to_sparse_mapping[i] represents output[i]'s the corresponding dim |
34 | // index in the begin_tensor. If |
35 | // output_to_sparse_mapping[i] is -1, it means the dimension doesn't show up |
36 | // in sparse_mapping. |
37 | gtl::InlinedVector<int64_t, 4> output_to_sparse_mapping; |
38 | // output_to_processing_mapping is similar to output_to_sparse_mapping, but |
39 | // for processing shape. |
40 | gtl::InlinedVector<int64_t, 4> output_to_processing_mapping; |
41 | // processing_to_sparse_mapping[i] represents input_shape[i]'s corresponding |
42 | // dim index in the begin_tensor. |
43 | gtl::InlinedVector<int64_t, 4> processing_to_sparse_mapping; |
44 | }; |
45 | |
46 | // Runs validation on the strided slice op parameters. |
47 | // |
48 | // Is a separate translation unit from the kernel so that: |
49 | // 1. The op's shape function can use it. |
50 | // 2. The code size is reduced vs templating this on the kernel's type. |
51 | // |
52 | // Note that when input_shape is not fully specified, only <final_shape> and |
53 | // <processing_shape> are valid; <is_identity>, <is_simple_slice> and other |
54 | // output parameters will not be accurate. |
55 | // |
56 | // If the rank of <input_shape> is unknown (i.e., "input_shape.unknown_rank()" |
57 | // is true)), the method returns an invalid status. |
58 | // |
59 | // If <begin_tensor> or <end_tensor> are nullptr, <begin> and <end> will not be |
60 | // valid. In this case, <slice_dim0> and <is_identity> will be true only if a |
61 | // determination can be made based on the information given. A best effort is |
62 | // made to set <processing_shape> and <final_shape> based on <input_shape>, but |
63 | // some dimensions of <processing_shape> and/or <final_shape> may be unknown |
64 | // (-1). Any validation that can be done without complete information is |
65 | // performed. |
66 | // |
67 | Status ValidateStridedSliceOp( |
68 | const Tensor* begin_tensor, const Tensor* end_tensor, |
69 | const Tensor& strides_tensor, const PartialTensorShape& input_shape, |
70 | int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask, |
71 | int32_t new_axis_mask, int32_t shrink_axis_mask, |
72 | PartialTensorShape* processing_shape, PartialTensorShape* final_shape, |
73 | bool* is_identity, bool* is_simple_slice, bool* slice_dim0, |
74 | gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end, |
75 | gtl::InlinedVector<int64_t, 4>* strides, |
76 | StridedSliceShapeSpec* shape_spec = nullptr); |
77 | |
78 | // Same as above, but the outputs are TensorShape, not PartialTensorShape |
79 | Status ValidateStridedSliceOp( |
80 | const Tensor* begin_tensor, const Tensor* end_tensor, |
81 | const Tensor& strides_tensor, const PartialTensorShape& input_shape, |
82 | int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask, |
83 | int32_t new_axis_mask, int32_t shrink_axis_mask, |
84 | TensorShape* processing_shape, TensorShape* final_shape, bool* is_identity, |
85 | bool* is_simple_slice, bool* slice_dim0, |
86 | gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end, |
87 | gtl::InlinedVector<int64_t, 4>* strides, |
88 | StridedSliceShapeSpec* shape_spec = nullptr); |
89 | |
90 | // Simple class for determining if it is possible to broadcast a tensor to a |
91 | // strided slice. Modelled after tensorflow::BCast, but with a few key |
92 | // differences: |
93 | // - the input_shape must be broadcastable to output_shape |
94 | // (i.e. the slice shape does not grow). |
95 | // - does not allow reducing or flattening dimensions, since we cannot apply |
96 | // these simplications to the destination slice. |
97 | // - allows for remapping dimensions, required in order to associate the input |
98 | // with correct dimensions in the full (unsliced) destination tensor. |
99 | class StridedSliceAssignBCast { |
100 | public: |
101 | using Vec = gtl::InlinedVector<int64_t, 4>; |
102 | |
103 | StridedSliceAssignBCast(const Vec& input_shape, const Vec& output_shape); |
104 | |
105 | // Remaps broadcast, resize, and output dimensions via the provided map. |
106 | // Negative values in the map correspond to dimensions being removed. |
107 | // Unmapped dimensions are set to 1. |
108 | // |
109 | // This is to support remapping slice -> processing dimensions. To relate |
110 | // the sliced output dimensions back to processing dimensions (i.e. those |
111 | // relative to the original unsliced input), we need to remove any axes |
112 | // that were added via the `new_axis_mask`, and add back any axes that were |
113 | // removed via the `shrink_axis_mask`. For example, an expression like |
114 | // |
115 | // >>> t = tf.zeros([3, 3]) |
116 | // >>> t[2, tf.newaxis, 0:2, tf.newaxis] = tf.ones([1, 3, 1]) |
117 | // ^ ^ ^ ^ |
118 | // |__ shrink axis new axis __| | |__ new axis |
119 | // |_____ dim 1 of t |
120 | // |
121 | // would have `new_axis_mask = 0b1010` and `shrink_axis_mask = 0b0001`. The |
122 | // slice has shape [1, 3, 1], but the original input tensor `t` has shape |
123 | // [3, 3]. To remap the slice dimensions back to the input dimensions, the |
124 | // mapping would use `num_dims = 2`, `dimension_map = {-1, 1, -1}`. This |
125 | // removes the two new axes added for the slice, maps the middle slice |
126 | // dimension to input dimension 1, and leaves input dimension 0 to have a |
127 | // default size of 1 to add back the shrink axis. |
128 | // |
129 | // Returns false if the remapping fails. |
130 | bool RemapDimensions(int64_t num_dims, const Vec& dimension_map); |
131 | |
132 | bool IsValid() const { return valid_; } |
133 | |
134 | bool IsBroadcastingRequired() const { return broadcasting_required_; } |
135 | |
136 | const Vec& reshape() const { return reshape_; } |
137 | |
138 | const Vec& bcast() const { return bcast_; } |
139 | |
140 | const Vec& result_shape() const { return result_shape_; } |
141 | |
142 | private: |
143 | bool valid_ = true; |
144 | bool broadcasting_required_ = false; |
145 | Vec reshape_; |
146 | Vec bcast_; |
147 | Vec result_shape_; |
148 | }; |
149 | |
150 | } // namespace tensorflow |
151 | |
152 | #endif // TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_ |
153 | |