#pragma once

#include <c10/core/Backend.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>

#include <ostream>

// Memory format is not a property of a Tensor. It is only a way to tell an
// operator how the result should be organized in memory, nothing more. That
// means memory format should never be used as the return value of any tensor
// state interrogation function (internally or externally).
//
// Possible options are:
// Preserve:
//     If any of the input tensors is in channels_last format, the operator
//     output should be in channels_last format.
//
// Contiguous:
//     Regardless of the input tensors' format, the output should be a
//     contiguous Tensor.
//
// ChannelsLast:
//     Regardless of the input tensors' format, the output should be in
//     channels_last format.

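// Illustrative ATen-level usage (a sketch, not part of this header; `x` is an
// assumed 4-d tensor):
//
//   // Request channels_last output regardless of the input layout:
//   at::Tensor y = x.contiguous(at::MemoryFormat::ChannelsLast);
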
namespace c10 {
enum class MemoryFormat : int8_t {
  Contiguous,
  Preserve,
  ChannelsLast,
  ChannelsLast3d,
  NumOptions
};

// If you are seeing this, it means that this call site was not checked to see
// whether the memory format could be preserved, and it was switched to the
// old default behavior of contiguous.
#define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format()

inline MemoryFormat get_contiguous_memory_format() {
  return MemoryFormat::Contiguous;
}

inline std::ostream& operator<<(
    std::ostream& stream,
    MemoryFormat memory_format) {
  switch (memory_format) {
    case MemoryFormat::Preserve:
      return stream << "Preserve";
    case MemoryFormat::Contiguous:
      return stream << "Contiguous";
    case MemoryFormat::ChannelsLast:
      return stream << "ChannelsLast";
    case MemoryFormat::ChannelsLast3d:
      return stream << "ChannelsLast3d";
    default:
      // Stream the underlying integer value; streaming the enum itself here
      // would recurse back into this operator.
      TORCH_CHECK(
          false, "Unknown memory format ", static_cast<int>(memory_format));
  }
}

// Note: the channels_last stride indices are hardcoded here for better
// performance.
template <typename T>
inline std::vector<T> get_channels_last_strides_2d(ArrayRef<T> sizes) {
  std::vector<T> strides(sizes.size());
  switch (sizes.size()) {
    case 4:
      strides[1] = 1;
      strides[3] = sizes[1];
      strides[2] = strides[3] * sizes[3];
      strides[0] = strides[2] * sizes[2];
      return strides;
    case 3:
      strides[0] = 1;
      strides[2] = sizes[0];
      strides[1] = strides[2] * sizes[2];
      return strides;
    default:
      TORCH_INTERNAL_ASSERT(
          false, "ChannelsLast2d doesn't support size ", sizes.size());
  }
}

inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
  return get_channels_last_strides_2d<int64_t>(sizes);
}
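
// For example (illustrative values), for an NCHW tensor of sizes
// [N, C, H, W] = [2, 3, 4, 5]:
//   get_channels_last_strides_2d returns [60, 1, 15, 3]
// i.e. C is the fastest-moving dimension (stride 1), then W (stride C),
// then H (stride C * W), then N (stride C * W * H).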

template <typename T>
std::vector<T> get_channels_last_strides_3d(ArrayRef<T> sizes) {
  std::vector<T> strides(sizes.size());
  switch (sizes.size()) {
    case 5:
      strides[1] = 1;
      strides[4] = sizes[1];
      strides[3] = strides[4] * sizes[4];
      strides[2] = strides[3] * sizes[3];
      strides[0] = strides[2] * sizes[2];
      return strides;
    case 4:
      strides[0] = 1;
      strides[3] = sizes[0];
      strides[2] = strides[3] * sizes[3];
      strides[1] = strides[2] * sizes[2];
      return strides;
    default:
      TORCH_INTERNAL_ASSERT(
          false, "ChannelsLast3d doesn't support size ", sizes.size());
  }
}

inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
  return get_channels_last_strides_3d<int64_t>(sizes);
}
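
// Analogously (illustrative values), for an NCDHW tensor of sizes
// [N, C, D, H, W] = [2, 3, 4, 5, 6]:
//   get_channels_last_strides_3d returns [360, 1, 90, 18, 3]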

// NOTE:
// Below are helper functions for is_channels_last_strides_xd.
// 1. Please do not combine these helper functions; each one handles exactly
// one case of sizes + memory_format. This way the stride indices form a
// constant array that is accessed with constant indices, and the compiler can
// fully unroll the loop over them for better performance.
// 2. There is no error checking in the helper functions; the caller ensures
// the input is valid.
// 3. All helper functions follow the same pattern; only the first one is
// commented.
template <typename T>
inline bool is_channels_last_strides_2d_s4(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  T min = 0;
  // special case for a trivial C dimension; default to NCHW
  if (strides[1] == 0) {
    return false;
  }
  // loop over stride indices
  for (auto& d : {1, 3, 2, 0}) {
    if (sizes[d] == 0) {
      return false;
    }
    if (strides[d] < min) {
      return false;
    }
    // Fall back to NCHW as the default layout for ambiguous cases.
    // This is the flaw of inferring memory_format implicitly from strides:
    // an N111 tensor has identical strides for its size-1 dimensions, and
    // two cases could lead us here:
    // a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
    // b. N11W contiguous Tensor sliced on the W-dimension
    //    ([N,1,1,1]@[W,W,W,W])
    if (d == 0 && min == strides[1]) {
      return false;
    }
    // This is necessary to:
    // 1. distinguish the memory_format of N1H1:
    //    [H, 1, 1, 1] channels_last stride
    //    [H, H, 1, 1] contiguous stride
    // 2. reject permutations of 1C1W:
    //    [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
    //    [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as channels_last
    min = strides[d];
    if (sizes[d] > 1) {
      min *= sizes[d];
    }
  }
  return true;
}
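
// For example (illustrative values), with sizes [N, C, H, W] = [2, 3, 4, 5]:
//   strides [60, 1, 15, 3]  -> true   (channels_last strides)
//   strides [60, 20, 5, 1]  -> false  (contiguous NCHW strides)
//   sizes [2, 1, 1, 1] with strides [1, 1, 1, 1] -> false (ambiguous; falls
//   back to NCHW, see Note [Ambiguous is_channels_last_strides_xd] below)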

template <typename T>
inline bool is_channels_last_strides_3d_s5(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  T min = 0;
  if (strides[1] == 0) {
    return false;
  }
  for (auto& d : {1, 4, 3, 2, 0}) {
    if (sizes[d] == 0) {
      return false;
    }
    if (strides[d] < min) {
      return false;
    }
    if (d == 0 && min == strides[1]) {
      return false;
    }
    min = strides[d];
    if (sizes[d] > 1) {
      min *= sizes[d];
    }
  }
  return true;
}
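
// For example (illustrative values), with sizes
// [N, C, D, H, W] = [2, 3, 4, 5, 6]:
//   strides [360, 1, 90, 18, 3]  -> true   (channels_last_3d strides)
//   strides [360, 120, 30, 6, 1] -> false  (contiguous NCDHW strides)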

// Note [Ambiguous is_channels_last_strides_xd]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The flaw of carrying memory_format implicitly through strides is very hard
// to work around (WAR) properly. See issue #24090.
// Without the history of permutations, we cannot infer the memory_format of a
// tensor from a snapshot of its size & stride, e.g.:
//
// 1. We can NOT specify the memory_format of an N111 tensor through strides
// in a meaningful way;
//
// 2. Two paths that end up with identical size/stride:
// an N11W contiguous tensor sliced at the W-dimension becomes
// [N,1,1,1]@[W,W,W,W]; an NC11 channels_last tensor sliced at the C-dimension
// becomes [N,1,1,1]@[C,C,C,C].
// So if we see a tensor [N,1,1,1]@[X,X,X,X], there is no way for us to infer
// the memory_format of the original tensor.
//
// Due to these limitations, our temporary WAR `is_channels_last_strides` makes
// a best effort to infer whether the original memory_format of a tensor is
// at::MemoryFormat::ChannelsLast. The two objectives of this function (ordered
// by importance):
// 1. Ensure that normal shape manipulation does not accidentally change the
// MemoryFormat of an existing tensor.
// 2. Allow users to mark tensors as MemoryFormat::ChannelsLast.
//
// The function does so by checking the strides of the tensor, including the
// strides of size-1 dimensions, even though PyTorch conventionally places no
// restriction on trivial strides (strides of size-1 dimensions).
//
// Note that this approach is a compromise; it does not solve the problem
// completely. In many cases we will not be able to infer the correct memory
// format.
// The implementation of `is_channels_last_strides` serves these objectives:
// MemoryFormat::ChannelsLast has to be explicitly opted into (no accidental
// conversion), with a best effort to maintain the ChannelsLast flag.
//
// Because this is not a bulletproof solution, through testing
// (aten/src/ATen/test/memory_format_test.cpp):
// a. we ensure that the common cases are supported;
// b. we identify the corner cases on which the implementation compromises.
//
// Once accumulated permutation is enabled to replace implicit memory_format
// through strides, we should update our tests and fix the issues they cover.
//
// Channels-last 2d is used as the example above, but this is a general
// problem for all is_channels_last_strides_xd implementations. Please check
// the helper functions (is_channels_last_strides_*d_s*) for more details.

template <typename T>
inline bool is_channels_last_strides_2d(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  switch (sizes.size()) {
    case 4:
      return is_channels_last_strides_2d_s4(sizes, strides);
    case 3:
      // TODO dim == 3 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

template <typename T>
inline bool is_channels_last_strides_3d(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  switch (sizes.size()) {
    case 5:
      return is_channels_last_strides_3d_s5(sizes, strides);
    case 4:
      // TODO dim == 4 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

inline bool is_channels_last_strides_2d(
    const IntArrayRef sizes,
    const IntArrayRef strides) {
  return is_channels_last_strides_2d<int64_t>(sizes, strides);
}

inline bool is_channels_last_strides_3d(
    const IntArrayRef sizes,
    const IntArrayRef strides) {
  return is_channels_last_strides_3d<int64_t>(sizes, strides);
}
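
// Illustrative round trip (a sketch under assumed values; not exercised in
// this header). A std::vector converts implicitly to IntArrayRef, selecting
// the non-template overloads:
//   std::vector<int64_t> sizes = {2, 3, 4, 5};
//   auto strides = get_channels_last_strides_2d(sizes); // {60, 1, 15, 3}
//   bool cl = is_channels_last_strides_2d(sizes, strides); // true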

} // namespace c10