1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/framework/common_shape_fns.h" |
16 | |
17 | #include "absl/container/flat_hash_map.h" |
18 | #include "absl/container/flat_hash_set.h" |
19 | #include "absl/strings/match.h" |
20 | #include "absl/strings/str_split.h" |
21 | #include "absl/strings/string_view.h" |
22 | #include "tensorflow/core/framework/attr_value.pb.h" |
23 | #include "tensorflow/core/framework/shape_inference.h" |
24 | #include "tensorflow/core/lib/core/errors.h" |
25 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
26 | #include "tensorflow/core/util/einsum_op_util.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | namespace shape_inference { |
31 | |
// The V2 version computes windowed output size with arbitrary dilation_rate and
// explicit padding, while the original version only handles the cases where
// dilation_rates equal to 1 and the padding is SAME or VALID.
Status GetWindowedOutputSizeFromDimsV2(
    shape_inference::InferenceContext* c,
    shape_inference::DimensionHandle input_size,
    shape_inference::DimensionOrConstant filter_size, int64_t dilation_rate,
    int64_t stride, Padding padding_type, int64_t padding_before,
    int64_t padding_after, shape_inference::DimensionHandle* output_size) {
  // A non-positive stride cannot define a sliding window.
  if (stride <= 0) {
    return errors::InvalidArgument("Stride must be > 0, but got ", stride);
  }

  if (dilation_rate < 1) {
    return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
                                   dilation_rate);
  }

  // See also the parallel implementation in GetWindowedOutputSizeVerbose.
  switch (padding_type) {
    case Padding::VALID:
      // VALID is EXPLICIT padding with zero padding on both sides; fall
      // through to the EXPLICIT computation below.
      padding_before = padding_after = 0;
      TF_FALLTHROUGH_INTENDED;
    case Padding::EXPLICIT:
      // output = floor((input + pad_before + pad_after - effective_filter
      //                 + stride) / stride), where
      // effective_filter = (filter_size - 1) * dilation_rate + 1.
      TF_RETURN_IF_ERROR(
          c->Add(input_size, padding_before + padding_after, &input_size));
      if (dilation_rate > 1) {
        DimensionHandle window_size;
        TF_RETURN_IF_ERROR(
            c->Subtract(c->MakeDim(filter_size), 1, &window_size));
        TF_RETURN_IF_ERROR(
            c->Multiply(window_size, dilation_rate, &window_size));
        TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
        TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
      } else {
        // With dilation 1 the effective window is just filter_size.
        TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
      }
      TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
    case Padding::SAME:
      // SAME: output = ceil(input / stride); filter size, dilation, and the
      // explicit padding arguments are irrelevant.
      TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
  }
  return OkStatus();
}
81 | |
82 | Status GetWindowedOutputSizeFromDims( |
83 | shape_inference::InferenceContext* c, |
84 | shape_inference::DimensionHandle input_size, |
85 | shape_inference::DimensionOrConstant filter_size, int64_t stride, |
86 | Padding padding_type, shape_inference::DimensionHandle* output_size) { |
87 | if (padding_type == Padding::EXPLICIT) { |
88 | return errors::Internal( |
89 | "GetWindowedOutputSizeFromDims does not handle EXPLICIT padding; call " |
90 | "GetWindowedOutputSizeFromDimsV2 instead" ); |
91 | } |
92 | return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size, |
93 | /*dilation_rate=*/1, stride, |
94 | padding_type, |
95 | // Give dummy values of -1 to |
96 | // padding_before and padding_after, |
97 | // since explicit padding is not used. |
98 | -1, -1, output_size); |
99 | } |
100 | |
101 | Status UnchangedShape(shape_inference::InferenceContext* c) { |
102 | c->set_output(0, c->input(0)); |
103 | auto* handle_data = c->input_handle_shapes_and_types(0); |
104 | if (handle_data != nullptr) { |
105 | c->set_output_handle_shapes_and_types(0, *handle_data); |
106 | } |
107 | return OkStatus(); |
108 | } |
109 | |
110 | Status MatMulShape(shape_inference::InferenceContext* c) { |
111 | ShapeHandle a; |
112 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a)); |
113 | |
114 | ShapeHandle b; |
115 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b)); |
116 | |
117 | bool transpose_a, transpose_b; |
118 | TF_RETURN_IF_ERROR(c->GetAttr("transpose_a" , &transpose_a)); |
119 | TF_RETURN_IF_ERROR(c->GetAttr("transpose_b" , &transpose_b)); |
120 | DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0); |
121 | DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1); |
122 | |
123 | // Validate that the inner shapes are compatible. |
124 | DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1); |
125 | DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0); |
126 | DimensionHandle merged; |
127 | TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged)); |
128 | |
129 | c->set_output(0, c->Matrix(output_rows, output_cols)); |
130 | return OkStatus(); |
131 | } |
132 | |
133 | namespace { |
134 | |
135 | // Validate that an Einsum subscript contains exactly one or zero ellipsis; and |
136 | // that periods (.) occur only within an ellipses (...). |
137 | Status ValidateEinsumEllipsis(absl::string_view subscript, |
138 | bool* found_ellipsis) { |
139 | const int num_periods = absl::c_count(subscript, '.'); |
140 | if (num_periods != 0 && num_periods != 3) { |
141 | return errors::InvalidArgument( |
142 | "Expected at most one ellipsis (...), but found " , num_periods, |
143 | " periods (.) in the input subscript: " , subscript); |
144 | } |
145 | if (num_periods == 3 && !absl::StrContains(subscript, "..." )) { |
146 | return errors::InvalidArgument( |
147 | "Periods found outside of ellipsis in subscript: " , subscript); |
148 | } |
149 | *found_ellipsis = num_periods > 0; |
150 | return OkStatus(); |
151 | } |
152 | |
153 | } // namespace |
154 | |
// Shape function for the Einsum op. Parses the "equation" attribute, merges
// dimensions that share a label across inputs, and broadcasts the ellipsis
// ('...') portions of the inputs to form the output's ellipsis portion.
Status EinsumShape(shape_inference::InferenceContext* c) {
  // We assume that the equation has a valid format. Either (x),(y)->(z)
  // or (x)->(z), where each of (x), (y) and (z) are concatenation of zero or
  // more latin alphabets and contains at most one ellipsis ('...').
  string equation;
  TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation));
  gtl::InlinedVector<string, 2> input_labels;
  string output_labels;
  TF_RETURN_IF_ERROR(
      ValidateEinsumEquation(equation, &input_labels, &output_labels));

  if (c->num_inputs() == 0 || c->num_inputs() > 2) {
    return errors::InvalidArgument("Expected either 1 or 2 inputs but got: ",
                                   c->num_inputs());
  }
  // The equation must describe exactly as many operands as were supplied.
  const int input_labels_size = input_labels.size();
  if (c->num_inputs() != input_labels_size) {
    return errors::InvalidArgument("Expected ", input_labels.size(),
                                   " inputs for equation ", equation,
                                   " but got: ", c->num_inputs());
  }

  // Validate input subscripts, build the label to dimension mapping and obtain
  // the broadcast shapes that map to ellipsis.
  absl::flat_hash_map<char, DimensionHandle> label_to_dimension;
  gtl::InlinedVector<ShapeHandle, 2> input_bcast_shapes(c->num_inputs());
  for (int i = 0, end = c->num_inputs(); i < end; ++i) {
    bool has_ellipsis = false;
    TF_RETURN_IF_ERROR(ValidateEinsumEllipsis(input_labels[i], &has_ellipsis));
    ShapeHandle input_shape = c->input(i);
    // Validate that the input rank is sufficient for the given number of named
    // labels.
    if (c->RankKnown(input_shape)) {
      if (has_ellipsis) {
        // The ellipsis occupies three characters of the subscript but names
        // no dimensions itself, so subtract it out.
        const int num_named_labels =
            static_cast<int>(input_labels[i].size()) - 3;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            c->WithRankAtLeast(input_shape, num_named_labels, &input_shape),
            " for ", i, "th input and equation: ", equation);
      } else {
        const int num_named_labels = static_cast<int>(input_labels[i].size());
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            c->WithRank(input_shape, num_named_labels, &input_shape), " for ",
            i, "th input and equation: ", equation);
      }
    }

    bool seen_ellipsis = false;
    input_bcast_shapes[i] = c->Scalar();
    // Run through the input labels; populate label_to_dimension mapping and
    // compute the broadcast shapes corresponding to the ellipsis (if present).
    for (int label_idx = 0, end = input_labels[i].size(); label_idx < end;
         ++label_idx) {
      const char label = input_labels[i][label_idx];
      // Calculate the input axis that the current label is referring to. After
      // the ellipsis, the axis may be found by using negative indices; i.e the
      // (rank - k)th dimension corresponds to the (num_labels - k)th label.
      const int64_t axis_before_ellipsis = label_idx;
      const int64_t axis_after_ellipsis =
          c->RankKnown(input_shape)
              ? label_idx + c->Rank(input_shape) - input_labels[i].size()
              : -1;

      // Populate the input broadcast shape when we encounter an ellipsis (...).
      if (label == '.') {
        if (!c->RankKnown(input_shape)) {
          // Unknown rank: the broadcast portion is unconstrained.
          input_bcast_shapes[i] = c->UnknownShape();
        } else {
          // The broadcast shape runs till the named label right after the
          // ellipsis, the label with index (label_idx + 3).
          TF_RETURN_IF_ERROR(c->Subshape(input_shape, axis_before_ellipsis,
                                         axis_after_ellipsis + 3,
                                         &input_bcast_shapes[i]));
        }
        label_idx += 2;  // Skip the rest of the ellipsis.
        seen_ellipsis = true;
        continue;
      }
      // Obtain the dimension that the current label corresponds to.
      int64_t axis = seen_ellipsis ? axis_after_ellipsis : axis_before_ellipsis;
      DimensionHandle new_dim = c->RankKnown(input_shape)
                                    ? c->Dim(input_shape, axis)
                                    : c->UnknownDim();
      // If we've seen this label before, make sure previous and current
      // dimensions are compatible.
      if (label_to_dimension.contains(label)) {
        DimensionHandle merged;
        TF_RETURN_IF_ERROR(
            c->Merge(label_to_dimension[label], new_dim, &merged));
        label_to_dimension[label] = merged;
      } else {
        label_to_dimension[label] = new_dim;
      }
    }
  }

  // For two inputs, broadcast the two input broadcast shapes to create the
  // output broadcast shape. For one input, just copy the single broadcast
  // shape.
  ShapeHandle output_bcast_shape;
  if (input_bcast_shapes.size() == 1) {
    output_bcast_shape = input_bcast_shapes[0];
  } else if (input_bcast_shapes.size() == 2) {
    TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
        c, input_bcast_shapes[0], input_bcast_shapes[1], true,
        &output_bcast_shape));
  }

  bool output_has_ellipsis = false;
  TF_RETURN_IF_ERROR(
      ValidateEinsumEllipsis(output_labels, &output_has_ellipsis));
  if (output_has_ellipsis) {
    // If the output subscript has ellipsis and the output broadcast rank is
    // unknown, then the output shape should have unknown rank.
    if (!c->RankKnown(output_bcast_shape)) {
      c->set_output(0, c->UnknownShape());
      return OkStatus();
    }
  } else {
    // If the output subscripts don't have ellipsis then make sure the output
    // broadcasting shape is empty.
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        c->WithRankAtMost(output_bcast_shape, 0, &output_bcast_shape),
        " for einsum equation '", equation,
        "' without ellipsis (...) in the output subscripts where input(s) have "
        "non-empty broadcasting shape");
    output_bcast_shape = c->Scalar();
  }

  // Create the output shape from output labels and label_to_dimension mapping.
  std::vector<DimensionHandle> output_dims;
  for (int label_idx = 0, end = output_labels.size(); label_idx < end;
       ++label_idx) {
    const char label = output_labels[label_idx];
    // Append the output_bcast_shape when the ellipsis is encountered.
    if (label == '.') {
      for (int k = 0; k < c->Rank(output_bcast_shape); ++k) {
        output_dims.push_back(c->Dim(output_bcast_shape, k));
      }
      label_idx += 2;  // Skip the rest of the ellipsis.
      continue;
    }
    // Every named output label must have appeared in some input subscript.
    auto dimension_it = label_to_dimension.find(label);
    if (dimension_it == label_to_dimension.end()) {
      return errors::InvalidArgument(
          "Einsum output subscripts for equation '", equation, "' has label '",
          label, "' which is not present in the input subscripts");
    }
    output_dims.push_back(dimension_it->second);
  }
  c->set_output(0, c->MakeShape(output_dims));
  return OkStatus();
}
308 | |
309 | Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) { |
310 | ShapeHandle a_shape; |
311 | ShapeHandle b_shape; |
312 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape)); |
313 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape)); |
314 | |
315 | // Determine output rows and columns. |
316 | bool adj_x; |
317 | bool adj_y; |
318 | TF_RETURN_IF_ERROR(c->GetAttr("adj_x" , &adj_x)); |
319 | TF_RETURN_IF_ERROR(c->GetAttr("adj_y" , &adj_y)); |
320 | DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2); |
321 | DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1); |
322 | |
323 | // Inner dimensions should be compatible. |
324 | DimensionHandle inner_merged; |
325 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1), |
326 | c->Dim(b_shape, adj_y ? -1 : -2), &inner_merged)); |
327 | |
328 | // Batch dimensions should broadcast with each other. |
329 | ShapeHandle a_batch_shape; |
330 | ShapeHandle b_batch_shape; |
331 | ShapeHandle output_batch_shape; |
332 | TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_shape)); |
333 | TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape)); |
334 | |
335 | TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( |
336 | c, a_batch_shape, b_batch_shape, true, &output_batch_shape)); |
337 | |
338 | ShapeHandle output_shape; |
339 | TF_RETURN_IF_ERROR(c->Concatenate( |
340 | output_batch_shape, c->Matrix(output_rows, output_cols), &output_shape)); |
341 | |
342 | c->set_output(0, output_shape); |
343 | return OkStatus(); |
344 | } |
345 | |
346 | Status BatchMatMulShape(shape_inference::InferenceContext* c) { |
347 | ShapeHandle a_shape; |
348 | ShapeHandle b_shape; |
349 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape)); |
350 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape)); |
351 | |
352 | // Determine output rows and cols. |
353 | bool adj_x; |
354 | bool adj_y; |
355 | TF_RETURN_IF_ERROR(c->GetAttr("adj_x" , &adj_x)); |
356 | TF_RETURN_IF_ERROR(c->GetAttr("adj_y" , &adj_y)); |
357 | DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2); |
358 | DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1); |
359 | |
360 | // Batch dims match between inputs. |
361 | ShapeHandle a_batch_dims; |
362 | ShapeHandle b_batch_dims; |
363 | ShapeHandle batch_dims; |
364 | TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims)); |
365 | TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims)); |
366 | TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims)); |
367 | |
368 | // Assert inner dims match. |
369 | DimensionHandle unused; |
370 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1), |
371 | c->Dim(b_shape, adj_y ? -1 : -2), &unused)); |
372 | |
373 | ShapeHandle out; |
374 | TF_RETURN_IF_ERROR( |
375 | c->Concatenate(batch_dims, c->Matrix(output_rows, output_cols), &out)); |
376 | c->set_output(0, out); |
377 | return OkStatus(); |
378 | } |
379 | |
380 | // -------------------------------------------------------------------------- |
381 | |
382 | Status BiasAddShape(shape_inference::InferenceContext* c) { |
383 | ShapeHandle input_shape; |
384 | |
385 | // Fetch the data_format attribute, which may not exist. |
386 | string data_format; |
387 | Status s = c->GetAttr("data_format" , &data_format); |
388 | |
389 | if (s.ok() && data_format == "NCHW" ) { |
390 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape)); |
391 | } else { |
392 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); |
393 | } |
394 | |
395 | ShapeHandle bias_shape; |
396 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape)); |
397 | DimensionHandle bias_dim = c->Dim(bias_shape, 0); |
398 | |
399 | // If rank unknown, return unknown shape. |
400 | if (!c->RankKnown(input_shape)) { |
401 | c->set_output(0, c->UnknownShape()); |
402 | return OkStatus(); |
403 | } |
404 | |
405 | // Output has the same shape as the input, and matches the length of |
406 | // the bias in its bias dimension. |
407 | ShapeHandle output_shape; |
408 | if (s.ok() && data_format == "NCHW" ) { |
409 | // Merge the length of bias_shape into the third to last dimension |
410 | ShapeHandle first; |
411 | TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, 1, &first)); |
412 | |
413 | ShapeHandle last; |
414 | TF_RETURN_IF_ERROR(c->Subshape(input_shape, 2, &last)); |
415 | |
416 | DimensionHandle input_bias_dim = c->Dim(input_shape, 1); |
417 | DimensionHandle merged_bias_dim; |
418 | TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim)); |
419 | ShapeHandle merged_bias = c->Vector(merged_bias_dim); |
420 | |
421 | ShapeHandle temp; |
422 | TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp)); |
423 | TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape)); |
424 | } else { |
425 | ShapeHandle all_but_bias; |
426 | TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias)); |
427 | |
428 | DimensionHandle input_bias_dim = c->Dim(input_shape, -1); |
429 | DimensionHandle merged_bias_dim; |
430 | TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim)); |
431 | |
432 | ShapeHandle merged_bias = c->Vector(merged_bias_dim); |
433 | TF_RETURN_IF_ERROR( |
434 | c->Concatenate(all_but_bias, merged_bias, &output_shape)); |
435 | } |
436 | |
437 | c->set_output(0, output_shape); |
438 | return OkStatus(); |
439 | } |
440 | |
441 | Status BiasAddGradShape(shape_inference::InferenceContext* c) { |
442 | ShapeHandle input_shape; |
443 | // Fetch the data_format attribute, which may not exist. |
444 | string data_format; |
445 | Status s = c->GetAttr("data_format" , &data_format); |
446 | |
447 | if (s.ok() && data_format == "NCHW" ) { |
448 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape)); |
449 | c->set_output(0, c->Vector(c->Dim(input_shape, 1))); |
450 | } else { |
451 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); |
452 | c->set_output(0, c->Vector(c->Dim(input_shape, -1))); |
453 | } |
454 | |
455 | return OkStatus(); |
456 | } |
457 | |
458 | Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, |
459 | const ShapeHandle shape_handle, |
460 | const string& tensor_name, |
461 | shape_inference::InferenceContext* c) { |
462 | if (tensor_format == FORMAT_NCHW_VECT_C) { |
463 | // Check that the vect dim has size 4 or 32. |
464 | const int num_dims = c->Rank(shape_handle); |
465 | DimensionHandle vect_dim = c->Dim( |
466 | shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format)); |
467 | int64_t vect_dim_val = c->Value(vect_dim); |
468 | if (vect_dim_val != 4 && vect_dim_val != 32) { |
469 | return errors::InvalidArgument( |
470 | "VECT_C dimension must be 4 or 32, but is " , vect_dim_val); |
471 | } |
472 | } |
473 | |
474 | return OkStatus(); |
475 | } |
476 | |
477 | Status DatasetIteratorShape(shape_inference::InferenceContext* c) { |
478 | shape_inference::ShapeHandle unused; |
479 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
480 | std::vector<PartialTensorShape> output_shapes; |
481 | TF_RETURN_IF_ERROR(c->GetAttr("output_shapes" , &output_shapes)); |
482 | const int output_shapes_size = output_shapes.size(); |
483 | if (output_shapes_size != c->num_outputs()) { |
484 | return errors::InvalidArgument( |
485 | "`output_shapes` must be the same length as `output_types` (" , |
486 | output_shapes.size(), " vs. " , c->num_outputs()); |
487 | } |
488 | for (size_t i = 0; i < output_shapes.size(); ++i) { |
489 | shape_inference::ShapeHandle output_shape_handle; |
490 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( |
491 | output_shapes[i], &output_shape_handle)); |
492 | c->set_output(static_cast<int>(i), output_shape_handle); |
493 | } |
494 | return OkStatus(); |
495 | } |
496 | |
497 | Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, |
498 | const std::vector<DimensionOrConstant>& spatial, |
499 | DimensionOrConstant C, ShapeHandle* out, |
500 | shape_inference::InferenceContext* context) { |
501 | const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format); |
502 | std::vector<DimensionHandle> dims_actual(num_dims); |
503 | dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N); |
504 | int outer_c_index = GetTensorFeatureDimIndex(num_dims, format); |
505 | dims_actual[outer_c_index] = context->MakeDim(C); |
506 | if (format == FORMAT_NCHW_VECT_C) { |
507 | dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] = |
508 | context->MakeDim(4); |
509 | } else if (format == FORMAT_NHWC_VECT_W) { |
510 | dims_actual[GetTensorInnerWidthDimIndex(num_dims, format)] = |
511 | context->MakeDim(4); |
512 | } |
513 | for (int spatial_dim = 0, end = spatial.size(); spatial_dim < end; |
514 | spatial_dim++) { |
515 | dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] = |
516 | context->MakeDim(spatial[spatial_dim]); |
517 | } |
518 | *out = context->MakeShape(dims_actual); |
519 | return OkStatus(); |
520 | } |
521 | |
522 | Status DimensionsFromShape(ShapeHandle shape, TensorFormat format, |
523 | DimensionHandle* batch_dim, |
524 | gtl::MutableArraySlice<DimensionHandle> spatial_dims, |
525 | DimensionHandle* filter_dim, |
526 | InferenceContext* context) { |
527 | const int32_t rank = |
528 | GetTensorDimsFromSpatialDims(spatial_dims.size(), format); |
529 | // Batch. |
530 | *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format)); |
531 | // Spatial. |
532 | for (int spatial_dim_index = 0, end = spatial_dims.size(); |
533 | spatial_dim_index < end; ++spatial_dim_index) { |
534 | spatial_dims[spatial_dim_index] = context->Dim( |
535 | shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index)); |
536 | } |
537 | // Channel. |
538 | *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format)); |
539 | if (format == FORMAT_NCHW_VECT_C) { |
540 | TF_RETURN_IF_ERROR(context->Multiply( |
541 | *filter_dim, |
542 | context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)), |
543 | filter_dim)); |
544 | } |
545 | return OkStatus(); |
546 | } |
547 | |
548 | // vect_size must be provided if format is NCHW_VECT_C. |
549 | Status ShapeFromDimensions(DimensionHandle batch_dim, |
550 | gtl::ArraySlice<DimensionHandle> spatial_dims, |
551 | DimensionHandle filter_dim, TensorFormat format, |
552 | absl::optional<DimensionHandle> vect_size, |
553 | InferenceContext* context, ShapeHandle* shape) { |
554 | const int32_t rank = |
555 | GetTensorDimsFromSpatialDims(spatial_dims.size(), format); |
556 | std::vector<DimensionHandle> out_dims(rank); |
557 | |
558 | // Batch. |
559 | out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim; |
560 | // Spatial. |
561 | for (int spatial_dim_index = 0, end = spatial_dims.size(); |
562 | spatial_dim_index < end; ++spatial_dim_index) { |
563 | out_dims[tensorflow::GetTensorSpatialDimIndex( |
564 | rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index]; |
565 | } |
566 | // Channel. |
567 | if (format == tensorflow::FORMAT_NCHW_VECT_C) { |
568 | // When format is NCHW_VECT_C, factor the feature map count into the outer |
569 | // feature count and the inner feature count (4 or 32). |
570 | CHECK(vect_size.has_value()); // Crash ok. |
571 | TF_RETURN_IF_ERROR(context->Divide( |
572 | filter_dim, *vect_size, /*evenly_divisible=*/true, |
573 | &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)])); |
574 | out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = *vect_size; |
575 | } else { |
576 | out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim; |
577 | } |
578 | |
579 | *shape = context->MakeShape(out_dims); |
580 | return OkStatus(); |
581 | } |
582 | |
583 | namespace { |
584 | |
// Shared shape function for Conv2D-like ops. Validates the data/filter
// formats, the strides/dilations attributes, and the channel-count agreement
// between input and filter (including grouped convolutions); resolves
// SAME/VALID/EXPLICIT padding and computes the spatial output dims via
// GetWindowedOutputSizeFromDimsV2. `supports_explicit_padding` selects
// between the `explicit_paddings` attribute and the `padding_list` attribute
// used by fused int8 convolutions.
Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
                       bool supports_explicit_padding) {
  // Missing format attributes fall back to NHWC / HWIO (or the VECT variant
  // when the data format is NCHW_VECT_C).
  string data_format_str, filter_format_str;
  if (!c->GetAttr("data_format", &data_format_str).ok()) {
    data_format_str = "NHWC";
  }
  if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
    filter_format_str =
        data_format_str == "NCHW_VECT_C" ? "OIHW_VECT_I" : "HWIO";
  }

  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  FilterTensorFormat filter_format;
  if (!FilterFormatFromString(filter_format_str, &filter_format)) {
    return errors::InvalidArgument("Invalid filter format string: ",
                                   filter_format_str);
  }

  constexpr int num_spatial_dims = 2;
  // 4 for NHWC/NCHW, 5 for the vectorized formats.
  const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
  ShapeHandle conv_input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
  TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
      data_format, conv_input_shape, "conv_input", c));

  // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));

  std::vector<int32> dilations;
  TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));

  if (dilations.size() != 4) {
    return errors::InvalidArgument(
        "Conv2D requires the dilation attribute to contain 4 values, but got: ",
        dilations.size());
  }

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
  if (strides.size() != 4) {
    return errors::InvalidArgument("Conv2D on data format ", data_format_str,
                                   " requires the stride attribute to contain"
                                   " 4 values, but got: ",
                                   strides.size());
  }

  // Pull the H/W entries out of the layout-ordered strides/dilations lists.
  const int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
  const int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
  const int32_t dilation_rows = GetTensorDim(dilations, data_format, 'H');
  const int32_t dilation_cols = GetTensorDim(dilations, data_format, 'W');

  // Split the input shape into batch, spatial, and depth dimensions.
  DimensionHandle batch_size_dim;
  DimensionHandle input_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
  TF_RETURN_IF_ERROR(DimensionsFromShape(
      conv_input_shape, data_format, &batch_size_dim,
      absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));

  DimensionHandle output_depth_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
  DimensionHandle filter_rows_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
  DimensionHandle filter_cols_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
  DimensionHandle filter_input_depth_dim;
  if (filter_format == FORMAT_OIHW_VECT_I) {
    // Vectorized filters split the input-channel count across an outer and an
    // inner dim; multiply them to recover the full filter input depth.
    TF_RETURN_IF_ERROR(c->Multiply(
        c->Dim(filter_shape,
               GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
        c->Dim(filter_shape,
               GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
        &filter_input_depth_dim));
  } else {
    filter_input_depth_dim = c->Dim(
        filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
  }

  // Check that the input tensor and the filter tensor agree on the channel
  // count. For grouped convolutions the input depth is a multiple of the
  // filter's input depth, and the output depth must then be a multiple of the
  // group count.
  if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
    int64_t input_depth_value = c->Value(input_depth_dim),
            filter_input_depth_value = c->Value(filter_input_depth_dim);
    if (filter_input_depth_value == 0)
      return errors::InvalidArgument("Depth of filter must not be 0");
    if (input_depth_value % filter_input_depth_value != 0)
      return errors::InvalidArgument(
          "Depth of input (", input_depth_value,
          ") is not a multiple of input depth of filter (",
          filter_input_depth_value, ")");
    if (input_depth_value != filter_input_depth_value) {
      int64_t num_groups = input_depth_value / filter_input_depth_value;
      if (c->ValueKnown(output_depth_dim)) {
        int64_t output_depth_value = c->Value(output_depth_dim);
        if (num_groups == 0)
          return errors::InvalidArgument("Number of groups must not be 0");
        if (output_depth_value % num_groups != 0)
          return errors::InvalidArgument(
              "Depth of output (", output_depth_value,
              ") is not a multiple of the number of groups (", num_groups, ")");
      }
    }
  }

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
  std::vector<int64_t> explicit_paddings;
  if (supports_explicit_padding) {
    Status s = c->GetAttr("explicit_paddings", &explicit_paddings);
    // Use the default value, which is an empty list, if the attribute is not
    // found. Otherwise return the error to the caller.
    if (!s.ok() && !errors::IsNotFound(s)) {
      return s;
    }
    TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                         /*num_dims=*/4, data_format));
  } else {
    if (padding == Padding::EXPLICIT) {
      return errors::InvalidArgument(
          "Expected non-explicit padding but got explicit padding");
    }
    std::vector<int64_t> p_list;
    // `padding_list` attribute is used by Fused int8 convolutions to support
    // explicit paddings.
    Status s_p_list = c->GetAttr("padding_list", &p_list);
    if (!s_p_list.ok() && !errors::IsNotFound(s_p_list)) {
      return s_p_list;
    }
    if (s_p_list.ok() && !p_list.empty()) {
      // A non-empty padding_list switches the op into EXPLICIT padding mode.
      padding = Padding::EXPLICIT;
      explicit_paddings = p_list;
      TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                           /*num_dims=*/4, data_format));
    }
  }

  DimensionHandle output_rows, output_cols;
  // -1 marks the explicit paddings as unused for SAME/VALID padding.
  int64_t pad_rows_before = -1, pad_rows_after = -1;
  int64_t pad_cols_before = -1, pad_cols_after = -1;
  if (padding == Padding::EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                             &pad_rows_before, &pad_rows_after);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                             &pad_cols_before, &pad_cols_after);
  }
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
      padding, pad_rows_before, pad_rows_after, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
      padding, pad_cols_before, pad_cols_after, &output_cols));

  // For NCHW_VECT_C, the output keeps the input's inner vector dimension.
  absl::optional<DimensionHandle> vect_size;
  if (data_format == FORMAT_NCHW_VECT_C) {
    vect_size.emplace(c->Dim(conv_input_shape,
                             GetTensorInnerFeatureDimIndex(rank, data_format)));
  }
  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(ShapeFromDimensions(
      batch_size_dim, {output_rows, output_cols}, output_depth_dim, data_format,
      vect_size, c, &output_shape));
  c->set_output(0, output_shape);
  return OkStatus();
}
757 | |
758 | } // namespace |
759 | |
760 | // Shape function for Conv2D-like operations that support explicit padding. |
761 | Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c) { |
762 | return Conv2DShapeImpl(c, true); |
763 | } |
764 | |
765 | // Shape function for Conv2D-like operations that do not support explicit |
766 | // padding. |
767 | Status Conv2DShape(shape_inference::InferenceContext* c) { |
768 | return Conv2DShapeImpl(c, false); |
769 | } |
770 | |
771 | // TODO(mjanusz): Unify all conv/pooling shape functions. |
772 | Status Conv3DShape(shape_inference::InferenceContext* c) { |
773 | ShapeHandle input_shape; |
774 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); |
775 | ShapeHandle filter_shape; |
776 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape)); |
777 | |
778 | string data_format; |
779 | Status s = c->GetAttr("data_format" , &data_format); |
780 | |
781 | std::vector<int32> dilations; |
782 | TF_RETURN_IF_ERROR(c->GetAttr("dilations" , &dilations)); |
783 | |
784 | if (dilations.size() != 5) { |
785 | return errors::InvalidArgument( |
786 | "Conv3D requires the dilation attribute to contain 5 values, but got: " , |
787 | dilations.size()); |
788 | } |
789 | |
790 | std::vector<int32> strides; |
791 | TF_RETURN_IF_ERROR(c->GetAttr("strides" , &strides)); |
792 | if (strides.size() != 5) { |
793 | return errors::InvalidArgument( |
794 | "Conv3D requires the stride attribute to contain 5 values, but got: " , |
795 | strides.size()); |
796 | } |
797 | |
798 | int32_t stride_planes, stride_rows, stride_cols; |
799 | int32_t dilation_planes, dilation_rows, dilation_cols; |
800 | if (s.ok() && data_format == "NCDHW" ) { |
801 | // Convert input_shape to NDHWC. |
802 | auto dim = [&](char dimension) { |
803 | return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension)); |
804 | }; |
805 | input_shape = |
806 | c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}}); |
807 | stride_planes = strides[2]; |
808 | stride_rows = strides[3]; |
809 | stride_cols = strides[4]; |
810 | dilation_planes = dilations[2]; |
811 | dilation_cols = dilations[3]; |
812 | dilation_rows = dilations[4]; |
813 | } else { |
814 | stride_planes = strides[1]; |
815 | stride_rows = strides[2]; |
816 | stride_cols = strides[3]; |
817 | dilation_planes = dilations[1]; |
818 | dilation_cols = dilations[2]; |
819 | dilation_rows = dilations[3]; |
820 | } |
821 | |
822 | DimensionHandle batch_size_dim = c->Dim(input_shape, 0); |
823 | DimensionHandle in_planes_dim = c->Dim(input_shape, 1); |
824 | DimensionHandle in_rows_dim = c->Dim(input_shape, 2); |
825 | DimensionHandle in_cols_dim = c->Dim(input_shape, 3); |
826 | DimensionHandle input_depth_dim = c->Dim(input_shape, 4); |
827 | |
828 | DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0); |
829 | DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1); |
830 | DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2); |
831 | DimensionHandle filter_input_depth_dim = c->Dim(filter_shape, 3); |
832 | DimensionHandle output_depth_dim = c->Dim(filter_shape, 4); |
833 | |
834 | // Check that the input tensor and the filter tensor agree on the channel |
835 | // count. |
836 | if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) { |
837 | int64_t input_depth_value = c->Value(input_depth_dim), |
838 | filter_input_depth_value = c->Value(filter_input_depth_dim); |
839 | if (filter_input_depth_value == 0) |
840 | return errors::InvalidArgument("Depth of filter must not be 0" ); |
841 | if (input_depth_value % filter_input_depth_value != 0) |
842 | return errors::InvalidArgument( |
843 | "Depth of input (" , input_depth_value, |
844 | ") is not a multiple of input depth of filter (" , |
845 | filter_input_depth_value, ")" ); |
846 | if (input_depth_value != filter_input_depth_value) { |
847 | int64_t num_groups = input_depth_value / filter_input_depth_value; |
848 | if (c->ValueKnown(output_depth_dim)) { |
849 | int64_t output_depth_value = c->Value(output_depth_dim); |
850 | if (num_groups == 0) |
851 | return errors::InvalidArgument("Number of groups must not be 0" ); |
852 | if (output_depth_value % num_groups != 0) |
853 | return errors::InvalidArgument( |
854 | "Depth of output (" , output_depth_value, |
855 | ") is not a multiple of the number of groups (" , num_groups, ")" ); |
856 | } |
857 | } |
858 | } |
859 | |
860 | Padding padding; |
861 | TF_RETURN_IF_ERROR(c->GetAttr("padding" , &padding)); |
862 | DimensionHandle output_planes, output_rows, output_cols; |
863 | |
864 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( |
865 | c, in_planes_dim, filter_planes_dim, dilation_planes, stride_planes, |
866 | padding, -1, -1, &output_planes)); |
867 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( |
868 | c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding, -1, |
869 | -1, &output_rows)); |
870 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( |
871 | c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding, -1, |
872 | -1, &output_cols)); |
873 | |
874 | ShapeHandle output_shape; |
875 | if (data_format == "NCDHW" ) { |
876 | output_shape = c->MakeShape({batch_size_dim, output_depth_dim, |
877 | output_planes, output_rows, output_cols}); |
878 | } else { |
879 | output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows, |
880 | output_cols, output_depth_dim}); |
881 | } |
882 | c->set_output(0, output_shape); |
883 | return OkStatus(); |
884 | } |
885 | |
// Shape function for Conv2DBackpropInput.
//
// Inputs: 0 = input_sizes (a shape tensor describing in_backprop),
// 1 = filter, 2 = out_backprop. Output 0 is the shape of in_backprop.
Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  // data_format is optional; default to channels-last when absent.
  if (!c->GetAttr("data_format", &data_format_str).ok()) {
    data_format_str = "NHWC";
  }
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }

  // For the rest of this function, output_grad_* describes out_backprop and
  // input_grad_* describes in_backprop.
  ShapeHandle output_grad_shape = c->input(2);
  TF_RETURN_IF_ERROR(c->WithRank(output_grad_shape, 4, &output_grad_shape));
  ShapeHandle filter_shape = c->input(1);
  TF_RETURN_IF_ERROR(c->WithRank(filter_shape, 4, &filter_shape));

  // Split out_backprop into batch / spatial / depth dims per data_format.
  DimensionHandle batch_size_dim;
  DimensionHandle output_grad_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> output_grad_spatial_dims(2);
  TF_RETURN_IF_ERROR(DimensionsFromShape(
      output_grad_shape, data_format, &batch_size_dim,
      absl::MakeSpan(output_grad_spatial_dims), &output_grad_depth_dim, c));
  DimensionHandle unused;
  // out_backprop's depth must match the filter's output-channel dimension.
  TF_RETURN_IF_ERROR(
      c->Merge(output_grad_depth_dim, c->Dim(filter_shape, 3), &unused));

  ShapeHandle specified_input_grad_shape;
  TF_RETURN_IF_ERROR(
      c->MakeShapeFromShapeTensor(0, &specified_input_grad_shape));
  // If input_sizes is not a constant tensor, constrain the result to an
  // unknown shape of rank 4.
  if (c->Rank(specified_input_grad_shape) == InferenceContext::kUnknownRank) {
    TF_RETURN_IF_ERROR(c->WithRank(specified_input_grad_shape, 4,
                                   &specified_input_grad_shape));
  }

  // input_grad_depth_dim doesn't equal c->Dim(filter_shape,2) when the number
  // of groups is larger than 1. If input_sizes is a 4D shape, we collect
  // input_grad_depth_dim from input_sizes; otherwise we compute it as
  // c->Dim(filter_shape,2).
  DimensionHandle input_grad_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> specified_input_grad_spatial_dims(2);
  int specified_input_grad_rank = c->Rank(specified_input_grad_shape);
  if (specified_input_grad_rank == 4) {
    DimensionHandle specified_batch_size_dim;
    TF_RETURN_IF_ERROR(DimensionsFromShape(
        specified_input_grad_shape, data_format, &specified_batch_size_dim,
        absl::MakeSpan(specified_input_grad_spatial_dims),
        &input_grad_depth_dim, c));
    // Batch size from input_sizes must agree with out_backprop's batch size.
    TF_RETURN_IF_ERROR(
        c->Merge(specified_batch_size_dim, batch_size_dim, &unused));
  } else if (specified_input_grad_rank == 2) {
    // A rank-2 input_sizes carries only the two spatial sizes.
    specified_input_grad_spatial_dims[0] =
        c->Dim(specified_input_grad_shape, 0);
    specified_input_grad_spatial_dims[1] =
        c->Dim(specified_input_grad_shape, 1);
    input_grad_depth_dim = c->Dim(filter_shape, 2);
  } else {
    return errors::InvalidArgument(
        "Conv2DBackpropInput requires input_sizes to contain 4 values or 2 "
        "values, but got: ",
        specified_input_grad_rank);
  }

  // Assemble in_backprop's shape in the requested layout.
  ShapeHandle input_grad_shape;
  TF_RETURN_IF_ERROR(ShapeFromDimensions(
      batch_size_dim, specified_input_grad_spatial_dims, input_grad_depth_dim,
      data_format, /*vect_size=*/absl::nullopt, c, &input_grad_shape));
  c->set_output(0, input_grad_shape);
  return OkStatus();
}
957 | |
958 | Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c) { |
959 | ShapeHandle input_shape; |
960 | // Fetch the data_format attribute, which may not exist. |
961 | string data_format; |
962 | Status s = c->GetAttr("data_format" , &data_format); |
963 | |
964 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); |
965 | if (s.ok() && data_format == "NCHW" ) { |
966 | c->set_output(1, c->Vector(c->Dim(input_shape, -3))); |
967 | } else { |
968 | c->set_output(1, c->Vector(c->Dim(input_shape, -1))); |
969 | } |
970 | ShapeHandle sh; |
971 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh)); |
972 | TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh)); |
973 | c->set_output(0, sh); |
974 | return OkStatus(); |
975 | } |
976 | |
977 | namespace { |
978 | |
979 | Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c, |
980 | bool supports_explicit_padding) { |
981 | ShapeHandle input_shape; |
982 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); |
983 | ShapeHandle filter_shape; |
984 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape)); |
985 | |
986 | std::vector<int32> strides; |
987 | TF_RETURN_IF_ERROR(c->GetAttr("strides" , &strides)); |
988 | |
989 | if (strides.size() != 4) { |
990 | return errors::InvalidArgument( |
991 | "DepthwiseConv2D requires the stride attribute to contain 4 values, " |
992 | "but got: " , |
993 | strides.size()); |
994 | } |
995 | |
996 | std::vector<int32> dilations; |
997 | if (!c->GetAttr("dilations" , &dilations).ok()) { |
998 | dilations.resize(4, 1); |
999 | } |
1000 | |
1001 | if (dilations.size() != 4) { |
1002 | return errors::InvalidArgument( |
1003 | "DepthwiseConv2D requires the dilations attribute to contain 4 values, " |
1004 | "but got: " , |
1005 | dilations.size()); |
1006 | } |
1007 | |
1008 | string data_format_str; |
1009 | Status s = c->GetAttr("data_format" , &data_format_str); |
1010 | TensorFormat data_format; |
1011 | if (!s.ok() || !FormatFromString(data_format_str, &data_format)) { |
1012 | data_format = FORMAT_NHWC; |
1013 | } |
1014 | int32_t stride_rows; |
1015 | int32_t stride_cols; |
1016 | int32_t dilation_rows; |
1017 | int32_t dilation_cols; |
1018 | if (data_format == FORMAT_NCHW) { |
1019 | // Canonicalize input shape to NHWC so the shape inference code below can |
1020 | // process it. |
1021 | input_shape = |
1022 | c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2), |
1023 | c->Dim(input_shape, 3), c->Dim(input_shape, 1)}}); |
1024 | stride_rows = strides[2]; |
1025 | stride_cols = strides[3]; |
1026 | dilation_rows = dilations[2]; |
1027 | dilation_cols = dilations[3]; |
1028 | } else { |
1029 | stride_rows = strides[1]; |
1030 | stride_cols = strides[2]; |
1031 | dilation_rows = dilations[1]; |
1032 | dilation_cols = dilations[2]; |
1033 | } |
1034 | |
1035 | DimensionHandle batch_size_dim = c->Dim(input_shape, 0); |
1036 | DimensionHandle in_rows_dim = c->Dim(input_shape, 1); |
1037 | DimensionHandle in_cols_dim = c->Dim(input_shape, 2); |
1038 | |
1039 | DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0); |
1040 | DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1); |
1041 | DimensionHandle input_depth = c->Dim(filter_shape, 2); |
1042 | DimensionHandle depth_multiplier = c->Dim(filter_shape, 3); |
1043 | |
1044 | // Check that the input depths are compatible. |
1045 | TF_RETURN_IF_ERROR( |
1046 | c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth)); |
1047 | |
1048 | DimensionHandle output_depth; |
1049 | TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth)); |
1050 | |
1051 | Padding padding; |
1052 | TF_RETURN_IF_ERROR(c->GetAttr("padding" , &padding)); |
1053 | |
1054 | std::vector<int64_t> explicit_paddings; |
1055 | if (supports_explicit_padding) { |
1056 | Status status = c->GetAttr("explicit_paddings" , &explicit_paddings); |
1057 | // Use the default value, which is an empty list, if the attribute is not |
1058 | // found. Otherwise return the error to the caller. |
1059 | if (!status.ok() && !errors::IsNotFound(status)) { |
1060 | return status; |
1061 | } |
1062 | TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings, |
1063 | /*num_dims=*/4, data_format)); |
1064 | } else { |
1065 | DCHECK(padding != Padding::EXPLICIT); |
1066 | } |
1067 | |
1068 | // TODO(mrry,shlens): Raise an error if the stride would cause |
1069 | // information in the input to be ignored. This will require a change |
1070 | // in the kernel implementation. |
1071 | DimensionHandle output_rows, output_cols; |
1072 | int64_t pad_rows_before = -1, pad_rows_after = -1; |
1073 | int64_t pad_cols_before = -1, pad_cols_after = -1; |
1074 | if (padding == Padding::EXPLICIT) { |
1075 | GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', |
1076 | &pad_rows_before, &pad_rows_after); |
1077 | GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', |
1078 | &pad_cols_before, &pad_cols_after); |
1079 | } |
1080 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( |
1081 | c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding, |
1082 | pad_rows_before, pad_rows_after, &output_rows)); |
1083 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( |
1084 | c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding, |
1085 | pad_cols_before, pad_cols_after, &output_cols)); |
1086 | |
1087 | ShapeHandle output_shape; |
1088 | if (data_format == FORMAT_NCHW) { |
1089 | output_shape = |
1090 | c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols}); |
1091 | } else { |
1092 | output_shape = |
1093 | c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); |
1094 | } |
1095 | c->set_output(0, output_shape); |
1096 | return OkStatus(); |
1097 | } |
1098 | |
1099 | }; // namespace |
1100 | |
1101 | Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) { |
1102 | return DepthwiseConv2DNativeShapeImpl(c, false); |
1103 | } |
1104 | |
1105 | Status DepthwiseConv2DNativeShapeWithExplicitPadding( |
1106 | shape_inference::InferenceContext* c) { |
1107 | return DepthwiseConv2DNativeShapeImpl(c, true); |
1108 | } |
1109 | |
1110 | Status AvgPoolShape(shape_inference::InferenceContext* c) { |
1111 | string data_format_str; |
1112 | TensorFormat data_format; |
1113 | Status s = c->GetAttr("data_format" , &data_format_str); |
1114 | if (s.ok()) { |
1115 | FormatFromString(data_format_str, &data_format); |
1116 | } else { |
1117 | data_format = FORMAT_NHWC; |
1118 | } |
1119 | |
1120 | const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; |
1121 | ShapeHandle input_shape; |
1122 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); |
1123 | |
1124 | TF_RETURN_IF_ERROR( |
1125 | CheckFormatConstraintsOnShape(data_format, input_shape, "input" , c)); |
1126 | |
1127 | std::vector<int32> strides; |
1128 | TF_RETURN_IF_ERROR(c->GetAttr("strides" , &strides)); |
1129 | if (strides.size() != 4) { |
1130 | return errors::InvalidArgument( |
1131 | "AvgPool requires the stride attribute to contain 4 values, but got: " , |
1132 | strides.size()); |
1133 | } |
1134 | |
1135 | std::vector<int32> kernel_sizes; |
1136 | TF_RETURN_IF_ERROR(c->GetAttr("ksize" , &kernel_sizes)); |
1137 | if (kernel_sizes.size() != 4) { |
1138 | return errors::InvalidArgument( |
1139 | "AvgPool requires the ksize attribute to contain 4 values, but got: " , |
1140 | kernel_sizes.size()); |
1141 | } |
1142 | |
1143 | int32_t stride_rows = GetTensorDim(strides, data_format, 'H'); |
1144 | int32_t stride_cols = GetTensorDim(strides, data_format, 'W'); |
1145 | int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); |
1146 | int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); |
1147 | |
1148 | constexpr int num_spatial_dims = 2; |
1149 | DimensionHandle batch_size_dim = c->Dim( |
1150 | input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); |
1151 | DimensionHandle in_rows_dim = c->Dim( |
1152 | input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); |
1153 | DimensionHandle in_cols_dim = c->Dim( |
1154 | input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); |
1155 | DimensionHandle depth_dim = c->Dim( |
1156 | input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); |
1157 | |
1158 | Padding padding; |
1159 | TF_RETURN_IF_ERROR(c->GetAttr("padding" , &padding)); |
1160 | |
1161 | // TODO(mrry,shlens): Raise an error if the stride would cause |
1162 | // information in the input to be ignored. This will require a change |
1163 | // in the kernel implementation. |
1164 | |
1165 | DimensionHandle output_rows, output_cols; |
1166 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( |
1167 | c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); |
1168 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( |
1169 | c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); |
1170 | |
1171 | ShapeHandle output_shape; |
1172 | TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, |
1173 | {output_rows, output_cols}, depth_dim, |
1174 | &output_shape, c)); |
1175 | c->set_output(0, output_shape); |
1176 | return OkStatus(); |
1177 | } |
1178 | |
1179 | Status AvgPoolGradShape(shape_inference::InferenceContext* c) { |
1180 | ShapeHandle s; |
1181 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); |
1182 | TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); |
1183 | c->set_output(0, s); |
1184 | return OkStatus(); |
1185 | } |
1186 | |
1187 | Status FusedBatchNormShape(shape_inference::InferenceContext* c) { |
1188 | string data_format_str; |
1189 | TF_RETURN_IF_ERROR(c->GetAttr("data_format" , &data_format_str)); |
1190 | TensorFormat data_format; |
1191 | if (!FormatFromString(data_format_str, &data_format)) { |
1192 | return errors::InvalidArgument("Invalid data format string: " , |
1193 | data_format_str); |
1194 | } |
1195 | const int rank = |
1196 | (data_format_str == "NDHWC" || data_format_str == "NCDHW" ) ? 5 : 4; |
1197 | ShapeHandle x; |
1198 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x)); |
1199 | |
1200 | bool is_training; |
1201 | TF_RETURN_IF_ERROR(c->GetAttr("is_training" , &is_training)); |
1202 | float exponential_avg_factor; |
1203 | if (!c->GetAttr("exponential_avg_factor" , &exponential_avg_factor).ok()) { |
1204 | exponential_avg_factor = 1.0f; // default value |
1205 | } |
1206 | int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5; |
1207 | |
1208 | int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format); |
1209 | DimensionHandle channel_dim = c->Dim(x, channel_dim_index); |
1210 | |
1211 | // covers scale, offset, and if is_training is false, mean, variance |
1212 | for (int i = 1; i < number_inputs; ++i) { |
1213 | ShapeHandle vec; |
1214 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); |
1215 | TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim)); |
1216 | } |
1217 | |
1218 | ShapeHandle y; |
1219 | TF_RETURN_IF_ERROR(c->ReplaceDim(x, channel_dim_index, channel_dim, &y)); |
1220 | c->set_output(0, y); |
1221 | ShapeHandle vector_shape = c->Vector(channel_dim); |
1222 | c->set_output(1, vector_shape); |
1223 | c->set_output(2, vector_shape); |
1224 | c->set_output(3, vector_shape); |
1225 | c->set_output(4, vector_shape); |
1226 | return OkStatus(); |
1227 | } |
1228 | |
1229 | Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c) { |
1230 | TF_RETURN_IF_ERROR(FusedBatchNormShape(c)); |
1231 | c->set_output(5, c->UnknownShape()); |
1232 | return OkStatus(); |
1233 | } |
1234 | |
1235 | Status FusedBatchNormExShape(shape_inference::InferenceContext* c) { |
1236 | TF_RETURN_IF_ERROR(FusedBatchNormV3Shape(c)); |
1237 | |
1238 | string data_format_str; |
1239 | TF_RETURN_IF_ERROR(c->GetAttr("data_format" , &data_format_str)); |
1240 | TensorFormat data_format; |
1241 | if (!FormatFromString(data_format_str, &data_format)) { |
1242 | return errors::InvalidArgument("Invalid data format string: " , |
1243 | data_format_str); |
1244 | } |
1245 | ShapeHandle x; |
1246 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x)); |
1247 | |
1248 | int channel_dim_index = GetTensorFeatureDimIndex(4, data_format); |
1249 | DimensionHandle channel_dim = c->Dim(x, channel_dim_index); |
1250 | |
1251 | // This is a cuDNN implementation constraint. |
1252 | if (c->ValueKnown(channel_dim) && c->Value(channel_dim) % 4 != 0) { |
1253 | return errors::InvalidArgument( |
1254 | "_FusedBatchNormEx channel dimension must be divisible by 4." ); |
1255 | } |
1256 | |
1257 | return OkStatus(); |
1258 | } |
1259 | |
// Shape function for FusedBatchNormGrad.
//
// Inputs: 0 = y_backprop, 1 = x, 2 = scale, 3 = reserve_space_1,
// 4 = reserve_space_2. Outputs: 0 = x_backprop (same shape as y_backprop
// with the merged channel dim), 1-2 = length-C scale/offset gradients,
// 3-4 = empty placeholder vectors.
Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  // 3-D layouts (NDHWC/NCDHW) imply rank-5 tensors; otherwise rank 4.
  const int rank =
      (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
  ShapeHandle y_backprop;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));

  // is_training is read to require the attr's presence; the visible code
  // does not otherwise use its value.
  bool is_training;
  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));

  // y_backprop and x must agree on the channel dimension.
  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
  DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
  TF_RETURN_IF_ERROR(
      c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));

  // covers scale, mean (reserve_space_1), variance (reserve_space_2)
  for (int i = 2; i < 5; ++i) {
    ShapeHandle vec;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
  }

  ShapeHandle x_backprop;
  TF_RETURN_IF_ERROR(
      c->ReplaceDim(y_backprop, channel_dim_index, channel_dim, &x_backprop));
  c->set_output(0, x_backprop);
  c->set_output(1, c->Vector(channel_dim));
  c->set_output(2, c->Vector(channel_dim));
  // Outputs 3 and 4 are unused placeholders: empty vectors.
  c->set_output(3, c->Vector(0));
  c->set_output(4, c->Vector(0));
  return OkStatus();
}
1300 | |
1301 | Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) { |
1302 | TF_RETURN_IF_ERROR(FusedBatchNormGradShape(c)); |
1303 | |
1304 | int num_side_inputs; |
1305 | TF_RETURN_IF_ERROR(c->GetAttr("num_side_inputs" , &num_side_inputs)); |
1306 | if (num_side_inputs == 0) { |
1307 | return OkStatus(); |
1308 | } |
1309 | |
1310 | string data_format_str; |
1311 | TF_RETURN_IF_ERROR(c->GetAttr("data_format" , &data_format_str)); |
1312 | TensorFormat data_format; |
1313 | if (!FormatFromString(data_format_str, &data_format)) { |
1314 | return errors::InvalidArgument("Invalid data format string: " , |
1315 | data_format_str); |
1316 | } |
1317 | const int rank = |
1318 | (data_format_str == "NDHWC" || data_format_str == "NCDHW" ) ? 5 : 4; |
1319 | ShapeHandle y_backprop; |
1320 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop)); |
1321 | ShapeHandle x; |
1322 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x)); |
1323 | |
1324 | int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format); |
1325 | DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index); |
1326 | TF_RETURN_IF_ERROR( |
1327 | c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim)); |
1328 | |
1329 | ShapeHandle side_input_backprop; |
1330 | TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, channel_dim_index, channel_dim, |
1331 | &side_input_backprop)); |
1332 | |
1333 | c->set_output(5, side_input_backprop); |
1334 | return OkStatus(); |
1335 | } |
1336 | |
1337 | Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor, |
1338 | int32* lower_diag_index, int32* upper_diag_index) { |
1339 | // This function assumes that the shape of diag_index_tensor is fully defined. |
1340 | if (diag_index_tensor->dims() == 0) { |
1341 | *lower_diag_index = diag_index_tensor->scalar<int32>()(); |
1342 | *upper_diag_index = *lower_diag_index; |
1343 | } else { |
1344 | int32_t num_elements = diag_index_tensor->dim_size(0); |
1345 | if (num_elements == 1) { |
1346 | *lower_diag_index = diag_index_tensor->vec<int32>()(0); |
1347 | *upper_diag_index = *lower_diag_index; |
1348 | } else if (num_elements == 2) { |
1349 | *lower_diag_index = diag_index_tensor->vec<int32>()(0); |
1350 | *upper_diag_index = diag_index_tensor->vec<int32>()(1); |
1351 | } else { |
1352 | return errors::InvalidArgument( |
1353 | "diag_index must be a vector with one or two elements. It has " , |
1354 | num_elements, " elements." ); |
1355 | } |
1356 | } |
1357 | return OkStatus(); |
1358 | } |
1359 | |
// Shape function for MatrixDiagPartV2.
//
// Inputs: 0 = input matrix (rank >= 2), 1 = k (diagonal index: scalar or
// length-1/2 vector), 2 = padding_value (scalar). Output 0 is the extracted
// diagonal band.
Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape, diag_index_shape, unused_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));

  // Without a constant diag index and a known input rank, the output shape
  // cannot be determined.
  const Tensor* diag_index_tensor = c->input_tensor(1);
  if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
      diag_index_tensor == nullptr) {
    c->set_output(0, c->UnknownShape());
    return OkStatus();
  }
  int32_t lower_diag_index = 0;
  int32_t upper_diag_index = 0;
  TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
                                   &upper_diag_index));
  if (lower_diag_index > upper_diag_index) {
    return errors::InvalidArgument(
        "lower_diag_index is greater than upper_diag_index");
  }

  // Validates lower_diag_index and upper_diag_index.
  const int32_t input_rank = c->Rank(input_shape);
  const int32_t num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
  const int32_t num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
  int32_t max_diag_len = InferenceContext::kUnknownDim;
  if (num_rows != InferenceContext::kUnknownDim &&
      num_cols != InferenceContext::kUnknownDim) {
    // A diagonal index d is valid iff -num_rows < d < num_cols.
    if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
        (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
      return errors::InvalidArgument("lower_diag_index is out of bound.");
    }
    if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
        (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
      return errors::InvalidArgument("upper_diag_index is out of bound.");
    }
    // Length of the longest diagonal in the requested band.
    max_diag_len = std::min(num_rows + std::min(upper_diag_index, 0),
                            num_cols - std::max(lower_diag_index, 0));
  }

  // Output: batch dims, then a num_diags dim when the band spans more than
  // one diagonal, then the (possibly unknown) diagonal length.
  std::vector<DimensionHandle> dims;
  dims.reserve(input_rank - 2);
  for (int i = 0; i < input_rank - 2; ++i) {
    dims.push_back(c->Dim(input_shape, i));
  }
  if (lower_diag_index < upper_diag_index) {
    dims.push_back(c->MakeDim(upper_diag_index - lower_diag_index + 1));
  }
  dims.push_back(c->MakeDim(max_diag_len));
  c->set_output(0, c->MakeShape(dims));
  return OkStatus();
}
1412 | |
// Shape function for MatrixDiagV2/MatrixDiagV3.
//
// Inputs (as wired by the op registration):
//   0: diagonal       -- rank >= 1; the diagonal values to place.
//   1: k              -- diagonal index/indices, scalar or length-<=1 vector
//                        shape holding [lower_diag_index, upper_diag_index].
//   2: num_rows       -- scalar; -1 means "infer from the diagonals".
//   3: num_cols       -- scalar; -1 means "infer from the diagonals".
//   4: padding_value  -- scalar.
Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) {
  // Checks input ranks.
  ShapeHandle input_shape, diag_index_shape, unused_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));

  // Reads the diagonal indices.
  const Tensor* diag_index_tensor = c->input_tensor(1);
  // The output shape is only computable when the input rank is known, the
  // shape of k is fully defined, and k itself is a compile-time constant.
  if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
      diag_index_tensor == nullptr) {
    c->set_output(0, c->UnknownShape());
    return OkStatus();
  }
  int32_t lower_diag_index = 0;
  int32_t upper_diag_index = 0;
  TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
                                   &upper_diag_index));
  if (lower_diag_index > upper_diag_index) {
    return errors::InvalidArgument(
        "lower_diag_index is greater than upper_diag_index" );
  }

  // Checks if the number of diagonals provided matches what we imply from
  // lower_diag_index and upper_diag_index.
  const int32_t input_rank = c->Rank(input_shape);
  if (lower_diag_index < upper_diag_index) {
    // A band of diagonals: the input carries them as
    // input[..., num_diags, max_diag_len].
    // NOTE(review): c->Value() returns kUnknownDim (-1) when a dim is
    // unknown, in which case this comparison reports a spurious mismatch --
    // confirm callers only reach here with fully defined trailing dims.
    const int32_t num_diags = c->Value(c->Dim(input_shape, input_rank - 2));
    const int32_t other_dim = c->Value(c->Dim(input_shape, input_rank - 1));

    if (num_diags != (upper_diag_index - lower_diag_index + 1)) {
      return errors::InvalidArgument(
          "The number of rows of `diagonal` doesn't match the number of "
          "diagonals implied from `d_lower` and `d_upper`.\n" ,
          "num_diags = " , num_diags, ", d_lower = " , lower_diag_index,
          ", d_upper = " , upper_diag_index, " " , input_rank, " " , other_dim);
    }
  }

  // Reads num_rows and num_cols.
  const Tensor* num_rows_tensor = c->input_tensor(2);
  const Tensor* num_cols_tensor = c->input_tensor(3);
  int64_t num_rows = -1;
  int64_t num_cols = -1;
  if (num_rows_tensor != nullptr) {
    TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_rows_tensor, &num_rows));
  }
  if (num_cols_tensor != nullptr) {
    TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_cols_tensor, &num_cols));
  }

  // Infers the missing num_rows or num_cols: If both are missing, assume
  // output is square. Otherwise, use the smallest possible value. Also
  // validates the provided values.
  const int32_t max_diag_len = c->Value(c->Dim(input_shape, input_rank - 1));
  // Smallest matrix that can hold a diagonal of max_diag_len elements at the
  // given offsets: a positive (super-)diagonal needs extra columns, a
  // negative (sub-)diagonal needs extra rows.
  const int32_t min_num_rows = max_diag_len - std::min(upper_diag_index, 0);
  const int32_t min_num_cols = max_diag_len + std::max(lower_diag_index, 0);
  if (num_rows == -1 && num_cols == -1) {  // Special case: square output.
    num_rows = std::max(min_num_rows, min_num_cols);
    num_cols = num_rows;
  }
  if (num_rows == -1) {
    num_rows = min_num_rows;
  } else if (num_rows < min_num_rows) {
    return errors::InvalidArgument("num_rows is too small" );
  }
  if (num_cols == -1) {
    num_cols = min_num_cols;
  } else if (num_cols < min_num_cols) {
    return errors::InvalidArgument("num_cols is too small." );
  }
  // At least one of them must match the minimum length.
  if (num_rows != min_num_rows && num_cols != min_num_cols) {
    return errors::InvalidArgument(
        "num_rows and num_cols are not consistent with lower_diag_index, "
        "upper_diag_index, and the length of the given diagonals.\n" ,
        "num_rows = " , num_rows, " != min_num_rows = " , min_num_rows,
        ", num_cols = " , num_cols, " != min_num_cols = " , min_num_cols);
  }

  // Sets output shape.
  ShapeHandle output_shape;
  const DimensionHandle output_row_dim = c->MakeDim(num_rows);
  const DimensionHandle output_col_dim = c->MakeDim(num_cols);
  if (lower_diag_index == upper_diag_index) {
    // Single diagonal: the last dim (diag length) is replaced by num_rows
    // and num_cols is appended: output = input[:-1] + [num_rows, num_cols].
    TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 1,
                                     output_row_dim, &output_shape));
    TF_RETURN_IF_ERROR(
        c->Concatenate(output_shape, c->Vector(output_col_dim), &output_shape));
  } else {
    // Band of diagonals: [num_diags, max_diag_len] is replaced in place by
    // [num_rows, num_cols], keeping the rank.
    TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 2,
                                     output_row_dim, &output_shape));
    TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, input_rank - 1,
                                     output_col_dim, &output_shape));
  }
  c->set_output(0, output_shape);
  return OkStatus();
}
1513 | |
// Shape function for MatrixSetDiagV2/V3: the output has the shape of the
// matrix input (input 0) with the diagonals described by `diagonal`
// (input 1) and the diagonal indices `k` (input 2) overwritten. When the
// input shape is not fully defined, dims are inferred from `diagonal`.
Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape, diag_shape, diag_index_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &diag_index_shape));

  int32_t lower_diag_index = 0;
  int32_t upper_diag_index = 0;
  bool diag_index_known = false;
  const Tensor* diag_index_tensor = c->input_tensor(2);
  // k is only usable here when it is a compile-time constant with a fully
  // defined shape; otherwise the defaults (0, 0) are kept and the rank
  // checks below are relaxed accordingly.
  if (diag_index_tensor != nullptr && c->FullyDefined(diag_index_shape)) {
    diag_index_known = true;
    TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
                                     &upper_diag_index));
    if (lower_diag_index > upper_diag_index) {
      return errors::InvalidArgument(
          "lower_diag_index is greater than upper_diag_index" );
    }
  }

  // Do more checks when input rank is known.
  if (c->RankKnown(input_shape)) {
    int32_t input_rank = c->Rank(input_shape);

    // If diag_index is set, we know the exact rank of diagonal.
    if (diag_index_known) {
      // A single diagonal (lower == upper) has one dim fewer than the input;
      // a band of diagonals has the same rank.
      TF_RETURN_IF_ERROR(c->WithRank(
          c->input(1),
          (lower_diag_index == upper_diag_index) ? input_rank - 1 : input_rank,
          &diag_shape));
    } else {
      // k unknown: diagonal's rank can be either input_rank - 1 or
      // input_rank.
      TF_RETURN_IF_ERROR(
          c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape));
      TF_RETURN_IF_ERROR(
          c->WithRankAtMost(c->input(1), input_rank, &diag_shape));
    }

    // Validates lower_diag_index and upper_diag_index.
    // NOTE: c->Value() yields kUnknownDim for unknown dims; the guard below
    // skips the bound checks in that case.
    const int32_t num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
    const int32_t num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
    if (num_rows != InferenceContext::kUnknownDim &&
        num_cols != InferenceContext::kUnknownDim) {
      // A valid diagonal offset d satisfies -num_rows < d < num_cols.
      // Offset 0 (the main diagonal) is always accepted so that empty
      // matrices pass.
      if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
          (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
        return errors::InvalidArgument("lower_diag_index is out of bound." );
      }
      if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
          (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
        return errors::InvalidArgument("upper_diag_index is out of bound." );
      }
    }
  }

  ShapeHandle output_shape = input_shape;
  if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) {
    // Try to infer parts of shape from diag.
    ShapeHandle diag_prefix;
    TF_RETURN_IF_ERROR(c->Subshape(
        diag_shape, 0, (lower_diag_index == upper_diag_index) ? -1 : -2,
        &diag_prefix));

    // The inner matrices can be rectangular, so we can't pinpoint their
    // exact height and width by just lower_diag_index, upper_diag_index,
    // and the longest length of given diagonals.
    TF_RETURN_IF_ERROR(
        c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape));
    TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape));
  }
  c->set_output(0, output_shape);
  return OkStatus();
}
1585 | |
1586 | Status MaxPoolShapeImpl(shape_inference::InferenceContext* c, |
1587 | bool supports_explicit_padding) { |
1588 | string data_format_str; |
1589 | TensorFormat data_format; |
1590 | Status s = c->GetAttr("data_format" , &data_format_str); |
1591 | if (s.ok()) { |
1592 | FormatFromString(data_format_str, &data_format); |
1593 | } else { |
1594 | data_format = FORMAT_NHWC; |
1595 | } |
1596 | |
1597 | const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; |
1598 | ShapeHandle input_shape; |
1599 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); |
1600 | |
1601 | TF_RETURN_IF_ERROR( |
1602 | CheckFormatConstraintsOnShape(data_format, input_shape, "input" , c)); |
1603 | |
1604 | std::vector<int32> strides; |
1605 | TF_RETURN_IF_ERROR(c->GetAttr("strides" , &strides)); |
1606 | if (strides.size() != 4) { |
1607 | return errors::InvalidArgument( |
1608 | "MaxPool requires the stride attribute to contain 4 values, but got: " , |
1609 | strides.size()); |
1610 | } |
1611 | |
1612 | std::vector<int32> kernel_sizes; |
1613 | TF_RETURN_IF_ERROR(c->GetAttr("ksize" , &kernel_sizes)); |
1614 | if (kernel_sizes.size() != 4) { |
1615 | return errors::InvalidArgument( |
1616 | "MaxPool requires the ksize attribute to contain 4 values, but got: " , |
1617 | kernel_sizes.size()); |
1618 | } |
1619 | |
1620 | int32_t stride_depth = GetTensorDim(strides, data_format, 'C'); |
1621 | int32_t stride_rows = GetTensorDim(strides, data_format, 'H'); |
1622 | int32_t stride_cols = GetTensorDim(strides, data_format, 'W'); |
1623 | int32_t kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); |
1624 | int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); |
1625 | int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); |
1626 | |
1627 | constexpr int num_spatial_dims = 2; |
1628 | DimensionHandle batch_size_dim = c->Dim( |
1629 | input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); |
1630 | DimensionHandle in_rows_dim = c->Dim( |
1631 | input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); |
1632 | DimensionHandle in_cols_dim = c->Dim( |
1633 | input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); |
1634 | DimensionHandle in_depth_dim = c->Dim( |
1635 | input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); |
1636 | |
1637 | Padding padding; |
1638 | TF_RETURN_IF_ERROR(c->GetAttr("padding" , &padding)); |
1639 | |
1640 | std::vector<int64_t> explicit_paddings; |
1641 | if (supports_explicit_padding) { |
1642 | Status status = c->GetAttr("explicit_paddings" , &explicit_paddings); |
1643 | // Use the default value, which is an empty list, if the attribute is not |
1644 | // found. Otherwise return the error to the caller. |
1645 | if (!status.ok() && !errors::IsNotFound(status)) { |
1646 | return status; |
1647 | } |
1648 | TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings, |
1649 | /*num_dims=*/4, data_format)); |
1650 | } else { |
1651 | DCHECK(padding != Padding::EXPLICIT); |
1652 | } |
1653 | |
1654 | ShapeHandle output_shape; |
1655 | DimensionHandle output_rows, output_cols, output_depth; |
1656 | int64_t pad_rows_before = -1, pad_rows_after = -1; |
1657 | int64_t pad_cols_before = -1, pad_cols_after = -1; |
1658 | if (padding == Padding::EXPLICIT) { |
1659 | GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', |
1660 | &pad_rows_before, &pad_rows_after); |
1661 | GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', |
1662 | &pad_cols_before, &pad_cols_after); |
1663 | } |
1664 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( |
1665 | c, in_rows_dim, kernel_rows, /*dilation_rate=*/1, stride_rows, padding, |
1666 | pad_rows_before, pad_rows_after, &output_rows)); |
1667 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( |
1668 | c, in_cols_dim, kernel_cols, /*dilation_rate=*/1, stride_cols, padding, |
1669 | pad_cols_before, pad_cols_after, &output_cols)); |
1670 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( |
1671 | c, in_depth_dim, kernel_depth, /*dilation_rate=*/1, stride_depth, padding, |
1672 | /*pad_before*/ 0, /*pad_after*/ 0, &output_depth)); |
1673 | |
1674 | TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, |
1675 | {output_rows, output_cols}, |
1676 | output_depth, &output_shape, c)); |
1677 | |
1678 | c->set_output(0, output_shape); |
1679 | return OkStatus(); |
1680 | } |
1681 | |
1682 | Status MaxPoolShape(shape_inference::InferenceContext* c) { |
1683 | return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/false); |
1684 | } |
1685 | |
1686 | Status MaxPoolGradShape(shape_inference::InferenceContext* c) { |
1687 | return UnchangedShapeWithRank(c, 4); |
1688 | } |
1689 | |
1690 | Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c) { |
1691 | return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/true); |
1692 | } |
1693 | |
1694 | Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { |
1695 | string data_format_str; |
1696 | TensorFormat data_format; |
1697 | Status s = c->GetAttr("data_format" , &data_format_str); |
1698 | if (s.ok()) { |
1699 | FormatFromString(data_format_str, &data_format); |
1700 | } else { |
1701 | data_format = FORMAT_NHWC; |
1702 | } |
1703 | |
1704 | const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; |
1705 | ShapeHandle input_shape; |
1706 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); |
1707 | |
1708 | TF_RETURN_IF_ERROR( |
1709 | CheckFormatConstraintsOnShape(data_format, input_shape, "input" , c)); |
1710 | |
1711 | std::vector<int32> kernel_sizes; |
1712 | std::vector<int32> strides; |
1713 | |
1714 | if (c->num_inputs() + 2 == num_inputs) { |
1715 | TF_RETURN_IF_ERROR(c->GetAttr("ksize" , &kernel_sizes)); |
1716 | |
1717 | TF_RETURN_IF_ERROR(c->GetAttr("strides" , &strides)); |
1718 | } else { |
1719 | // Verify shape of ksize and strides input. |
1720 | ShapeHandle size; |
1721 | DimensionHandle unused; |
1722 | TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size)); |
1723 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused)); |
1724 | TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size)); |
1725 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused)); |
1726 | |
1727 | const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2); |
1728 | if (kernel_sizes_tensor == nullptr) { |
1729 | c->set_output(0, c->UnknownShape()); |
1730 | return OkStatus(); |
1731 | } |
1732 | kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements()); |
1733 | auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>(); |
1734 | std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), |
1735 | kernel_sizes.begin()); |
1736 | |
1737 | const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1); |
1738 | if (strides_tensor == nullptr) { |
1739 | c->set_output(0, c->UnknownShape()); |
1740 | return OkStatus(); |
1741 | } |
1742 | strides.resize(strides_tensor->shape().num_elements()); |
1743 | auto strides_vec = strides_tensor->flat<int32>(); |
1744 | std::copy_n(&strides_vec(0), strides.size(), strides.begin()); |
1745 | } |
1746 | |
1747 | if (strides.size() != 4) { |
1748 | return errors::InvalidArgument( |
1749 | "MaxPool requires the stride attribute to contain 4 values, but " |
1750 | "got: " , |
1751 | strides.size()); |
1752 | } |
1753 | if (kernel_sizes.size() != 4) { |
1754 | return errors::InvalidArgument( |
1755 | "MaxPool requires the ksize attribute to contain 4 values, but got: " , |
1756 | kernel_sizes.size()); |
1757 | } |
1758 | |
1759 | int32_t stride_depth = GetTensorDim(strides, data_format, 'C'); |
1760 | int32_t stride_rows = GetTensorDim(strides, data_format, 'H'); |
1761 | int32_t stride_cols = GetTensorDim(strides, data_format, 'W'); |
1762 | int32_t kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); |
1763 | int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); |
1764 | int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); |
1765 | |
1766 | constexpr int num_spatial_dims = 2; |
1767 | DimensionHandle batch_size_dim = c->Dim( |
1768 | input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); |
1769 | DimensionHandle in_rows_dim = c->Dim( |
1770 | input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); |
1771 | DimensionHandle in_cols_dim = c->Dim( |
1772 | input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); |
1773 | DimensionHandle in_depth_dim = c->Dim( |
1774 | input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); |
1775 | |
1776 | Padding padding; |
1777 | TF_RETURN_IF_ERROR(c->GetAttr("padding" , &padding)); |
1778 | |
1779 | ShapeHandle output_shape; |
1780 | DimensionHandle output_rows, output_cols, output_depth; |
1781 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( |
1782 | c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); |
1783 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( |
1784 | c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); |
1785 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( |
1786 | c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); |
1787 | |
1788 | TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, |
1789 | {output_rows, output_cols}, |
1790 | output_depth, &output_shape, c)); |
1791 | |
1792 | c->set_output(0, output_shape); |
1793 | return OkStatus(); |
1794 | } |
1795 | |
// Shape function for 3-D pooling ops (AvgPool3D / MaxPool3D). Supports
// NDHWC (the default) and NCDHW data formats.
Status Pool3DShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));

  string data_format;
  // data_format is optional: when the attr is absent (s not ok) or not
  // "NCDHW", the NDHWC layout is assumed below.
  Status s = c->GetAttr("data_format" , &data_format);

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides" , &strides));
  if (strides.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D ops require the stride attribute to contain 5 values, but "
        "got: " ,
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize" , &kernel_sizes));
  if (kernel_sizes.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D requires the ksize attribute to contain 5 values, but got: " ,
        kernel_sizes.size());
  }

  int32_t stride_planes, stride_rows, stride_cols;
  int32_t kernel_planes, kernel_rows, kernel_cols;

  if (s.ok() && data_format == "NCDHW" ) {
    // Convert input_shape to NDHWC.
    auto dim = [&](char dimension) {
      return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    };
    input_shape =
        c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    // NCDHW layout: strides/ksize are ordered [N, C, D, H, W].
    stride_planes = strides[2];
    stride_rows = strides[3];
    stride_cols = strides[4];
    kernel_planes = kernel_sizes[2];
    kernel_rows = kernel_sizes[3];
    kernel_cols = kernel_sizes[4];
  } else {
    // NDHWC layout: strides/ksize are ordered [N, D, H, W, C].
    stride_planes = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
    kernel_planes = kernel_sizes[1];
    kernel_rows = kernel_sizes[2];
    kernel_cols = kernel_sizes[3];
  }

  // From this point input_shape is in NDHWC order regardless of the attr.
  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
  DimensionHandle output_depth_dim = c->Dim(input_shape, 4);

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding" , &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_planes, output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));

  ShapeHandle output_shape;
  // Emit the output in the caller's requested layout.
  if (data_format == "NCDHW" ) {
    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
                                 output_planes, output_rows, output_cols});
  } else {
    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
                                 output_cols, output_depth_dim});
  }

  c->set_output(0, output_shape);
  return OkStatus();
}
1877 | |
1878 | Status MaxPool3DGradShape(shape_inference::InferenceContext* c) { |
1879 | return UnchangedShapeWithRank(c, 5); |
1880 | } |
1881 | |
1882 | Status AvgPool3DGradShape(shape_inference::InferenceContext* c) { |
1883 | ShapeHandle s; |
1884 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); |
1885 | TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); |
1886 | c->set_output(0, s); |
1887 | return OkStatus(); |
1888 | } |
1889 | |
1890 | Status UnknownShape(shape_inference::InferenceContext* c) { |
1891 | for (int i = 0; i < c->num_outputs(); ++i) { |
1892 | c->set_output(i, c->UnknownShape()); |
1893 | } |
1894 | return OkStatus(); |
1895 | } |
1896 | |
1897 | template <typename T> |
1898 | Status ReductionShapeHelper(const Tensor* reduction_indices_t, |
1899 | const int32_t input_rank, |
1900 | std::set<int64_t>* true_indices) { |
1901 | auto reduction_indices = reduction_indices_t->flat<T>(); |
1902 | for (int i = 0; i < reduction_indices_t->NumElements(); ++i) { |
1903 | const T reduction_index = reduction_indices(i); |
1904 | if (reduction_index < -input_rank || reduction_index >= input_rank) { |
1905 | return errors::InvalidArgument("Invalid reduction dimension " , |
1906 | reduction_index, " for input with " , |
1907 | input_rank, " dimensions." ); |
1908 | } |
1909 | |
1910 | auto wrapped_index = reduction_index; |
1911 | if (wrapped_index < 0) { |
1912 | wrapped_index += input_rank; |
1913 | } |
1914 | |
1915 | true_indices->insert(wrapped_index); |
1916 | } |
1917 | return OkStatus(); |
1918 | } |
1919 | |
// Shape function for reduction ops (Sum, Mean, Max, ...). Input 0 is the
// tensor to reduce; input 1 holds the reduction axes; the "keep_dims" attr
// controls whether reduced axes become size-1 dims or are dropped.
Status ReductionShape(InferenceContext* c) {
  ShapeHandle input = c->input(0);

  ShapeHandle indices;
  // Older versions of TensorFlow accidentally allowed higher rank tensors like
  // [[1,2]] or [[1],[2]] to represent axis=[1,2].
  if (c->graph_def_version() < 21) {
    indices = c->input(1);
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
  }

  bool keep_dims;
  TF_RETURN_IF_ERROR(c->GetAttr("keep_dims" , &keep_dims));

  const Tensor* reduction_indices_t = c->input_tensor(1);
  if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
    // If we do not have the reduction values at runtime, or the
    // rank of the input, we don't know the output shape.

    if (keep_dims && c->RankKnown(input)) {
      // output rank matches input input if <keep_dims>.
      c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
      return OkStatus();
    } else {
      return shape_inference::UnknownShape(c);
    }
  }

  const int32_t input_rank = c->Rank(input);
  std::set<int64_t> true_indices;
  // Axes may be int32 or int64; ReductionShapeHelper canonicalizes negative
  // axes and rejects out-of-range values.
  if (reduction_indices_t->dtype() == DataType::DT_INT32) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
                                                   input_rank, &true_indices));
  } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int64_t>(
        reduction_indices_t, input_rank, &true_indices));
  } else {
    return errors::InvalidArgument(
        "reduction_indices can only be int32 or int64" );
  }

  std::vector<DimensionHandle> dims;
  for (int i = 0; i < input_rank; ++i) {
    // Reduced dims are kept as 1 only when keep_dims is set; non-reduced
    // dims pass through unchanged.
    if (true_indices.count(i) > 0) {
      if (keep_dims) {
        dims.emplace_back(c->MakeDim(1));
      }
    } else {
      dims.emplace_back(c->Dim(input, i));
    }
  }

  c->set_output(0, c->MakeShape(dims));
  return OkStatus();
}
1976 | |
1977 | Status ConcatShapeHelper(InferenceContext* c, int start_value_index, |
1978 | int end_value_index, int dim_index) { |
1979 | ShapeHandle unused; |
1980 | TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused)); |
1981 | const Tensor* concat_dim_t = c->input_tensor(dim_index); |
1982 | if (concat_dim_t == nullptr) { |
1983 | // Return an unknown shape with same rank as inputs, or an unknown rank |
1984 | // if no input's rank is known. |
1985 | |
1986 | // Find rank. |
1987 | int32_t rank = InferenceContext::kUnknownRank; |
1988 | for (int i = start_value_index; i < end_value_index; ++i) { |
1989 | if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i)); |
1990 | if (rank != InferenceContext::kUnknownRank) { |
1991 | break; |
1992 | } |
1993 | } |
1994 | if (rank == InferenceContext::kUnknownRank) { |
1995 | c->set_output(0, c->UnknownShape()); |
1996 | return OkStatus(); |
1997 | } else if (rank == 0) { |
1998 | return errors::InvalidArgument( |
1999 | "Can't concatenate scalars (use tf.stack instead)" ); |
2000 | } else { |
2001 | for (int i = start_value_index; i < end_value_index; ++i) { |
2002 | // Check that all the inputs are of the correct rank. |
2003 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused)); |
2004 | } |
2005 | } |
2006 | // Build result of <rank> different unknown dims. |
2007 | std::vector<DimensionHandle> dims; |
2008 | dims.reserve(rank); |
2009 | for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim()); |
2010 | c->set_output(0, c->MakeShape(dims)); |
2011 | return OkStatus(); |
2012 | } |
2013 | |
2014 | // Merge all the non-concat dims, and sum the concat dim to make an output |
2015 | // shape. |
2016 | int64_t concat_dim; |
2017 | if (concat_dim_t->dtype() == DT_INT32) { |
2018 | concat_dim = static_cast<int64_t>(concat_dim_t->flat<int32>()(0)); |
2019 | } else { |
2020 | concat_dim = concat_dim_t->flat<int64_t>()(0); |
2021 | } |
2022 | |
2023 | // Minimum required number of dimensions. |
2024 | const int64 min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1; |
2025 | |
2026 | ShapeHandle output_before; |
2027 | ShapeHandle output_after; |
2028 | |
2029 | ShapeHandle input = c->input(end_value_index - 1); |
2030 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input)); |
2031 | TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before)); |
2032 | DimensionHandle output_middle = c->Dim(input, concat_dim); |
2033 | if (concat_dim == -1) { |
2034 | output_after = c->Scalar(); // no dimensions. |
2035 | } else { |
2036 | TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after)); |
2037 | } |
2038 | |
2039 | for (int i = end_value_index - 2; i >= start_value_index; --i) { |
2040 | ShapeHandle before; |
2041 | ShapeHandle after; |
2042 | input = c->input(i); |
2043 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input)); |
2044 | TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before)); |
2045 | DimensionHandle middle = c->Dim(input, concat_dim); |
2046 | if (concat_dim == -1) { |
2047 | after = c->Scalar(); |
2048 | } else { |
2049 | TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after)); |
2050 | } |
2051 | |
2052 | TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before)); |
2053 | TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle)); |
2054 | TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after)); |
2055 | } |
2056 | |
2057 | ShapeHandle s; |
2058 | TF_RETURN_IF_ERROR( |
2059 | c->Concatenate(output_before, c->Vector(output_middle), &s)); |
2060 | TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s)); |
2061 | c->set_output(0, s); |
2062 | return OkStatus(); |
2063 | } |
2064 | |
2065 | Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) { |
2066 | return ConcatShapeHelper(c, 1 /* start_value_index */, |
2067 | 1 + num_inputs_to_concat /* end_value_index */, |
2068 | 0 /* dim_index */); |
2069 | } |
2070 | |
2071 | Status ConcatV2Shape(InferenceContext* c) { |
2072 | return ConcatShapeHelper(c, 0 /* start_value_index */, |
2073 | c->num_inputs() - 1 /* end_value_index */, |
2074 | c->num_inputs() - 1 /* dim_index */); |
2075 | } |
2076 | |
2077 | Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) { |
2078 | return ConcatShapeHelper(c, 0 /* start_value_index */, |
2079 | num_inputs_to_concat /* end_value_index */, |
2080 | num_inputs_to_concat /* dim_index */); |
2081 | } |
2082 | |
// Computes the NumPy-style broadcast of shape_x and shape_y into *out.
// When incompatible_shape_error is false, incompatible inputs produce an
// unknown shape (or a scalar on merge failure) instead of an error; this
// supports ops like tf.equal with its incompatible_shape_error attr.
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
                                            ShapeHandle shape_x,
                                            ShapeHandle shape_y,
                                            bool incompatible_shape_error,
                                            ShapeHandle* out) {
  CHECK_NOTNULL(out);
  // Without both ranks nothing can be said about the output.
  if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
    *out = c->UnknownShape();
    return OkStatus();
  }
  const int32_t rank_x = c->Rank(shape_x);
  const int32_t rank_y = c->Rank(shape_y);
  const int32_t rank_out = std::max(rank_x, rank_y);

  // To compute the broadcast dimensions, we zip together shape_x and shape_y
  // and
  // pad with 1 to make them the same length.
  std::vector<DimensionHandle> dims;
  DimensionHandle dim_one;
  if (rank_x != rank_y) dim_one = c->MakeDim(1);
  for (int i = 0; i < rank_out; ++i) {
    // The shorter shape is implicitly left-padded with 1s.
    const auto dim_x = i < (rank_out - rank_x)
                           ? dim_one
                           : c->Dim(shape_x, i - (rank_out - rank_x));
    const bool dim_y_is_one = (i < (rank_out - rank_y));
    const auto dim_y =
        dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y));
    if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
      // One or both dimensions is unknown.
      //
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      // TODO(cwhipkey): For shape inference, if we eliminate the shape checks
      // in C++ op code, we must still assert that the unknown dim is either 1
      // or the same as the known dim.
      // - If either dimension is 1, the other dimension is the output.
      // - If both are unknown then dimension is unknown
      if (c->Value(dim_x) > 1) {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return OkStatus();
        }
        dims.push_back(dim_x);
      } else if (c->Value(dim_y) > 1) {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return OkStatus();
        }
        dims.push_back(dim_y);
      } else if (c->Value(dim_x) == 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_y) == 1) {
        dims.push_back(dim_x);
      } else if (dim_y.SameHandle(dim_x)) {
        // Identical unknown dims broadcast to themselves.
        dims.push_back(dim_x);
      } else if (!c->ValueKnown(dim_x) && !c->ValueKnown(dim_y)) {
        dims.push_back(c->UnknownDim());
      } else {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return OkStatus();
        }
        dims.push_back(c->UnknownDim());
      }
    } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
      // Both known, at least one is 1: the size-1 dim broadcasts to the
      // other. Note dim_y_is_one guards against broadcasting onto padding.
      if (c->Value(dim_x) == 1 && !dim_y_is_one) {
        // We will broadcast dim_x to dim_y.
        dims.push_back(dim_y);
      } else {
        DCHECK_EQ(c->Value(dim_y), 1);
        // We will broadcast dim_y to dim_x.
        dims.push_back(dim_x);
      }
    } else {
      // Both known and neither is 1: they must match exactly.
      DimensionHandle dim;
      Status s = c->Merge(dim_x, dim_y, &dim);
      if (!s.ok()) {
        if (!incompatible_shape_error) {
          // Scalar sentinel for an incompatible-but-tolerated mismatch.
          *out = c->MakeShape({});
          return OkStatus();
        }
        return s;
      }
      dims.push_back(dim);
    }
  }

  *out = c->MakeShape(dims);
  return OkStatus();
}
2173 | |
2174 | Status RandomShape(shape_inference::InferenceContext* c) { |
2175 | shape_inference::ShapeHandle out; |
2176 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
2177 | c->set_output(0, out); |
2178 | return OkStatus(); |
2179 | } |
2180 | |
2181 | Status UnsortedSegmentReductionShapeFn(InferenceContext* c) { |
2182 | ShapeHandle s_data = c->input(0); |
2183 | ShapeHandle s_segment_ids = c->input(1); |
2184 | ShapeHandle s_num_segments = c->input(2); |
2185 | TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments)); |
2186 | |
2187 | ShapeHandle out; |
2188 | |
2189 | // Leading dimensions of data must be compatible with dimensions of |
2190 | // <s_segment_ids>. |
2191 | if (c->RankKnown(s_segment_ids)) { |
2192 | TF_RETURN_IF_ERROR( |
2193 | c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids)); |
2194 | |
2195 | // Get the value of the num_segments input tensor. |
2196 | DimensionHandle num_segments_dim; |
2197 | TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim)); |
2198 | |
2199 | // Output is {segment_id_rank} + s_data[segment_id_rank:]. |
2200 | ShapeHandle s_data_suffix; |
2201 | TF_RETURN_IF_ERROR( |
2202 | c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix)); |
2203 | TF_RETURN_IF_ERROR( |
2204 | c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out)); |
2205 | } else { |
2206 | out = c->UnknownShape(); |
2207 | } |
2208 | c->set_output(0, out); |
2209 | return OkStatus(); |
2210 | } |
2211 | |
namespace {

// This SliceHelper processes the output shape of the `slice`
// when the tensor of `sizes` is available.
//
// T is the element type of the `sizes` tensor (int32 or int64_t). For each
// dimension i of the slice:
//   - sizes[i] >= 0: the output dimension is exactly sizes[i];
//   - sizes[i] == -1: the slice runs to the end of the input dimension, so
//     the output dimension is input_dim[i] - begin[i];
//   - sizes[i] < -1: invalid; reported as an InvalidArgument error.
// Appends one DimensionHandle per element of `sizes` to `*dims`.
template <typename T>
Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
                   const Tensor* sizes_value,
                   std::vector<DimensionHandle>* dims) {
  auto sizes_vec = sizes_value->vec<T>();
  for (int i = 0; i < sizes_value->NumElements(); ++i) {
    DimensionHandle dim = c->Dim(c->input(0), i);
    if (sizes_vec(i) != -1) {
      auto dim_val = c->Value(dim);
      if (sizes_vec(i) < 0) {
        return errors::InvalidArgument(
            "Out of bounds slicing on dimension " , i, " of length " , dim_val,
            ": sizes vector cannot be < -1, but was " , sizes_vec(i));
      }

      // Fixed, non-negative size: use it directly.
      dims->emplace_back(c->MakeDim(sizes_vec(i)));
    } else {
      // sizes[i] == -1: slice to the end, size = dim - begin[i].
      DimensionHandle result;
      TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
      dims->emplace_back(result);
    }
  }

  return OkStatus();
}
}  // namespace
2242 | |
2243 | Status SliceShape(InferenceContext* c) { |
2244 | ShapeHandle input = c->input(0); |
2245 | ShapeHandle begin_shape; |
2246 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); |
2247 | ShapeHandle sizes_shape; |
2248 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape)); |
2249 | |
2250 | // Merge to check compatibility of begin and sizes tensors. |
2251 | TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape)); |
2252 | |
2253 | DimensionHandle ndims = c->Dim(begin_shape, 0); |
2254 | if (c->ValueKnown(ndims)) { |
2255 | TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input)); |
2256 | } |
2257 | |
2258 | // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known |
2259 | // values, even though the `begin` value does not represent a shape. |
2260 | ShapeHandle begin_value; |
2261 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value)); |
2262 | |
2263 | // We check the tensor value here and will only use |
2264 | // `MakeShapeFromShapeTensor` when `sizes_value` is null. |
2265 | // The reason is that `sizes` might contain -1, which can't |
2266 | // be represented (-1 in the ShapeHandle would mean "unknown"). |
2267 | const Tensor* sizes_value = c->input_tensor(2); |
2268 | |
2269 | if (sizes_value != nullptr) { |
2270 | TF_RETURN_IF_ERROR( |
2271 | c->WithRank(begin_value, sizes_value->NumElements(), &begin_value)); |
2272 | std::vector<DimensionHandle> dims; |
2273 | // If the begin and sizes tensors are available, then |
2274 | // we can be precise about the shape of the output. |
2275 | if (sizes_value->dtype() == DT_INT64) { |
2276 | TF_RETURN_IF_ERROR( |
2277 | SliceHelper<int64_t>(c, begin_value, sizes_value, &dims)); |
2278 | } else { |
2279 | TF_RETURN_IF_ERROR( |
2280 | SliceHelper<int32>(c, begin_value, sizes_value, &dims)); |
2281 | } |
2282 | c->set_output(0, c->MakeShape(dims)); |
2283 | return OkStatus(); |
2284 | } else { |
2285 | // In case `sizes` is not available (`sizes_value` is null), |
2286 | // we could try to use `MakeShapeFromShapeTensor` here. |
2287 | // If sizes contain -1, we will simply consider it as `Unknown`. |
2288 | // This is less than ideal but still an improvement of shape inference. |
2289 | // The following is an example that returns [None, 1, None] with this |
2290 | // code path: |
2291 | // z = tf.zeros((1, 2, 3)) |
2292 | // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1]) |
2293 | // m.get_shape().as_list() |
2294 | ShapeHandle sizes_value; |
2295 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value)); |
2296 | if (c->RankKnown(sizes_value)) { |
2297 | TF_RETURN_IF_ERROR( |
2298 | c->WithRank(begin_value, c->Rank(sizes_value), &begin_value)); |
2299 | std::vector<DimensionHandle> dims; |
2300 | dims.reserve(c->Rank(sizes_value)); |
2301 | for (int i = 0; i < c->Rank(sizes_value); ++i) { |
2302 | dims.emplace_back(c->Dim(sizes_value, i)); |
2303 | } |
2304 | c->set_output(0, c->MakeShape(dims)); |
2305 | return OkStatus(); |
2306 | } |
2307 | // We might know the rank of the input. |
2308 | if (c->RankKnown(input)) { |
2309 | c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); |
2310 | return OkStatus(); |
2311 | } else { |
2312 | return shape_inference::UnknownShape(c); |
2313 | } |
2314 | } |
2315 | |
2316 | return OkStatus(); |
2317 | } |
2318 | |
2319 | Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, |
2320 | ShapeHandle values_shape, ShapeHandle shape_shape) { |
2321 | // Validate ranks. |
2322 | ShapeHandle unused_shape; |
2323 | TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape)); |
2324 | TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape)); |
2325 | TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape)); |
2326 | |
2327 | // Number of elements in indices and values must match. |
2328 | DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0); |
2329 | if (c->ValueKnown(num_index_elements_dim)) { |
2330 | DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0); |
2331 | if (c->ValueKnown(num_values_elements_dim)) { |
2332 | int64_t num_index_elements = c->Value(num_index_elements_dim); |
2333 | int64_t num_values_elements = c->Value(num_values_elements_dim); |
2334 | if (num_index_elements != num_values_elements) { |
2335 | return errors::InvalidArgument("Number of elements in index (" , |
2336 | num_index_elements, ") and values (" , |
2337 | num_values_elements, ") do not match." ); |
2338 | } |
2339 | } |
2340 | } |
2341 | |
2342 | // Rank embedded in indices must match shape. |
2343 | DimensionHandle index_rank_dim = c->Dim(indices_shape, 1); |
2344 | if (c->ValueKnown(index_rank_dim)) { |
2345 | DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0); |
2346 | if (c->ValueKnown(shape_rank_dim)) { |
2347 | int64_t index_rank = c->Value(index_rank_dim); |
2348 | int32_t shape_rank = c->Value(shape_rank_dim); |
2349 | if (index_rank != shape_rank) { |
2350 | return errors::InvalidArgument("Index rank (" , index_rank, |
2351 | ") and shape rank (" , shape_rank, |
2352 | ") do not match." ); |
2353 | } |
2354 | } |
2355 | } |
2356 | |
2357 | return OkStatus(); |
2358 | } |
2359 | |
2360 | Status ValidateVariableResourceHandle( |
2361 | InferenceContext* c, std::vector<ShapeAndType>* shape_and_type) { |
2362 | auto* handle_data = c->input_handle_shapes_and_types(0); |
2363 | if (handle_data == nullptr || handle_data->empty()) { |
2364 | shape_and_type->emplace_back(c->UnknownShape(), DT_INVALID); |
2365 | } else { |
2366 | *shape_and_type = *handle_data; |
2367 | DataType value_dtype; |
2368 | TF_RETURN_IF_ERROR(c->GetAttr("dtype" , &value_dtype)); |
2369 | if (shape_and_type->at(0).dtype != value_dtype) { |
2370 | return errors::InvalidArgument( |
2371 | "Trying to read variable with wrong dtype. " |
2372 | "Expected " , |
2373 | DataTypeString(shape_and_type->at(0).dtype), " got " , |
2374 | DataTypeString(value_dtype)); |
2375 | } |
2376 | } |
2377 | return OkStatus(); |
2378 | } |
2379 | |
// Shape function for GatherNd-style ops.
// Output is indices.shape[:-1] + params.shape[indices.shape[-1]:], i.e. the
// last dimension of `indices` indexes into the leading dimensions of
// `params` and is consumed.
Status GatherNdShape(InferenceContext* c) {
  ShapeHandle params;
  std::vector<ShapeAndType> handle_shape_and_type;
  // Input 0 may be a resource handle; in that case the params shape comes
  // from the handle's shape-and-type data (validated against the op dtype).
  if (c->input_handle_shapes_and_types(0) != nullptr) {
    TF_RETURN_IF_ERROR(
        ValidateVariableResourceHandle(c, &handle_shape_and_type));
    params = handle_shape_and_type[0].shape;
  } else {
    params = c->input(0);
  }
  ShapeHandle indices;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices));
  // r_dim = indices.shape[-1]: how many leading params dims each index row
  // addresses.
  DimensionHandle r_dim = c->Dim(indices, -1);

  if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) {
    // Not enough static information to compute the output shape.
    c->set_output(0, c->UnknownShape());
    return OkStatus();
  }

  if (c->Value(r_dim) > c->Rank(params)) {
    return errors::InvalidArgument(
        "indices.shape[-1] must be <= params.rank, but saw indices shape: " ,
        c->DebugString(indices), " and params shape: " , c->DebugString(params));
  }

  // Remove r_dim from indices to get output.
  ShapeHandle indices_slice;
  ShapeHandle params_slice;
  TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice));
  TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), &params_slice));
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out));
  c->set_output(0, out);
  return OkStatus();
}
2415 | |
// Shared shape function for scatter_nd-style ops.
//
// Validates that `indices_shape` / `updates_shape` are consistent with
// `input_shape`:
//   - indices.shape[:-1] must match updates.shape[:rank(indices)-1]
//     (the "outer" dims addressing scatter locations);
//   - input.shape[ix:] must match updates.shape[rank(indices)-1:], where
//     ix = indices.shape[-1] (the per-update slice shape).
// Checks are only applied when the relevant ranks/values are known.
// For tf.scatter_nd (no resource handle on input 0), output 0 is set to
// `input_shape`.
Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle updates_shape,
                            ShapeHandle input_shape) {
  // Scattering anything into an empty input is an error.
  if (c->Value(c->NumElements(input_shape)) == 0 &&
      (c->Value(c->NumElements(indices_shape)) > 0 ||
       c->Value(c->NumElements(updates_shape)) > 0)) {
    return errors::InvalidArgument(
        "Indices and updates specified for empty input" );
  }

  if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape) &&
      c->Rank(updates_shape) != 0) {
    // All dims of indices except the last address scatter locations.
    const int64_t outer_dims = c->Rank(indices_shape) - 1;
    const DimensionHandle ixdim = c->Dim(indices_shape, -1);

    // We can only do more validation if the last dimension of indices
    // is a known value.
    if (c->ValueKnown(ixdim)) {
      int64_t ix = c->Value(ixdim);
      ShapeHandle unused;
      ShapeHandle prefix_indices;
      TF_RETURN_IF_ERROR(
          c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
      ShapeHandle prefix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));

      // The outer dims of indices and updates must agree.
      Status s = c->Merge(prefix_indices, prefix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "Dimensions [0," , outer_dims,
            ") of indices[shape=" , c->DebugString(indices_shape),
            "] = " , c->DebugString(prefix_indices),
            " must match dimensions [0," , outer_dims,
            ") of updates[shape=" , c->DebugString(updates_shape),
            "] = " , c->DebugString(prefix_updates), ": " , s.error_message());
      }

      // The per-update slice shape (updates suffix) must match the input
      // suffix starting at dimension ix.
      ShapeHandle suffix_output;
      TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &suffix_output));
      ShapeHandle suffix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, outer_dims, &suffix_updates));
      s = c->Merge(suffix_output, suffix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "Dimensions [" , ix, "," , c->Rank(input_shape),
            ") of input[shape=" , c->DebugString(input_shape),
            "] = " , c->DebugString(suffix_output), " must match dimensions [" ,
            outer_dims, "," , c->Rank(updates_shape),
            ") of updates[shape=" , c->DebugString(updates_shape),
            "] = " , c->DebugString(suffix_updates), ": " , s.error_message());
      }
    }
  }

  if (c->input_handle_shapes_and_types(0) == nullptr && c->num_outputs() > 0) {
    // This is called for tf.scatter_nd; output is a tensor with this shape.
    c->set_output(0, input_shape);
  }
  return OkStatus();
}
2478 | |
2479 | Status ExplicitShape(InferenceContext* c) { |
2480 | PartialTensorShape shape; |
2481 | TF_RETURN_IF_ERROR(c->GetAttr("shape" , &shape)); |
2482 | ShapeHandle output_shape; |
2483 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape)); |
2484 | c->set_output(0, output_shape); |
2485 | return OkStatus(); |
2486 | } |
2487 | |
2488 | Status ExplicitShapes(InferenceContext* c) { |
2489 | std::vector<PartialTensorShape> shapes; |
2490 | TF_RETURN_IF_ERROR(c->GetAttr("shapes" , &shapes)); |
2491 | if (shapes.empty()) { |
2492 | return errors::Internal("shapes attribute is empty" ); |
2493 | } |
2494 | for (int i = 0, end = shapes.size(); i < end; ++i) { |
2495 | ShapeHandle output_shape; |
2496 | TF_RETURN_IF_ERROR( |
2497 | c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape)); |
2498 | c->set_output(i, output_shape); |
2499 | } |
2500 | return OkStatus(); |
2501 | } |
2502 | |
2503 | Status SparseReduceShapeFn(InferenceContext* c) { |
2504 | // Input 0: input_indices |
2505 | // Input 1: input_values |
2506 | // Input 2: input_shape |
2507 | // Input 3: reduction_axes |
2508 | // Attr: keep_dims |
2509 | bool keep_dims = false; |
2510 | TF_RETURN_IF_ERROR(c->GetAttr("keep_dims" , &keep_dims)); |
2511 | |
2512 | const Tensor* shape_tensor = c->input_tensor(2); |
2513 | const Tensor* axes_tensor = c->input_tensor(3); |
2514 | if (shape_tensor != nullptr && axes_tensor != nullptr) { |
2515 | auto shape_vec = shape_tensor->flat<int64_t>(); |
2516 | auto axes_vec = axes_tensor->flat<int32>(); |
2517 | |
2518 | int64_t ndims = shape_vec.size(); |
2519 | absl::flat_hash_set<int64_t> axes; |
2520 | if (ndims == 0) |
2521 | return errors::InvalidArgument( |
2522 | "Number of dims in shape tensor must not be 0" ); |
2523 | for (int i = 0; i < axes_vec.size(); i++) { |
2524 | axes.insert((axes_vec(i) + ndims) % ndims); |
2525 | } |
2526 | |
2527 | std::vector<DimensionHandle> dims; |
2528 | if (keep_dims) { |
2529 | dims.reserve(ndims); |
2530 | for (int d = 0; d < ndims; ++d) { |
2531 | if (axes.find(d) == axes.end()) { |
2532 | dims.push_back(c->MakeDim(shape_vec(d))); |
2533 | } else { |
2534 | dims.push_back(c->MakeDim(1)); |
2535 | } |
2536 | } |
2537 | } else { |
2538 | for (int d = 0; d < ndims; ++d) { |
2539 | if (axes.find(d) == axes.end()) { |
2540 | dims.push_back(c->MakeDim(shape_vec(d))); |
2541 | } |
2542 | } |
2543 | } |
2544 | |
2545 | c->set_output(0, c->MakeShape(dims)); |
2546 | return OkStatus(); |
2547 | } |
2548 | return UnknownShape(c); |
2549 | } |
2550 | |
2551 | Status QuantizedConv2DShape(InferenceContext* c) { |
2552 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
2553 | ShapeHandle unused; |
2554 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
2555 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
2556 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
2557 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); |
2558 | c->set_output(1, c->Scalar()); |
2559 | c->set_output(2, c->Scalar()); |
2560 | return OkStatus(); |
2561 | } |
2562 | |
2563 | Status FusedQuantizedConvShape(InferenceContext* c, int num_dims) { |
2564 | std::vector<string> fused_ops; |
2565 | TF_RETURN_IF_ERROR(c->GetAttr("fused_ops" , &fused_ops)); |
2566 | ShapeHandle unused, channel; |
2567 | bool fused_sum, fused_bias, fused_requantize; |
2568 | fused_sum = |
2569 | std::find(fused_ops.begin(), fused_ops.end(), "Sum" ) != fused_ops.end(); |
2570 | fused_bias = std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd" ) != |
2571 | fused_ops.end(); |
2572 | fused_requantize = std::find(fused_ops.begin(), fused_ops.end(), |
2573 | "Requantize" ) != fused_ops.end(); |
2574 | const int kMinInputBaseIdx = 2; |
2575 | const int kMinFilterBaseIdx = 4; |
2576 | int min_input_filter_offset = 0; |
2577 | if (fused_bias && !fused_sum) { |
2578 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); // bias |
2579 | min_input_filter_offset = 1; |
2580 | } else if (fused_sum && !fused_bias) { |
2581 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), num_dims, &unused)); // summand |
2582 | min_input_filter_offset = 1; |
2583 | } else if (fused_bias && fused_sum) { |
2584 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); // bias |
2585 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), num_dims, &unused)); // summand |
2586 | min_input_filter_offset = 2; |
2587 | } |
2588 | TF_RETURN_IF_ERROR( |
2589 | c->WithRank(c->input(kMinInputBaseIdx + min_input_filter_offset), 0, |
2590 | &unused)); // min_input |
2591 | TF_RETURN_IF_ERROR( |
2592 | c->WithRank(c->input(kMinInputBaseIdx + min_input_filter_offset + 1), 0, |
2593 | &unused)); // max_input |
2594 | TF_RETURN_IF_ERROR( |
2595 | c->WithRankAtMost(c->input(kMinFilterBaseIdx + min_input_filter_offset), |
2596 | 1, &channel)); // min_filter |
2597 | TF_RETURN_IF_ERROR(c->WithRankAtMost( |
2598 | c->input(kMinFilterBaseIdx + min_input_filter_offset + 1), 1, |
2599 | &channel)); // max_filter |
2600 | if (fused_requantize) { |
2601 | c->set_output(1, c->Scalar()); |
2602 | c->set_output(2, c->Scalar()); |
2603 | } else { |
2604 | c->set_output(1, channel); |
2605 | c->set_output(2, channel); |
2606 | } |
2607 | return Status::OK(); |
2608 | } |
2609 | |
2610 | Status FusedQuantizedConv2DShape(InferenceContext* c) { |
2611 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShapeImpl(c, true)); |
2612 | TF_RETURN_IF_ERROR(FusedQuantizedConvShape(c, 4)); |
2613 | return Status::OK(); |
2614 | } |
2615 | |
2616 | Status FusedQuantizedDepthwiseConv2D(InferenceContext* c) { |
2617 | TF_RETURN_IF_ERROR(DepthwiseConv2DNativeShapeImpl(c, true)); |
2618 | TF_RETURN_IF_ERROR(FusedQuantizedConvShape(c, 4)); |
2619 | return Status::OK(); |
2620 | } |
2621 | |
2622 | Status QuantizedAvgPoolShape(InferenceContext* c) { |
2623 | TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c)); |
2624 | ShapeHandle unused; |
2625 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
2626 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
2627 | c->set_output(1, c->Scalar()); |
2628 | c->set_output(2, c->Scalar()); |
2629 | return OkStatus(); |
2630 | } |
2631 | |
2632 | Status QuantizeV2Shape(InferenceContext* c) { |
2633 | int axis = -1; |
2634 | Status s = c->GetAttr("axis" , &axis); |
2635 | if (!s.ok() && s.code() != error::NOT_FOUND) { |
2636 | return s; |
2637 | } |
2638 | if (axis < -1) { |
2639 | return errors::InvalidArgument("axis should be at least -1, got " , axis); |
2640 | } |
2641 | const int minmax_rank = (axis == -1) ? 0 : 1; |
2642 | TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
2643 | ShapeHandle minmax; |
2644 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax)); |
2645 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax)); |
2646 | if (axis != -1) { |
2647 | ShapeHandle input; |
2648 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input)); |
2649 | DimensionHandle depth; |
2650 | TF_RETURN_IF_ERROR( |
2651 | c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth)); |
2652 | } |
2653 | c->set_output(1, minmax); |
2654 | c->set_output(2, minmax); |
2655 | return OkStatus(); |
2656 | } |
2657 | |
2658 | Status ReduceScatterShape(shape_inference::InferenceContext* c) { |
2659 | shape_inference::ShapeHandle in = c->input(0); |
2660 | if (!c->RankKnown(in)) { |
2661 | // Input shape unknown, so set unknown output shape. |
2662 | c->set_output(0, in); |
2663 | return OkStatus(); |
2664 | } |
2665 | |
2666 | shape_inference::ShapeHandle group_assignment_shape = c->input(1); |
2667 | if (c->Rank(group_assignment_shape) != 2) |
2668 | return errors::InvalidArgument( |
2669 | "ReduceScatter group_assignment should be rank 2" ); |
2670 | |
2671 | const Tensor* scatter_dimension = c->input_tensor(2); |
2672 | if (!scatter_dimension) { |
2673 | c->set_output(0, c->UnknownShape()); |
2674 | return OkStatus(); |
2675 | } |
2676 | int64_t scatter_dim; |
2677 | TF_RETURN_IF_ERROR(c->GetScalarFromTensor(scatter_dimension, &scatter_dim)); |
2678 | |
2679 | std::vector<shape_inference::DimensionHandle> out_dims; |
2680 | out_dims.reserve(c->Rank(in)); |
2681 | for (int i = 0; i < c->Rank(in); ++i) { |
2682 | // If the dimension is the scatter_dimension, then divide the dimension |
2683 | // by the partition size in the group_assignment. |
2684 | if (i == scatter_dim) { |
2685 | shape_inference::DimensionHandle dim = c->Dim(in, i); |
2686 | shape_inference::DimensionHandle out_dim; |
2687 | TF_RETURN_IF_ERROR(c->Divide(dim, c->Dim(group_assignment_shape, 1), |
2688 | /*evenly_divisible=*/true, &out_dim)); |
2689 | out_dims.push_back(out_dim); |
2690 | } else { |
2691 | out_dims.emplace_back(c->Dim(in, i)); |
2692 | } |
2693 | } |
2694 | c->set_output(0, c->MakeShape(out_dims)); |
2695 | return OkStatus(); |
2696 | } |
2697 | |
2698 | } // namespace shape_inference |
2699 | |
2700 | } // namespace tensorflow |
2701 | |