1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_ |
16 | #define TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_ |
17 | |
18 | #include <array> |
19 | |
20 | #include "tensorflow/core/framework/shape_inference.h" |
21 | #include "tensorflow/core/util/padding.h" |
22 | #include "tensorflow/core/util/tensor_format.h" |
23 | |
24 | namespace tensorflow { |
25 | |
26 | namespace shape_inference { |
27 | |
28 | // Like GetWindowedOutputSize, but deals with DimensionHandles. Does not support |
29 | // EXPLICIT padding. |
30 | Status GetWindowedOutputSizeFromDims(InferenceContext* c, |
31 | DimensionHandle input_size, |
32 | DimensionOrConstant filter_size, |
33 | int64_t stride, Padding padding_type, |
34 | DimensionHandle* output_size); |
35 | |
36 | // The V2 version computes the same outputs with arbitrary dilation_rate, and |
37 | // supports EXPLICIT padding. For detailed equations, refer to the comments |
38 | // for GetWindowedOutputSizeV2(). The 'padding_before' and 'padding_after' |
39 | // parameters are only used if padding_type == EXPLICIT. |
40 | Status GetWindowedOutputSizeFromDimsV2( |
41 | InferenceContext* c, DimensionHandle input_size, |
42 | DimensionOrConstant filter_size, int64_t dilation_rate, int64_t stride, |
43 | Padding padding_type, int64_t padding_before, int64_t padding_after, |
44 | DimensionHandle* output_size); |
45 | |
46 | // Transfers shape of input(0) to output(0). |
47 | Status UnchangedShape(shape_inference::InferenceContext* c); |
48 | |
49 | // Transfers shape of input(0) to output(0), after asserting its rank is <rank>. |
50 | inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c, |
51 | int32_t rank) { |
52 | ShapeHandle out; |
53 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out)); |
54 | c->set_output(0, out); |
55 | return OkStatus(); |
56 | } |
57 | |
58 | // Transfers shape of input(0) to output(0), after asserting its rank >= <rank>. |
59 | inline Status UnchangedShapeWithRankAtLeast( |
60 | shape_inference::InferenceContext* c, int32_t rank) { |
61 | ShapeHandle out; |
62 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out)); |
63 | c->set_output(0, out); |
64 | return OkStatus(); |
65 | } |
66 | |
67 | // Transfers shape of input(0) to output(0), after asserting its rank <= <rank>. |
68 | inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c, |
69 | int32_t rank) { |
70 | ShapeHandle out; |
71 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out)); |
72 | c->set_output(0, out); |
73 | return OkStatus(); |
74 | } |
75 | |
76 | // Shape function for use with ops no outputs. |
77 | inline Status NoOutputs(shape_inference::InferenceContext* c) { |
78 | return OkStatus(); |
79 | } |
80 | |
81 | // Shape function for ops that output a single scalar value. |
82 | inline Status ScalarShape(shape_inference::InferenceContext* c) { |
83 | c->set_output(0, c->Scalar()); |
84 | return OkStatus(); |
85 | } |
86 | |
87 | // Shape function for binary ops where both inputs and the output match. |
88 | inline Status MergeBothInputsShapeFn(InferenceContext* c) { |
89 | ShapeHandle out; |
90 | TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out)); |
91 | c->set_output(0, out); |
92 | return OkStatus(); |
93 | } |
94 | |
95 | // Shape function for dataset iterators. |
96 | Status DatasetIteratorShape(shape_inference::InferenceContext* c); |
97 | |
// Returns a new shape with the specified dims arranged in the specified
// format. The returned shape is owned by `context`.
// Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth.
101 | Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, |
102 | const std::vector<DimensionOrConstant>& spatial, |
103 | DimensionOrConstant C, ShapeHandle* out, |
104 | shape_inference::InferenceContext* context); |
105 | |
106 | // Shape function for MatMul-like operations. |
107 | Status MatMulShape(shape_inference::InferenceContext* c); |
108 | |
109 | // Shape function for Batched MatMul-like operations with broadcasting across |
110 | // batch dimensions. |
111 | Status BatchMatMulV2Shape(shape_inference::InferenceContext* c); |
112 | |
// Shape function for BatchMatMul-like operations.
114 | Status BatchMatMulShape(shape_inference::InferenceContext* c); |
115 | |
116 | // Shape function for Einsum. |
117 | Status EinsumShape(shape_inference::InferenceContext* c); |
118 | |
119 | // Shape function for BiasAdd-like operations. |
120 | Status BiasAddShape(shape_inference::InferenceContext* c); |
121 | |
122 | // Shape function for BiasAddGrad-like operations. |
123 | Status BiasAddGradShape(shape_inference::InferenceContext* c); |
124 | |
125 | // Shape function for Conv2D-like operations that support explicit padding. |
126 | Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c); |
127 | |
128 | // Shape function for Conv2D-like operations that do not support explicit |
129 | // padding. |
130 | Status Conv2DShape(shape_inference::InferenceContext* c); |
131 | |
132 | // Shape function for Conv3D-like operations. |
133 | Status Conv3DShape(shape_inference::InferenceContext* c); |
134 | |
135 | // Shape function for DepthwiseConv2D-like operations that support explicit |
136 | // padding. |
137 | Status DepthwiseConv2DNativeShapeWithExplicitPadding( |
138 | shape_inference::InferenceContext* c); |
139 | |
140 | // Shape function for DepthwiseConv2D-like operations that do not support |
141 | // explicit padding. |
142 | Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c); |
143 | |
144 | // Shape function for Conv2DBackpropInput. |
145 | Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c); |
146 | |
147 | // Shape function for Conv2DBackpropFilterWithBias. |
148 | Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c); |
149 | |
150 | // Shape function for AvgPool-like operations. |
151 | Status AvgPoolShape(shape_inference::InferenceContext* c); |
152 | |
153 | // Shape function for AvgPoolGrad-like operations. |
154 | Status AvgPoolGradShape(shape_inference::InferenceContext* c); |
155 | |
156 | // Shape function for FusedBatchNorm and FusedBatchNormV2 operations. |
157 | Status FusedBatchNormShape(shape_inference::InferenceContext* c); |
158 | |
159 | // Shape function for FusedBatchNormV3 operations. |
160 | Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c); |
161 | |
162 | // Shape function for _FusedBatchNormEx operations. |
163 | Status FusedBatchNormExShape(shape_inference::InferenceContext* c); |
164 | |
165 | // Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations. |
166 | Status FusedBatchNormGradShape(shape_inference::InferenceContext* c); |
167 | |
168 | // Shape function for _FusedBatchNormGradEx operations. |
169 | Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c); |
170 | |
171 | // Shape function for MatrixDiagPartV2 and MatrixDiagPartV3 operations. |
172 | Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c); |
173 | |
174 | // Shape function for MatrixDiagV2 and MatrixDiagV3 operations. |
175 | Status MatrixDiagV2Shape(shape_inference::InferenceContext* c); |
176 | |
177 | // Shape function for MatrixSetDiagV2 and MatrixSetDiagV3 operations. |
178 | Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c); |
179 | |
180 | // Shape function for MaxPool-like operations that support explicit padding. |
181 | Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c); |
182 | |
183 | // Shape function for MaxPool-like operations that do not support explicit |
184 | // padding. |
185 | Status MaxPoolShape(shape_inference::InferenceContext* c); |
186 | |
187 | // Shape function for MaxPoolV2-like operations. |
188 | Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs); |
189 | |
190 | // Shape function for MaxPoolGrad-like operations. |
191 | Status MaxPoolGradShape(shape_inference::InferenceContext* c); |
192 | |
193 | // Shape function for 3D Pooling operations. |
194 | Status Pool3DShape(shape_inference::InferenceContext* c); |
195 | |
196 | // Shape function for MaxPool3DGrad-like operations. |
197 | Status MaxPool3DGradShape(shape_inference::InferenceContext* c); |
198 | |
199 | // Shape function for AvgPool3DGrad-like operations. |
200 | Status AvgPool3DGradShape(shape_inference::InferenceContext* c); |
201 | |
202 | // Shape function for use with ops whose output shapes are unknown. |
203 | Status UnknownShape(shape_inference::InferenceContext* c); |
204 | |
205 | // Shape function for reduction operations. |
206 | Status ReductionShape(shape_inference::InferenceContext* c); |
207 | |
208 | // Shape function for unsorted segment operations. |
209 | Status UnsortedSegmentReductionShapeFn(InferenceContext* c); |
210 | |
// Shape function for concat operations.
// <num_inputs_to_concat> is the number of inputs to concatenate; they are
// taken from inputs [1, num_inputs_to_concat] of the op. Input 0 is the
// concat_dim input.
215 | Status ConcatShape(shape_inference::InferenceContext* c, |
216 | int num_inputs_to_concat); |
217 | |
218 | // Shape function for concat operations. |
219 | Status ConcatV2Shape(shape_inference::InferenceContext* c); |
220 | |
221 | Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat); |
222 | |
// Helper for binary operators that broadcast their inputs: computes the
// broadcast shape of `shape_x` and `shape_y` and stores it in `*out`.
// Note: out cannot be NULL.
226 | Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, |
227 | ShapeHandle shape_x, |
228 | ShapeHandle shape_y, |
229 | bool incompatible_shape_error, |
230 | ShapeHandle* out); |
231 | |
232 | // Shape function for binary operators that broadcast their inputs |
233 | // and with output to output_index. |
234 | inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, |
235 | int output_index) { |
236 | ShapeHandle out; |
237 | TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( |
238 | c, c->input(0), c->input(1), true, &out)); |
239 | c->set_output(output_index, out); |
240 | return OkStatus(); |
241 | } |
242 | |
243 | // Shape function for binary operators that broadcast their inputs. |
244 | // Tested by ops/math_ops_test.cc. |
245 | inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) { |
246 | return BroadcastBinaryOpOutputShapeFn(c, 0); |
247 | } |
248 | |
249 | // Shape function for random operations. |
250 | Status RandomShape(shape_inference::InferenceContext* c); |
251 | |
252 | // Shape function for Slice operations. |
253 | Status SliceShape(shape_inference::InferenceContext* c); |
254 | |
255 | // Validates the 3 component tensors of a sparse tensor have the proper |
256 | // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py. |
257 | Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, |
258 | ShapeHandle values_shape, ShapeHandle shape_shape); |
259 | |
260 | Status ValidateVariableResourceHandle( |
261 | InferenceContext* c, std::vector<ShapeAndType>* shape_and_type); |
262 | |
263 | // Shape function for GatherNd operations. |
264 | Status GatherNdShape(InferenceContext* c); |
265 | |
266 | // Helper shape function for ScatterNd.../TensorScatter... operations. |
267 | Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, |
268 | ShapeHandle updates_shape, ShapeHandle input_shape); |
269 | |
270 | // Shape function for ops with an explicit "shape" attribute. |
271 | Status ExplicitShape(InferenceContext* c); |
272 | |
273 | // Shape function for multiple-output ops with an explicit "shapes" attribute. |
274 | Status ExplicitShapes(InferenceContext* c); |
275 | |
276 | // Shape function for SparseReduceMax and SparseReduceSum. |
277 | Status SparseReduceShapeFn(InferenceContext* c); |
278 | |
279 | // Shape function for QuantizedConv2D op. |
280 | Status QuantizedConv2DShape(InferenceContext* c); |
281 | |
282 | // Shape function for _QuantizedConv2D op/fusion. |
283 | Status FusedQuantizedConv2DShape(InferenceContext* c); |
284 | |
285 | // Shape function for _QuantizedDepthwiseConv2D op/fusion. |
286 | Status FusedQuantizedDepthwiseConv2D(InferenceContext* c); |
287 | |
// Shape function for the QuantizedAvgPool op.
289 | Status QuantizedAvgPoolShape(InferenceContext* c); |
290 | |
// Shape function for the QuantizeV2 op.
292 | Status QuantizeV2Shape(InferenceContext* c); |
293 | |
// Shape function for ReduceScatter ops.
295 | Status ReduceScatterShape(shape_inference::InferenceContext* c); |
296 | |
297 | } // namespace shape_inference |
298 | |
299 | } // namespace tensorflow |
300 | |
301 | #endif // TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_ |
302 | |