1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
16#define TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
17
18#include "tensorflow/core/framework/tensor.h"
19#include "tensorflow/core/framework/tensor_shape.h"
20#include "tensorflow/core/framework/types.h"
21#include "tensorflow/core/lib/core/status.h"
22#include "tensorflow/core/lib/gtl/inlined_vector.h"
23
24namespace tensorflow {
25
26struct StridedSliceShapeSpec {
27 // Begin mask canonlized in dense form.
28 int32_t begin_dense_mask;
29 // End mask canonlized in dense form.
30 int32_t end_dense_mask;
31 // Shrink axis mask canonlized in dense form.
32 int32_t shrink_axis_dense_mask;
33 // output_to_sparse_mapping[i] represents output[i]'s the corresponding dim
34 // index in the begin_tensor. If
35 // output_to_sparse_mapping[i] is -1, it means the dimension doesn't show up
36 // in sparse_mapping.
37 gtl::InlinedVector<int64_t, 4> output_to_sparse_mapping;
38 // output_to_processing_mapping is similar to output_to_sparse_mapping, but
39 // for processing shape.
40 gtl::InlinedVector<int64_t, 4> output_to_processing_mapping;
41 // processing_to_sparse_mapping[i] represents input_shape[i]'s corresponding
42 // dim index in the begin_tensor.
43 gtl::InlinedVector<int64_t, 4> processing_to_sparse_mapping;
44};
45
46// Runs validation on the strided slice op parameters.
47//
48// Is a separate translation unit from the kernel so that:
49// 1. The op's shape function can use it.
50// 2. The code size is reduced vs templating this on the kernel's type.
51//
52// Note that when input_shape is not fully specified, only <final_shape> and
53// <processing_shape> are valid; <is_identity>, <is_simple_slice> and other
54// output parameters will not be accurate.
55//
56// If the rank of <input_shape> is unknown (i.e., "input_shape.unknown_rank()"
57// is true)), the method returns an invalid status.
58//
59// If <begin_tensor> or <end_tensor> are nullptr, <begin> and <end> will not be
60// valid. In this case, <slice_dim0> and <is_identity> will be true only if a
61// determination can be made based on the information given. A best effort is
62// made to set <processing_shape> and <final_shape> based on <input_shape>, but
63// some dimensions of <processing_shape> and/or <final_shape> may be unknown
64// (-1). Any validation that can be done without complete information is
65// performed.
66//
67Status ValidateStridedSliceOp(
68 const Tensor* begin_tensor, const Tensor* end_tensor,
69 const Tensor& strides_tensor, const PartialTensorShape& input_shape,
70 int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask,
71 int32_t new_axis_mask, int32_t shrink_axis_mask,
72 PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
73 bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
74 gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end,
75 gtl::InlinedVector<int64_t, 4>* strides,
76 StridedSliceShapeSpec* shape_spec = nullptr);
77
78// Same as above, but the outputs are TensorShape, not PartialTensorShape
79Status ValidateStridedSliceOp(
80 const Tensor* begin_tensor, const Tensor* end_tensor,
81 const Tensor& strides_tensor, const PartialTensorShape& input_shape,
82 int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask,
83 int32_t new_axis_mask, int32_t shrink_axis_mask,
84 TensorShape* processing_shape, TensorShape* final_shape, bool* is_identity,
85 bool* is_simple_slice, bool* slice_dim0,
86 gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end,
87 gtl::InlinedVector<int64_t, 4>* strides,
88 StridedSliceShapeSpec* shape_spec = nullptr);
89
90// Simple class for determining if it is possible to broadcast a tensor to a
91// strided slice. Modelled after tensorflow::BCast, but with a few key
92// differences:
93// - the input_shape must be broadcastable to output_shape
94// (i.e. the slice shape does not grow).
95// - does not allow reducing or flattening dimensions, since we cannot apply
96// these simplications to the destination slice.
97// - allows for remapping dimensions, required in order to associate the input
98// with correct dimensions in the full (unsliced) destination tensor.
99class StridedSliceAssignBCast {
100 public:
101 using Vec = gtl::InlinedVector<int64_t, 4>;
102
103 StridedSliceAssignBCast(const Vec& input_shape, const Vec& output_shape);
104
105 // Remaps broadcast, resize, and output dimensions via the provided map.
106 // Negative values in the map correspond to dimensions being removed.
107 // Unmapped dimensions are set to 1.
108 //
109 // This is to support remapping slice -> processing dimensions. To relate
110 // the sliced output dimensions back to processing dimensions (i.e. those
111 // relative to the original unsliced input), we need to remove any axes
112 // that were added via the `new_axis_mask`, and add back any axes that were
113 // removed via the `shrink_axis_mask`. For example, an expression like
114 //
115 // >>> t = tf.zeros([3, 3])
116 // >>> t[2, tf.newaxis, 0:2, tf.newaxis] = tf.ones([1, 3, 1])
117 // ^ ^ ^ ^
118 // |__ shrink axis new axis __| | |__ new axis
119 // |_____ dim 1 of t
120 //
121 // would have `new_axis_mask = 0b1010` and `shrink_axis_mask = 0b0001`. The
122 // slice has shape [1, 3, 1], but the original input tensor `t` has shape
123 // [3, 3]. To remap the slice dimensions back to the input dimensions, the
124 // mapping would use `num_dims = 2`, `dimension_map = {-1, 1, -1}`. This
125 // removes the two new axes added for the slice, maps the middle slice
126 // dimension to input dimension 1, and leaves input dimension 0 to have a
127 // default size of 1 to add back the shrink axis.
128 //
129 // Returns false if the remapping fails.
130 bool RemapDimensions(int64_t num_dims, const Vec& dimension_map);
131
132 bool IsValid() const { return valid_; }
133
134 bool IsBroadcastingRequired() const { return broadcasting_required_; }
135
136 const Vec& reshape() const { return reshape_; }
137
138 const Vec& bcast() const { return bcast_; }
139
140 const Vec& result_shape() const { return result_shape_; }
141
142 private:
143 bool valid_ = true;
144 bool broadcasting_required_ = false;
145 Vec reshape_;
146 Vec bcast_;
147 Vec result_shape_;
148};
149
150} // namespace tensorflow
151
152#endif // TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
153