/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
#define TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_

#include <array>

#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

namespace shape_inference {
// Like GetWindowedOutputSize, but deals with DimensionHandles. Does not
// support EXPLICIT padding.
Status GetWindowedOutputSizeFromDims(InferenceContext* c,
                                     DimensionHandle input_size,
                                     DimensionOrConstant filter_size,
                                     int64_t stride, Padding padding_type,
                                     DimensionHandle* output_size);

// The V2 version computes the same outputs with arbitrary dilation_rate, and
// supports EXPLICIT padding. For detailed equations, refer to the comments
// for GetWindowedOutputSizeV2(). The 'padding_before' and 'padding_after'
// parameters are only used if padding_type == EXPLICIT.
Status GetWindowedOutputSizeFromDimsV2(
    InferenceContext* c, DimensionHandle input_size,
    DimensionOrConstant filter_size, int64_t dilation_rate, int64_t stride,
    Padding padding_type, int64_t padding_before, int64_t padding_after,
    DimensionHandle* output_size);
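//
// A hypothetical sketch (not part of this header) of how these helpers are
// typically composed inside a convolution-style shape function. The function
// name, window size, stride, and padding choice are illustrative only:
//
//   Status MyConvLikeShapeFn(InferenceContext* c) {
//     ShapeHandle input;
//     TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
//     DimensionHandle batch = c->Dim(input, 0);
//     DimensionHandle rows = c->Dim(input, 1);
//     DimensionHandle cols = c->Dim(input, 2);
//     DimensionHandle out_rows, out_cols;
//     // 3x3 window, stride 2, SAME padding (values are illustrative).
//     TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
//         c, rows, /*filter_size=*/3, /*stride=*/2, Padding::SAME, &out_rows));
//     TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
//         c, cols, /*filter_size=*/3, /*stride=*/2, Padding::SAME, &out_cols));
//     c->set_output(
//         0, c->MakeShape({batch, out_rows, out_cols, c->Dim(input, 3)}));
//     return OkStatus();
//   }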

// Transfers shape of input(0) to output(0).
Status UnchangedShape(shape_inference::InferenceContext* c);

// Transfers shape of input(0) to output(0), after asserting its rank is <rank>.
inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c,
                                     int32_t rank) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out));
  c->set_output(0, out);
  return OkStatus();
}
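//
// A hypothetical registration sketch (not part of this header) showing how
// these helpers are typically wired into an op definition; the op name is
// illustrative only:
//
//   REGISTER_OP("MyUnchangedOp")
//       .Input("x: float")
//       .Output("y: float")
//       .SetShapeFn(shape_inference::UnchangedShape);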

// Transfers shape of input(0) to output(0), after asserting its rank >= <rank>.
inline Status UnchangedShapeWithRankAtLeast(
    shape_inference::InferenceContext* c, int32_t rank) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out));
  c->set_output(0, out);
  return OkStatus();
}

// Transfers shape of input(0) to output(0), after asserting its rank <= <rank>.
inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c,
                                           int32_t rank) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out));
  c->set_output(0, out);
  return OkStatus();
}

// Shape function for use with ops that have no outputs.
inline Status NoOutputs(shape_inference::InferenceContext* c) {
  return OkStatus();
}

// Shape function for ops that output a single scalar value.
inline Status ScalarShape(shape_inference::InferenceContext* c) {
  c->set_output(0, c->Scalar());
  return OkStatus();
}

// Shape function for binary ops where both inputs and the output match.
inline Status MergeBothInputsShapeFn(InferenceContext* c) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
  c->set_output(0, out);
  return OkStatus();
}
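//
// A hypothetical sketch (not part of this header): an elementwise binary op
// whose two inputs must have identical shapes can reuse MergeBothInputsShapeFn
// directly; the op name is illustrative only:
//
//   REGISTER_OP("MyStrictElementwiseAdd")
//       .Input("a: float")
//       .Input("b: float")
//       .Output("sum: float")
//       .SetShapeFn(shape_inference::MergeBothInputsShapeFn);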

// Shape function for dataset iterators.
Status DatasetIteratorShape(shape_inference::InferenceContext* c);

// Returns a new shape with the specified dims arranged in the specified
// format. The returned value is owned by this context.
// Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth.
Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
                           const std::vector<DimensionOrConstant>& spatial,
                           DimensionOrConstant C, ShapeHandle* out,
                           shape_inference::InferenceContext* context);
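//
// A hypothetical sketch (not part of this header) of building an output shape
// in a specific layout inside a shape function; batch_dim, height_dim,
// width_dim, and channel_dim stand in for dimensions computed earlier and are
// illustrative only:
//
//   ShapeHandle out;
//   TF_RETURN_IF_ERROR(MakeShapeFromFormat(FORMAT_NCHW, /*N=*/batch_dim,
//                                          {height_dim, width_dim},
//                                          /*C=*/channel_dim, &out, c));
//   c->set_output(0, out);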

// Shape function for MatMul-like operations.
Status MatMulShape(shape_inference::InferenceContext* c);

// Shape function for Batched MatMul-like operations with broadcasting across
// batch dimensions.
Status BatchMatMulV2Shape(shape_inference::InferenceContext* c);

// Shape function for BatchMatMul-like operations.
Status BatchMatMulShape(shape_inference::InferenceContext* c);

// Shape function for Einsum.
Status EinsumShape(shape_inference::InferenceContext* c);

// Shape function for BiasAdd-like operations.
Status BiasAddShape(shape_inference::InferenceContext* c);

// Shape function for BiasAddGrad-like operations.
Status BiasAddGradShape(shape_inference::InferenceContext* c);

// Shape function for Conv2D-like operations that support explicit padding.
Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c);

// Shape function for Conv2D-like operations that do not support explicit
// padding.
Status Conv2DShape(shape_inference::InferenceContext* c);

// Shape function for Conv3D-like operations.
Status Conv3DShape(shape_inference::InferenceContext* c);

// Shape function for DepthwiseConv2D-like operations that support explicit
// padding.
Status DepthwiseConv2DNativeShapeWithExplicitPadding(
    shape_inference::InferenceContext* c);

// Shape function for DepthwiseConv2D-like operations that do not support
// explicit padding.
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);

// Shape function for Conv2DBackpropInput.
Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c);

// Shape function for Conv2DBackpropFilterWithBias.
Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c);

// Shape function for AvgPool-like operations.
Status AvgPoolShape(shape_inference::InferenceContext* c);

// Shape function for AvgPoolGrad-like operations.
Status AvgPoolGradShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNorm and FusedBatchNormV2 operations.
Status FusedBatchNormShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNormV3 operations.
Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c);

// Shape function for _FusedBatchNormEx operations.
Status FusedBatchNormExShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations.
Status FusedBatchNormGradShape(shape_inference::InferenceContext* c);

// Shape function for _FusedBatchNormGradEx operations.
Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c);

// Shape function for MatrixDiagPartV2 and MatrixDiagPartV3 operations.
Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c);

// Shape function for MatrixDiagV2 and MatrixDiagV3 operations.
Status MatrixDiagV2Shape(shape_inference::InferenceContext* c);

// Shape function for MatrixSetDiagV2 and MatrixSetDiagV3 operations.
Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c);

// Shape function for MaxPool-like operations that support explicit padding.
Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c);

// Shape function for MaxPool-like operations that do not support explicit
// padding.
Status MaxPoolShape(shape_inference::InferenceContext* c);

// Shape function for MaxPoolV2-like operations.
Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);
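//
// A hypothetical sketch (not part of this header): shape functions that take
// extra parameters, such as MaxPoolV2Shape, are typically bound with a lambda
// at registration time; the op name and argument value are illustrative only:
//
//   REGISTER_OP("MyMaxPoolV2Like")
//       ...
//       .SetShapeFn([](shape_inference::InferenceContext* c) {
//         return shape_inference::MaxPoolV2Shape(c, /*num_inputs=*/3);
//       });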

// Shape function for MaxPoolGrad-like operations.
Status MaxPoolGradShape(shape_inference::InferenceContext* c);

// Shape function for 3D Pooling operations.
Status Pool3DShape(shape_inference::InferenceContext* c);

// Shape function for MaxPool3DGrad-like operations.
Status MaxPool3DGradShape(shape_inference::InferenceContext* c);

// Shape function for AvgPool3DGrad-like operations.
Status AvgPool3DGradShape(shape_inference::InferenceContext* c);

// Shape function for use with ops whose output shapes are unknown.
Status UnknownShape(shape_inference::InferenceContext* c);

// Shape function for reduction operations.
Status ReductionShape(shape_inference::InferenceContext* c);

// Shape function for unsorted segment operations.
Status UnsortedSegmentReductionShapeFn(InferenceContext* c);

// Shape function for concat operations.
// The <num_inputs_to_concat> tensors to concatenate are taken from inputs
// [1, num_inputs_to_concat] of the op; input 0 is the concat_dim input.
Status ConcatShape(shape_inference::InferenceContext* c,
                   int num_inputs_to_concat);
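//
// A hypothetical registration sketch (not part of this header): a concat-like
// op can derive num_inputs_to_concat from the number of inputs, since input 0
// is the concat_dim; the op name is illustrative only:
//
//   REGISTER_OP("MyConcatLike")
//       ...
//       .SetShapeFn([](shape_inference::InferenceContext* c) {
//         return shape_inference::ConcatShape(c, c->num_inputs() - 1);
//       });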

// Shape function for concat operations.
Status ConcatV2Shape(shape_inference::InferenceContext* c);

Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);

// Helper for shape functions of binary operators that broadcast their inputs.
// Computes the broadcast of <shape_x> and <shape_y> into <out>.
// Note: <out> cannot be NULL.
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
                                            ShapeHandle shape_x,
                                            ShapeHandle shape_y,
                                            bool incompatible_shape_error,
                                            ShapeHandle* out);

// Shape function for binary operators that broadcast their inputs and write
// the result to the output at <output_index>.
inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
                                             int output_index) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
      c, c->input(0), c->input(1), true, &out));
  c->set_output(output_index, out);
  return OkStatus();
}

// Shape function for binary operators that broadcast their inputs.
// Tested by ops/math_ops_test.cc.
inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
  return BroadcastBinaryOpOutputShapeFn(c, 0);
}
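//
// A hypothetical registration sketch (not part of this header) for a
// broadcasting binary op; the op name is illustrative only:
//
//   REGISTER_OP("MyBroadcastingMul")
//       .Input("x: float")
//       .Input("y: float")
//       .Output("z: float")
//       .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);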

// Shape function for random operations.
Status RandomShape(shape_inference::InferenceContext* c);

// Shape function for Slice operations.
Status SliceShape(shape_inference::InferenceContext* c);

// Validates that the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle values_shape, ShapeHandle shape_shape);
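//
// A hypothetical sketch (not part of this header) of validating a sparse
// input inside a shape function, assuming inputs 0..2 hold the indices,
// values, and dense_shape tensors (the input ordering is illustrative only):
//
//   TF_RETURN_IF_ERROR(ValidateSparseTensor(c, c->input(0), c->input(1),
//                                           c->input(2)));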

Status ValidateVariableResourceHandle(
    InferenceContext* c, std::vector<ShapeAndType>* shape_and_type);

// Shape function for GatherNd operations.
Status GatherNdShape(InferenceContext* c);

// Helper shape function for ScatterNd.../TensorScatter... operations.
Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle updates_shape, ShapeHandle input_shape);

// Shape function for ops with an explicit "shape" attribute.
Status ExplicitShape(InferenceContext* c);

// Shape function for multiple-output ops with an explicit "shapes" attribute.
Status ExplicitShapes(InferenceContext* c);
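//
// A hypothetical registration sketch (not part of this header): an op whose
// output shape comes entirely from a "shape" attribute; the op name is
// illustrative only:
//
//   REGISTER_OP("MyPlaceholderLike")
//       .Output("out: float")
//       .Attr("shape: shape")
//       .SetShapeFn(shape_inference::ExplicitShape);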

// Shape function for SparseReduceMax and SparseReduceSum.
Status SparseReduceShapeFn(InferenceContext* c);

// Shape function for QuantizedConv2D op.
Status QuantizedConv2DShape(InferenceContext* c);

// Shape function for _QuantizedConv2D op/fusion.
Status FusedQuantizedConv2DShape(InferenceContext* c);

// Shape function for _QuantizedDepthwiseConv2D op/fusion.
Status FusedQuantizedDepthwiseConv2D(InferenceContext* c);

// Shape function for QuantizedAvgPool op.
Status QuantizedAvgPoolShape(InferenceContext* c);

// Shape function for QuantizeV2 op.
Status QuantizeV2Shape(InferenceContext* c);

// Shape function for ReduceScatter ops.
Status ReduceScatterShape(shape_inference::InferenceContext* c);

}  // namespace shape_inference

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_