/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/einsum_op_util.h"

namespace tensorflow {

namespace shape_inference {
// The V2 version computes the windowed output size with an arbitrary
// dilation_rate and explicit padding, while the original version only handles
// the case where the dilation rate is 1 and the padding is SAME or VALID.
Status GetWindowedOutputSizeFromDimsV2(
    shape_inference::InferenceContext* c,
    shape_inference::DimensionHandle input_size,
    shape_inference::DimensionOrConstant filter_size, int64_t dilation_rate,
    int64_t stride, Padding padding_type, int64_t padding_before,
    int64_t padding_after, shape_inference::DimensionHandle* output_size) {
  if (stride <= 0) {
    return errors::InvalidArgument("Stride must be > 0, but got ", stride);
  }

  if (dilation_rate < 1) {
    return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
                                   dilation_rate);
  }

  // See also the parallel implementation in GetWindowedOutputSizeVerbose.
  switch (padding_type) {
    case Padding::VALID:
      padding_before = padding_after = 0;
      TF_FALLTHROUGH_INTENDED;
    case Padding::EXPLICIT:
      TF_RETURN_IF_ERROR(
          c->Add(input_size, padding_before + padding_after, &input_size));
      if (dilation_rate > 1) {
        DimensionHandle window_size;
        TF_RETURN_IF_ERROR(
            c->Subtract(c->MakeDim(filter_size), 1, &window_size));
        TF_RETURN_IF_ERROR(
            c->Multiply(window_size, dilation_rate, &window_size));
        TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
        TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
      } else {
        TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
      }
      TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
    case Padding::SAME:
      TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
  }
  return OkStatus();
}
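
// A worked example of the arithmetic above (a sketch, not exhaustive): with
// input_size = 10, filter_size = 3, dilation_rate = 2 and stride = 2, the
// effective window is (3 - 1) * 2 + 1 = 5, so VALID padding yields
// (10 - 5 + 2) / 2 = 3 output elements, while SAME padding yields
// ceil(10 / 2) = (10 + 1) / 2 = 5 regardless of the window size.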

Status GetWindowedOutputSizeFromDims(
    shape_inference::InferenceContext* c,
    shape_inference::DimensionHandle input_size,
    shape_inference::DimensionOrConstant filter_size, int64_t stride,
    Padding padding_type, shape_inference::DimensionHandle* output_size) {
  if (padding_type == Padding::EXPLICIT) {
    return errors::Internal(
        "GetWindowedOutputSizeFromDims does not handle EXPLICIT padding; call "
        "GetWindowedOutputSizeFromDimsV2 instead");
  }
  return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size,
                                         /*dilation_rate=*/1, stride,
                                         padding_type,
                                         // Give dummy values of -1 to
                                         // padding_before and padding_after,
                                         // since explicit padding is not used.
                                         -1, -1, output_size);
}

Status UnchangedShape(shape_inference::InferenceContext* c) {
  c->set_output(0, c->input(0));
  auto* handle_data = c->input_handle_shapes_and_types(0);
  if (handle_data != nullptr) {
    c->set_output_handle_shapes_and_types(0, *handle_data);
  }
  return OkStatus();
}

Status MatMulShape(shape_inference::InferenceContext* c) {
  ShapeHandle a;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));

  ShapeHandle b;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));

  bool transpose_a, transpose_b;
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
  DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
  DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);

  // Validate that the inner shapes are compatible.
  DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
  DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
  DimensionHandle merged;
  TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));

  c->set_output(0, c->Matrix(output_rows, output_cols));
  return OkStatus();
}
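
// For example (a sketch): a [2, 3] matrix times a [3, 5] matrix infers a
// [2, 5] output; with transpose_b = true the second input would instead be
// expected to have shape [5, 3], and the Merge call above rejects inputs
// whose inner dimensions are known and unequal.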

namespace {

// Validates that an Einsum subscript contains at most one ellipsis, and that
// periods (.) occur only within an ellipsis (...).
Status ValidateEinsumEllipsis(absl::string_view subscript,
                              bool* found_ellipsis) {
  const int num_periods = absl::c_count(subscript, '.');
  if (num_periods != 0 && num_periods != 3) {
    return errors::InvalidArgument(
        "Expected at most one ellipsis (...), but found ", num_periods,
        " periods (.) in the input subscript: ", subscript);
  }
  if (num_periods == 3 && !absl::StrContains(subscript, "...")) {
    return errors::InvalidArgument(
        "Periods found outside of ellipsis in subscript: ", subscript);
  }
  *found_ellipsis = num_periods > 0;
  return OkStatus();
}
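
// Examples (a sketch): "ab...c" passes with *found_ellipsis = true, and "abc"
// passes with *found_ellipsis = false. "a.bc" fails the first check (one
// period is neither zero nor three), and "a..b." fails the second (three
// periods that do not form a contiguous "...").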

}  // namespace

Status EinsumShape(shape_inference::InferenceContext* c) {
  // We assume that the equation has a valid format: either (x),(y)->(z) or
  // (x)->(z), where each of (x), (y) and (z) is a concatenation of zero or
  // more Latin letters and contains at most one ellipsis ('...').
  string equation;
  TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation));
  gtl::InlinedVector<string, 2> input_labels;
  string output_labels;
  TF_RETURN_IF_ERROR(
      ValidateEinsumEquation(equation, &input_labels, &output_labels));

  if (c->num_inputs() == 0 || c->num_inputs() > 2) {
    return errors::InvalidArgument("Expected either 1 or 2 inputs but got: ",
                                   c->num_inputs());
  }
  const int input_labels_size = input_labels.size();
  if (c->num_inputs() != input_labels_size) {
    return errors::InvalidArgument("Expected ", input_labels.size(),
                                   " inputs for equation ", equation,
                                   " but got: ", c->num_inputs());
  }

  // Validate input subscripts, build the label to dimension mapping and obtain
  // the broadcast shapes that map to ellipsis.
  absl::flat_hash_map<char, DimensionHandle> label_to_dimension;
  gtl::InlinedVector<ShapeHandle, 2> input_bcast_shapes(c->num_inputs());
  for (int i = 0, end = c->num_inputs(); i < end; ++i) {
    bool has_ellipsis = false;
    TF_RETURN_IF_ERROR(ValidateEinsumEllipsis(input_labels[i], &has_ellipsis));
    ShapeHandle input_shape = c->input(i);
    // Validate that the input rank is sufficient for the given number of named
    // labels.
    if (c->RankKnown(input_shape)) {
      if (has_ellipsis) {
        const int num_named_labels =
            static_cast<int>(input_labels[i].size()) - 3;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            c->WithRankAtLeast(input_shape, num_named_labels, &input_shape),
            " for ", i, "th input and equation: ", equation);
      } else {
        const int num_named_labels = static_cast<int>(input_labels[i].size());
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            c->WithRank(input_shape, num_named_labels, &input_shape), " for ",
            i, "th input and equation: ", equation);
      }
    }

    bool seen_ellipsis = false;
    input_bcast_shapes[i] = c->Scalar();
    // Run through the input labels; populate label_to_dimension mapping and
    // compute the broadcast shapes corresponding to the ellipsis (if present).
    for (int label_idx = 0, end = input_labels[i].size(); label_idx < end;
         ++label_idx) {
      const char label = input_labels[i][label_idx];
      // Calculate the input axis that the current label is referring to. After
      // the ellipsis, the axis may be found by using negative indices; i.e.,
      // the (rank - k)th dimension corresponds to the (num_labels - k)th label.
      const int64_t axis_before_ellipsis = label_idx;
      const int64_t axis_after_ellipsis =
          c->RankKnown(input_shape)
              ? label_idx + c->Rank(input_shape) - input_labels[i].size()
              : -1;

      // Populate the input broadcast shape when we encounter an ellipsis
      // (...).
      if (label == '.') {
        if (!c->RankKnown(input_shape)) {
          input_bcast_shapes[i] = c->UnknownShape();
        } else {
          // The broadcast shape runs till the named label right after the
          // ellipsis, the label with index (label_idx + 3).
          TF_RETURN_IF_ERROR(c->Subshape(input_shape, axis_before_ellipsis,
                                         axis_after_ellipsis + 3,
                                         &input_bcast_shapes[i]));
        }
        label_idx += 2;  // Skip the rest of the ellipsis.
        seen_ellipsis = true;
        continue;
      }
      // Obtain the dimension that the current label corresponds to.
      int64_t axis = seen_ellipsis ? axis_after_ellipsis : axis_before_ellipsis;
      DimensionHandle new_dim = c->RankKnown(input_shape)
                                    ? c->Dim(input_shape, axis)
                                    : c->UnknownDim();
      // If we've seen this label before, make sure previous and current
      // dimensions are compatible.
      if (label_to_dimension.contains(label)) {
        DimensionHandle merged;
        TF_RETURN_IF_ERROR(
            c->Merge(label_to_dimension[label], new_dim, &merged));
        label_to_dimension[label] = merged;
      } else {
        label_to_dimension[label] = new_dim;
      }
    }
  }

  // For two inputs, broadcast the two input broadcast shapes to create the
  // output broadcast shape. For one input, just copy the single broadcast
  // shape.
  ShapeHandle output_bcast_shape;
  if (input_bcast_shapes.size() == 1) {
    output_bcast_shape = input_bcast_shapes[0];
  } else if (input_bcast_shapes.size() == 2) {
    TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
        c, input_bcast_shapes[0], input_bcast_shapes[1], true,
        &output_bcast_shape));
  }

  bool output_has_ellipsis = false;
  TF_RETURN_IF_ERROR(
      ValidateEinsumEllipsis(output_labels, &output_has_ellipsis));
  if (output_has_ellipsis) {
    // If the output subscript has an ellipsis and the output broadcast rank is
    // unknown, then the output shape should have unknown rank.
    if (!c->RankKnown(output_bcast_shape)) {
      c->set_output(0, c->UnknownShape());
      return OkStatus();
    }
  } else {
    // If the output subscripts don't have an ellipsis then make sure the
    // output broadcasting shape is empty.
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        c->WithRankAtMost(output_bcast_shape, 0, &output_bcast_shape),
        " for einsum equation '", equation,
        "' without ellipsis (...) in the output subscripts where input(s) have "
        "non-empty broadcasting shape");
    output_bcast_shape = c->Scalar();
  }

  // Create the output shape from output labels and label_to_dimension mapping.
  std::vector<DimensionHandle> output_dims;
  for (int label_idx = 0, end = output_labels.size(); label_idx < end;
       ++label_idx) {
    const char label = output_labels[label_idx];
    // Append the output_bcast_shape when the ellipsis is encountered.
    if (label == '.') {
      for (int k = 0; k < c->Rank(output_bcast_shape); ++k) {
        output_dims.push_back(c->Dim(output_bcast_shape, k));
      }
      label_idx += 2;  // Skip the rest of the ellipsis.
      continue;
    }
    auto dimension_it = label_to_dimension.find(label);
    if (dimension_it == label_to_dimension.end()) {
      return errors::InvalidArgument(
          "Einsum output subscripts for equation '", equation, "' has label '",
          label, "' which is not present in the input subscripts");
    }
    output_dims.push_back(dimension_it->second);
  }
  c->set_output(0, c->MakeShape(output_dims));
  return OkStatus();
}
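
// A worked example of the inference above (a sketch): for equation "ab,bc->ac"
// with inputs [2, 3] and [3, 4], label 'b' is merged across inputs and the
// output is [2, 4]. For "...ij,...jk->...ik" with inputs [5, 2, 3] and [3, 4],
// the ellipsis maps to broadcast shapes [5] and [], which broadcast to [5],
// giving an output of [5, 2, 4].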

Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
  ShapeHandle a_shape;
  ShapeHandle b_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));

  // Determine output rows and columns.
  bool adj_x;
  bool adj_y;
  TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
  TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
  DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
  DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);

  // Inner dimensions should be compatible.
  DimensionHandle inner_merged;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
                              c->Dim(b_shape, adj_y ? -1 : -2), &inner_merged));

  // Batch dimensions should broadcast with each other.
  ShapeHandle a_batch_shape;
  ShapeHandle b_batch_shape;
  ShapeHandle output_batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape));

  TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
      c, a_batch_shape, b_batch_shape, true, &output_batch_shape));

  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(c->Concatenate(
      output_batch_shape, c->Matrix(output_rows, output_cols), &output_shape));

  c->set_output(0, output_shape);
  return OkStatus();
}
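
// For example (a sketch): a = [2, 1, 3, 4] and b = [5, 4, 6] have batch
// prefixes [2, 1] and [5], which broadcast to [2, 5], so the inferred output
// shape is [2, 5, 3, 6].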

Status BatchMatMulShape(shape_inference::InferenceContext* c) {
  ShapeHandle a_shape;
  ShapeHandle b_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));

  // Determine output rows and cols.
  bool adj_x;
  bool adj_y;
  TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
  TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
  DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
  DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);

  // Batch dims match between inputs.
  ShapeHandle a_batch_dims;
  ShapeHandle b_batch_dims;
  ShapeHandle batch_dims;
  TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
  TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
  TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));

  // Assert inner dims match.
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
                              c->Dim(b_shape, adj_y ? -1 : -2), &unused));

  ShapeHandle out;
  TF_RETURN_IF_ERROR(
      c->Concatenate(batch_dims, c->Matrix(output_rows, output_cols), &out));
  c->set_output(0, out);
  return OkStatus();
}
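
// Unlike BatchMatMulV2Shape above, this function Merges the batch prefixes
// instead of broadcasting them: [2, 3, 4] x [2, 4, 5] infers [2, 3, 5], but
// the broadcasting example above ([2, 1, 3, 4] x [5, 4, 6]) would be rejected
// here because the batch prefixes have different ranks.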

// --------------------------------------------------------------------------

Status BiasAddShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;

  // Fetch the data_format attribute, which may not exist.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  if (s.ok() && data_format == "NCHW") {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  }

  ShapeHandle bias_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape));
  DimensionHandle bias_dim = c->Dim(bias_shape, 0);

  // If rank unknown, return unknown shape.
  if (!c->RankKnown(input_shape)) {
    c->set_output(0, c->UnknownShape());
    return OkStatus();
  }

  // Output has the same shape as the input, and matches the length of
  // the bias in its bias dimension.
  ShapeHandle output_shape;
  if (s.ok() && data_format == "NCHW") {
    // Merge the length of bias_shape into the channel dimension (dim 1).
    ShapeHandle first;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, 1, &first));

    ShapeHandle last;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 2, &last));

    DimensionHandle input_bias_dim = c->Dim(input_shape, 1);
    DimensionHandle merged_bias_dim;
    TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
    ShapeHandle merged_bias = c->Vector(merged_bias_dim);

    ShapeHandle temp;
    TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp));
    TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape));
  } else {
    ShapeHandle all_but_bias;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias));

    DimensionHandle input_bias_dim = c->Dim(input_shape, -1);
    DimensionHandle merged_bias_dim;
    TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));

    ShapeHandle merged_bias = c->Vector(merged_bias_dim);
    TF_RETURN_IF_ERROR(
        c->Concatenate(all_but_bias, merged_bias, &output_shape));
  }

  c->set_output(0, output_shape);
  return OkStatus();
}
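
// For example (a sketch): NHWC input [2, 4, 4, 8] with bias [8] infers
// [2, 4, 4, 8]; NCHW input [2, 8, 4, 4] with bias [8] merges the bias length
// into dim 1 and infers [2, 8, 4, 4]. A bias of length 7 would fail the Merge.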

Status BiasAddGradShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  // Fetch the data_format attribute, which may not exist.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  if (s.ok() && data_format == "NCHW") {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
    c->set_output(0, c->Vector(c->Dim(input_shape, 1)));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
    c->set_output(0, c->Vector(c->Dim(input_shape, -1)));
  }

  return OkStatus();
}

Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
                                     const ShapeHandle shape_handle,
                                     const string& tensor_name,
                                     shape_inference::InferenceContext* c) {
  if (tensor_format == FORMAT_NCHW_VECT_C) {
    // Check that the vect dim has size 4 or 32.
    const int num_dims = c->Rank(shape_handle);
    DimensionHandle vect_dim = c->Dim(
        shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format));
    int64_t vect_dim_val = c->Value(vect_dim);
    if (vect_dim_val != 4 && vect_dim_val != 32) {
      return errors::InvalidArgument(
          "VECT_C dimension must be 4 or 32, but is ", vect_dim_val);
    }
  }

  return OkStatus();
}

Status DatasetIteratorShape(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
  std::vector<PartialTensorShape> output_shapes;
  TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
  const int output_shapes_size = output_shapes.size();
  if (output_shapes_size != c->num_outputs()) {
    return errors::InvalidArgument(
        "`output_shapes` must be the same length as `output_types` (",
        output_shapes.size(), " vs. ", c->num_outputs(), ")");
  }
  for (size_t i = 0; i < output_shapes.size(); ++i) {
    shape_inference::ShapeHandle output_shape_handle;
    TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
        output_shapes[i], &output_shape_handle));
    c->set_output(static_cast<int>(i), output_shape_handle);
  }
  return OkStatus();
}

Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
                           const std::vector<DimensionOrConstant>& spatial,
                           DimensionOrConstant C, ShapeHandle* out,
                           shape_inference::InferenceContext* context) {
  const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
  std::vector<DimensionHandle> dims_actual(num_dims);
  dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N);
  int outer_c_index = GetTensorFeatureDimIndex(num_dims, format);
  dims_actual[outer_c_index] = context->MakeDim(C);
  if (format == FORMAT_NCHW_VECT_C) {
    dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
        context->MakeDim(4);
  } else if (format == FORMAT_NHWC_VECT_W) {
    dims_actual[GetTensorInnerWidthDimIndex(num_dims, format)] =
        context->MakeDim(4);
  }
  for (int spatial_dim = 0, end = spatial.size(); spatial_dim < end;
       spatial_dim++) {
    dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =
        context->MakeDim(spatial[spatial_dim]);
  }
  *out = context->MakeShape(dims_actual);
  return OkStatus();
}
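
// For example (a sketch): MakeShapeFromFormat(FORMAT_NCHW, 2, {5, 7}, 3, ...)
// builds the shape [2, 3, 5, 7], while FORMAT_NCHW_VECT_C with the same
// arguments builds the rank-5 shape [2, 3, 5, 7, 4], since the inner feature
// dimension is hard-coded to 4 here.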

Status DimensionsFromShape(ShapeHandle shape, TensorFormat format,
                           DimensionHandle* batch_dim,
                           gtl::MutableArraySlice<DimensionHandle> spatial_dims,
                           DimensionHandle* filter_dim,
                           InferenceContext* context) {
  const int32_t rank =
      GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
  // Batch.
  *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
  // Spatial.
  for (int spatial_dim_index = 0, end = spatial_dims.size();
       spatial_dim_index < end; ++spatial_dim_index) {
    spatial_dims[spatial_dim_index] = context->Dim(
        shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index));
  }
  // Channel.
  *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format));
  if (format == FORMAT_NCHW_VECT_C) {
    TF_RETURN_IF_ERROR(context->Multiply(
        *filter_dim,
        context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)),
        filter_dim));
  }
  return OkStatus();
}

// vect_size must be provided if format is NCHW_VECT_C.
Status ShapeFromDimensions(DimensionHandle batch_dim,
                           gtl::ArraySlice<DimensionHandle> spatial_dims,
                           DimensionHandle filter_dim, TensorFormat format,
                           absl::optional<DimensionHandle> vect_size,
                           InferenceContext* context, ShapeHandle* shape) {
  const int32_t rank =
      GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
  std::vector<DimensionHandle> out_dims(rank);

  // Batch.
  out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
  // Spatial.
  for (int spatial_dim_index = 0, end = spatial_dims.size();
       spatial_dim_index < end; ++spatial_dim_index) {
    out_dims[tensorflow::GetTensorSpatialDimIndex(
        rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index];
  }
  // Channel.
  if (format == tensorflow::FORMAT_NCHW_VECT_C) {
    // When format is NCHW_VECT_C, factor the feature map count into the outer
    // feature count and the inner feature count (4 or 32).
    CHECK(vect_size.has_value());  // Crash ok.
    TF_RETURN_IF_ERROR(context->Divide(
        filter_dim, *vect_size, /*evenly_divisible=*/true,
        &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)]));
    out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = *vect_size;
  } else {
    out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim;
  }

  *shape = context->MakeShape(out_dims);
  return OkStatus();
}
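
// For example (a sketch): with format NCHW_VECT_C, a total depth of 32 and
// vect_size 4, the channel count is factored as 32 / 4 = 8, producing a shape
// of the form [N, 8, H, W, 4]; a depth that is not a multiple of vect_size
// fails the evenly_divisible Divide.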

namespace {

Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
                       bool supports_explicit_padding) {
  string data_format_str, filter_format_str;
  if (!c->GetAttr("data_format", &data_format_str).ok()) {
    data_format_str = "NHWC";
  }
  if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
    filter_format_str =
        data_format_str == "NCHW_VECT_C" ? "OIHW_VECT_I" : "HWIO";
  }

  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  FilterTensorFormat filter_format;
  if (!FilterFormatFromString(filter_format_str, &filter_format)) {
    return errors::InvalidArgument("Invalid filter format string: ",
                                   filter_format_str);
  }

  constexpr int num_spatial_dims = 2;
  const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
  ShapeHandle conv_input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
  TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
      data_format, conv_input_shape, "conv_input", c));

  // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));

  std::vector<int32> dilations;
  TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));

  if (dilations.size() != 4) {
    return errors::InvalidArgument(
        "Conv2D requires the dilation attribute to contain 4 values, but got: ",
        dilations.size());
  }

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
  if (strides.size() != 4) {
    return errors::InvalidArgument("Conv2D on data format ", data_format_str,
                                   " requires the stride attribute to contain"
                                   " 4 values, but got: ",
                                   strides.size());
  }

  const int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
  const int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
  const int32_t dilation_rows = GetTensorDim(dilations, data_format, 'H');
  const int32_t dilation_cols = GetTensorDim(dilations, data_format, 'W');

  DimensionHandle batch_size_dim;
  DimensionHandle input_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
  TF_RETURN_IF_ERROR(DimensionsFromShape(
      conv_input_shape, data_format, &batch_size_dim,
      absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));

  DimensionHandle output_depth_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
  DimensionHandle filter_rows_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
  DimensionHandle filter_cols_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
  DimensionHandle filter_input_depth_dim;
  if (filter_format == FORMAT_OIHW_VECT_I) {
    TF_RETURN_IF_ERROR(c->Multiply(
        c->Dim(filter_shape,
               GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
        c->Dim(filter_shape,
               GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
        &filter_input_depth_dim));
  } else {
    filter_input_depth_dim = c->Dim(
        filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
  }

  // Check that the input tensor and the filter tensor agree on the channel
  // count.
  if (c->ValueKnown(input_depth_dim) &&
      c->ValueKnown(filter_input_depth_dim)) {
    int64_t input_depth_value = c->Value(input_depth_dim),
            filter_input_depth_value = c->Value(filter_input_depth_dim);
    if (filter_input_depth_value == 0)
      return errors::InvalidArgument("Depth of filter must not be 0");
    if (input_depth_value % filter_input_depth_value != 0)
      return errors::InvalidArgument(
          "Depth of input (", input_depth_value,
          ") is not a multiple of input depth of filter (",
          filter_input_depth_value, ")");
    if (input_depth_value != filter_input_depth_value) {
      int64_t num_groups = input_depth_value / filter_input_depth_value;
      if (c->ValueKnown(output_depth_dim)) {
        int64_t output_depth_value = c->Value(output_depth_dim);
        if (num_groups == 0)
          return errors::InvalidArgument("Number of groups must not be 0");
        if (output_depth_value % num_groups != 0)
          return errors::InvalidArgument(
              "Depth of output (", output_depth_value,
              ") is not a multiple of the number of groups (", num_groups,
              ")");
      }
    }
  }

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
  std::vector<int64_t> explicit_paddings;
  if (supports_explicit_padding) {
    Status s = c->GetAttr("explicit_paddings", &explicit_paddings);
    // Use the default value, which is an empty list, if the attribute is not
    // found. Otherwise return the error to the caller.
    if (!s.ok() && !errors::IsNotFound(s)) {
      return s;
    }
    TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                         /*num_dims=*/4, data_format));
  } else {
    if (padding == Padding::EXPLICIT) {
      return errors::InvalidArgument(
          "Expected non-explicit padding but got explicit padding");
    }
    std::vector<int64_t> p_list;
    // `padding_list` attribute is used by Fused int8 convolutions to support
    // explicit paddings.
    Status s_p_list = c->GetAttr("padding_list", &p_list);
    if (!s_p_list.ok() && !errors::IsNotFound(s_p_list)) {
      return s_p_list;
    }
    if (s_p_list.ok() && !p_list.empty()) {
      padding = Padding::EXPLICIT;
      explicit_paddings = p_list;
      TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                           /*num_dims=*/4, data_format));
    }
  }

  DimensionHandle output_rows, output_cols;
  int64_t pad_rows_before = -1, pad_rows_after = -1;
  int64_t pad_cols_before = -1, pad_cols_after = -1;
  if (padding == Padding::EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                             &pad_rows_before, &pad_rows_after);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                             &pad_cols_before, &pad_cols_after);
  }
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
      padding, pad_rows_before, pad_rows_after, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
      padding, pad_cols_before, pad_cols_after, &output_cols));

  absl::optional<DimensionHandle> vect_size;
  if (data_format == FORMAT_NCHW_VECT_C) {
    vect_size.emplace(c->Dim(conv_input_shape,
                             GetTensorInnerFeatureDimIndex(rank, data_format)));
  }
  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(ShapeFromDimensions(
      batch_size_dim, {output_rows, output_cols}, output_depth_dim, data_format,
      vect_size, c, &output_shape));
  c->set_output(0, output_shape);
  return OkStatus();
}
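
// A worked example of the flow above (a sketch): an NHWC input [1, 10, 10, 3]
// with an HWIO filter [3, 3, 3, 16], strides [1, 2, 2, 1] and dilations of 1
// infers [1, 5, 5, 16] under SAME padding (ceil(10 / 2) = 5) and
// [1, 4, 4, 16] under VALID padding ((10 - 3 + 2) / 2 = 4).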

}  // namespace

// Shape function for Conv2D-like operations that support explicit padding.
Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
  return Conv2DShapeImpl(c, true);
}

// Shape function for Conv2D-like operations that do not support explicit
// padding.
Status Conv2DShape(shape_inference::InferenceContext* c) {
  return Conv2DShapeImpl(c, false);
}

// TODO(mjanusz): Unify all conv/pooling shape functions.
Status Conv3DShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));

  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  std::vector<int32> dilations;
  TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));

  if (dilations.size() != 5) {
    return errors::InvalidArgument(
        "Conv3D requires the dilation attribute to contain 5 values, but got: ",
        dilations.size());
  }

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 5) {
    return errors::InvalidArgument(
        "Conv3D requires the stride attribute to contain 5 values, but got: ",
        strides.size());
  }

  int32_t stride_planes, stride_rows, stride_cols;
  int32_t dilation_planes, dilation_rows, dilation_cols;
  if (s.ok() && data_format == "NCDHW") {
    // Convert input_shape to NDHWC.
    auto dim = [&](char dimension) {
      return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    };
    input_shape =
        c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    stride_planes = strides[2];
    stride_rows = strides[3];
    stride_cols = strides[4];
    dilation_planes = dilations[2];
    dilation_rows = dilations[3];
    dilation_cols = dilations[4];
  } else {
    stride_planes = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
    dilation_planes = dilations[1];
    dilation_rows = dilations[2];
    dilation_cols = dilations[3];
  }

  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
  DimensionHandle input_depth_dim = c->Dim(input_shape, 4);

  DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
  DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
  DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
  DimensionHandle filter_input_depth_dim = c->Dim(filter_shape, 3);
  DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);

  // Check that the input tensor and the filter tensor agree on the channel
  // count.
  if (c->ValueKnown(input_depth_dim) &&
      c->ValueKnown(filter_input_depth_dim)) {
    int64_t input_depth_value = c->Value(input_depth_dim),
            filter_input_depth_value = c->Value(filter_input_depth_dim);
    if (filter_input_depth_value == 0)
      return errors::InvalidArgument("Depth of filter must not be 0");
    if (input_depth_value % filter_input_depth_value != 0)
      return errors::InvalidArgument(
          "Depth of input (", input_depth_value,
          ") is not a multiple of input depth of filter (",
          filter_input_depth_value, ")");
    if (input_depth_value != filter_input_depth_value) {
      int64_t num_groups = input_depth_value / filter_input_depth_value;
      if (c->ValueKnown(output_depth_dim)) {
        int64_t output_depth_value = c->Value(output_depth_dim);
        if (num_groups == 0)
          return errors::InvalidArgument("Number of groups must not be 0");
        if (output_depth_value % num_groups != 0)
          return errors::InvalidArgument(
              "Depth of output (", output_depth_value,
              ") is not a multiple of the number of groups (", num_groups,
              ")");
      }
    }
  }

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
  DimensionHandle output_planes, output_rows, output_cols;

  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_planes_dim, filter_planes_dim, dilation_planes, stride_planes,
      padding, -1, -1, &output_planes));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding, -1,
      -1, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding, -1,
      -1, &output_cols));

  ShapeHandle output_shape;
  if (data_format == "NCDHW") {
    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
                                 output_planes, output_rows, output_cols});
  } else {
    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
                                 output_cols, output_depth_dim});
  }
  c->set_output(0, output_shape);
  return OkStatus();
}
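
// For example (a sketch): an NCDHW input [1, 4, 8, 8, 8] with a filter
// [3, 3, 3, 4, 16] (planes, rows, cols, in_depth, out_depth), strides
// [1, 1, 2, 2, 2] and SAME padding infers [1, 16, 4, 4, 4], since each
// spatial extent becomes ceil(8 / 2) = 4.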

Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  if (!c->GetAttr("data_format", &data_format_str).ok()) {
    data_format_str = "NHWC";
  }
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }

  // For the rest of this function, output_grad_* describes out_backprop and
  // input_grad_* describes in_backprop.
  ShapeHandle output_grad_shape = c->input(2);
  TF_RETURN_IF_ERROR(c->WithRank(output_grad_shape, 4, &output_grad_shape));
  ShapeHandle filter_shape = c->input(1);
  TF_RETURN_IF_ERROR(c->WithRank(filter_shape, 4, &filter_shape));

  DimensionHandle batch_size_dim;
  DimensionHandle output_grad_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> output_grad_spatial_dims(2);
  TF_RETURN_IF_ERROR(DimensionsFromShape(
      output_grad_shape, data_format, &batch_size_dim,
      absl::MakeSpan(output_grad_spatial_dims), &output_grad_depth_dim, c));
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(
      c->Merge(output_grad_depth_dim, c->Dim(filter_shape, 3), &unused));

  ShapeHandle specified_input_grad_shape;
  TF_RETURN_IF_ERROR(
      c->MakeShapeFromShapeTensor(0, &specified_input_grad_shape));
  if (c->Rank(specified_input_grad_shape) == InferenceContext::kUnknownRank) {
    TF_RETURN_IF_ERROR(c->WithRank(specified_input_grad_shape, 4,
                                   &specified_input_grad_shape));
  }

  // input_grad_depth_dim doesn't equal c->Dim(filter_shape, 2) when the number
  // of groups is larger than 1. If input_sizes is a 4D shape, we collect
  // input_grad_depth_dim from input_sizes; otherwise we compute it as
  // c->Dim(filter_shape, 2).
  DimensionHandle input_grad_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> specified_input_grad_spatial_dims(2);
  int specified_input_grad_rank = c->Rank(specified_input_grad_shape);
  if (specified_input_grad_rank == 4) {
    DimensionHandle specified_batch_size_dim;
    TF_RETURN_IF_ERROR(DimensionsFromShape(
        specified_input_grad_shape, data_format, &specified_batch_size_dim,
        absl::MakeSpan(specified_input_grad_spatial_dims),
        &input_grad_depth_dim, c));
    TF_RETURN_IF_ERROR(
        c->Merge(specified_batch_size_dim, batch_size_dim, &unused));
  } else if (specified_input_grad_rank == 2) {
    specified_input_grad_spatial_dims[0] =
        c->Dim(specified_input_grad_shape, 0);
    specified_input_grad_spatial_dims[1] =
        c->Dim(specified_input_grad_shape, 1);
    input_grad_depth_dim = c->Dim(filter_shape, 2);
  } else {
    return errors::InvalidArgument(
        "Conv2DBackpropInput requires input_sizes to contain 4 values or 2 "
        "values, but got: ",
        specified_input_grad_rank);
  }

  ShapeHandle input_grad_shape;
  TF_RETURN_IF_ERROR(ShapeFromDimensions(
      batch_size_dim, specified_input_grad_spatial_dims, input_grad_depth_dim,
      data_format, /*vect_size=*/absl::nullopt, c, &input_grad_shape));
  c->set_output(0, input_grad_shape);
  return OkStatus();
}
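
// For example (a sketch, assuming NHWC): if input_sizes holds the 4 values
// [1, 10, 10, 3], the inferred in_backprop shape is [1, 10, 10, 3] directly;
// if it holds only the 2 spatial values [10, 10], the batch size comes from
// out_backprop and the depth from dim 2 of the filter.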

Status Conv2DBackpropFilterWithBiasShape(
    shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  // Fetch the data_format attribute, which may not exist.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
  if (s.ok() && data_format == "NCHW") {
    c->set_output(1, c->Vector(c->Dim(input_shape, -3)));
  } else {
    c->set_output(1, c->Vector(c->Dim(input_shape, -1)));
  }
  ShapeHandle sh;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh));
  TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh));
  c->set_output(0, sh);
  return OkStatus();
}

namespace {

Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c,
                                      bool supports_explicit_padding) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D requires the stride attribute to contain 4 values, "
        "but got: ",
        strides.size());
  }

  std::vector<int32> dilations;
  if (!c->GetAttr("dilations", &dilations).ok()) {
    dilations.resize(4, 1);
  }

  if (dilations.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D requires the dilations attribute to contain 4 values, "
        "but got: ",
        dilations.size());
  }

  string data_format_str;
  Status s = c->GetAttr("data_format", &data_format_str);
  TensorFormat data_format;
  if (!s.ok() || !FormatFromString(data_format_str, &data_format)) {
    data_format = FORMAT_NHWC;
  }
  int32_t stride_rows;
  int32_t stride_cols;
  int32_t dilation_rows;
  int32_t dilation_cols;
  if (data_format == FORMAT_NCHW) {
    // Canonicalize input shape to NHWC so the shape inference code below can
    // process it.
    input_shape =
        c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
                       c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
    stride_rows = strides[2];
    stride_cols = strides[3];
    dilation_rows = dilations[2];
    dilation_cols = dilations[3];
  } else {
    stride_rows = strides[1];
    stride_cols = strides[2];
    dilation_rows = dilations[1];
    dilation_cols = dilations[2];
  }

  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 2);

  DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
  DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
  DimensionHandle input_depth = c->Dim(filter_shape, 2);
  DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);

  // Check that the input depths are compatible.
  TF_RETURN_IF_ERROR(
      c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));

  DimensionHandle output_depth;
  TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  std::vector<int64_t> explicit_paddings;
  if (supports_explicit_padding) {
    Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
    // Use the default value, which is an empty list, if the attribute is not
    // found. Otherwise return the error to the caller.
    if (!status.ok() && !errors::IsNotFound(status)) {
      return status;
    }
    TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                         /*num_dims=*/4, data_format));
  } else {
    DCHECK(padding != Padding::EXPLICIT);
  }

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_rows, output_cols;
  int64_t pad_rows_before = -1, pad_rows_after = -1;
  int64_t pad_cols_before = -1, pad_cols_after = -1;
  if (padding == Padding::EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                             &pad_rows_before, &pad_rows_after);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                             &pad_cols_before, &pad_cols_after);
  }
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding,
      pad_rows_before, pad_rows_after, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding,
      pad_cols_before, pad_cols_after, &output_cols));

  ShapeHandle output_shape;
  if (data_format == FORMAT_NCHW) {
    output_shape =
        c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
  } else {
    output_shape =
        c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
  }
  c->set_output(0, output_shape);
  return OkStatus();
}
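
// For example (a sketch): an NHWC input [1, 10, 10, 3] with a filter
// [3, 3, 3, 2] (the last dimension is the depth multiplier) produces an
// output depth of 3 * 2 = 6, so SAME padding with stride 1 infers
// [1, 10, 10, 6].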

}  // namespace

Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
  return DepthwiseConv2DNativeShapeImpl(c, false);
}

Status DepthwiseConv2DNativeShapeWithExplicitPadding(
    shape_inference::InferenceContext* c) {
  return DepthwiseConv2DNativeShapeImpl(c, true);
}

Status AvgPoolShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    FormatFromString(data_format_str, &data_format);
  } else {
    data_format = FORMAT_NHWC;
  }

  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "AvgPool requires the stride attribute to contain 4 values, but got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "AvgPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
  int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
  int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.

  DimensionHandle output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));

  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols}, depth_dim,
                                         &output_shape, c));
  c->set_output(0, output_shape);
  return OkStatus();
}
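
// For example (a sketch): an NHWC input [1, 7, 7, 8] with ksize [1, 2, 2, 1],
// strides [1, 2, 2, 1] and VALID padding infers [1, 3, 3, 8], since
// (7 - 2 + 2) / 2 = 3 for each spatial dimension.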

Status AvgPoolGradShape(shape_inference::InferenceContext* c) {
  ShapeHandle s;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
  TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
  c->set_output(0, s);
  return OkStatus();
}

Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  const int rank =
      (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));

  bool is_training;
  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
  float exponential_avg_factor;
  if (!c->GetAttr("exponential_avg_factor", &exponential_avg_factor).ok()) {
    exponential_avg_factor = 1.0f;  // default value
  }
  int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;

  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
  DimensionHandle channel_dim = c->Dim(x, channel_dim_index);

  // Covers scale and offset, plus mean and variance when they are also
  // consumed (i.e. when is_training is false or exponential_avg_factor != 1).
  for (int i = 1; i < number_inputs; ++i) {
    ShapeHandle vec;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
  }

  ShapeHandle y;
  TF_RETURN_IF_ERROR(c->ReplaceDim(x, channel_dim_index, channel_dim, &y));
  c->set_output(0, y);
  ShapeHandle vector_shape = c->Vector(channel_dim);
  c->set_output(1, vector_shape);
  c->set_output(2, vector_shape);
  c->set_output(3, vector_shape);
  c->set_output(4, vector_shape);
  return OkStatus();
}
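
// For example (a sketch): an NHWC input x of shape [2, 4, 4, 8] with scale and
// offset vectors of length 8 infers y = [2, 4, 4, 8] and four length-8 vector
// outputs (batch mean, batch variance, and the two reserve spaces).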

Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c) {
  TF_RETURN_IF_ERROR(FusedBatchNormShape(c));
  c->set_output(5, c->UnknownShape());
  return OkStatus();
}

Status FusedBatchNormExShape(shape_inference::InferenceContext* c) {
  TF_RETURN_IF_ERROR(FusedBatchNormV3Shape(c));

  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));

  int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
  DimensionHandle channel_dim = c->Dim(x, channel_dim_index);

  // This is a cuDNN implementation constraint.
  if (c->ValueKnown(channel_dim) && c->Value(channel_dim) % 4 != 0) {
    return errors::InvalidArgument(
        "_FusedBatchNormEx channel dimension must be divisible by 4.");
  }

  return OkStatus();
}

Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  const int rank =
      (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
  ShapeHandle y_backprop;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));

  bool is_training;
  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));

  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
  DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
  TF_RETURN_IF_ERROR(
      c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));

  // covers scale, mean (reserve_space_1), variance (reserve_space_2)
  for (int i = 2; i < 5; ++i) {
    ShapeHandle vec;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
  }

  ShapeHandle x_backprop;
  TF_RETURN_IF_ERROR(
      c->ReplaceDim(y_backprop, channel_dim_index, channel_dim, &x_backprop));
  c->set_output(0, x_backprop);
  c->set_output(1, c->Vector(channel_dim));
  c->set_output(2, c->Vector(channel_dim));
  c->set_output(3, c->Vector(0));
  c->set_output(4, c->Vector(0));
  return OkStatus();
}

Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) {
  TF_RETURN_IF_ERROR(FusedBatchNormGradShape(c));

  int num_side_inputs;
  TF_RETURN_IF_ERROR(c->GetAttr("num_side_inputs", &num_side_inputs));
  if (num_side_inputs == 0) {
    return OkStatus();
  }

  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  const int rank =
      (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
  ShapeHandle y_backprop;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));

  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
  DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
  TF_RETURN_IF_ERROR(
      c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));

  ShapeHandle side_input_backprop;
  TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, channel_dim_index, channel_dim,
                                   &side_input_backprop));

  c->set_output(5, side_input_backprop);
  return OkStatus();
}

Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor,
                     int32* lower_diag_index, int32* upper_diag_index) {
  // This function assumes that the shape of diag_index_tensor is fully
  // defined.
  if (diag_index_tensor->dims() == 0) {
    *lower_diag_index = diag_index_tensor->scalar<int32>()();
    *upper_diag_index = *lower_diag_index;
  } else {
    int32_t num_elements = diag_index_tensor->dim_size(0);
    if (num_elements == 1) {
      *lower_diag_index = diag_index_tensor->vec<int32>()(0);
      *upper_diag_index = *lower_diag_index;
    } else if (num_elements == 2) {
      *lower_diag_index = diag_index_tensor->vec<int32>()(0);
      *upper_diag_index = diag_index_tensor->vec<int32>()(1);
    } else {
      return errors::InvalidArgument(
          "diag_index must be a vector with one or two elements. It has ",
          num_elements, " elements.");
    }
  }
  return OkStatus();
}
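
// For example (a sketch): a scalar diag_index of 1 yields the single band
// lower = upper = 1, while the two-element vector [-1, 2] yields the band
// from the first subdiagonal up to the second superdiagonal.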

Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape, diag_index_shape, unused_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));

  const Tensor* diag_index_tensor = c->input_tensor(1);
  if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
      diag_index_tensor == nullptr) {
    c->set_output(0, c->UnknownShape());
    return OkStatus();
  }
  int32_t lower_diag_index = 0;
  int32_t upper_diag_index = 0;
  TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
                                   &upper_diag_index));
  if (lower_diag_index > upper_diag_index) {
    return errors::InvalidArgument(
        "lower_diag_index is greater than upper_diag_index");
  }

  // Validates lower_diag_index and upper_diag_index.
  const int32_t input_rank = c->Rank(input_shape);
  const int32_t num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
  const int32_t num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
  int32_t max_diag_len = InferenceContext::kUnknownDim;
  if (num_rows != InferenceContext::kUnknownDim &&
      num_cols != InferenceContext::kUnknownDim) {
    if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
        (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
      return errors::InvalidArgument("lower_diag_index is out of bounds.");
    }
    if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
        (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
      return errors::InvalidArgument("upper_diag_index is out of bounds.");
    }
    max_diag_len = std::min(num_rows + std::min(upper_diag_index, 0),
                            num_cols - std::max(lower_diag_index, 0));
  }

  std::vector<DimensionHandle> dims;
  dims.reserve(input_rank - 2);
  for (int i = 0; i < input_rank - 2; ++i) {
    dims.push_back(c->Dim(input_shape, i));
  }
  if (lower_diag_index < upper_diag_index) {
    dims.push_back(c->MakeDim(upper_diag_index - lower_diag_index + 1));
  }
  dims.push_back(c->MakeDim(max_diag_len));
  c->set_output(0, c->MakeShape(dims));
  return OkStatus();
}
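
// For example (a sketch): a [3, 4] input with k = [-1, 1] extracts
// 1 - (-1) + 1 = 3 diagonals, each padded to max_diag_len =
// min(3 + min(1, 0), 4 - max(-1, 0)) = 3, so the output shape is [3, 3].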

Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) {
  // Checks input ranks.
  ShapeHandle input_shape, diag_index_shape, unused_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));

  // Reads the diagonal indices.
  const Tensor* diag_index_tensor = c->input_tensor(1);
  if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
      diag_index_tensor == nullptr) {
    c->set_output(0, c->UnknownShape());
    return OkStatus();
  }
  int32_t lower_diag_index = 0;
  int32_t upper_diag_index = 0;
  TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
                                   &upper_diag_index));
  if (lower_diag_index > upper_diag_index) {
    return errors::InvalidArgument(
        "lower_diag_index is greater than upper_diag_index");
  }

  // Checks if the number of diagonals provided matches what we imply from
  // lower_diag_index and upper_diag_index.
  const int32_t input_rank = c->Rank(input_shape);
  if (lower_diag_index < upper_diag_index) {
    const int32_t num_diags = c->Value(c->Dim(input_shape, input_rank - 2));
    const int32_t other_dim = c->Value(c->Dim(input_shape, input_rank - 1));

    if (num_diags != (upper_diag_index - lower_diag_index + 1)) {
      return errors::InvalidArgument(
          "The number of rows of `diagonal` doesn't match the number of "
          "diagonals implied from `d_lower` and `d_upper`.\n",
          "num_diags = ", num_diags, ", d_lower = ", lower_diag_index,
          ", d_upper = ", upper_diag_index, ", input_rank = ", input_rank,
          ", other_dim = ", other_dim);
    }
  }

  // Reads num_rows and num_cols.
  const Tensor* num_rows_tensor = c->input_tensor(2);
  const Tensor* num_cols_tensor = c->input_tensor(3);
  int64_t num_rows = -1;
  int64_t num_cols = -1;
  if (num_rows_tensor != nullptr) {
    TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_rows_tensor, &num_rows));
  }
  if (num_cols_tensor != nullptr) {
    TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_cols_tensor, &num_cols));
  }

  // Infers the missing num_rows or num_cols: If both are missing, assume
  // output is square. Otherwise, use the smallest possible value. Also
  // validates the provided values.
  const int32_t max_diag_len = c->Value(c->Dim(input_shape, input_rank - 1));
  const int32_t min_num_rows = max_diag_len - std::min(upper_diag_index, 0);
  const int32_t min_num_cols = max_diag_len + std::max(lower_diag_index, 0);
  if (num_rows == -1 && num_cols == -1) {  // Special case.
    num_rows = std::max(min_num_rows, min_num_cols);
    num_cols = num_rows;
  }
  if (num_rows == -1) {
    num_rows = min_num_rows;
  } else if (num_rows < min_num_rows) {
    return errors::InvalidArgument("num_rows is too small.");
  }
  if (num_cols == -1) {
    num_cols = min_num_cols;
  } else if (num_cols < min_num_cols) {
    return errors::InvalidArgument("num_cols is too small.");
  }
  // At least one of them must match the minimum length.
  if (num_rows != min_num_rows && num_cols != min_num_cols) {
    return errors::InvalidArgument(
        "num_rows and num_cols are not consistent with lower_diag_index, "
        "upper_diag_index, and the length of the given diagonals.\n",
        "num_rows = ", num_rows, " != min_num_rows = ", min_num_rows,
        ", num_cols = ", num_cols, " != min_num_cols = ", min_num_cols);
  }

  // Sets output shape.
  ShapeHandle output_shape;
  const DimensionHandle output_row_dim = c->MakeDim(num_rows);
  const DimensionHandle output_col_dim = c->MakeDim(num_cols);
  if (lower_diag_index == upper_diag_index) {
    TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 1,
                                     output_row_dim, &output_shape));
    TF_RETURN_IF_ERROR(
        c->Concatenate(output_shape, c->Vector(output_col_dim), &output_shape));
  } else {
    TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 2,
                                     output_row_dim, &output_shape));
    TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, input_rank - 1,
                                     output_col_dim, &output_shape));
  }
  c->set_output(0, output_shape);
  return OkStatus();
}
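
// Worked example (illustrative only; the shapes are hypothetical): for
// diagonals of shape [2, 4, 3] with k = [-1, 2] and num_rows/num_cols both
// unspecified, num_diags (4) matches 2 - (-1) + 1, max_diag_len = 3, and
//   min_num_rows = 3 - min(2, 0) = 3,  min_num_cols = 3 + max(-1, 0) = 3,
// so both default to max(3, 3) = 3 and the inferred output shape is [2, 3, 3].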

Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape, diag_shape, diag_index_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &diag_index_shape));

  int32_t lower_diag_index = 0;
  int32_t upper_diag_index = 0;
  bool diag_index_known = false;
  const Tensor* diag_index_tensor = c->input_tensor(2);
  if (diag_index_tensor != nullptr && c->FullyDefined(diag_index_shape)) {
    diag_index_known = true;
    TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
                                     &upper_diag_index));
    if (lower_diag_index > upper_diag_index) {
      return errors::InvalidArgument(
          "lower_diag_index is greater than upper_diag_index");
    }
  }

  // Do more checks when input rank is known.
  if (c->RankKnown(input_shape)) {
    int32_t input_rank = c->Rank(input_shape);

    // If diag_index is set, we know the exact rank of diagonal.
    if (diag_index_known) {
      TF_RETURN_IF_ERROR(c->WithRank(
          c->input(1),
          (lower_diag_index == upper_diag_index) ? input_rank - 1 : input_rank,
          &diag_shape));
    } else {
      TF_RETURN_IF_ERROR(
          c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape));
      TF_RETURN_IF_ERROR(
          c->WithRankAtMost(c->input(1), input_rank, &diag_shape));
    }

    // Validates lower_diag_index and upper_diag_index.
    const int32_t num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
    const int32_t num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
    if (num_rows != InferenceContext::kUnknownDim &&
        num_cols != InferenceContext::kUnknownDim) {
      if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
          (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
        return errors::InvalidArgument("lower_diag_index is out of bounds.");
      }
      if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
          (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
        return errors::InvalidArgument("upper_diag_index is out of bounds.");
      }
    }
  }

  ShapeHandle output_shape = input_shape;
  if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) {
    // Try to infer parts of shape from diag.
    ShapeHandle diag_prefix;
    TF_RETURN_IF_ERROR(c->Subshape(
        diag_shape, 0, (lower_diag_index == upper_diag_index) ? -1 : -2,
        &diag_prefix));

    // The inner matrices can be rectangular, so we can't pinpoint their
    // exact height and width by just lower_diag_index, upper_diag_index,
    // and the longest length of given diagonals.
    TF_RETURN_IF_ERROR(
        c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape));
    TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape));
  }
  c->set_output(0, output_shape);
  return OkStatus();
}
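
// Worked example (illustrative only; the shapes are hypothetical): for an
// input of shape [?, 3, 4], a diagonal of shape [2, 3], and k = [0], the
// diagonal prefix [2] is extended to [2, ?, ?] and merged with the input, so
// the unknown batch dimension is refined and the output shape is [2, 3, 4].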

Status MaxPoolShapeImpl(shape_inference::InferenceContext* c,
                        bool supports_explicit_padding) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    if (!FormatFromString(data_format_str, &data_format)) {
      return errors::InvalidArgument("Invalid data format string: ",
                                     data_format_str);
    }
  } else {
    data_format = FORMAT_NHWC;
  }

  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the stride attribute to contain 4 values, but got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32_t stride_depth = GetTensorDim(strides, data_format, 'C');
  int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
  int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
  int32_t kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
  int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle in_depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  std::vector<int64_t> explicit_paddings;
  if (supports_explicit_padding) {
    Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
    // Use the default value, which is an empty list, if the attribute is not
    // found. Otherwise return the error to the caller.
    if (!status.ok() && !errors::IsNotFound(status)) {
      return status;
    }
    TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                         /*num_dims=*/4, data_format));
  } else {
    DCHECK(padding != Padding::EXPLICIT);
  }

  ShapeHandle output_shape;
  DimensionHandle output_rows, output_cols, output_depth;
  int64_t pad_rows_before = -1, pad_rows_after = -1;
  int64_t pad_cols_before = -1, pad_cols_after = -1;
  if (padding == Padding::EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                             &pad_rows_before, &pad_rows_after);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                             &pad_cols_before, &pad_cols_after);
  }
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_rows_dim, kernel_rows, /*dilation_rate=*/1, stride_rows, padding,
      pad_rows_before, pad_rows_after, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_cols_dim, kernel_cols, /*dilation_rate=*/1, stride_cols, padding,
      pad_cols_before, pad_cols_after, &output_cols));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_depth_dim, kernel_depth, /*dilation_rate=*/1, stride_depth, padding,
      /*pad_before*/ 0, /*pad_after*/ 0, &output_depth));

  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols},
                                         output_depth, &output_shape, c));

  c->set_output(0, output_shape);
  return OkStatus();
}
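
// Worked example (illustrative only; the values are hypothetical): for an
// NHWC input [1, 7, 7, 3] with ksize [1, 2, 2, 1], strides [1, 2, 2, 1], and
// VALID padding, each spatial dimension becomes (7 - 2 + 2) / 2 = 3 with floor
// division, giving an inferred output shape of [1, 3, 3, 3]; with SAME padding
// it is (7 + 2 - 1) / 2 = 4, giving [1, 4, 4, 3].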

Status MaxPoolShape(shape_inference::InferenceContext* c) {
  return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/false);
}

Status MaxPoolGradShape(shape_inference::InferenceContext* c) {
  return UnchangedShapeWithRank(c, 4);
}

Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
  return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/true);
}

Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    if (!FormatFromString(data_format_str, &data_format)) {
      return errors::InvalidArgument("Invalid data format string: ",
                                     data_format_str);
    }
  } else {
    data_format = FORMAT_NHWC;
  }

  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> kernel_sizes;
  std::vector<int32> strides;

  if (c->num_inputs() + 2 == num_inputs) {
    TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));

    TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  } else {
    // Verify shape of ksize and strides input.
    ShapeHandle size;
    DimensionHandle unused;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));

    const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
    if (kernel_sizes_tensor == nullptr) {
      c->set_output(0, c->UnknownShape());
      return OkStatus();
    }
    kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
    auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
    std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
                kernel_sizes.begin());

    const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
    if (strides_tensor == nullptr) {
      c->set_output(0, c->UnknownShape());
      return OkStatus();
    }
    strides.resize(strides_tensor->shape().num_elements());
    auto strides_vec = strides_tensor->flat<int32>();
    std::copy_n(&strides_vec(0), strides.size(), strides.begin());
  }

  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the stride attribute to contain 4 values, but "
        "got: ",
        strides.size());
  }
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32_t stride_depth = GetTensorDim(strides, data_format, 'C');
  int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
  int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
  int32_t kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
  int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle in_depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  ShapeHandle output_shape;
  DimensionHandle output_rows, output_cols, output_depth;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));

  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols},
                                         output_depth, &output_shape, c));

  c->set_output(0, output_shape);
  return OkStatus();
}

Status Pool3DShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));

  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D ops require the stride attribute to contain 5 values, but "
        "got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D requires the ksize attribute to contain 5 values, but got: ",
        kernel_sizes.size());
  }

  int32_t stride_planes, stride_rows, stride_cols;
  int32_t kernel_planes, kernel_rows, kernel_cols;

  if (s.ok() && data_format == "NCDHW") {
    // Convert input_shape to NDHWC.
    auto dim = [&](char dimension) {
      return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    };
    input_shape =
        c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    stride_planes = strides[2];
    stride_rows = strides[3];
    stride_cols = strides[4];
    kernel_planes = kernel_sizes[2];
    kernel_rows = kernel_sizes[3];
    kernel_cols = kernel_sizes[4];
  } else {
    stride_planes = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
    kernel_planes = kernel_sizes[1];
    kernel_rows = kernel_sizes[2];
    kernel_cols = kernel_sizes[3];
  }

  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
  DimensionHandle output_depth_dim = c->Dim(input_shape, 4);

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_planes, output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));

  ShapeHandle output_shape;
  if (data_format == "NCDHW") {
    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
                                 output_planes, output_rows, output_cols});
  } else {
    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
                                 output_cols, output_depth_dim});
  }

  c->set_output(0, output_shape);
  return OkStatus();
}
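
// Worked example (illustrative only; the values are hypothetical): for an
// NDHWC input [1, 8, 8, 8, 2] with ksize [1, 2, 2, 2, 1], strides
// [1, 2, 2, 2, 1], and VALID padding, each spatial dimension becomes
// (8 - 2 + 2) / 2 = 4 with floor division, so the inferred output shape is
// [1, 4, 4, 4, 2].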

Status MaxPool3DGradShape(shape_inference::InferenceContext* c) {
  return UnchangedShapeWithRank(c, 5);
}

Status AvgPool3DGradShape(shape_inference::InferenceContext* c) {
  ShapeHandle s;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
  TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
  c->set_output(0, s);
  return OkStatus();
}

Status UnknownShape(shape_inference::InferenceContext* c) {
  for (int i = 0; i < c->num_outputs(); ++i) {
    c->set_output(i, c->UnknownShape());
  }
  return OkStatus();
}

template <typename T>
Status ReductionShapeHelper(const Tensor* reduction_indices_t,
                            const int32_t input_rank,
                            std::set<int64_t>* true_indices) {
  auto reduction_indices = reduction_indices_t->flat<T>();
  for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
    const T reduction_index = reduction_indices(i);
    if (reduction_index < -input_rank || reduction_index >= input_rank) {
      return errors::InvalidArgument("Invalid reduction dimension ",
                                     reduction_index, " for input with ",
                                     input_rank, " dimensions.");
    }

    auto wrapped_index = reduction_index;
    if (wrapped_index < 0) {
      wrapped_index += input_rank;
    }

    true_indices->insert(wrapped_index);
  }
  return OkStatus();
}

Status ReductionShape(InferenceContext* c) {
  ShapeHandle input = c->input(0);

  ShapeHandle indices;
  // Older versions of TensorFlow accidentally allowed higher rank tensors like
  // [[1,2]] or [[1],[2]] to represent axis=[1,2].
  if (c->graph_def_version() < 21) {
    indices = c->input(1);
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
  }

  bool keep_dims;
  TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));

  const Tensor* reduction_indices_t = c->input_tensor(1);
  if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
    // If we do not have the reduction values at runtime, or the
    // rank of the input, we don't know the output shape.

    if (keep_dims && c->RankKnown(input)) {
      // Output rank matches input rank if <keep_dims>.
      c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
      return OkStatus();
    } else {
      return shape_inference::UnknownShape(c);
    }
  }

  const int32_t input_rank = c->Rank(input);
  std::set<int64_t> true_indices;
  if (reduction_indices_t->dtype() == DataType::DT_INT32) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
                                                   input_rank, &true_indices));
  } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int64_t>(
        reduction_indices_t, input_rank, &true_indices));
  } else {
    return errors::InvalidArgument(
        "reduction_indices can only be int32 or int64");
  }

  std::vector<DimensionHandle> dims;
  for (int i = 0; i < input_rank; ++i) {
    if (true_indices.count(i) > 0) {
      if (keep_dims) {
        dims.emplace_back(c->MakeDim(1));
      }
    } else {
      dims.emplace_back(c->Dim(input, i));
    }
  }

  c->set_output(0, c->MakeShape(dims));
  return OkStatus();
}
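
// Worked example (illustrative only; the values are hypothetical): reducing
// an input of shape [2, 3, 5] over reduction_indices [0, -1] wraps -1 to 2,
// so the inferred output shape is [3] with keep_dims=false and [1, 3, 1] with
// keep_dims=true.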

Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
                         int end_value_index, int dim_index) {
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
  const Tensor* concat_dim_t = c->input_tensor(dim_index);
  if (concat_dim_t == nullptr) {
    // Return an unknown shape with same rank as inputs, or an unknown rank
    // if no input's rank is known.

    // Find rank.
    int32_t rank = InferenceContext::kUnknownRank;
    for (int i = start_value_index; i < end_value_index; ++i) {
      if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
      if (rank != InferenceContext::kUnknownRank) {
        break;
      }
    }
    if (rank == InferenceContext::kUnknownRank) {
      c->set_output(0, c->UnknownShape());
      return OkStatus();
    } else if (rank == 0) {
      return errors::InvalidArgument(
          "Can't concatenate scalars (use tf.stack instead)");
    } else {
      for (int i = start_value_index; i < end_value_index; ++i) {
        // Check that all the inputs are of the correct rank.
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
      }
    }
    // Build result of <rank> different unknown dims.
    std::vector<DimensionHandle> dims;
    dims.reserve(rank);
    for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
    c->set_output(0, c->MakeShape(dims));
    return OkStatus();
  }

  // Merge all the non-concat dims, and sum the concat dim to make an output
  // shape.
  int64_t concat_dim;
  if (concat_dim_t->dtype() == DT_INT32) {
    concat_dim = static_cast<int64_t>(concat_dim_t->flat<int32>()(0));
  } else {
    concat_dim = concat_dim_t->flat<int64_t>()(0);
  }

  // Minimum required number of dimensions.
  const int64_t min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;

  ShapeHandle output_before;
  ShapeHandle output_after;

  ShapeHandle input = c->input(end_value_index - 1);
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
  DimensionHandle output_middle = c->Dim(input, concat_dim);
  if (concat_dim == -1) {
    output_after = c->Scalar();  // no dimensions.
  } else {
    TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
  }

  for (int i = end_value_index - 2; i >= start_value_index; --i) {
    ShapeHandle before;
    ShapeHandle after;
    input = c->input(i);
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
    TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
    DimensionHandle middle = c->Dim(input, concat_dim);
    if (concat_dim == -1) {
      after = c->Scalar();
    } else {
      TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
    }

    TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
    TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle));
    TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after));
  }

  ShapeHandle s;
  TF_RETURN_IF_ERROR(
      c->Concatenate(output_before, c->Vector(output_middle), &s));
  TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s));
  c->set_output(0, s);
  return OkStatus();
}
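
// Worked example (illustrative only; the shapes are hypothetical):
// concatenating inputs [2, 3] and [2, 5] along axis 1 merges the non-concat
// dimension (2 with 2) and sums the concat dimension (3 + 5), so the inferred
// output shape is [2, 8].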

Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
  return ConcatShapeHelper(c, 1 /* start_value_index */,
                           1 + num_inputs_to_concat /* end_value_index */,
                           0 /* dim_index */);
}

Status ConcatV2Shape(InferenceContext* c) {
  return ConcatShapeHelper(c, 0 /* start_value_index */,
                           c->num_inputs() - 1 /* end_value_index */,
                           c->num_inputs() - 1 /* dim_index */);
}

Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
  return ConcatShapeHelper(c, 0 /* start_value_index */,
                           num_inputs_to_concat /* end_value_index */,
                           num_inputs_to_concat /* dim_index */);
}

Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
                                            ShapeHandle shape_x,
                                            ShapeHandle shape_y,
                                            bool incompatible_shape_error,
                                            ShapeHandle* out) {
  CHECK_NOTNULL(out);
  if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
    *out = c->UnknownShape();
    return OkStatus();
  }
  const int32_t rank_x = c->Rank(shape_x);
  const int32_t rank_y = c->Rank(shape_y);
  const int32_t rank_out = std::max(rank_x, rank_y);

  // To compute the broadcast dimensions, we zip together shape_x and shape_y
  // and pad with 1 to make them the same length.
  std::vector<DimensionHandle> dims;
  DimensionHandle dim_one;
  if (rank_x != rank_y) dim_one = c->MakeDim(1);
  for (int i = 0; i < rank_out; ++i) {
    const auto dim_x = i < (rank_out - rank_x)
                           ? dim_one
                           : c->Dim(shape_x, i - (rank_out - rank_x));
    const bool dim_y_is_one = (i < (rank_out - rank_y));
    const auto dim_y =
        dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y));
    if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
      // One or both dimensions is unknown.
      //
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      //   TODO(cwhipkey): For shape inference, if we eliminate the shape
      //   checks in C++ op code, we must still assert that the unknown dim is
      //   either 1 or the same as the known dim.
      // - If either dimension is 1, the other dimension is the output.
      // - If both are unknown, the output dimension is unknown.
      if (c->Value(dim_x) > 1) {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return OkStatus();
        }
        dims.push_back(dim_x);
      } else if (c->Value(dim_y) > 1) {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return OkStatus();
        }
        dims.push_back(dim_y);
      } else if (c->Value(dim_x) == 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_y) == 1) {
        dims.push_back(dim_x);
      } else if (dim_y.SameHandle(dim_x)) {
        dims.push_back(dim_x);
      } else if (!c->ValueKnown(dim_x) && !c->ValueKnown(dim_y)) {
        dims.push_back(c->UnknownDim());
      } else {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return OkStatus();
        }
        dims.push_back(c->UnknownDim());
      }
    } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
      if (c->Value(dim_x) == 1 && !dim_y_is_one) {
        // We will broadcast dim_x to dim_y.
        dims.push_back(dim_y);
      } else {
        DCHECK_EQ(c->Value(dim_y), 1);
        // We will broadcast dim_y to dim_x.
        dims.push_back(dim_x);
      }
    } else {
      DimensionHandle dim;
      Status s = c->Merge(dim_x, dim_y, &dim);
      if (!s.ok()) {
        if (!incompatible_shape_error) {
          *out = c->MakeShape({});
          return OkStatus();
        }
        return s;
      }
      dims.push_back(dim);
    }
  }

  *out = c->MakeShape(dims);
  return OkStatus();
}
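
// Worked example (illustrative only; the shapes are hypothetical):
// broadcasting shape_x = [2, 1, 5] against shape_y = [3, 1] pads shape_y to
// [1, 3, 1]; dimension-wise, (2, 1) -> 2, (1, 3) -> 3, and (5, 1) -> 5, so
// *out is [2, 3, 5].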

Status RandomShape(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle out;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
  c->set_output(0, out);
  return OkStatus();
}

Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
  ShapeHandle s_data = c->input(0);
  ShapeHandle s_segment_ids = c->input(1);
  ShapeHandle s_num_segments = c->input(2);
  TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));

  ShapeHandle out;

  // Leading dimensions of data must be compatible with dimensions of
  // <s_segment_ids>.
  if (c->RankKnown(s_segment_ids)) {
    TF_RETURN_IF_ERROR(
        c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));

    // Get the value of the num_segments input tensor.
    DimensionHandle num_segments_dim;
    TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));

    // Output is {segment_id_rank} + s_data[segment_id_rank:].
    ShapeHandle s_data_suffix;
    TF_RETURN_IF_ERROR(
        c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
    TF_RETURN_IF_ERROR(
        c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
  } else {
    out = c->UnknownShape();
  }
  c->set_output(0, out);
  return OkStatus();
}
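
// Worked example (illustrative only; the shapes are hypothetical): for data
// of shape [6, 4], segment_ids of shape [6], and num_segments = 3, the data
// prefix [6] merges with segment_ids, the suffix is [4], and the inferred
// output shape is [3, 4].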

namespace {

// This SliceHelper computes the output dimensions of the slice when the
// `sizes` tensor is available.
template <typename T>
Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
                   const Tensor* sizes_value,
                   std::vector<DimensionHandle>* dims) {
  auto sizes_vec = sizes_value->vec<T>();
  for (int i = 0; i < sizes_value->NumElements(); ++i) {
    DimensionHandle dim = c->Dim(c->input(0), i);
    if (sizes_vec(i) != -1) {
      auto dim_val = c->Value(dim);
      if (sizes_vec(i) < 0) {
        return errors::InvalidArgument(
            "Out of bounds slicing on dimension ", i, " of length ", dim_val,
            ": sizes vector cannot be < -1, but was ", sizes_vec(i));
      }

      dims->emplace_back(c->MakeDim(sizes_vec(i)));
    } else {
      DimensionHandle result;
      TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
      dims->emplace_back(result);
    }
  }

  return OkStatus();
}
}  // namespace

Status SliceShape(InferenceContext* c) {
  ShapeHandle input = c->input(0);
  ShapeHandle begin_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
  ShapeHandle sizes_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));

  // Merge to check compatibility of begin and sizes tensors.
  TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));

  DimensionHandle ndims = c->Dim(begin_shape, 0);
  if (c->ValueKnown(ndims)) {
    TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
  }

  // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
  // values, even though the `begin` value does not represent a shape.
  ShapeHandle begin_value;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));

  // We check the tensor value here and will only use
  // `MakeShapeFromShapeTensor` when `sizes_value` is null.
  // The reason is that `sizes` might contain -1, which can't
  // be represented (-1 in the ShapeHandle would mean "unknown").
  const Tensor* sizes_value = c->input_tensor(2);

  if (sizes_value != nullptr) {
    TF_RETURN_IF_ERROR(
        c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
    std::vector<DimensionHandle> dims;
    // If the begin and sizes tensors are available, then
    // we can be precise about the shape of the output.
    if (sizes_value->dtype() == DT_INT64) {
      TF_RETURN_IF_ERROR(
          SliceHelper<int64_t>(c, begin_value, sizes_value, &dims));
    } else {
      TF_RETURN_IF_ERROR(
          SliceHelper<int32>(c, begin_value, sizes_value, &dims));
    }
    c->set_output(0, c->MakeShape(dims));
    return OkStatus();
  } else {
    // In case `sizes` is not available (`sizes_value` is null),
    // we could try to use `MakeShapeFromShapeTensor` here.
    // If sizes contain -1, we will simply consider it as `Unknown`.
    // This is less than ideal but still an improvement to shape inference.
    // The following is an example that returns [None, 1, None] with this
    // code path:
    //   z = tf.zeros((1, 2, 3))
    //   m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
    //   m.get_shape().as_list()
    ShapeHandle sizes_value;
    TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
    if (c->RankKnown(sizes_value)) {
      TF_RETURN_IF_ERROR(
          c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
      std::vector<DimensionHandle> dims;
      dims.reserve(c->Rank(sizes_value));
      for (int i = 0; i < c->Rank(sizes_value); ++i) {
        dims.emplace_back(c->Dim(sizes_value, i));
      }
      c->set_output(0, c->MakeShape(dims));
      return OkStatus();
    }
    // We might know the rank of the input.
    if (c->RankKnown(input)) {
      c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
      return OkStatus();
    } else {
      return shape_inference::UnknownShape(c);
    }
  }

  return OkStatus();
}
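
// Worked example (illustrative only; the values are hypothetical): slicing an
// input of shape [3, 4, 5] with begin = [1, 0, 2] and sizes = [1, -1, 2]
// keeps the explicit sizes and resolves the -1 to 4 - 0 = 4, so the inferred
// output shape is [1, 4, 2].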

Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle values_shape, ShapeHandle shape_shape) {
  // Validate ranks.
  ShapeHandle unused_shape;
  TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));

  // Number of elements in indices and values must match.
  DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
  if (c->ValueKnown(num_index_elements_dim)) {
    DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
    if (c->ValueKnown(num_values_elements_dim)) {
      int64_t num_index_elements = c->Value(num_index_elements_dim);
      int64_t num_values_elements = c->Value(num_values_elements_dim);
      if (num_index_elements != num_values_elements) {
        return errors::InvalidArgument("Number of elements in index (",
                                       num_index_elements, ") and values (",
                                       num_values_elements, ") do not match.");
      }
    }
  }

  // Rank embedded in indices must match shape.
  DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
  if (c->ValueKnown(index_rank_dim)) {
    DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
    if (c->ValueKnown(shape_rank_dim)) {
      int64_t index_rank = c->Value(index_rank_dim);
      int64_t shape_rank = c->Value(shape_rank_dim);
      if (index_rank != shape_rank) {
        return errors::InvalidArgument("Index rank (", index_rank,
                                       ") and shape rank (", shape_rank,
                                       ") do not match.");
      }
    }
  }

  return OkStatus();
}
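
// Worked example (illustrative only; the shapes are hypothetical): a sparse
// tensor with indices of shape [6, 3] validates only against values of shape
// [6] (one value per index row) and a dense-shape vector of shape [3] (one
// entry per indexed dimension).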

Status ValidateVariableResourceHandle(
    InferenceContext* c, std::vector<ShapeAndType>* shape_and_type) {
  auto* handle_data = c->input_handle_shapes_and_types(0);
  if (handle_data == nullptr || handle_data->empty()) {
    shape_and_type->emplace_back(c->UnknownShape(), DT_INVALID);
  } else {
    *shape_and_type = *handle_data;
    DataType value_dtype;
    TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
    if (shape_and_type->at(0).dtype != value_dtype) {
      return errors::InvalidArgument(
          "Trying to read variable with wrong dtype. "
          "Expected ",
          DataTypeString(shape_and_type->at(0).dtype), " got ",
          DataTypeString(value_dtype));
    }
  }
  return OkStatus();
}

Status GatherNdShape(InferenceContext* c) {
  ShapeHandle params;
  std::vector<ShapeAndType> handle_shape_and_type;
  if (c->input_handle_shapes_and_types(0) != nullptr) {
    TF_RETURN_IF_ERROR(
        ValidateVariableResourceHandle(c, &handle_shape_and_type));
    params = handle_shape_and_type[0].shape;
  } else {
    params = c->input(0);
  }
  ShapeHandle indices;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices));
  DimensionHandle r_dim = c->Dim(indices, -1);

  if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) {
    c->set_output(0, c->UnknownShape());
    return OkStatus();
  }

  if (c->Value(r_dim) > c->Rank(params)) {
    return errors::InvalidArgument(
        "indices.shape[-1] must be <= params.rank, but saw indices shape: ",
        c->DebugString(indices), " and params shape: ", c->DebugString(params));
  }

  // Remove r_dim from indices to get output.
  ShapeHandle indices_slice;
  ShapeHandle params_slice;
  TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice));
  TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), &params_slice));
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out));
  c->set_output(0, out);
  return OkStatus();
}
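
// Worked example (illustrative only; the shapes are hypothetical): for params
// of shape [10, 20, 30] and indices of shape [5, 2], the trailing index
// dimension (2) is dropped and replaced by params[2:] = [30], so the inferred
// output shape is [5, 30].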

Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle updates_shape,
                            ShapeHandle input_shape) {
  if (c->Value(c->NumElements(input_shape)) == 0 &&
      (c->Value(c->NumElements(indices_shape)) > 0 ||
       c->Value(c->NumElements(updates_shape)) > 0)) {
    return errors::InvalidArgument(
        "Indices and updates specified for empty input");
  }

  if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape) &&
      c->Rank(updates_shape) != 0) {
    const int64_t outer_dims = c->Rank(indices_shape) - 1;
    const DimensionHandle ixdim = c->Dim(indices_shape, -1);

    // We can only do more validation if the last dimension of indices
    // is a known value.
    if (c->ValueKnown(ixdim)) {
      int64_t ix = c->Value(ixdim);
      ShapeHandle unused;
      ShapeHandle prefix_indices;
      TF_RETURN_IF_ERROR(
          c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
      ShapeHandle prefix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));

      Status s = c->Merge(prefix_indices, prefix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "Dimensions [0,", outer_dims,
            ") of indices[shape=", c->DebugString(indices_shape),
            "] = ", c->DebugString(prefix_indices),
            " must match dimensions [0,", outer_dims,
            ") of updates[shape=", c->DebugString(updates_shape),
            "] = ", c->DebugString(prefix_updates), ": ", s.error_message());
      }

      ShapeHandle suffix_output;
      TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &suffix_output));
      ShapeHandle suffix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, outer_dims, &suffix_updates));
      s = c->Merge(suffix_output, suffix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "Dimensions [", ix, ",", c->Rank(input_shape),
            ") of input[shape=", c->DebugString(input_shape),
            "] = ", c->DebugString(suffix_output), " must match dimensions [",
            outer_dims, ",", c->Rank(updates_shape),
            ") of updates[shape=", c->DebugString(updates_shape),
            "] = ", c->DebugString(suffix_updates), ": ", s.error_message());
      }
    }
  }

  if (c->input_handle_shapes_and_types(0) == nullptr && c->num_outputs() > 0) {
    // This is called for tf.scatter_nd; output is a tensor with this shape.
    c->set_output(0, input_shape);
  }
  return OkStatus();
}
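
// Worked example (illustrative only; the shapes are hypothetical): for
// indices [4, 1], updates [4, 3], and input shape [8, 3], the outer
// dimensions of indices and updates ([4] vs. [4]) merge, and updates[1:] =
// [3] merges with input[1:] = [3], so the output keeps the input shape
// [8, 3].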

Status ExplicitShape(InferenceContext* c) {
  PartialTensorShape shape;
  TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape));
  c->set_output(0, output_shape);
  return OkStatus();
}

Status ExplicitShapes(InferenceContext* c) {
  std::vector<PartialTensorShape> shapes;
  TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
  if (shapes.empty()) {
    return errors::Internal("shapes attribute is empty");
  }
  for (int i = 0, end = shapes.size(); i < end; ++i) {
    ShapeHandle output_shape;
    TF_RETURN_IF_ERROR(
        c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape));
    c->set_output(i, output_shape);
  }
  return OkStatus();
}

Status SparseReduceShapeFn(InferenceContext* c) {
  // Input 0: input_indices
  // Input 1: input_values
  // Input 2: input_shape
  // Input 3: reduction_axes
  // Attr: keep_dims
  bool keep_dims = false;
  TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));

  const Tensor* shape_tensor = c->input_tensor(2);
  const Tensor* axes_tensor = c->input_tensor(3);
  if (shape_tensor != nullptr && axes_tensor != nullptr) {
    auto shape_vec = shape_tensor->flat<int64_t>();
    auto axes_vec = axes_tensor->flat<int32>();

    int64_t ndims = shape_vec.size();
    absl::flat_hash_set<int64_t> axes;
    if (ndims == 0) {
      return errors::InvalidArgument(
          "Number of dims in shape tensor must not be 0");
    }
    for (int i = 0; i < axes_vec.size(); i++) {
      axes.insert((axes_vec(i) + ndims) % ndims);
    }

    std::vector<DimensionHandle> dims;
    if (keep_dims) {
      dims.reserve(ndims);
      for (int d = 0; d < ndims; ++d) {
        if (axes.find(d) == axes.end()) {
          dims.push_back(c->MakeDim(shape_vec(d)));
        } else {
          dims.push_back(c->MakeDim(1));
        }
      }
    } else {
      for (int d = 0; d < ndims; ++d) {
        if (axes.find(d) == axes.end()) {
          dims.push_back(c->MakeDim(shape_vec(d)));
        }
      }
    }

    c->set_output(0, c->MakeShape(dims));
    return OkStatus();
  }
  return UnknownShape(c);
}
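
// Worked example (illustrative only; the values are hypothetical): for a
// dense shape [2, 3, 4] and reduction_axes [-1], the axis wraps to
// (-1 + 3) % 3 = 2, so the inferred output shape is [2, 3] with
// keep_dims=false and [2, 3, 1] with keep_dims=true.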

Status QuantizedConv2DShape(InferenceContext* c) {
  TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
  c->set_output(1, c->Scalar());
  c->set_output(2, c->Scalar());
  return OkStatus();
}

Status FusedQuantizedConvShape(InferenceContext* c, int num_dims) {
  std::vector<string> fused_ops;
  TF_RETURN_IF_ERROR(c->GetAttr("fused_ops", &fused_ops));
  ShapeHandle unused, channel;
  bool fused_sum, fused_bias, fused_requantize;
  fused_sum =
      std::find(fused_ops.begin(), fused_ops.end(), "Sum") != fused_ops.end();
  fused_bias = std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") !=
               fused_ops.end();
  fused_requantize = std::find(fused_ops.begin(), fused_ops.end(),
                               "Requantize") != fused_ops.end();
  const int kMinInputBaseIdx = 2;
  const int kMinFilterBaseIdx = 4;
  int min_input_filter_offset = 0;
  if (fused_bias && !fused_sum) {
    TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));  // bias
    min_input_filter_offset = 1;
  } else if (fused_sum && !fused_bias) {
    TF_RETURN_IF_ERROR(c->WithRank(c->input(2), num_dims, &unused));  // summand
    min_input_filter_offset = 1;
  } else if (fused_bias && fused_sum) {
    TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));         // bias
    TF_RETURN_IF_ERROR(c->WithRank(c->input(3), num_dims, &unused));  // summand
    min_input_filter_offset = 2;
  }
  TF_RETURN_IF_ERROR(
      c->WithRank(c->input(kMinInputBaseIdx + min_input_filter_offset), 0,
                  &unused));  // min_input
  TF_RETURN_IF_ERROR(
      c->WithRank(c->input(kMinInputBaseIdx + min_input_filter_offset + 1), 0,
                  &unused));  // max_input
  TF_RETURN_IF_ERROR(
      c->WithRankAtMost(c->input(kMinFilterBaseIdx + min_input_filter_offset),
                        1, &channel));  // min_filter
  TF_RETURN_IF_ERROR(c->WithRankAtMost(
      c->input(kMinFilterBaseIdx + min_input_filter_offset + 1), 1,
      &channel));  // max_filter
  if (fused_requantize) {
    c->set_output(1, c->Scalar());
    c->set_output(2, c->Scalar());
  } else {
    c->set_output(1, channel);
    c->set_output(2, channel);
  }
  return OkStatus();
}

Status FusedQuantizedConv2DShape(InferenceContext* c) {
  TF_RETURN_IF_ERROR(shape_inference::Conv2DShapeImpl(c, true));
  TF_RETURN_IF_ERROR(FusedQuantizedConvShape(c, 4));
  return OkStatus();
}

Status FusedQuantizedDepthwiseConv2D(InferenceContext* c) {
  TF_RETURN_IF_ERROR(DepthwiseConv2DNativeShapeImpl(c, true));
  TF_RETURN_IF_ERROR(FusedQuantizedConvShape(c, 4));
  return OkStatus();
}

Status QuantizedAvgPoolShape(InferenceContext* c) {
  TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c));
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
  c->set_output(1, c->Scalar());
  c->set_output(2, c->Scalar());
  return OkStatus();
}

Status QuantizeV2Shape(InferenceContext* c) {
  int axis = -1;
  Status s = c->GetAttr("axis", &axis);
  if (!s.ok() && s.code() != error::NOT_FOUND) {
    return s;
  }
  if (axis < -1) {
    return errors::InvalidArgument("axis should be at least -1, got ", axis);
  }
  const int minmax_rank = (axis == -1) ? 0 : 1;
  TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
  ShapeHandle minmax;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
  if (axis != -1) {
    ShapeHandle input;
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
    DimensionHandle depth;
    TF_RETURN_IF_ERROR(
        c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
  }
  c->set_output(1, minmax);
  c->set_output(2, minmax);
  return OkStatus();
}

Status ReduceScatterShape(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle in = c->input(0);
  if (!c->RankKnown(in)) {
    // Input shape unknown, so set unknown output shape.
    c->set_output(0, in);
    return OkStatus();
  }

  shape_inference::ShapeHandle group_assignment_shape = c->input(1);
  if (c->Rank(group_assignment_shape) != 2) {
    return errors::InvalidArgument(
        "ReduceScatter group_assignment should be rank 2");
  }

  const Tensor* scatter_dimension = c->input_tensor(2);
  if (!scatter_dimension) {
    c->set_output(0, c->UnknownShape());
    return OkStatus();
  }
  int64_t scatter_dim;
  TF_RETURN_IF_ERROR(c->GetScalarFromTensor(scatter_dimension, &scatter_dim));

  std::vector<shape_inference::DimensionHandle> out_dims;
  out_dims.reserve(c->Rank(in));
  for (int i = 0; i < c->Rank(in); ++i) {
    // If the dimension is the scatter_dimension, then divide the dimension
    // by the partition size in the group_assignment.
    if (i == scatter_dim) {
      shape_inference::DimensionHandle dim = c->Dim(in, i);
      shape_inference::DimensionHandle out_dim;
      TF_RETURN_IF_ERROR(c->Divide(dim, c->Dim(group_assignment_shape, 1),
                                   /*evenly_divisible=*/true, &out_dim));
      out_dims.push_back(out_dim);
    } else {
      out_dims.emplace_back(c->Dim(in, i));
    }
  }
  c->set_output(0, c->MakeShape(out_dims));
  return OkStatus();
}
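
// Worked example (illustrative only; the values are hypothetical): for input
// [8, 16], group_assignment of shape [2, 4], and scatter_dim = 0, dimension 0
// is evenly divided by the group size 4, so the inferred output shape is
// [2, 16].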

}  // namespace shape_inference

}  // namespace tensorflow