1 | #pragma once |
2 | |
3 | #include <c10/core/Backend.h> |
4 | #include <c10/util/ArrayRef.h> |
5 | #include <c10/util/Exception.h> |
6 | |
7 | #include <ostream> |
8 | |
// Memory format is not a property of a Tensor. It is a way to tell an
// operator how the result should be organized in memory and nothing more. That
// means memory format should never be used as a return value of any tensor
// state interrogation function (internally or externally).
13 | // |
// Possible options are:
//  Preserve:
//    If any of the input tensors is in channels_last format, the operator's
//    output should be in channels_last format.
//
//  Contiguous:
//    Regardless of the input tensors' format, the output should be a
//    contiguous Tensor.
//
//  ChannelsLast:
//    Regardless of the input tensors' format, the output should be in
//    channels_last format.
//
//  ChannelsLast3d:
//    Regardless of the input tensors' format, the output should be in
//    channels_last_3d format.
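//
// For illustration only, a call site might resolve Preserve roughly as
// sketched below. `resolve_output_format` is a hypothetical helper, not part
// of this header; it assumes ATen's at::Tensor::suggest_memory_format()
// heuristic:
//
//   at::MemoryFormat resolve_output_format(
//       const at::Tensor& input,
//       at::MemoryFormat requested) {
//     if (requested == at::MemoryFormat::Preserve) {
//       // Follow the input's inferred layout, e.g. a channels_last input
//       // stays channels_last.
//       return input.suggest_memory_format();
//     }
//     // Contiguous / ChannelsLast / ChannelsLast3d: honor the request.
//     return requested;
//   }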
26 | |
27 | namespace c10 { |
28 | enum class MemoryFormat : int8_t { |
29 | Contiguous, |
30 | Preserve, |
31 | ChannelsLast, |
32 | ChannelsLast3d, |
33 | NumOptions |
34 | }; |
35 | |
// If you are seeing this, it means that this call site was not checked to see
// whether the memory format could be preserved, and it was switched to the old
// default behavior of contiguous.
39 | #define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format() |
40 | |
41 | inline MemoryFormat get_contiguous_memory_format() { |
42 | return MemoryFormat::Contiguous; |
43 | } |
44 | |
45 | inline std::ostream& operator<<( |
46 | std::ostream& stream, |
47 | at::MemoryFormat memory_format) { |
  switch (memory_format) {
    case MemoryFormat::Preserve:
      return stream << "Preserve";
    case MemoryFormat::Contiguous:
      return stream << "Contiguous";
    case MemoryFormat::ChannelsLast:
      return stream << "ChannelsLast";
    case MemoryFormat::ChannelsLast3d:
      return stream << "ChannelsLast3d";
    default:
      TORCH_CHECK(false, "Unknown memory format ", memory_format);
  }
60 | } |
61 | |
// Note: The channels_last stride indices are hardcoded here for better
// performance.
64 | template <typename T> |
65 | inline std::vector<T> get_channels_last_strides_2d(ArrayRef<T> sizes) { |
66 | std::vector<T> strides(sizes.size()); |
67 | switch (sizes.size()) { |
68 | case 4: |
69 | strides[1] = 1; |
70 | strides[3] = sizes[1]; |
71 | strides[2] = strides[3] * sizes[3]; |
72 | strides[0] = strides[2] * sizes[2]; |
73 | return strides; |
74 | case 3: |
75 | strides[0] = 1; |
76 | strides[2] = sizes[0]; |
77 | strides[1] = strides[2] * sizes[2]; |
78 | return strides; |
79 | default: |
      TORCH_INTERNAL_ASSERT(
          false, "ChannelsLast2d doesn't support size ", sizes.size());
82 | } |
83 | } |
84 | |
85 | inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) { |
86 | return get_channels_last_strides_2d<int64_t>(sizes); |
87 | } |
88 | |
template <typename T>
inline std::vector<T> get_channels_last_strides_3d(ArrayRef<T> sizes) {
91 | std::vector<T> strides(sizes.size()); |
92 | switch (sizes.size()) { |
93 | case 5: |
94 | strides[1] = 1; |
95 | strides[4] = sizes[1]; |
96 | strides[3] = strides[4] * sizes[4]; |
97 | strides[2] = strides[3] * sizes[3]; |
98 | strides[0] = strides[2] * sizes[2]; |
99 | return strides; |
100 | case 4: |
101 | strides[0] = 1; |
102 | strides[3] = sizes[0]; |
103 | strides[2] = strides[3] * sizes[3]; |
104 | strides[1] = strides[2] * sizes[2]; |
105 | return strides; |
106 | default: |
      TORCH_INTERNAL_ASSERT(
          false, "ChannelsLast3d doesn't support size ", sizes.size());
109 | } |
110 | } |
111 | |
112 | inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) { |
113 | return get_channels_last_strides_3d<int64_t>(sizes); |
114 | } |
115 | |
// NOTE:
// Below are helper functions for is_channels_last_strides_xd.
// 1. Please do not combine these helper functions; each one handles exactly
// one case of sizes + memory_format. This way the stride indices form a
// constant array that is accessed with constant index numbers, and the
// compiler will fully unroll the loop over the stride indices for better
// performance.
// 2. There is no error checking in the helper functions; the caller ensures
// the correctness of the input.
// 3. All helper functions have a similar structure; only the first one is
// commented in detail.
127 | template <typename T> |
128 | inline bool is_channels_last_strides_2d_s4( |
129 | const ArrayRef<T> sizes, |
130 | const ArrayRef<T> strides) { |
131 | T min = 0; |
132 | // special case for trivial C dimension. default to NCHW |
133 | if (strides[1] == 0) { |
134 | return false; |
135 | } |
136 | // loop strides indices |
137 | for (auto& d : {1, 3, 2, 0}) { |
138 | if (sizes[d] == 0) { |
139 | return false; |
140 | } |
141 | if (strides[d] < min) { |
142 | return false; |
143 | } |
    // Fall back to NCHW as the default layout for ambiguous cases.
    // This is the flaw of inferring memory_format implicitly from strides:
    // an N111 tensor has identical strides for its size-1 dimensions, and
    // two cases could lead us here:
    // a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
    // b. N11W contiguous Tensor sliced on the W-dimension
    //    ([N,1,1,1]@[W,W,W,W])
151 | if (d == 0 && min == strides[1]) { |
152 | return false; |
153 | } |
    // This is necessary to:
    // 1. distinguish the memory_format of N1H1;
    //     [H, 1, 1, 1] channels_last stride
    //     [H, H, 1, 1] contiguous stride
    // 2. identify the permutation of 1C1H:
    //     [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
    //     [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as channels_last
161 | min = strides[d]; |
162 | if (sizes[d] > 1) { |
163 | min *= sizes[d]; |
164 | } |
165 | } |
166 | return true; |
167 | } |
168 | |
169 | template <typename T> |
170 | inline bool is_channels_last_strides_3d_s5( |
171 | const ArrayRef<T> sizes, |
172 | const ArrayRef<T> strides) { |
173 | T min = 0; |
174 | if (strides[1] == 0) { |
175 | return false; |
176 | } |
177 | for (auto& d : {1, 4, 3, 2, 0}) { |
178 | if (sizes[d] == 0) { |
179 | return false; |
180 | } |
181 | if (strides[d] < min) { |
182 | return false; |
183 | } |
184 | if (d == 0 && min == strides[1]) { |
185 | return false; |
186 | } |
187 | min = strides[d]; |
188 | if (sizes[d] > 1) { |
189 | min *= sizes[d]; |
190 | } |
191 | } |
192 | return true; |
193 | } |
194 | |
// Note [Ambiguous is_channels_last_strides_xd]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The flaw of carrying memory_format implicitly through strides is very hard
// to work around properly. See issue #24090.
// Without the history of permutations, we can't infer the memory_format of a
// tensor from a snapshot of its size & stride, e.g.:
//
// 1. We can NOT specify the memory_format of an N111 tensor through strides
// in a meaningful way;
//
// 2. Two paths can end up with identical size/stride:
//  An N11W contiguous tensor sliced at the W-dimension becomes
//  [N,1,1,1]@[W,W,W,W];
//  An NC11 channels_last tensor sliced at the C-dimension becomes
//  [N,1,1,1]@[C,C,C,C].
// So if we see a tensor [N,1,1,1]@[X,X,X,X], there's no way for us to infer
// the memory_format of the original tensor.
//
// Given these limitations, our temporary workaround `is_channels_last_strides`
// makes a best effort to infer whether the original memory_format of a tensor
// is at::MemoryFormat::ChannelsLast. The two objectives of this function
// (ordered by their importance) are:
// 1. Ensure that normal shape manipulation does not accidentally change the
//    MemoryFormat of an existing tensor.
// 2. Allow users to mark tensors as MemoryFormat::ChannelsLast.
//
// The function does so by checking the strides of the tensor, including the
// strides of size-1 dimensions, even though conventionally PyTorch places no
// restriction on trivial strides (strides of size-1 dimensions).
//
// Note that this approach is a compromise; it does not solve the problem
// completely. In many cases we will not be able to infer the correct memory
// format.
// The implementation of `is_channels_last_strides` serves the objectives:
// MemoryFormat::ChannelsLast has to be explicitly opted into (no accidental
// conversion), with a best effort to maintain the ChannelsLast flag.
//
// Since this is not a bulletproof solution, through testing
// (aten/src/ATen/test/memory_format_test.cpp)
// a. we ensure that the common cases are supported;
// b. we identify the corner cases where the implementation compromises.
//
// By the time accumulated permutations are enabled to replace implicit
// memory_format through strides, we should update our tests and fix the
// issues they reveal.
//
// We used Channels Last 2d as the example above, but this is a general
// problem for all the is_channels_last_strides_xd implementations. Please
// check the helper functions (is_channels_last_strides_*d_s*) for more
// details.
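
// For example (a concrete instance of case 2 above), an N11W contiguous
// tensor of size [2, 1, 1, 5] sliced down to [2, 1, 1, 1] keeps strides
// [5, 5, 5, 5], and the check deliberately falls back to NCHW:
//
//   is_channels_last_strides_2d_s4<int64_t>({2, 1, 1, 1}, {5, 5, 5, 5});
//   // == false, via the `d == 0 && min == strides[1]` bail-out above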
244 | |
245 | template <typename T> |
246 | inline bool is_channels_last_strides_2d( |
247 | const ArrayRef<T> sizes, |
248 | const ArrayRef<T> strides) { |
249 | switch (sizes.size()) { |
250 | case 4: |
251 | return is_channels_last_strides_2d_s4(sizes, strides); |
252 | case 3: |
      // TODO: the dim == 3 case will be enabled once it is fully tested
254 | return false; |
255 | default: |
256 | return false; |
257 | } |
258 | } |
259 | |
260 | template <typename T> |
261 | inline bool is_channels_last_strides_3d( |
262 | const ArrayRef<T> sizes, |
263 | const ArrayRef<T> strides) { |
264 | switch (sizes.size()) { |
265 | case 5: |
266 | return is_channels_last_strides_3d_s5(sizes, strides); |
267 | case 4: |
      // TODO: the dim == 4 case will be enabled once it is fully tested
269 | return false; |
270 | default: |
271 | return false; |
272 | } |
273 | } |
274 | |
275 | inline bool is_channels_last_strides_2d( |
276 | const IntArrayRef sizes, |
277 | const IntArrayRef strides) { |
278 | return is_channels_last_strides_2d<int64_t>(sizes, strides); |
279 | } |
280 | |
281 | inline bool is_channels_last_strides_3d( |
282 | const IntArrayRef sizes, |
283 | const IntArrayRef strides) { |
284 | return is_channels_last_strides_3d<int64_t>(sizes, strides); |
285 | } |
286 | |
287 | } // namespace c10 |
288 | |