1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
17#define TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
18
19#include <array>
20#include <vector>
21
22#include "absl/strings/string_view.h"
23#include "tensorflow/core/framework/tensor.h"
24#include "tensorflow/core/lib/gtl/array_slice.h"
25#include "tensorflow/core/lib/gtl/inlined_vector.h"
26#include "tensorflow/core/platform/types.h"
27
28namespace tensorflow {
29
// Tensor format for input/output activations used in convolution operations.
// The mnemonics spell out the meaning of each tensor dimension, ordered from
// largest to smallest memory stride.
// N = Batch, H = Image Height, W = Image Width, C = Number of Channels.
// "spatial..." below stands for all spatial dimensions (e.g. H, W, or D, H, W).
// TODO(pauldonnelly): It would probably be better to switch to a registration
// process for tensor formats, so specialized formats could be defined more
// locally to where they are used.
enum TensorFormat {
  // FORMAT_NHWC is the default format in TensorFlow.
  // Layout: [N, spatial..., C].
  FORMAT_NHWC = 0,

  // FORMAT_NCHW often improves performance on GPUs.
  // Layout: [N, C, spatial...].
  FORMAT_NCHW = 1,

  // NCHW_VECT_C is the most performant tensor format for cudnn6's quantized
  // int8 convolution and fused convolution. It is laid out in the same order
  // as NCHW, except that the size of the Channels dimension is divided by 4,
  // and a new dimension of size 4 is appended, which packs 4 adjacent channel
  // activations for the same pixel into an int32. Thus an NCHW format tensor
  // with dimensions [N, C, H, W] would have dimensions [N, C/4, H, W, 4] in
  // NCHW_VECT_C format.
  // A pre-condition of this format is that C must be a multiple of 4.
  FORMAT_NCHW_VECT_C = 2,

  // Similar to NHWC, but the size of the W dimension is divided by 4, and a
  // new dimension of size 4 is appended, which packs 4 adjacent activations
  // in the width dimension. Thus an NHWC tensor [N, H, W, C] becomes
  // [N, H, W/4, C, 4]. W must be a multiple of 4.
  FORMAT_NHWC_VECT_W = 3,

  // Note: although the current code in this file assumes VECT_C and VECT_W
  // enums imply int8x4 vectors, this should not be relied upon.
  // In the future we may change the meaning of these enums to include vectors
  // of other types such as int16x2, with op implementations automatically
  // determining which format is implied based on the datatype.

  // FORMAT_HWNC is for TPUs. Layout: [spatial..., N, C].
  FORMAT_HWNC = 4,

  // FORMAT_HWCN is for TPUs. Layout: [spatial..., C, N].
  FORMAT_HWCN = 5,
};
71
// Tensor format for convolutional filters.
// The mnemonics spell out the meaning of each tensor dimension, ordered from
// largest to smallest memory stride.
// H = Kernel Height, W = Kernel Width, I = Input Channels, O = Output Channels.
// Note: In cudnnGetFilter4dDescriptor(), 'O' is called 'K', 'I' is called 'C'.
enum FilterTensorFormat {
  // FORMAT_HWIO is the default filter format in TensorFlow.
  // Ops that do not have a 'filter_format' attribute will assume this format.
  // Layout: [spatial..., I, O].
  FORMAT_HWIO = 0,

  // FORMAT_OIHW often improves performance on GPUs.
  // Layout: [O, I, spatial...].
  FORMAT_OIHW = 1,

  // FORMAT_OHWI is the filter layout used by cuDNN for NHWC convolutions.
  // Layout (per the mnemonic): [O, spatial..., I].
  FORMAT_OHWI = 2,

  // OIHW_VECT_I is the most performant tensor format for cudnn6's quantized
  // int8 convolution and fused convolution. It is analogous to the NCHW_VECT_C
  // data format. It is laid out in the same order as OIHW, except that the size
  // of the Input Channels dimension is divided by 4, and a new dimension of
  // size 4 is appended, which packs 4 adjacent input channel weights into an
  // int32. Thus an OIHW format filter with dimensions [O, I, H, W] would have
  // dimensions [O, I/4, H, W, 4] in OIHW_VECT_I format.
  // A pre-condition of this format is that I must be a multiple of 4.
  FORMAT_OIHW_VECT_I = 3,
};
98
// Parses an activation tensor format from 'format_str' into '*format'.
// Returns true if the parsing succeeds, and false if it fails.
bool FormatFromString(absl::string_view format_str, TensorFormat* format);

// Parses a filter tensor format from 'format_str' into '*format'.
// Returns true if the parsing succeeds, and false if it fails.
bool FilterFormatFromString(absl::string_view format_str,
                            FilterTensorFormat* format);

// Converts an activation tensor format into its string name.
std::string ToString(TensorFormat format);

// Converts a filter tensor format into its string name.
std::string ToString(FilterTensorFormat format);
113
114// Returns the number of spatial dims of a tensor of rank 'num_dims' and tensor
115// format 'format'.
116inline int GetTensorSpatialDims(int num_dims, TensorFormat format) {
117 switch (format) {
118 case FORMAT_NHWC:
119 case FORMAT_NCHW:
120 case FORMAT_HWNC:
121 case FORMAT_HWCN:
122 return num_dims - 2; // Exclude N,C.
123 case FORMAT_NCHW_VECT_C:
124 case FORMAT_NHWC_VECT_W:
125 // Note: the VECT_W is not counted as an independent spatial dim here,
126 // since it just a component of the width dimension.
127 return num_dims - 3; // Exclude N,C,VectDim.
128 default:
129 LOG(FATAL) << "Unknown format " << format;
130 return -1; // Avoid compiler warning about missing return value
131 }
132}
133
134inline int GetFilterTensorSpatialDims(int num_dims, FilterTensorFormat format) {
135 if (format == FORMAT_OIHW_VECT_I) {
136 return num_dims - 3; // Exclude O,I,InnerI.
137 } else {
138 return num_dims - 2; // Exclude O,I.
139 }
140}
141
142// Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
143// tensor format 'format'. This is the inverse of GetTensorSpatialDims.
144inline int GetTensorDimsFromSpatialDims(int num_spatial_dims,
145 TensorFormat format) {
146 switch (format) {
147 case FORMAT_NHWC:
148 case FORMAT_NCHW:
149 case FORMAT_HWNC:
150 case FORMAT_HWCN:
151 return num_spatial_dims + 2; // Include N,C.
152 case FORMAT_NCHW_VECT_C:
153 case FORMAT_NHWC_VECT_W:
154 return num_spatial_dims + 3; // Include N,C,VectDim.
155 default:
156 LOG(FATAL) << "Unknown format " << format;
157 return -1; // Avoid compiler warning about missing return value
158 }
159}
160
161// Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
162// filter tensor format 'format'.
163inline int GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,
164 FilterTensorFormat format) {
165 if (format == FORMAT_OIHW_VECT_I) {
166 return num_spatial_dims + 3; // Include O,I,InnerI.
167 } else {
168 return num_spatial_dims + 2; // Include O,I.
169 }
170}
171
172// Returns the index of the batch dimension.
173inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
174 switch (format) {
175 case FORMAT_NHWC:
176 case FORMAT_NCHW:
177 case FORMAT_NCHW_VECT_C:
178 case FORMAT_NHWC_VECT_W:
179 return 0;
180 case FORMAT_HWNC:
181 return num_dims - 2;
182 case FORMAT_HWCN:
183 return num_dims - 1;
184 default:
185 LOG(FATAL) << "Unknown format " << format;
186 return -1; // Avoid compiler warning about missing return value
187 }
188}
189
190// Returns the index of the feature dimension. If format is NCHW_VECT_C, returns
191// the index of the outer feature dimension (i.e. dimension 1, whose size would
192// be num_features / 4 in this case).
193inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) {
194 switch (format) {
195 case FORMAT_NHWC:
196 case FORMAT_HWNC:
197 return num_dims - 1;
198 case FORMAT_NHWC_VECT_W:
199 case FORMAT_HWCN:
200 return num_dims - 2;
201 case FORMAT_NCHW:
202 case FORMAT_NCHW_VECT_C:
203 return 1;
204 default:
205 LOG(FATAL) << "Unknown format " << format;
206 return -1; // Avoid compiler warning about missing return value
207 }
208}
209
210// Returns the index of the inner feature dimension.
211inline int GetTensorInnerFeatureDimIndex(int num_dims, TensorFormat format) {
212 DCHECK_EQ(format, FORMAT_NCHW_VECT_C);
213 return num_dims - 1;
214}
215
216// Returns the index of the inner width dimension.
217inline int GetTensorInnerWidthDimIndex(int num_dims, TensorFormat format) {
218 DCHECK_EQ(format, FORMAT_NHWC_VECT_W);
219 return num_dims - 1;
220}
221
222// Returns the dimension index of the specified 'spatial_dim' within an
223// activation tensor. If format is NHWC_VECT_W and spatial_dim is 1, returns
224// the index of the outer width dimension (i.e. dimension 2, whose size would
225// be width / 4 in this case).
226inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format,
227 int spatial_dim) {
228 CHECK(spatial_dim >= 0 &&
229 spatial_dim < GetTensorSpatialDims(num_dims, format))
230 << spatial_dim << " " << num_dims << " " << ToString(format);
231 switch (format) {
232 case FORMAT_NHWC:
233 case FORMAT_NHWC_VECT_W:
234 return spatial_dim + 1;
235 case FORMAT_NCHW:
236 case FORMAT_NCHW_VECT_C:
237 return spatial_dim + 2;
238 case FORMAT_HWNC:
239 case FORMAT_HWCN:
240 return spatial_dim;
241 default:
242 LOG(FATAL) << "Unknown format " << format;
243 return -1; // Avoid compiler warning about missing return value
244 }
245}
246
247inline int GetFilterTensorSpatialDimIndex(int num_dims,
248 FilterTensorFormat format, int dim) {
249 CHECK(dim >= 0 && dim < GetFilterTensorSpatialDims(num_dims, format))
250 << dim << " " << num_dims << " " << ToString(format);
251 switch (format) {
252 case FORMAT_HWIO:
253 return dim;
254 case FORMAT_OIHW:
255 case FORMAT_OIHW_VECT_I:
256 return dim + 2;
257 default:
258 LOG(FATAL) << "Unknown format " << format;
259 return -1; // Avoid compiler warning about missing return value
260 }
261}
262
263// Returns the index of the inner input channels dimension.
264inline int GetFilterTensorInnerInputChannelsDimIndex(
265 int num_dims, FilterTensorFormat format) {
266 DCHECK_EQ(format, FORMAT_OIHW_VECT_I);
267 return num_dims - 1;
268}
269
270// Returns the index of the input channels dimension.
271// If 'format' is FORMAT_OIHW_VECT_I, returns the dimension index of the
272// outer input channel (i.e. 1), which holds num_input_channels / 4.
273inline int GetFilterTensorInputChannelsDimIndex(int num_dims,
274 FilterTensorFormat format) {
275 switch (format) {
276 case FORMAT_HWIO:
277 return num_dims - 2;
278 case FORMAT_OIHW:
279 case FORMAT_OIHW_VECT_I:
280 return 1;
281 default:
282 LOG(FATAL) << "Unknown format " << format;
283 return -1; // Avoid compiler warning about missing return value
284 }
285}
286
287// Returns the index of the output channels dimension.
288inline int GetFilterTensorOutputChannelsDimIndex(int num_dims,
289 FilterTensorFormat format) {
290 switch (format) {
291 case FORMAT_HWIO:
292 return num_dims - 1;
293 case FORMAT_OIHW:
294 case FORMAT_OIHW_VECT_I:
295 return 0;
296 default:
297 LOG(FATAL) << "Unknown format " << format;
298 return -1; // Avoid compiler warning about missing return value
299 }
300}
301
302// TODO(pauldonnelly): Replace these tensor dimension index functions with
303// constant structs to improve performance and reduce code size in Compute()
304// functions.
305
306// Return the dimension index for the specified 'dimension' of the specified
307// data 'tensor_format'. 'dimension' is a char that can be 'N' (batch size),
308// 'C' (channels), 'H' (height), 'W' (width), or a numbered spatial dimension:
309// '0', .. (NUM_SPATIAL_DIMS-1)..
310// If 'format' is NCHW_VECT_C and 'dimension' is 'C', returns the index of
311// the outer channel dimension (i.e. 1).
312template <int NUM_SPATIAL_DIMS>
313inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
314 if (format == FORMAT_NHWC || format == FORMAT_NHWC_VECT_W) {
315 // clang-format off
316 switch (dimension) {
317 case 'N': return 0;
318 case '0': return 1;
319 case '1': return 2;
320 case '2': return 3;
321 case 'H': return NUM_SPATIAL_DIMS - 1;
322 case 'W': return NUM_SPATIAL_DIMS;
323 case 'C': return NUM_SPATIAL_DIMS + 1;
324 default:
325 LOG(FATAL) << "Invalid dimension: " << dimension;
326 return -1; // Avoid compiler warning about missing return value
327 }
328 } else if (format == FORMAT_NCHW || format == FORMAT_NCHW_VECT_C) {
329 switch (dimension) {
330 case 'N': return 0;
331 case 'C': return 1;
332 case '0': return 2;
333 case '1': return 3;
334 case '2': return 4;
335 case 'H': return NUM_SPATIAL_DIMS;
336 case 'W': return NUM_SPATIAL_DIMS + 1;
337 default:
338 LOG(FATAL) << "Invalid dimension: " << dimension;
339 return -1; // Avoid compiler warning about missing return value
340 }
341 } else if (format == FORMAT_HWNC) {
342 switch (dimension) {
343 case '0': return 0;
344 case '1': return 1;
345 case '2': return 2;
346 case 'H': return NUM_SPATIAL_DIMS - 2;
347 case 'W': return NUM_SPATIAL_DIMS - 1;
348 case 'N': return NUM_SPATIAL_DIMS;
349 case 'C': return NUM_SPATIAL_DIMS + 1;
350 default:
351 LOG(FATAL) << "Invalid dimension: " << dimension;
352 return -1; // Avoid compiler warning about missing return value
353 }
354 } else if (format == FORMAT_HWCN) {
355 switch (dimension) {
356 case '0': return 0;
357 case '1': return 1;
358 case '2': return 2;
359 case 'H': return NUM_SPATIAL_DIMS - 2;
360 case 'W': return NUM_SPATIAL_DIMS - 1;
361 case 'C': return NUM_SPATIAL_DIMS;
362 case 'N': return NUM_SPATIAL_DIMS + 1;
363 default:
364 LOG(FATAL) << "Invalid dimension: " << dimension;
365 return -1; // Avoid compiler warning about missing return value
366 }
367 } else {
368 LOG(FATAL) << "Invalid format: " << static_cast<int>(format);
369 return -1; // Avoid compiler warning about missing return value
370 }
371 // clang-format on
372}
373
374// Return the dimension index for the specified 'dimension' of the specified
375// 'filter_tensor_format'. 'dimension' is a char that can be 'O' (num output
376// channels), 'I' (num input channels), 'H' (height), 'W' (width), or a
377// numbered spatial dimension: '0', .. (NUM_SPATIAL_DIMS-1).
378// If 'format' is OIHW_VECT_I and 'dimension' is 'I', returns the index of the
379// outer input channels dimension (i.e. 1).
380template <int NUM_SPATIAL_DIMS>
381inline int GetFilterDimIndex(FilterTensorFormat filter_tensor_format,
382 char dimension) {
383 // clang-format off
384 if (filter_tensor_format == FORMAT_HWIO) {
385 switch (dimension) {
386 case '0': return 0;
387 case '1': return 1;
388 case '2': return 2;
389 case 'H': return NUM_SPATIAL_DIMS - 2;
390 case 'W': return NUM_SPATIAL_DIMS - 1;
391 case 'I': return NUM_SPATIAL_DIMS;
392 case 'O': return NUM_SPATIAL_DIMS + 1;
393 default:
394 LOG(FATAL) << "Invalid dimension: " << dimension;
395 return -1; // Avoid compiler warning about missing return value
396 }
397 } else if (filter_tensor_format == FORMAT_OIHW ||
398 filter_tensor_format == FORMAT_OIHW_VECT_I) {
399 switch (dimension) {
400 case 'O': return 0;
401 case 'I': return 1;
402 case '0': return 2;
403 case '1': return 3;
404 case '2': return 4;
405 case 'H': return NUM_SPATIAL_DIMS;
406 case 'W': return NUM_SPATIAL_DIMS + 1;
407 default:
408 LOG(FATAL) << "Invalid dimension: " << dimension;
409 return -1; // Avoid compiler warning about missing return value
410 }
411 } else {
412 LOG(FATAL) << "Invalid format: " << static_cast<int>(filter_tensor_format);
413 return -1; // Avoid compiler warning about missing return value
414 }
415 // clang-format on
416}
417
418inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
419 return GetTensorDimIndex<2>(format, dimension);
420}
421
422inline int32 GetTensorDimIndex(TensorFormat format, char dimension,
423 int num_total_dims) {
424 int32_t index = (GetTensorSpatialDims(num_total_dims, format) == 3)
425 ? GetTensorDimIndex<3>(format, dimension)
426 : GetTensorDimIndex<2>(format, dimension);
427 CHECK(index >= 0 && index < num_total_dims) // Crash OK.
428 << "Invalid index from the dimension: " << index << ", " << format << ", "
429 << dimension;
430 return index;
431}
432
433// Return the element from 'dimension_attributes' that corresponds to the
434// specified 'dimension' according to 'tensor_format'.
435template <typename T>
436T GetTensorDim(gtl::ArraySlice<T> dimension_attributes,
437 TensorFormat tensor_format, char dimension) {
438 int index =
439 GetTensorDimIndex(tensor_format, dimension, dimension_attributes.size());
440 return dimension_attributes[index];
441}
442
443// Return the element from 'dimension_attribute' that corresponds to the
444// specified 'dimension' according to 'filter_tensor_format'.
445template <typename T>
446T GetFilterDim(gtl::ArraySlice<T> dimension_attribute,
447 FilterTensorFormat filter_tensor_format, char dimension) {
448 int index = (GetFilterTensorSpatialDims(dimension_attribute.size(),
449 filter_tensor_format) == 3)
450 ? GetFilterDimIndex<3>(filter_tensor_format, dimension)
451 : GetFilterDimIndex<2>(filter_tensor_format, dimension);
452 using size_type = typename gtl::ArraySlice<T>::size_type;
453 CHECK(index >= 0 &&
454 static_cast<size_type>(index) < dimension_attribute.size())
455 << "Invalid index from the dimension: " << index << ", "
456 << filter_tensor_format << ", " << dimension;
457 return dimension_attribute[index];
458}
459
460template <typename T>
461T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
462 char dimension) {
463 return GetTensorDim(gtl::ArraySlice<T>(attributes), format, dimension);
464}
465
466// Return the size of the specified 'dimension' within 'tensor_shape'
467// according to 'tensor_format'.
468inline int64_t GetTensorDim(const TensorShape& tensor_shape,
469 TensorFormat tensor_format, char dimension) {
470 return GetTensorDim(gtl::ArraySlice<int64_t>(tensor_shape.dim_sizes()),
471 tensor_format, dimension);
472}
473
474// Return the size of the specified 'dimension' within 'tensor_shape'
475// according to 'tensor_filter_format'.
476inline int64_t GetFilterDim(const TensorShape& tensor_shape,
477 FilterTensorFormat tensor_filter_format,
478 char dimension) {
479 return GetFilterDim(gtl::ArraySlice<int64_t>(tensor_shape.dim_sizes()),
480 tensor_filter_format, dimension);
481}
482
483// Return the size of the specified 'dimension' of 'tensor' according to
484// 'tensor_format'.
485inline int64_t GetTensorDim(const Tensor& tensor, TensorFormat tensor_format,
486 char dimension) {
487 return GetTensorDim(tensor.shape(), tensor_format, dimension);
488}
489
490// Return the size of the specified 'dimension' of 'tensor' according to
491// 'filter_tensor_format'.
492inline int64_t GetFilterDim(const Tensor& tensor,
493 FilterTensorFormat filter_tensor_format,
494 char dimension) {
495 return GetFilterDim(tensor.shape(), filter_tensor_format, dimension);
496}
497
498inline void GetExplicitPaddingForDim(
499 const std::vector<int64_t>& explicit_paddings, TensorFormat tensor_format,
500 char dimension, int64_t* padding_before, int64_t* padding_after) {
501 int index =
502 GetTensorDimIndex(tensor_format, dimension, explicit_paddings.size() / 2);
503 *padding_before = explicit_paddings[2 * index];
504 *padding_after = explicit_paddings[2 * index + 1];
505}
506
// Return the string that specifies the data format for convnet operations
// (2-D and 3-D variants respectively).
std::string GetConvnetDataFormatAttrString();
std::string GetConvnet3dDataFormatAttrString();

// Return the string that specifies the filter format for convnet operations
// (2-D and 3-D variants respectively).
std::string GetConvnetFilterFormatAttrString();
std::string GetConvnet3dFilterFormatAttrString();
// NOTE(review): despite sitting with the filter-format declarations, the name
// suggests this returns a *data* format attr string covering both 2-D and 3-D
// ops — confirm against the definition.
std::string GetConvnetDataFormat2D3DAttrString();
515
516// Returns a tensor shape for the specified format and dimension sizes.
517// Works for both 2D and 3D operations. The output shapes are as follows:
518// FORMAT_NHWC: (N, spatial, C); rank = spatial.size() + 2
519// FORMAT_NCHW: (N, C, spatial); rank = spatial.size() + 2
520// FORMAT_NCHW_VECT_C: (N, C, spatial, InnerC); rank = spatial.size() + 3
521// FORMAT_NHWC_VECT_W: (N, spatial, C, InnerW); rank = spatial.size() + 3
522inline TensorShape ShapeFromFormat(TensorFormat format, int64_t N,
523 gtl::ArraySlice<int64_t> spatial,
524 int64_t C) {
525 const int dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
526 gtl::InlinedVector<int64_t, 6> dim_sizes(dims);
527 dim_sizes[GetTensorBatchDimIndex(dims, format)] = N;
528 for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
529 auto dim_size = spatial[dim];
530 if (format == FORMAT_NHWC_VECT_W &&
531 static_cast<size_t>(dim) == spatial.size() - 1) {
532 CHECK_EQ(0, dim_size % 4)
533 << "FORMAT_NHWC_VECT_W requires W to be a multiple of 4, but W="
534 << dim_size;
535 dim_sizes[GetTensorInnerWidthDimIndex(dims, format)] = 4;
536 dim_size /= 4;
537 }
538 dim_sizes[GetTensorSpatialDimIndex(dims, format, dim)] = dim_size;
539 }
540
541 int feature_index = GetTensorFeatureDimIndex(dims, format);
542 if (format == FORMAT_NCHW_VECT_C) {
543 CHECK_EQ(0, C % 4) << "NCHW_VECT_C requires C to be a multiple of 4, but C="
544 << C;
545 C /= 4;
546 dim_sizes[GetTensorInnerFeatureDimIndex(dims, format)] = 4;
547 }
548 dim_sizes[feature_index] = C;
549 return TensorShape(dim_sizes);
550}
551
552// Return a tensor shape of the specified 'format', and dimensions.
553// Works for both 2D and 3D operations. If 'format' is OIHW_VECT_I,
554// the output TensorShape has spatial.size() + 3 dimensions, otherwise
555// it has spatial.size() + 2 dimensions.
556inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
557 gtl::ArraySlice<int64_t> spatial,
558 int64_t I, int64_t O) {
559 const int dims = GetFilterTensorDimsFromSpatialDims(spatial.size(), format);
560 gtl::InlinedVector<int64_t, 6> dim_sizes(dims);
561 dim_sizes[GetFilterTensorOutputChannelsDimIndex(dims, format)] = O;
562 for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
563 dim_sizes[GetFilterTensorSpatialDimIndex(dims, format, dim)] = spatial[dim];
564 }
565
566 if (format == FORMAT_OIHW_VECT_I) {
567 CHECK_EQ(0, I % 4) << "OIHW_VECT_I requires I to be a multiple of 4, but I="
568 << I;
569 I /= 4;
570 dim_sizes[GetFilterTensorInnerInputChannelsDimIndex(dims, format)] = 4;
571 }
572 dim_sizes[GetFilterTensorInputChannelsDimIndex(dims, format)] = I;
573 return TensorShape(dim_sizes);
574}
575
576// Return a tensor shape of the specified 'format', and dimensions.
577inline TensorShape ShapeFromFormat(TensorFormat format, int64_t N, int64_t H,
578 int64_t W, int64_t C) {
579 return ShapeFromFormat(format, N, {H, W}, C);
580}
581
582// Return a filter tensor shape of the specified 'format', and dimensions.
583inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
584 int64_t H, int64_t W, int64_t I,
585 int64_t O) {
586 return ShapeFromFilterTensorFormat(format, {H, W}, I, O);
587}
588
589// Returns a copy of the specified tensor 'src_shape' converted from
590// 'src_format' to 'dst_format'.
591inline TensorShape ShapeFromFormat(TensorFormat dst_format,
592 const TensorShape& src_shape,
593 TensorFormat src_format) {
594 if (src_format == dst_format) {
595 return src_shape;
596 }
597
598 const int64_t batch = GetTensorDim(src_shape, src_format, 'N');
599 const int64_t channels = GetTensorDim(src_shape, src_format, 'C') *
600 (src_format == FORMAT_NCHW_VECT_C ? 4 : 1);
601 const int num_src_spatial_dims =
602 GetTensorSpatialDims(src_shape.dims(), src_format);
603 std::vector<int64_t> spatial_dims(num_src_spatial_dims);
604 for (int spatial_dim = 0; spatial_dim < num_src_spatial_dims; ++spatial_dim) {
605 spatial_dims[spatial_dim] = gtl::ArraySlice<int64_t>(
606 src_shape.dim_sizes())[GetTensorSpatialDimIndex(
607 src_shape.dims(), src_format, spatial_dim)];
608 }
609 if (src_format == FORMAT_NHWC_VECT_W) {
610 spatial_dims[num_src_spatial_dims - 1] *= 4;
611 }
612 return ShapeFromFormat(dst_format, batch, {spatial_dims}, channels);
613}
614
615// Returns a copy of the specified filter tensor 'src_shape' converted from
616// 'src_filter_format' to 'dst_filter_format'.
617inline TensorShape ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,
618 const TensorShape& src_shape,
619 FilterTensorFormat src_filter_format) {
620 if (src_filter_format == dst_filter_format) {
621 return src_shape;
622 }
623
624 const int64_t output_channels =
625 GetFilterDim(src_shape, src_filter_format, 'O');
626 const int64_t input_channels =
627 GetFilterDim(src_shape, src_filter_format, 'I') *
628 (src_filter_format == FORMAT_OIHW_VECT_I ? 4 : 1);
629
630 if (GetFilterTensorSpatialDims(src_shape.dims(), src_filter_format) == 3) {
631 return ShapeFromFilterTensorFormat(
632 dst_filter_format,
633 {{GetFilterDim(src_shape, src_filter_format, '0'),
634 GetFilterDim(src_shape, src_filter_format, '1'),
635 GetFilterDim(src_shape, src_filter_format, '2')}},
636 input_channels, output_channels);
637 }
638
639 return ShapeFromFilterTensorFormat(
640 dst_filter_format,
641 {{GetFilterDim(src_shape, src_filter_format, 'H'),
642 GetFilterDim(src_shape, src_filter_format, 'W')}},
643 input_channels, output_channels);
644}
645
646} // namespace tensorflow
647
648#endif // TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
649