1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/util/strided_slice_op.h" |
17 | |
18 | #include <algorithm> |
19 | #include <array> |
20 | #include <iterator> |
21 | |
22 | #include "tensorflow/core/framework/bounds_check.h" |
23 | #include "tensorflow/core/lib/core/status.h" |
24 | |
25 | namespace tensorflow { |
26 | namespace { |
27 | |
// Sentinel values stored in StridedSliceDenseSpec::final_shape_gather_indices
// in place of a (non-negative) dense dimension index.

// Marks a dimension removed by shrink_axis_mask (e.g. foo[3]); it is skipped
// when the final shape is assembled.
constexpr int32_t kShrinkAxis = -1;
// Marks a size-1 dimension inserted by new_axis_mask (tf.newaxis).
constexpr int32_t kNewAxis = -2;
30 | |
31 | // Sparse slicing specification |
32 | // if one does foo[3:5, ..., -3], this will have 3 length tensors |
33 | struct StridedSliceSparseSpec { |
34 | int64_t dims; |
35 | int32 num_add_axis_after_ellipsis; |
36 | const Tensor* begin_tensor; |
37 | const Tensor* end_tensor; |
38 | const Tensor& strides_tensor; |
39 | const int32 begin_mask, end_mask; |
40 | int32 ellipsis_mask; |
41 | const int32 new_axis_mask, shrink_axis_mask; |
42 | }; |
43 | |
// Dense slicing specification: the sparse spec with the ellipsis expanded and
// new axes taken out, so begin/end/strides hold exactly one entry per input
// dimension. For foo[3:5, ..., -3] on a 10-dimensional foo, each of these
// InlinedVectors ends up with 10 entries, whereas the sparse spec had
// 3-element tensors.
struct StridedSliceDenseSpec {
  // Input rank; begin/end/strides are resized to this many entries.
  const int64_t dims;
  // Dense-indexed masks: bit i refers to input dimension i.
  int32 begin_mask;
  int32 end_mask;
  // Whether begin/end values were supplied (false when the sparse tensor
  // pointer is null).
  bool begin_valid;
  bool end_valid;
  // Per-input-dimension begin/end/strides; references to caller-owned vectors.
  gtl::InlinedVector<int64_t, 4>& begin;
  gtl::InlinedVector<int64_t, 4>& end;
  gtl::InlinedVector<int64_t, 4>& strides;
  // This vector helps construct the final shape of the slice.
  // The final tensor is reduced in rank whenever a single index e.g. foo[3]
  // is called for, and increases in rank with tf.newaxis entries.
  // A non-negative entry is a dense dimension whose size is taken from the
  // canonical end-begin interval; kNewAxis yields a size-1 dimension; a
  // kShrinkAxis entry (a shrunk dimension) is skipped.
  gtl::InlinedVector<int32, 4> final_shape_gather_indices;
  // Same size as final_shape_gather_indices, but records the sparse index a
  // dimension comes from instead of the dense index. -1 means the dimension
  // does not come from an explicit sparse entry (it was produced by the
  // ellipsis, a new axis, or is shrunk).
  gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
  // For each dense (input) dimension, the sparse index that produced it.
  gtl::InlinedVector<int32, 4> input_shape_gather_indices_sparse;
  // The dense indexed shrink mask says which processing dimensions
  // should be shrunk. For example, if foo.shape = (10,10,10,10)
  // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
  // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10).
  int32 shrink_axis_mask;
};
77 | |
78 | } // namespace |
79 | |
80 | template <class T> |
81 | static Status TF_MUST_USE_RESULT BuildDenseSpec( |
82 | const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) { |
83 | if (dense->dims < 0) { |
84 | return errors::InvalidArgument("Unexpected negative dense.dims" ); |
85 | } |
86 | |
87 | // Build expanded begin, end, strides, begin_mask, end_mask |
88 | // to remove any ellipsis |
89 | dense->begin.resize(dense->dims); |
90 | dense->end.resize(dense->dims); |
91 | dense->strides.resize(dense->dims); |
92 | dense->input_shape_gather_indices_sparse.resize(dense->dims); |
93 | // What indices to get the final shape from. |
94 | dense->begin_mask = 0; |
95 | dense->end_mask = 0; |
96 | dense->shrink_axis_mask = 0; |
97 | { |
98 | int full_index = 0; |
99 | |
100 | const T* const strides_flat = sparse.strides_tensor.vec<T>().data(); |
101 | dense->begin_valid = sparse.begin_tensor != nullptr; |
102 | dense->end_valid = sparse.end_tensor != nullptr; |
103 | |
104 | const T* const begin_flat = sparse.begin_tensor != nullptr |
105 | ? sparse.begin_tensor->vec<T>().data() |
106 | : nullptr; |
107 | const T* const end_flat = sparse.end_tensor != nullptr |
108 | ? sparse.end_tensor->vec<T>().data() |
109 | : nullptr; |
110 | |
111 | for (int i = 0; i < sparse.dims; i++) { |
112 | if ((1 << i) & sparse.ellipsis_mask) { |
113 | // Expand the ellipsis into the appropriate indices |
114 | // NOTE: this only works because we guaranteed one ellipsis |
115 | int32_t next_index = std::min(dense->dims - (sparse.dims - i) + 1 + |
116 | sparse.num_add_axis_after_ellipsis, |
117 | dense->dims); |
118 | for (; full_index < next_index; full_index++) { |
119 | // new_axis' aren't real axis so you have to skip |
120 | dense->begin[full_index] = dense->end[full_index] = 0; |
121 | dense->strides[full_index] = 1; |
122 | dense->begin_mask |= (1 << full_index); |
123 | dense->end_mask |= (1 << full_index); |
124 | dense->final_shape_gather_indices.push_back(full_index); |
125 | dense->final_shape_gather_indices_sparse.push_back(-1); |
126 | dense->input_shape_gather_indices_sparse[full_index] = i; |
127 | } |
128 | } else if ((1 << i) & sparse.new_axis_mask) { |
129 | dense->final_shape_gather_indices.push_back(kNewAxis); |
130 | dense->final_shape_gather_indices_sparse.push_back(-1); |
131 | } else { |
132 | if (full_index == dense->begin.size()) { |
133 | return errors::InvalidArgument("Index out of range using input dim " , |
134 | full_index, "; input has only " , |
135 | dense->dims, " dims" ); |
136 | } |
137 | |
138 | // Gather slicing spec into appropriate index |
139 | if (begin_flat != nullptr) { |
140 | dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat[i]); |
141 | } |
142 | if (end_flat != nullptr) { |
143 | dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat[i]); |
144 | } |
145 | dense->strides[full_index] = |
146 | internal::SubtleMustCopy<T>(strides_flat[i]); |
147 | if (sparse.begin_mask & (1 << i)) { |
148 | dense->begin_mask |= (1 << full_index); |
149 | } |
150 | if (sparse.end_mask & (1 << i)) { |
151 | dense->end_mask |= (1 << full_index); |
152 | } |
153 | // If shrink, record where to get the dimensionality from (i.e. |
154 | // new_axis creates a fake 1 size dimension. Also remember shrink |
155 | // axis (now in dense form) so we can ignore dense->end below. |
156 | if (sparse.shrink_axis_mask & (1 << i)) { |
157 | dense->final_shape_gather_indices.push_back(kShrinkAxis); |
158 | dense->final_shape_gather_indices_sparse.push_back(-1); |
159 | dense->shrink_axis_mask |= (1 << full_index); |
160 | } else { |
161 | dense->final_shape_gather_indices.push_back(full_index); |
162 | // Remember that where in the sparse shape the dense dim comes |
163 | // from. |
164 | dense->final_shape_gather_indices_sparse.push_back(i); |
165 | } |
166 | dense->input_shape_gather_indices_sparse[full_index] = i; |
167 | full_index++; |
168 | } |
169 | } |
170 | } |
171 | return OkStatus(); |
172 | } |
173 | |
174 | Status ValidateStridedSliceOp( |
175 | const Tensor* begin_tensor, const Tensor* end_tensor, |
176 | const Tensor& strides_tensor, const PartialTensorShape& input_shape, |
177 | int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask, |
178 | int32_t new_axis_mask, int32_t shrink_axis_mask, |
179 | PartialTensorShape* processing_shape, PartialTensorShape* final_shape, |
180 | bool* is_identity, bool* is_simple_slice, bool* slice_dim0, |
181 | gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end, |
182 | gtl::InlinedVector<int64_t, 4>* strides, |
183 | StridedSliceShapeSpec* shape_spec) { |
184 | if (input_shape.unknown_rank()) { |
185 | // Note: If the rank is unknown, "input_shape.dims()" is -1. |
186 | return errors::InvalidArgument("Unexpected input_shape with unknown rank" ); |
187 | } |
188 | |
189 | const bool begin_is_wrong = |
190 | begin_tensor != nullptr && |
191 | !(TensorShapeUtils::IsVector(begin_tensor->shape()) && |
192 | begin_tensor->NumElements() == strides_tensor.NumElements() && |
193 | begin_tensor->NumElements() < 32 /* using 32 bit masks */); |
194 | const bool end_is_wrong = |
195 | end_tensor != nullptr && |
196 | !(TensorShapeUtils::IsVector(end_tensor->shape()) && |
197 | end_tensor->NumElements() == strides_tensor.NumElements()); |
198 | if (begin_is_wrong || end_is_wrong || |
199 | !TensorShapeUtils::IsVector(strides_tensor.shape())) { |
200 | if (begin_tensor != nullptr && end_tensor != nullptr) { |
201 | return errors::InvalidArgument( |
202 | "Expected begin, end, and strides to be 1D equal size tensors, " , |
203 | "but got shapes " , begin_tensor->shape().DebugString(), ", " , |
204 | end_tensor->shape().DebugString(), ", and " , |
205 | strides_tensor.shape().DebugString(), " instead." ); |
206 | } else { |
207 | return errors::InvalidArgument( |
208 | "Expected begin, end, and strides to be 1D equal size tensors, " , |
209 | "but got shape " , strides_tensor.shape().DebugString(), |
210 | " for strides." ); |
211 | } |
212 | } |
213 | // Use bit compares to ensure ellipsis_mask is 0 or a power of 2 |
214 | // i.e. there exists only no more than one ellipsis |
215 | if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) { |
216 | return errors::InvalidArgument( |
217 | "Multiple ellipses in slice spec not allowed" ); |
218 | } |
219 | |
220 | // Step 1: Account for ellipsis and new axis |
221 | // |
222 | // Check for ellipses and count how many non-newaxis' there are after |
223 | // TODO(aselle): Convert this to do a fast log2 followed by iteration |
224 | // counting ones in next guys |
225 | bool ellipsis_seen = false; |
226 | |
227 | StridedSliceSparseSpec sparse_spec = {strides_tensor.NumElements(), |
228 | 0, |
229 | begin_tensor, |
230 | end_tensor, |
231 | strides_tensor, |
232 | begin_mask_spec, |
233 | end_mask_spec, |
234 | ellipsis_mask, |
235 | new_axis_mask, |
236 | shrink_axis_mask}; |
237 | |
238 | for (int32_t i = 0; i < sparse_spec.dims; i++) { |
239 | if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) { |
240 | sparse_spec.num_add_axis_after_ellipsis++; |
241 | } |
242 | if ((1 << i) & ellipsis_mask) { |
243 | ellipsis_seen = true; |
244 | } |
245 | } |
246 | // If no ellipsis insert one at the end |
247 | if (!ellipsis_seen) { |
248 | sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims); |
249 | sparse_spec.dims++; // this effects loop iteration below |
250 | } |
251 | |
252 | // Step 2: Make a sparse spec into a full index spec |
253 | // |
254 | // The sparse spec does not correspond to the number of dimensions |
255 | // Make a dense spec that corresponds to the number of dimensions |
256 | // |
257 | // For example suppose foo[...,3:] on foo.shape=(2,2,3) then |
258 | // we need to produce the missing begin_mask for the first two |
259 | // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2 |
260 | // we achieve begin_mask=6, end_mask=7 |
261 | StridedSliceDenseSpec dense_spec = {input_shape.dims(), |
262 | 0 /* begin_mask */, |
263 | 0 /* end_mask */, |
264 | false /* begin_valid */, |
265 | false /* end_valid */, |
266 | *begin, |
267 | *end, |
268 | *strides}; |
269 | |
270 | if (strides_tensor.dtype() == DT_INT32) { |
271 | TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec)); |
272 | } else if (strides_tensor.dtype() == DT_INT64) { |
273 | TF_RETURN_IF_ERROR(BuildDenseSpec<int64_t>(sparse_spec, &dense_spec)); |
274 | } else if (strides_tensor.dtype() == DT_INT16) { |
275 | TF_RETURN_IF_ERROR(BuildDenseSpec<int16_t>(sparse_spec, &dense_spec)); |
276 | } else { |
277 | LOG(FATAL) << "begin must be either int16, int32 or int64" ; |
278 | } |
279 | |
280 | // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit |
281 | // and bounds check! |
282 | *is_identity = true; |
283 | *slice_dim0 = true; |
284 | *is_simple_slice = true; |
285 | processing_shape->Clear(); |
286 | for (int i = 0; i < input_shape.dims(); ++i) { |
287 | int64_t& begin_i = (*begin)[i]; |
288 | int64_t& end_i = (*end)[i]; |
289 | int64_t& stride_i = (*strides)[i]; |
290 | int64_t dim_i = input_shape.dim_size(i); |
291 | if (stride_i == 0) { |
292 | return errors::InvalidArgument("strides[" , i, "] must be non-zero" ); |
293 | } |
294 | bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i)); |
295 | if (dim_i == -1) { |
296 | processing_shape->AddDim(shrink_i ? 1 : -1); |
297 | continue; |
298 | } |
299 | |
300 | const std::array<int64_t, 2> masks = { |
301 | {dense_spec.begin_mask & (1 << i), dense_spec.end_mask & (1 << i)}}; |
302 | const std::array<int64_t, 2> valid_range = { |
303 | {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}}; |
304 | |
305 | auto canonical = [stride_i, dim_i, masks, valid_range](int64_t x, int c) { |
306 | if (masks[c]) { |
307 | return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1]; |
308 | } else { |
309 | int64_t x_fwd = |
310 | x < 0 ? dim_i + x : x; // make negative indices positive |
311 | return x_fwd < valid_range[0] |
312 | ? valid_range[0] |
313 | : x_fwd > valid_range[1] ? valid_range[1] : x_fwd; |
314 | } |
315 | }; |
316 | if (shrink_i && stride_i <= 0) { |
317 | return errors::InvalidArgument( |
318 | "only stride 1 allowed on non-range indexing." ); |
319 | } |
320 | (*is_simple_slice) &= stride_i == 1; |
321 | |
322 | const bool begin_and_end_masked = |
323 | (dense_spec.begin_mask & (1 << i)) && (dense_spec.end_mask & (1 << i)); |
324 | if (dense_spec.begin_valid && dense_spec.end_valid) { |
325 | if (shrink_i) { |
326 | // If we are shrinking, the end index is now possibly incorrect. In |
327 | // particular foo[-1] produces sparse_begin = -1, sparse_end = 0. |
328 | // and canonical puts these to n-1 and 0, which implies a degenerate |
329 | // interval. Fortunately, it is now safe to re-create end as begin+1. |
330 | int64_t x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i; |
331 | begin_i = x_fwd; |
332 | end_i = begin_i + 1; |
333 | if (x_fwd < 0 || x_fwd >= dim_i) { |
334 | return errors::InvalidArgument( |
335 | "slice index " , begin_i, " of dimension " , i, " out of bounds." ); |
336 | } |
337 | } else { |
338 | begin_i = canonical(begin_i, 0); |
339 | end_i = canonical(end_i, 1); |
340 | } |
341 | // Update optimization values |
342 | bool take_all_in_dimension = |
343 | stride_i == 1 && begin_i == 0 && end_i == dim_i; |
344 | (*is_identity) &= take_all_in_dimension; |
345 | (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension; |
346 | } else { |
347 | (*is_identity) &= stride_i == 1 && begin_and_end_masked; |
348 | (*slice_dim0) &= (i == 0 && stride_i == 1) || begin_and_end_masked; |
349 | } |
350 | // Compute the processing shape (the intermediate Eigen will produce) |
351 | int64_t interval_length; |
352 | bool known_interval = false; |
353 | if (dense_spec.begin_valid && dense_spec.end_valid) { |
354 | interval_length = end_i - begin_i; |
355 | known_interval = true; |
356 | } else if (shrink_i) { |
357 | // The dimension is still known as 1 for the processing_shape, but will be |
358 | // discarded for the final shape. |
359 | interval_length = 1; |
360 | known_interval = true; |
361 | } else if (begin_and_end_masked) { |
362 | // Even if we don't have values for begin or end, we do know that this |
363 | // dimension covers the whole interval. If we have shape information for |
364 | // this dimension, that tells us the interval length. |
365 | if (dim_i >= 0) { |
366 | if (stride_i < 0) { |
367 | interval_length = -dim_i; |
368 | } else { |
369 | interval_length = dim_i; |
370 | } |
371 | known_interval = true; |
372 | } |
373 | } |
374 | if (known_interval) { |
375 | int64_t size_i; |
376 | // Hold zero if the interval is degenerate, otherwise account for |
377 | // remainder |
378 | if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) { |
379 | size_i = 0; |
380 | } else { |
381 | size_i = interval_length / stride_i + |
382 | (interval_length % stride_i != 0 ? 1 : 0); |
383 | } |
384 | processing_shape->AddDim(size_i); |
385 | } else { |
386 | processing_shape->AddDim(-1); |
387 | } |
388 | } |
389 | |
390 | // Step 4: Compute the final shape |
391 | // |
392 | // new_axis will increase dimension by 1 (with a one-size dimension) |
393 | // slices like foo[3,...] will reduce dimension by 1. |
394 | // This cannot be done earlier, because it depends on Step 3. |
395 | final_shape->Clear(); |
396 | if (shape_spec != nullptr) { |
397 | shape_spec->output_to_sparse_mapping.clear(); |
398 | shape_spec->output_to_processing_mapping.clear(); |
399 | shape_spec->processing_to_sparse_mapping.assign( |
400 | dense_spec.input_shape_gather_indices_sparse.begin(), |
401 | dense_spec.input_shape_gather_indices_sparse.end()); |
402 | |
403 | shape_spec->begin_dense_mask = dense_spec.begin_mask; |
404 | shape_spec->end_dense_mask = dense_spec.end_mask; |
405 | shape_spec->shrink_axis_dense_mask = dense_spec.shrink_axis_mask; |
406 | } |
407 | |
408 | for (int64_t dense_dim = 0; |
409 | dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) { |
410 | int64_t gather_index = dense_spec.final_shape_gather_indices[dense_dim]; |
411 | int64_t sparse_index = |
412 | dense_spec.final_shape_gather_indices_sparse[dense_dim]; |
413 | if (gather_index >= 0) { |
414 | final_shape->AddDim(processing_shape->dim_size(gather_index)); |
415 | if (shape_spec != nullptr) { |
416 | shape_spec->output_to_sparse_mapping.push_back(sparse_index); |
417 | shape_spec->output_to_processing_mapping.push_back(gather_index); |
418 | } |
419 | } else if (gather_index == kNewAxis) { |
420 | final_shape->AddDim(1); |
421 | if (shape_spec != nullptr) { |
422 | shape_spec->output_to_sparse_mapping.push_back(-1); |
423 | shape_spec->output_to_processing_mapping.push_back(-1); |
424 | } |
425 | } |
426 | } |
427 | |
428 | return OkStatus(); |
429 | } |
430 | |
431 | Status ValidateStridedSliceOp( |
432 | const Tensor* begin_tensor, const Tensor* end_tensor, |
433 | const Tensor& strides_tensor, const PartialTensorShape& input_shape, |
434 | int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask, |
435 | int32_t new_axis_mask, int32_t shrink_axis_mask, |
436 | TensorShape* processing_shape, TensorShape* final_shape, bool* is_identity, |
437 | bool* is_simple_slice, bool* slice_dim0, |
438 | gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end, |
439 | gtl::InlinedVector<int64_t, 4>* strides, |
440 | StridedSliceShapeSpec* shape_spec) { |
441 | // Validate with PartialTensorShape output |
442 | PartialTensorShape partial_processing_shape, partial_final_shape; |
443 | TF_RETURN_IF_ERROR(ValidateStridedSliceOp( |
444 | begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec, |
445 | end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask, |
446 | &partial_processing_shape, &partial_final_shape, is_identity, |
447 | is_simple_slice, slice_dim0, begin, end, strides, shape_spec)); |
448 | |
449 | // Verify that the output shapes are fully known |
450 | if (!partial_processing_shape.AsTensorShape(processing_shape) || |
451 | !partial_final_shape.AsTensorShape(final_shape)) { |
452 | return errors::Internal("ValidateStridedSliceOp returned partial shapes " , |
453 | partial_processing_shape.DebugString(), " and " , |
454 | partial_final_shape.DebugString()); |
455 | } |
456 | return OkStatus(); |
457 | } |
458 | |
459 | StridedSliceAssignBCast::StridedSliceAssignBCast( |
460 | const StridedSliceAssignBCast::Vec& input_shape, |
461 | const StridedSliceAssignBCast::Vec& output_shape) |
462 | : valid_(true), |
463 | broadcasting_required_(false), |
464 | reshape_(output_shape.size()), |
465 | bcast_(output_shape.size()), |
466 | result_shape_(output_shape) { |
467 | // The input needs to be reshaped to have the same number of dimensions as |
468 | // the output. This is accomplished by either prepending with ones or removing |
469 | // leading, as necessary. |
470 | size_t input_start = 0; |
471 | size_t prepend_size = 0; |
472 | if (output_shape.size() < input_shape.size()) { |
473 | // Numpy allows assigning a larger rank array to smaller as long as |
474 | // broadcasting would otherwise work and the prefix dimensions are all 1. |
475 | // Though this behavior is undocumented, we allow it here for consistency. |
476 | // See https://github.com/numpy/numpy/issues/21744 for details. |
477 | input_start = input_shape.size() - output_shape.size(); |
478 | for (size_t i = 0; i < input_start; ++i) { |
479 | if (input_shape[i] != 1) { |
480 | valid_ = false; |
481 | return; |
482 | } |
483 | } |
484 | } else { |
485 | prepend_size = output_shape.size() - input_shape.size(); |
486 | } |
487 | std::fill_n(reshape_.begin(), prepend_size, 1); |
488 | std::copy(input_shape.begin() + input_start, input_shape.end(), |
489 | reshape_.begin() + prepend_size); |
490 | |
491 | // In order to broadcast, dimensions must either be equal or one. |
492 | for (size_t i = 0; i < output_shape.size(); ++i) { |
493 | if (reshape_[i] == output_shape[i]) { |
494 | bcast_[i] = 1; |
495 | } else if (reshape_[i] == 1) { |
496 | bcast_[i] = output_shape[i]; |
497 | broadcasting_required_ = true; |
498 | } else { |
499 | valid_ = false; |
500 | return; |
501 | } |
502 | } |
503 | } |
504 | |
505 | bool StridedSliceAssignBCast::RemapDimensions( |
506 | int64_t num_dims, const StridedSliceAssignBCast::Vec& dimension_map) { |
507 | // Each element in the map corresponds to the original result shape, so |
508 | // the sizes must be equal. |
509 | if (dimension_map.size() != result_shape_.size()) { |
510 | return false; |
511 | } |
512 | |
513 | // Ensure all indices are within-bounds before any modifications are made - |
514 | // otherwise we could be left in a corrupted state. |
515 | for (size_t i = 0; i < dimension_map.size(); ++i) { |
516 | int64_t dim = dimension_map[i]; |
517 | if (dim >= num_dims) { |
518 | return false; |
519 | } |
520 | } |
521 | |
522 | Vec old_reshape = std::move(reshape_); |
523 | Vec old_bcast = std::move(bcast_); |
524 | Vec old_result_shape = std::move(result_shape_); |
525 | reshape_ = Vec(num_dims); |
526 | bcast_ = Vec(num_dims); |
527 | result_shape_ = Vec(num_dims); |
528 | std::fill_n(reshape_.begin(), num_dims, 1); |
529 | std::fill_n(bcast_.begin(), num_dims, 1); |
530 | std::fill_n(result_shape_.begin(), num_dims, 1); |
531 | for (size_t i = 0; i < dimension_map.size(); ++i) { |
532 | int64_t dim = dimension_map[i]; |
533 | if (dim >= 0) { |
534 | reshape_[dim] = old_reshape[i]; |
535 | bcast_[dim] = old_bcast[i]; |
536 | result_shape_[dim] = old_result_shape[i]; |
537 | } |
538 | } |
539 | |
540 | return true; |
541 | } |
542 | |
543 | } // namespace tensorflow |
544 | |