1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/scatter_nd_util.h" |
17 | |
18 | #include "tensorflow/core/framework/tensor_shape.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | Status ValidateScatterNdUpdateShape(const TensorShape& params_shape, |
23 | const TensorShape& indices_shape, |
24 | const TensorShape& updates_shape) { |
25 | const int64_t slice_dim = |
26 | (indices_shape.dims() > 1) |
27 | ? indices_shape.dim_size(indices_shape.dims() - 1) |
28 | : 1; |
29 | const int64_t batch_dim = |
30 | (indices_shape.dims() > 1) ? indices_shape.dims() - 1 : 1; |
31 | |
32 | auto shape_err_prefix = [&]() { |
33 | return errors::InvalidArgument( |
34 | "Dimensions [0," , batch_dim, |
35 | ") of indices[shape=" , indices_shape.DebugString(), |
36 | "] must match dimensions [0," , batch_dim, |
37 | ") of updates[shape=" , updates_shape.DebugString(), "]" ); |
38 | }; |
39 | auto shape_err_suffix = [&]() { |
40 | return errors::InvalidArgument( |
41 | "Dimensions [" , slice_dim, "," , params_shape.dims(), |
42 | ") of input[shape=" , params_shape.DebugString(), |
43 | "] must match dimensions [" , slice_dim, "," , updates_shape.dims(), |
44 | ") of updates[shape=" , updates_shape.DebugString(), "]" ); |
45 | }; |
46 | |
47 | if (updates_shape.dims() < batch_dim) return shape_err_prefix(); |
48 | if (params_shape.dims() < slice_dim + (updates_shape.dims() - batch_dim)) { |
49 | return shape_err_suffix(); |
50 | } |
51 | if (updates_shape.dims() != batch_dim + params_shape.dims() - slice_dim) { |
52 | return shape_err_suffix(); |
53 | } |
54 | for (int d = 0; d < batch_dim; ++d) { |
55 | if (updates_shape.dim_size(d) != indices_shape.dim_size(d)) |
56 | return shape_err_prefix(); |
57 | } |
58 | for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) { |
59 | if (updates_shape.dim_size(d + batch_dim) != |
60 | params_shape.dim_size(d + slice_dim)) { |
61 | return shape_err_suffix(); |
62 | } |
63 | } |
64 | return OkStatus(); |
65 | } |
66 | |
67 | } // namespace tensorflow |
68 | |