1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_ |
17 | #define TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_ |
18 | |
19 | #include <array> |
20 | #include <vector> |
21 | |
22 | #include "absl/strings/string_view.h" |
23 | #include "tensorflow/core/framework/tensor.h" |
24 | #include "tensorflow/core/lib/gtl/array_slice.h" |
25 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
26 | #include "tensorflow/core/platform/types.h" |
27 | |
28 | namespace tensorflow { |
29 | |
// Layouts for the input/output activation tensors of convolution ops.
// Each enumerator lists its dimensions from the largest memory stride
// (outermost) to the smallest (innermost), using
//   N = batch, H = image height, W = image width, C = number of channels.
// TODO(pauldonnelly): It would probably be better to switch to a registration
// process for tensor formats, so specialized formats could be defined more
// locally to where they are used.
enum TensorFormat {
  // TensorFlow's default activation layout.
  FORMAT_NHWC = 0,

  // Layout that frequently performs better on GPUs.
  FORMAT_NCHW = 1,

  // Channel-vectorized NCHW, the most performant layout for cudnn6's quantized
  // int8 convolution and fused convolution. Dimension order matches NCHW, but
  // the channel dimension shrinks to C/4 and a trailing dimension of size 4 is
  // added, packing 4 adjacent channel activations of one pixel into an int32:
  // an NCHW tensor [N, C, H, W] becomes [N, C/4, H, W, 4].
  // Precondition: C is a multiple of 4.
  FORMAT_NCHW_VECT_C = 2,

  // Width-vectorized NHWC: the W dimension shrinks to W/4 and a trailing
  // dimension of size 4 packs 4 adjacent activations along the width.
  FORMAT_NHWC_VECT_W = 3,

  // Caveat: this file currently treats the VECT_C and VECT_W layouts as int8x4
  // vectors, but callers should not rely on that. Their meaning may later be
  // widened to other vector types (e.g. int16x2), with op implementations
  // choosing the interpretation from the data type.

  // TPU layout.
  FORMAT_HWNC = 4,

  // TPU layout.
  FORMAT_HWCN = 5,
};
71 | |
// Layouts for convolution filter tensors, with dimensions listed from the
// largest memory stride (outermost) to the smallest (innermost), using
//   H = kernel height, W = kernel width, I = input channels,
//   O = output channels.
// cudnnGetFilter4dDescriptor() calls 'O' 'K' and calls 'I' 'C'.
enum FilterTensorFormat {
  // TensorFlow's default filter layout; ops without a 'filter_format'
  // attribute assume it.
  FORMAT_HWIO = 0,

  // Layout that frequently performs better on GPUs.
  FORMAT_OIHW = 1,

  // Layout cuDNN uses for NHWC convolutions.
  FORMAT_OHWI = 2,

  // Input-channel-vectorized OIHW, the filter counterpart of NCHW_VECT_C and
  // the most performant layout for cudnn6's quantized int8 convolution and
  // fused convolution. Dimension order matches OIHW, but the input-channel
  // dimension shrinks to I/4 and a trailing dimension of size 4 packs 4
  // adjacent input-channel weights into an int32: [O, I, H, W] becomes
  // [O, I/4, H, W, 4].
  // Precondition: I is a multiple of 4.
  FORMAT_OIHW_VECT_I = 3,
};
98 | |
// Parses an activation tensor format name from 'format_str' and stores the
// result through 'format'.
// Returns true if the parsing succeeds, and false if it fails.
bool FormatFromString(absl::string_view format_str, TensorFormat* format);

// Parses a *filter* tensor format name from 'format_str' and stores the
// result through 'format'.
// Returns true if the parsing succeeds, and false if it fails.
bool FilterFormatFromString(absl::string_view format_str,
                            FilterTensorFormat* format);

// Converts an activation tensor format into its string name.
std::string ToString(TensorFormat format);

// Converts a filter tensor format into its string name.
std::string ToString(FilterTensorFormat format);
113 | |
114 | // Returns the number of spatial dims of a tensor of rank 'num_dims' and tensor |
115 | // format 'format'. |
116 | inline int GetTensorSpatialDims(int num_dims, TensorFormat format) { |
117 | switch (format) { |
118 | case FORMAT_NHWC: |
119 | case FORMAT_NCHW: |
120 | case FORMAT_HWNC: |
121 | case FORMAT_HWCN: |
122 | return num_dims - 2; // Exclude N,C. |
123 | case FORMAT_NCHW_VECT_C: |
124 | case FORMAT_NHWC_VECT_W: |
125 | // Note: the VECT_W is not counted as an independent spatial dim here, |
126 | // since it just a component of the width dimension. |
127 | return num_dims - 3; // Exclude N,C,VectDim. |
128 | default: |
129 | LOG(FATAL) << "Unknown format " << format; |
130 | return -1; // Avoid compiler warning about missing return value |
131 | } |
132 | } |
133 | |
134 | inline int GetFilterTensorSpatialDims(int num_dims, FilterTensorFormat format) { |
135 | if (format == FORMAT_OIHW_VECT_I) { |
136 | return num_dims - 3; // Exclude O,I,InnerI. |
137 | } else { |
138 | return num_dims - 2; // Exclude O,I. |
139 | } |
140 | } |
141 | |
142 | // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and |
143 | // tensor format 'format'. This is the inverse of GetTensorSpatialDims. |
144 | inline int GetTensorDimsFromSpatialDims(int num_spatial_dims, |
145 | TensorFormat format) { |
146 | switch (format) { |
147 | case FORMAT_NHWC: |
148 | case FORMAT_NCHW: |
149 | case FORMAT_HWNC: |
150 | case FORMAT_HWCN: |
151 | return num_spatial_dims + 2; // Include N,C. |
152 | case FORMAT_NCHW_VECT_C: |
153 | case FORMAT_NHWC_VECT_W: |
154 | return num_spatial_dims + 3; // Include N,C,VectDim. |
155 | default: |
156 | LOG(FATAL) << "Unknown format " << format; |
157 | return -1; // Avoid compiler warning about missing return value |
158 | } |
159 | } |
160 | |
161 | // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and |
162 | // filter tensor format 'format'. |
163 | inline int GetFilterTensorDimsFromSpatialDims(int num_spatial_dims, |
164 | FilterTensorFormat format) { |
165 | if (format == FORMAT_OIHW_VECT_I) { |
166 | return num_spatial_dims + 3; // Include O,I,InnerI. |
167 | } else { |
168 | return num_spatial_dims + 2; // Include O,I. |
169 | } |
170 | } |
171 | |
172 | // Returns the index of the batch dimension. |
173 | inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) { |
174 | switch (format) { |
175 | case FORMAT_NHWC: |
176 | case FORMAT_NCHW: |
177 | case FORMAT_NCHW_VECT_C: |
178 | case FORMAT_NHWC_VECT_W: |
179 | return 0; |
180 | case FORMAT_HWNC: |
181 | return num_dims - 2; |
182 | case FORMAT_HWCN: |
183 | return num_dims - 1; |
184 | default: |
185 | LOG(FATAL) << "Unknown format " << format; |
186 | return -1; // Avoid compiler warning about missing return value |
187 | } |
188 | } |
189 | |
190 | // Returns the index of the feature dimension. If format is NCHW_VECT_C, returns |
191 | // the index of the outer feature dimension (i.e. dimension 1, whose size would |
192 | // be num_features / 4 in this case). |
193 | inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) { |
194 | switch (format) { |
195 | case FORMAT_NHWC: |
196 | case FORMAT_HWNC: |
197 | return num_dims - 1; |
198 | case FORMAT_NHWC_VECT_W: |
199 | case FORMAT_HWCN: |
200 | return num_dims - 2; |
201 | case FORMAT_NCHW: |
202 | case FORMAT_NCHW_VECT_C: |
203 | return 1; |
204 | default: |
205 | LOG(FATAL) << "Unknown format " << format; |
206 | return -1; // Avoid compiler warning about missing return value |
207 | } |
208 | } |
209 | |
210 | // Returns the index of the inner feature dimension. |
211 | inline int GetTensorInnerFeatureDimIndex(int num_dims, TensorFormat format) { |
212 | DCHECK_EQ(format, FORMAT_NCHW_VECT_C); |
213 | return num_dims - 1; |
214 | } |
215 | |
216 | // Returns the index of the inner width dimension. |
217 | inline int GetTensorInnerWidthDimIndex(int num_dims, TensorFormat format) { |
218 | DCHECK_EQ(format, FORMAT_NHWC_VECT_W); |
219 | return num_dims - 1; |
220 | } |
221 | |
222 | // Returns the dimension index of the specified 'spatial_dim' within an |
223 | // activation tensor. If format is NHWC_VECT_W and spatial_dim is 1, returns |
224 | // the index of the outer width dimension (i.e. dimension 2, whose size would |
225 | // be width / 4 in this case). |
226 | inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format, |
227 | int spatial_dim) { |
228 | CHECK(spatial_dim >= 0 && |
229 | spatial_dim < GetTensorSpatialDims(num_dims, format)) |
230 | << spatial_dim << " " << num_dims << " " << ToString(format); |
231 | switch (format) { |
232 | case FORMAT_NHWC: |
233 | case FORMAT_NHWC_VECT_W: |
234 | return spatial_dim + 1; |
235 | case FORMAT_NCHW: |
236 | case FORMAT_NCHW_VECT_C: |
237 | return spatial_dim + 2; |
238 | case FORMAT_HWNC: |
239 | case FORMAT_HWCN: |
240 | return spatial_dim; |
241 | default: |
242 | LOG(FATAL) << "Unknown format " << format; |
243 | return -1; // Avoid compiler warning about missing return value |
244 | } |
245 | } |
246 | |
247 | inline int GetFilterTensorSpatialDimIndex(int num_dims, |
248 | FilterTensorFormat format, int dim) { |
249 | CHECK(dim >= 0 && dim < GetFilterTensorSpatialDims(num_dims, format)) |
250 | << dim << " " << num_dims << " " << ToString(format); |
251 | switch (format) { |
252 | case FORMAT_HWIO: |
253 | return dim; |
254 | case FORMAT_OIHW: |
255 | case FORMAT_OIHW_VECT_I: |
256 | return dim + 2; |
257 | default: |
258 | LOG(FATAL) << "Unknown format " << format; |
259 | return -1; // Avoid compiler warning about missing return value |
260 | } |
261 | } |
262 | |
263 | // Returns the index of the inner input channels dimension. |
264 | inline int GetFilterTensorInnerInputChannelsDimIndex( |
265 | int num_dims, FilterTensorFormat format) { |
266 | DCHECK_EQ(format, FORMAT_OIHW_VECT_I); |
267 | return num_dims - 1; |
268 | } |
269 | |
270 | // Returns the index of the input channels dimension. |
271 | // If 'format' is FORMAT_OIHW_VECT_I, returns the dimension index of the |
272 | // outer input channel (i.e. 1), which holds num_input_channels / 4. |
273 | inline int GetFilterTensorInputChannelsDimIndex(int num_dims, |
274 | FilterTensorFormat format) { |
275 | switch (format) { |
276 | case FORMAT_HWIO: |
277 | return num_dims - 2; |
278 | case FORMAT_OIHW: |
279 | case FORMAT_OIHW_VECT_I: |
280 | return 1; |
281 | default: |
282 | LOG(FATAL) << "Unknown format " << format; |
283 | return -1; // Avoid compiler warning about missing return value |
284 | } |
285 | } |
286 | |
287 | // Returns the index of the output channels dimension. |
288 | inline int GetFilterTensorOutputChannelsDimIndex(int num_dims, |
289 | FilterTensorFormat format) { |
290 | switch (format) { |
291 | case FORMAT_HWIO: |
292 | return num_dims - 1; |
293 | case FORMAT_OIHW: |
294 | case FORMAT_OIHW_VECT_I: |
295 | return 0; |
296 | default: |
297 | LOG(FATAL) << "Unknown format " << format; |
298 | return -1; // Avoid compiler warning about missing return value |
299 | } |
300 | } |
301 | |
302 | // TODO(pauldonnelly): Replace these tensor dimension index functions with |
303 | // constant structs to improve performance and reduce code size in Compute() |
304 | // functions. |
305 | |
306 | // Return the dimension index for the specified 'dimension' of the specified |
307 | // data 'tensor_format'. 'dimension' is a char that can be 'N' (batch size), |
308 | // 'C' (channels), 'H' (height), 'W' (width), or a numbered spatial dimension: |
309 | // '0', .. (NUM_SPATIAL_DIMS-1).. |
310 | // If 'format' is NCHW_VECT_C and 'dimension' is 'C', returns the index of |
311 | // the outer channel dimension (i.e. 1). |
312 | template <int NUM_SPATIAL_DIMS> |
313 | inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { |
314 | if (format == FORMAT_NHWC || format == FORMAT_NHWC_VECT_W) { |
315 | // clang-format off |
316 | switch (dimension) { |
317 | case 'N': return 0; |
318 | case '0': return 1; |
319 | case '1': return 2; |
320 | case '2': return 3; |
321 | case 'H': return NUM_SPATIAL_DIMS - 1; |
322 | case 'W': return NUM_SPATIAL_DIMS; |
323 | case 'C': return NUM_SPATIAL_DIMS + 1; |
324 | default: |
325 | LOG(FATAL) << "Invalid dimension: " << dimension; |
326 | return -1; // Avoid compiler warning about missing return value |
327 | } |
328 | } else if (format == FORMAT_NCHW || format == FORMAT_NCHW_VECT_C) { |
329 | switch (dimension) { |
330 | case 'N': return 0; |
331 | case 'C': return 1; |
332 | case '0': return 2; |
333 | case '1': return 3; |
334 | case '2': return 4; |
335 | case 'H': return NUM_SPATIAL_DIMS; |
336 | case 'W': return NUM_SPATIAL_DIMS + 1; |
337 | default: |
338 | LOG(FATAL) << "Invalid dimension: " << dimension; |
339 | return -1; // Avoid compiler warning about missing return value |
340 | } |
341 | } else if (format == FORMAT_HWNC) { |
342 | switch (dimension) { |
343 | case '0': return 0; |
344 | case '1': return 1; |
345 | case '2': return 2; |
346 | case 'H': return NUM_SPATIAL_DIMS - 2; |
347 | case 'W': return NUM_SPATIAL_DIMS - 1; |
348 | case 'N': return NUM_SPATIAL_DIMS; |
349 | case 'C': return NUM_SPATIAL_DIMS + 1; |
350 | default: |
351 | LOG(FATAL) << "Invalid dimension: " << dimension; |
352 | return -1; // Avoid compiler warning about missing return value |
353 | } |
354 | } else if (format == FORMAT_HWCN) { |
355 | switch (dimension) { |
356 | case '0': return 0; |
357 | case '1': return 1; |
358 | case '2': return 2; |
359 | case 'H': return NUM_SPATIAL_DIMS - 2; |
360 | case 'W': return NUM_SPATIAL_DIMS - 1; |
361 | case 'C': return NUM_SPATIAL_DIMS; |
362 | case 'N': return NUM_SPATIAL_DIMS + 1; |
363 | default: |
364 | LOG(FATAL) << "Invalid dimension: " << dimension; |
365 | return -1; // Avoid compiler warning about missing return value |
366 | } |
367 | } else { |
368 | LOG(FATAL) << "Invalid format: " << static_cast<int>(format); |
369 | return -1; // Avoid compiler warning about missing return value |
370 | } |
371 | // clang-format on |
372 | } |
373 | |
374 | // Return the dimension index for the specified 'dimension' of the specified |
375 | // 'filter_tensor_format'. 'dimension' is a char that can be 'O' (num output |
376 | // channels), 'I' (num input channels), 'H' (height), 'W' (width), or a |
377 | // numbered spatial dimension: '0', .. (NUM_SPATIAL_DIMS-1). |
378 | // If 'format' is OIHW_VECT_I and 'dimension' is 'I', returns the index of the |
379 | // outer input channels dimension (i.e. 1). |
380 | template <int NUM_SPATIAL_DIMS> |
381 | inline int GetFilterDimIndex(FilterTensorFormat filter_tensor_format, |
382 | char dimension) { |
383 | // clang-format off |
384 | if (filter_tensor_format == FORMAT_HWIO) { |
385 | switch (dimension) { |
386 | case '0': return 0; |
387 | case '1': return 1; |
388 | case '2': return 2; |
389 | case 'H': return NUM_SPATIAL_DIMS - 2; |
390 | case 'W': return NUM_SPATIAL_DIMS - 1; |
391 | case 'I': return NUM_SPATIAL_DIMS; |
392 | case 'O': return NUM_SPATIAL_DIMS + 1; |
393 | default: |
394 | LOG(FATAL) << "Invalid dimension: " << dimension; |
395 | return -1; // Avoid compiler warning about missing return value |
396 | } |
397 | } else if (filter_tensor_format == FORMAT_OIHW || |
398 | filter_tensor_format == FORMAT_OIHW_VECT_I) { |
399 | switch (dimension) { |
400 | case 'O': return 0; |
401 | case 'I': return 1; |
402 | case '0': return 2; |
403 | case '1': return 3; |
404 | case '2': return 4; |
405 | case 'H': return NUM_SPATIAL_DIMS; |
406 | case 'W': return NUM_SPATIAL_DIMS + 1; |
407 | default: |
408 | LOG(FATAL) << "Invalid dimension: " << dimension; |
409 | return -1; // Avoid compiler warning about missing return value |
410 | } |
411 | } else { |
412 | LOG(FATAL) << "Invalid format: " << static_cast<int>(filter_tensor_format); |
413 | return -1; // Avoid compiler warning about missing return value |
414 | } |
415 | // clang-format on |
416 | } |
417 | |
418 | inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { |
419 | return GetTensorDimIndex<2>(format, dimension); |
420 | } |
421 | |
422 | inline int32 GetTensorDimIndex(TensorFormat format, char dimension, |
423 | int num_total_dims) { |
424 | int32_t index = (GetTensorSpatialDims(num_total_dims, format) == 3) |
425 | ? GetTensorDimIndex<3>(format, dimension) |
426 | : GetTensorDimIndex<2>(format, dimension); |
427 | CHECK(index >= 0 && index < num_total_dims) // Crash OK. |
428 | << "Invalid index from the dimension: " << index << ", " << format << ", " |
429 | << dimension; |
430 | return index; |
431 | } |
432 | |
433 | // Return the element from 'dimension_attributes' that corresponds to the |
434 | // specified 'dimension' according to 'tensor_format'. |
435 | template <typename T> |
436 | T GetTensorDim(gtl::ArraySlice<T> dimension_attributes, |
437 | TensorFormat tensor_format, char dimension) { |
438 | int index = |
439 | GetTensorDimIndex(tensor_format, dimension, dimension_attributes.size()); |
440 | return dimension_attributes[index]; |
441 | } |
442 | |
443 | // Return the element from 'dimension_attribute' that corresponds to the |
444 | // specified 'dimension' according to 'filter_tensor_format'. |
445 | template <typename T> |
446 | T GetFilterDim(gtl::ArraySlice<T> dimension_attribute, |
447 | FilterTensorFormat filter_tensor_format, char dimension) { |
448 | int index = (GetFilterTensorSpatialDims(dimension_attribute.size(), |
449 | filter_tensor_format) == 3) |
450 | ? GetFilterDimIndex<3>(filter_tensor_format, dimension) |
451 | : GetFilterDimIndex<2>(filter_tensor_format, dimension); |
452 | using size_type = typename gtl::ArraySlice<T>::size_type; |
453 | CHECK(index >= 0 && |
454 | static_cast<size_type>(index) < dimension_attribute.size()) |
455 | << "Invalid index from the dimension: " << index << ", " |
456 | << filter_tensor_format << ", " << dimension; |
457 | return dimension_attribute[index]; |
458 | } |
459 | |
460 | template <typename T> |
461 | T GetTensorDim(const std::vector<T>& attributes, TensorFormat format, |
462 | char dimension) { |
463 | return GetTensorDim(gtl::ArraySlice<T>(attributes), format, dimension); |
464 | } |
465 | |
466 | // Return the size of the specified 'dimension' within 'tensor_shape' |
467 | // according to 'tensor_format'. |
468 | inline int64_t GetTensorDim(const TensorShape& tensor_shape, |
469 | TensorFormat tensor_format, char dimension) { |
470 | return GetTensorDim(gtl::ArraySlice<int64_t>(tensor_shape.dim_sizes()), |
471 | tensor_format, dimension); |
472 | } |
473 | |
474 | // Return the size of the specified 'dimension' within 'tensor_shape' |
475 | // according to 'tensor_filter_format'. |
476 | inline int64_t GetFilterDim(const TensorShape& tensor_shape, |
477 | FilterTensorFormat tensor_filter_format, |
478 | char dimension) { |
479 | return GetFilterDim(gtl::ArraySlice<int64_t>(tensor_shape.dim_sizes()), |
480 | tensor_filter_format, dimension); |
481 | } |
482 | |
483 | // Return the size of the specified 'dimension' of 'tensor' according to |
484 | // 'tensor_format'. |
485 | inline int64_t GetTensorDim(const Tensor& tensor, TensorFormat tensor_format, |
486 | char dimension) { |
487 | return GetTensorDim(tensor.shape(), tensor_format, dimension); |
488 | } |
489 | |
490 | // Return the size of the specified 'dimension' of 'tensor' according to |
491 | // 'filter_tensor_format'. |
492 | inline int64_t GetFilterDim(const Tensor& tensor, |
493 | FilterTensorFormat filter_tensor_format, |
494 | char dimension) { |
495 | return GetFilterDim(tensor.shape(), filter_tensor_format, dimension); |
496 | } |
497 | |
498 | inline void GetExplicitPaddingForDim( |
499 | const std::vector<int64_t>& explicit_paddings, TensorFormat tensor_format, |
500 | char dimension, int64_t* padding_before, int64_t* padding_after) { |
501 | int index = |
502 | GetTensorDimIndex(tensor_format, dimension, explicit_paddings.size() / 2); |
503 | *padding_before = explicit_paddings[2 * index]; |
504 | *padding_after = explicit_paddings[2 * index + 1]; |
505 | } |
506 | |
// Return the string that specifies the data format attribute for convnet
// operations (the "3d" variant presumably serves 3-D convolutions; confirm
// against the implementation).
std::string GetConvnetDataFormatAttrString();
std::string GetConvnet3dDataFormatAttrString();

// Return the string that specifies the filter format attribute for convnet
// operations (2-D and, presumably, 3-D variants).
std::string GetConvnetFilterFormatAttrString();
std::string GetConvnet3dFilterFormatAttrString();
// NOTE(review): despite its position among the filter-format helpers, the
// name suggests this returns a *data* format attr string shared by 2-D and
// 3-D ops; confirm against the implementation.
std::string GetConvnetDataFormat2D3DAttrString();
515 | |
516 | // Returns a tensor shape for the specified format and dimension sizes. |
517 | // Works for both 2D and 3D operations. The output shapes are as follows: |
518 | // FORMAT_NHWC: (N, spatial, C); rank = spatial.size() + 2 |
519 | // FORMAT_NCHW: (N, C, spatial); rank = spatial.size() + 2 |
520 | // FORMAT_NCHW_VECT_C: (N, C, spatial, InnerC); rank = spatial.size() + 3 |
521 | // FORMAT_NHWC_VECT_W: (N, spatial, C, InnerW); rank = spatial.size() + 3 |
522 | inline TensorShape ShapeFromFormat(TensorFormat format, int64_t N, |
523 | gtl::ArraySlice<int64_t> spatial, |
524 | int64_t C) { |
525 | const int dims = GetTensorDimsFromSpatialDims(spatial.size(), format); |
526 | gtl::InlinedVector<int64_t, 6> dim_sizes(dims); |
527 | dim_sizes[GetTensorBatchDimIndex(dims, format)] = N; |
528 | for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) { |
529 | auto dim_size = spatial[dim]; |
530 | if (format == FORMAT_NHWC_VECT_W && |
531 | static_cast<size_t>(dim) == spatial.size() - 1) { |
532 | CHECK_EQ(0, dim_size % 4) |
533 | << "FORMAT_NHWC_VECT_W requires W to be a multiple of 4, but W=" |
534 | << dim_size; |
535 | dim_sizes[GetTensorInnerWidthDimIndex(dims, format)] = 4; |
536 | dim_size /= 4; |
537 | } |
538 | dim_sizes[GetTensorSpatialDimIndex(dims, format, dim)] = dim_size; |
539 | } |
540 | |
541 | int feature_index = GetTensorFeatureDimIndex(dims, format); |
542 | if (format == FORMAT_NCHW_VECT_C) { |
543 | CHECK_EQ(0, C % 4) << "NCHW_VECT_C requires C to be a multiple of 4, but C=" |
544 | << C; |
545 | C /= 4; |
546 | dim_sizes[GetTensorInnerFeatureDimIndex(dims, format)] = 4; |
547 | } |
548 | dim_sizes[feature_index] = C; |
549 | return TensorShape(dim_sizes); |
550 | } |
551 | |
552 | // Return a tensor shape of the specified 'format', and dimensions. |
553 | // Works for both 2D and 3D operations. If 'format' is OIHW_VECT_I, |
554 | // the output TensorShape has spatial.size() + 3 dimensions, otherwise |
555 | // it has spatial.size() + 2 dimensions. |
556 | inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format, |
557 | gtl::ArraySlice<int64_t> spatial, |
558 | int64_t I, int64_t O) { |
559 | const int dims = GetFilterTensorDimsFromSpatialDims(spatial.size(), format); |
560 | gtl::InlinedVector<int64_t, 6> dim_sizes(dims); |
561 | dim_sizes[GetFilterTensorOutputChannelsDimIndex(dims, format)] = O; |
562 | for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) { |
563 | dim_sizes[GetFilterTensorSpatialDimIndex(dims, format, dim)] = spatial[dim]; |
564 | } |
565 | |
566 | if (format == FORMAT_OIHW_VECT_I) { |
567 | CHECK_EQ(0, I % 4) << "OIHW_VECT_I requires I to be a multiple of 4, but I=" |
568 | << I; |
569 | I /= 4; |
570 | dim_sizes[GetFilterTensorInnerInputChannelsDimIndex(dims, format)] = 4; |
571 | } |
572 | dim_sizes[GetFilterTensorInputChannelsDimIndex(dims, format)] = I; |
573 | return TensorShape(dim_sizes); |
574 | } |
575 | |
576 | // Return a tensor shape of the specified 'format', and dimensions. |
577 | inline TensorShape ShapeFromFormat(TensorFormat format, int64_t N, int64_t H, |
578 | int64_t W, int64_t C) { |
579 | return ShapeFromFormat(format, N, {H, W}, C); |
580 | } |
581 | |
582 | // Return a filter tensor shape of the specified 'format', and dimensions. |
583 | inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format, |
584 | int64_t H, int64_t W, int64_t I, |
585 | int64_t O) { |
586 | return ShapeFromFilterTensorFormat(format, {H, W}, I, O); |
587 | } |
588 | |
589 | // Returns a copy of the specified tensor 'src_shape' converted from |
590 | // 'src_format' to 'dst_format'. |
591 | inline TensorShape ShapeFromFormat(TensorFormat dst_format, |
592 | const TensorShape& src_shape, |
593 | TensorFormat src_format) { |
594 | if (src_format == dst_format) { |
595 | return src_shape; |
596 | } |
597 | |
598 | const int64_t batch = GetTensorDim(src_shape, src_format, 'N'); |
599 | const int64_t channels = GetTensorDim(src_shape, src_format, 'C') * |
600 | (src_format == FORMAT_NCHW_VECT_C ? 4 : 1); |
601 | const int num_src_spatial_dims = |
602 | GetTensorSpatialDims(src_shape.dims(), src_format); |
603 | std::vector<int64_t> spatial_dims(num_src_spatial_dims); |
604 | for (int spatial_dim = 0; spatial_dim < num_src_spatial_dims; ++spatial_dim) { |
605 | spatial_dims[spatial_dim] = gtl::ArraySlice<int64_t>( |
606 | src_shape.dim_sizes())[GetTensorSpatialDimIndex( |
607 | src_shape.dims(), src_format, spatial_dim)]; |
608 | } |
609 | if (src_format == FORMAT_NHWC_VECT_W) { |
610 | spatial_dims[num_src_spatial_dims - 1] *= 4; |
611 | } |
612 | return ShapeFromFormat(dst_format, batch, {spatial_dims}, channels); |
613 | } |
614 | |
615 | // Returns a copy of the specified filter tensor 'src_shape' converted from |
616 | // 'src_filter_format' to 'dst_filter_format'. |
617 | inline TensorShape ShapeFromFilterFormat(FilterTensorFormat dst_filter_format, |
618 | const TensorShape& src_shape, |
619 | FilterTensorFormat src_filter_format) { |
620 | if (src_filter_format == dst_filter_format) { |
621 | return src_shape; |
622 | } |
623 | |
624 | const int64_t output_channels = |
625 | GetFilterDim(src_shape, src_filter_format, 'O'); |
626 | const int64_t input_channels = |
627 | GetFilterDim(src_shape, src_filter_format, 'I') * |
628 | (src_filter_format == FORMAT_OIHW_VECT_I ? 4 : 1); |
629 | |
630 | if (GetFilterTensorSpatialDims(src_shape.dims(), src_filter_format) == 3) { |
631 | return ShapeFromFilterTensorFormat( |
632 | dst_filter_format, |
633 | {{GetFilterDim(src_shape, src_filter_format, '0'), |
634 | GetFilterDim(src_shape, src_filter_format, '1'), |
635 | GetFilterDim(src_shape, src_filter_format, '2')}}, |
636 | input_channels, output_channels); |
637 | } |
638 | |
639 | return ShapeFromFilterTensorFormat( |
640 | dst_filter_format, |
641 | {{GetFilterDim(src_shape, src_filter_format, 'H'), |
642 | GetFilterDim(src_shape, src_filter_format, 'W')}}, |
643 | input_channels, output_channels); |
644 | } |
645 | |
646 | } // namespace tensorflow |
647 | |
648 | #endif // TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_ |
649 | |