1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_ |
18 | |
19 | #include "tensorflow/core/kernels/eigen_convolution_helpers.h" |
20 | |
21 | // Note this header is used in both TF and TFLite. |
22 | namespace Eigen { |
23 | |
24 | namespace internal { |
25 | |
26 | #if !EIGEN_ALTIVEC_USE_CUSTOM_PACK |
27 | // WARNING: Most of the code here implicitly assumes that the matrix is in |
28 | // ColMajor layout. This is guaranteed by the tensor contraction (see |
29 | // TensorContraction.h). |
30 | // |
31 | // Inside Eigen a tensor contraction is represented by a matrix multiplication. |
32 | // We don't want to actually extract image patches and reshape the result into |
33 | // a matrix (this involves allocating huge extra memory), so the patch |
34 | // extraction and reshape operations are implicit. |
35 | // |
36 | // TensorContractionInputMapper takes a matrix index and returns the coefficient |
37 | // (or the packet) of the "virtual tensor", that would be at that index if we |
38 | // were to actually reshape the result of patch extraction. |
39 | // |
40 | // TensorContractionSubMapper provides a similar view into the "virtual matrix" |
41 | // at the given vertical and horizontal offsets. |
42 | // |
43 | // "Virtual matrix" dimensions: |
// *0: kernelChannels * kernelRows * kernelCols;
// 1: out_height * out_width * OTHERS (e.g. batches, etc.)
//
// *) extracted patches are contiguous in memory (innermost dimension assuming
// col major layout)
49 | // |
// With these dimensions:
51 | // row - offset within a single patch (in code: patchId) |
52 | // col - index of the extracted patch (in code: patchIndex) |
53 | // patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions) |
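//
// For example (illustrative numbers): a 3x3 kernel over a 16-channel input
// that produces a 28x28 output for a batch of 32 gives a virtual matrix
// with 16 * 3 * 3 = 144 rows and 28 * 28 * 32 = 25088 columns.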
54 | // |
55 | // TODO(ezhulenev): Consolidate this part of the code with the image patch |
56 | // extraction code since they are both very similar. |
57 | |
58 | template <typename NewDimension, Index Rows, Index Cols, typename ArgType, |
59 | typename Device, typename Scalar_, typename Index, |
60 | typename nocontract_t, typename contract_t, int Side, int packet_size, |
61 | bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> |
62 | class TensorContractionInputMapper< |
63 | Scalar_, Index, Side, |
64 | TensorEvaluator< |
65 | const TensorReshapingOp<NewDimension, |
66 | const TensorImagePatchOp<Rows, Cols, ArgType> >, |
67 | Device>, |
68 | nocontract_t, contract_t, packet_size, inner_dim_contiguous, |
69 | inner_dim_reordered, Alignment> { |
70 | public: |
71 | typedef Scalar_ Scalar; |
72 | |
73 | typedef TensorContractionInputMapper< |
74 | Scalar, Index, Side, |
75 | TensorEvaluator< |
76 | const TensorReshapingOp< |
77 | NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, |
78 | Device>, |
79 | nocontract_t, contract_t, packet_size, inner_dim_contiguous, |
80 | inner_dim_reordered, Alignment> |
81 | Self; |
82 | |
83 | typedef TensorContractionSubMapper< |
84 | Scalar, Index, Side, |
85 | TensorEvaluator< |
86 | const TensorReshapingOp< |
87 | NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, |
88 | Device>, |
89 | nocontract_t, contract_t, packet_size, inner_dim_contiguous, |
90 | inner_dim_reordered, Alignment> |
91 | SubMapper; |
92 | |
93 | typedef SubMapper VectorMapper; |
94 | typedef SubMapper LinearMapper; |
95 | typedef typename packet_traits<Scalar>::type Packet; |
96 | |
97 | typedef TensorEvaluator<ArgType, Device> TensorEvaluatorT; |
98 | |
99 | EIGEN_DEVICE_FUNC |
100 | TensorContractionInputMapper( |
101 | const TensorEvaluator< |
102 | const TensorReshapingOp< |
103 | NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, |
104 | Device>& tensor, |
105 | const nocontract_t&, const nocontract_t&, const contract_t&, |
106 | const contract_t&) |
107 | : m_impl(tensor.impl().impl()) { |
108 | Index patch_rows; |
109 | Index patch_depth; |
110 | if (internal::traits<ArgType>::Layout == ColMajor) { |
111 | patch_depth = tensor.impl().dimensions()[0]; |
112 | patch_rows = tensor.impl().dimensions()[1]; |
113 | m_patch_cols = tensor.impl().dimensions()[2]; |
114 | m_num_patches = tensor.impl().dimensions()[3]; |
115 | } else { |
116 | const size_t NumDims = tensor.impl().dimensions().size(); |
117 | patch_depth = tensor.impl().dimensions()[NumDims - 1]; |
118 | patch_rows = tensor.impl().dimensions()[NumDims - 2]; |
119 | m_patch_cols = tensor.impl().dimensions()[NumDims - 3]; |
120 | m_num_patches = tensor.impl().dimensions()[NumDims - 4]; |
121 | } |
122 | |
123 | // Strides for navigating through the single patch. |
124 | m_patch_row_stride = patch_depth; |
125 | m_patch_col_stride = patch_rows * m_patch_row_stride; |
126 | |
127 | m_patch_row_inflate_strides = tensor.impl().rowInflateStride(); |
128 | m_patch_col_inflate_strides = tensor.impl().colInflateStride(); |
129 | |
130 | m_colStride = patch_rows; |
131 | |
132 | m_outputRows = tensor.impl().outputRows(); |
133 | m_outputCols = tensor.impl().outputCols(); |
134 | m_row_strides = tensor.impl().userRowStride(); |
135 | m_col_strides = tensor.impl().userColStride(); |
136 | |
137 | m_in_row_strides = tensor.impl().userInRowStride(); |
138 | m_in_col_strides = tensor.impl().userInColStride(); |
139 | |
140 | if (internal::traits<ArgType>::Layout == ColMajor) { |
141 | m_inputRows = tensor.impl().impl().dimensions()[1]; |
142 | m_inputCols = tensor.impl().impl().dimensions()[2]; |
143 | } else { |
144 | const int NumDims = tensor.impl().impl().dimensions().size(); |
145 | m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2]; |
146 | m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3]; |
147 | } |
148 | |
149 | m_rowInputStride = patch_depth; |
150 | m_colInputStride = patch_depth * m_inputRows; |
151 | m_patchInputStride = patch_depth * m_inputRows * m_inputCols; |
152 | |
153 | m_rowPaddingTop = tensor.impl().rowPaddingTop(); |
154 | m_colPaddingLeft = tensor.impl().colPaddingLeft(); |
155 | |
156 | m_fastPatchRowStride = |
157 | internal::TensorIntDivisor<Index>(m_patch_row_stride); |
158 | m_fastPatchColStride = |
159 | internal::TensorIntDivisor<Index>(m_patch_col_stride); |
160 | m_fastInputRowStride = |
161 | internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides); |
162 | m_fastInputColStride = |
163 | internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides); |
164 | m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches); |
165 | m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride); |
166 | m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows); |
167 | m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth); |
168 | } |
169 | |
170 | EIGEN_DEVICE_FUNC |
171 | TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper) |
172 | : m_impl(base_mapper.m_impl) { |
173 | m_patch_cols = base_mapper.m_patch_cols; |
174 | m_num_patches = base_mapper.m_num_patches; |
175 | |
176 | m_patch_row_stride = base_mapper.m_patch_row_stride; |
177 | m_patch_col_stride = base_mapper.m_patch_col_stride; |
178 | |
179 | m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides; |
180 | m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides; |
181 | |
182 | m_colStride = base_mapper.m_colStride; |
183 | |
184 | m_rowInputStride = base_mapper.m_rowInputStride; |
185 | m_colInputStride = base_mapper.m_colInputStride; |
186 | m_patchInputStride = base_mapper.m_patchInputStride; |
187 | |
188 | m_inputRows = base_mapper.m_inputRows; |
189 | m_inputCols = base_mapper.m_inputCols; |
190 | |
191 | m_outputRows = base_mapper.m_outputRows; |
192 | m_outputCols = base_mapper.m_outputCols; |
193 | m_row_strides = base_mapper.m_row_strides; |
194 | m_col_strides = base_mapper.m_col_strides; |
195 | |
196 | m_in_row_strides = base_mapper.m_in_row_strides; |
197 | m_in_col_strides = base_mapper.m_in_col_strides; |
198 | |
199 | m_rowPaddingTop = base_mapper.m_rowPaddingTop; |
200 | m_colPaddingLeft = base_mapper.m_colPaddingLeft; |
201 | |
202 | m_fastPatchRowStride = base_mapper.m_fastPatchRowStride; |
203 | m_fastPatchColStride = base_mapper.m_fastPatchColStride; |
204 | m_fastInputRowStride = base_mapper.m_fastInputRowStride; |
205 | m_fastInputColStride = base_mapper.m_fastInputColStride; |
206 | m_fastNumPatches = base_mapper.m_fastNumPatches; |
207 | m_fastColStride = base_mapper.m_fastColStride; |
208 | m_fastOutputRows = base_mapper.m_fastOutputRows; |
209 | m_fastDimZero = base_mapper.m_fastDimZero; |
210 | } |
211 | |
// Returns true if the image patches are "non-standard", e.g. there are
// non-trivial strides or inflations in the input; in that case some
// optimizations for loading packets are turned off.
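// For example, patches for a dilated convolution (user in_row/in_col strides
// greater than 1), or patches taken from an inflated input (inflate strides
// greater than 1, as used e.g. when computing convolution gradients), are
// "non-standard".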
215 | EIGEN_DEVICE_FUNC |
216 | EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { |
217 | return m_in_row_strides != 1 || m_in_col_strides != 1 || |
218 | m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1; |
219 | } |
220 | |
221 | EIGEN_DEVICE_FUNC |
222 | EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { |
223 | return SubMapper(*this, i, j); |
224 | } |
225 | |
226 | EIGEN_DEVICE_FUNC |
227 | EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const { |
228 | return LinearMapper(*this, i, j); |
229 | } |
230 | |
231 | EIGEN_DEVICE_FUNC |
232 | EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const { |
233 | Index rowIndex, colIndex, otherIndex; |
234 | computeBaseIndices(0, rowIndex, colIndex, otherIndex); |
235 | return loadCoeff(row, rowIndex, colIndex, otherIndex); |
236 | } |
237 | |
// Load the coefficient at the patchIndex location instead of the usual
// m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
// gpu code.
242 | EIGEN_DEVICE_FUNC |
243 | EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const { |
244 | Index rowIndex, colIndex, otherIndex; |
245 | computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex); |
246 | return loadCoeff(row, rowIndex, colIndex, otherIndex); |
247 | } |
248 | |
249 | EIGEN_DEVICE_FUNC |
250 | EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const { |
251 | Index rowIndex, colIndex, otherIndex; |
252 | computeBaseIndices(0, rowIndex, colIndex, otherIndex); |
253 | return loadPacket(row, rowIndex, colIndex, otherIndex); |
254 | } |
255 | |
256 | // Load the packet at the patchIndex location instead of the usual m_rowIndex, |
257 | // m_colIndex, m_otherIndex. This is currently only used by the gpu code. |
258 | EIGEN_DEVICE_FUNC |
259 | EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const { |
260 | Index rowIndex, colIndex, otherIndex; |
261 | computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex); |
262 | return loadPacket(row, rowIndex, colIndex, otherIndex); |
263 | } |
264 | |
265 | EIGEN_DEVICE_FUNC |
266 | EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const { |
267 | return m_impl; |
268 | } |
269 | |
270 | EIGEN_DEVICE_FUNC |
271 | EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; } |
272 | EIGEN_DEVICE_FUNC |
273 | EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; } |
274 | EIGEN_DEVICE_FUNC |
275 | EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; } |
276 | |
277 | private: |
278 | friend class TensorContractionSubMapper< |
279 | Scalar, Index, Side, |
280 | TensorEvaluator< |
281 | const TensorReshapingOp< |
282 | NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, |
283 | Device>, |
284 | nocontract_t, contract_t, packet_size, inner_dim_contiguous, |
285 | inner_dim_reordered, Alignment>; |
286 | |
287 | // Load coefficient from a patch specified by the "within patch offset" |
288 | // (patchId) and the precomputed indices of the first element of the patch. |
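// For example (illustrative numbers): with patchDepth() == 4 and
// patchRows() == 3, patchId == 22 decomposes below into patchOffset == 5,
// colOffset == 1, rowOffset == 2 and depth == 2.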
289 | EIGEN_DEVICE_FUNC |
290 | EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex, |
291 | Index colIndex, Index otherIndex) const { |
292 | // Find the offset of the element wrt the location of the first element. |
293 | const Index patchOffset = patchId / m_fastDimZero; |
294 | |
295 | const Index colOffset = patchOffset / m_fastColStride; |
296 | const Index inputCol = colIndex + colOffset * m_in_col_strides; |
297 | const Index origInputCol = |
298 | (m_patch_col_inflate_strides == 1) |
299 | ? inputCol |
300 | : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0); |
301 | |
302 | const Index rowOffset = patchOffset - colOffset * m_colStride; |
303 | const Index inputRow = rowIndex + rowOffset * m_in_row_strides; |
304 | const Index origInputRow = |
305 | (m_patch_row_inflate_strides == 1) |
306 | ? inputRow |
307 | : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0); |
308 | if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols || |
309 | origInputRow >= m_inputRows || |
310 | (inputCol != origInputCol * m_patch_col_inflate_strides) || |
311 | (inputRow != origInputRow * m_patch_row_inflate_strides)) { |
312 | return Scalar(0); |
313 | } |
314 | const Index depth = patchId - patchOffset * patchDepth(); |
315 | const Index inputIndex = depth + origInputRow * m_rowInputStride + |
316 | origInputCol * m_colInputStride + otherIndex; |
317 | return m_impl.coeff(inputIndex); |
318 | } |
319 | |
320 | // This is the same as loadCoeff(...), but optimized for all `inflate_strides` |
321 | // and `in_strides` equal to 1 (template specialization without templates). |
322 | EIGEN_DEVICE_FUNC |
323 | EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex, |
324 | Index colIndex, |
325 | Index otherIndex) const { |
326 | eigen_assert(!nonStandardPatches()); |
327 | |
328 | // Find the offset of the element wrt the location of the first element. |
329 | const Index patchOffset = patchId / m_fastDimZero; |
330 | const Index colOffset = patchOffset / m_fastColStride; |
331 | const Index rowOffset = patchOffset - colOffset * m_colStride; |
332 | const Index inputCol = colIndex + colOffset; |
333 | const Index inputRow = rowIndex + rowOffset; |
334 | if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || |
335 | inputRow >= m_inputRows) { |
336 | return Scalar(0); |
337 | } |
338 | const Index depth = patchId - patchOffset * patchDepth(); |
339 | const Index inputIndex = depth + inputRow * m_rowInputStride + |
340 | inputCol * m_colInputStride + otherIndex; |
341 | return m_impl.coeff(inputIndex); |
342 | } |
343 | |
344 | // Load packet from a patch specified by the "within patch offset" |
345 | // (patchId) and the precomputed indices of the first element of the patch. |
346 | EIGEN_DEVICE_FUNC |
347 | EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex, |
348 | Index colIndex, |
349 | Index otherIndex) const { |
350 | const Index packetSize = internal::unpacket_traits<Packet>::size; |
351 | EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) |
352 | eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols); |
353 | |
354 | if (nonStandardPatches()) { |
355 | return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex); |
356 | } |
357 | typedef decltype(m_impl) TensorEvaluatorT; |
358 | return loadPacketStandard<Packet, TensorEvaluatorT>(patchId, rowIndex, |
359 | colIndex, otherIndex); |
360 | } |
361 | |
362 | // Helper function to load a 'partial' packet - this is the single column |
363 | // part of a packet that is split across two columns. In the 'partial' packet, |
364 | // the elements corresponding to the column (specified through colOffset) are |
365 | // loaded and the rest of the elements are zero-filled into the 'partial' |
366 | // packet. This function is called from loadPacketStandardFromTwoColumns(). |
367 | // This code path is exercised only when the packet type supports masked load |
368 | // and when the partial packet load is available in the TensorEvaluator. |
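// For example (illustrative numbers): with packetSize == 8 and
// span == {3, 6}, lanes 0..2 and lane 7 of the returned packet are zero,
// while lanes 3..6 hold consecutive elements of the column at colOffset.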
369 | EIGEN_DEVICE_FUNC |
370 | EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard( |
371 | Index rowIndex, Index colIndex, Index otherIndex, Index patchId, |
372 | const Index span[], const Index patchOffsets[], Index colOffset) const { |
373 | const Index inputCol = colIndex + colOffset; |
374 | const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride, |
375 | patchOffsets[1] - colOffset * m_colStride}; |
376 | const Index inputRows[2] = {rowIndex + rowOffsets[0], |
377 | rowIndex + rowOffsets[1]}; |
378 | |
379 | if (inputRows[0] >= m_inputRows || inputRows[1] < 0 || |
380 | inputCol >= m_inputCols || inputCol < 0) { |
381 | // Partial packet is all zeros |
382 | return internal::pset1<Packet>(Scalar(0)); |
383 | } else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) { |
384 | // From inputIndex-span[0], we need to load elements starting from index |
// span[0] all the way up to (and including) span[1].
386 | const Index depth = patchId - patchOffsets[0] * patchDepth(); |
387 | const Index inputIndex = depth + inputRows[0] * m_rowInputStride + |
388 | inputCol * m_colInputStride + otherIndex; |
389 | return m_impl.template partialPacket<Packet>( |
390 | inputIndex - span[0], mask<Packet>(span[0], span[1] + 1)); |
391 | } else { |
392 | // Using slow path for this partial packet. |
// We need to load elements starting from index span[0] all the way up to
394 | // (and including) span[1]. We split this load into 3 parts: |
395 | // 0 : span[0]-1 - Zeros will be loaded for these indices |
396 | // span[0] : span[1] - Elements will be loaded here for these indices |
// span[1]+1 : packetSize-1 - Zeros will be loaded for these indices
398 | const Index packetSize = internal::unpacket_traits<Packet>::size; |
399 | EIGEN_ALIGN_MAX |
400 | std::remove_const_t<Scalar> values[packetSize]; |
401 | for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0); |
402 | for (int i = span[0]; i < span[1] + 1; ++i) |
403 | values[i] = |
404 | loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex); |
405 | for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0); |
406 | return internal::pload<Packet>(values); |
407 | } |
408 | } |
409 | |
410 | // Helper function to load a packet that is split across two columns. |
411 | // If required, this function is called from loadPacketStandard() when the |
412 | // packet type supports masked load and when the partial packet load is |
413 | // available in the TensorEvaluator. |
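// For example (illustrative numbers): a packet of size 8 whose first three
// lanes come from column c and whose last five lanes come from column c + 1
// is assembled as por([x x x 0 0 0 0 0], [0 0 0 y y y y y]).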
414 | EIGEN_DEVICE_FUNC |
415 | EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns( |
416 | Index patchId, Index rowIndex, Index colIndex, Index otherIndex, |
417 | const Index patchOffsets[], const Index colOffsets[]) const { |
418 | eigen_assert(colOffsets[1] == colOffsets[0] + 1); |
419 | const Index packetSize = internal::unpacket_traits<Packet>::size; |
420 | |
421 | // Packet to load will be split into 2 parts where each part spans a single |
422 | // column. First determine where to split. |
423 | const Index patchIdSplit = |
424 | ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1; |
425 | const Index patchOffsetSplit = patchIdSplit / m_fastDimZero; |
426 | |
427 | // patchIds[i]: patchId corresponding to partial packet i |
428 | // spans[i]: Start and end indices corresponding to the elements |
429 | // to be loaded for partial packet i |
430 | // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i |
431 | const Index patchIds[2] = {patchId, patchIdSplit + 1}; |
432 | const Index spans[2][2] = {{0, patchIdSplit - patchId}, |
433 | {patchIdSplit - patchId + 1, packetSize - 1}}; |
434 | const Index patchOffsets2Cols[2][2] = { |
435 | {patchOffsets[0], patchOffsetSplit}, |
436 | {patchOffsetSplit + 1, patchOffsets[1]}}; |
437 | |
438 | // Load partial packets and do bit-wise OR to generate required packet |
439 | return internal::por<Packet>( |
440 | loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0], |
441 | spans[0], patchOffsets2Cols[0], |
442 | colOffsets[0]), |
443 | loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1], |
444 | spans[1], patchOffsets2Cols[1], |
445 | colOffsets[1])); |
446 | } |
447 | |
// Helper function to load a packet that is contained in a single column.
449 | // If required, this function is called from loadPacketStandard(). |
450 | EIGEN_DEVICE_FUNC |
451 | EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn( |
452 | Index patchId, Index rowIndex, Index colIndex, Index otherIndex, |
453 | const Index patchOffsets[], const Index colOffsets[], |
454 | const Index inputCols[]) const { |
455 | eigen_assert(colOffsets[0] == colOffsets[1]); |
456 | const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride, |
457 | patchOffsets[1] - colOffsets[1] * m_colStride}; |
458 | eigen_assert(rowOffsets[0] <= rowOffsets[1]); |
459 | const Index inputRows[2] = {rowIndex + rowOffsets[0], |
460 | rowIndex + rowOffsets[1]}; |
461 | |
462 | if (inputRows[0] >= m_inputRows || inputRows[1] < 0) { |
463 | // all zeros |
return internal::pset1<Packet>(Scalar(0));
465 | } |
466 | |
467 | if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) { |
468 | // no padding |
469 | const Index depth = patchId - patchOffsets[0] * patchDepth(); |
470 | const Index inputIndex = depth + inputRows[0] * m_rowInputStride + |
471 | inputCols[0] * m_colInputStride + otherIndex; |
472 | return m_impl.template packet<Unaligned>(inputIndex); |
473 | } |
474 | return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex); |
475 | } |
476 | |
477 | // Load standard packet from a patch specified by the "within patch offset" |
478 | // (patchId) and the precomputed indices of the first element of the patch. |
479 | // This function will be called if partial packet loading is not available |
480 | // for the TensorEvaluator or if the packet type does not support masked |
481 | // load. |
482 | template <typename PacketT, typename TensorEvaluatorT> |
483 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if< |
484 | !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, |
485 | PacketT>::type |
486 | loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, |
487 | Index otherIndex) const { |
488 | const Index packetSize = internal::unpacket_traits<Packet>::size; |
489 | EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) |
490 | eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols); |
491 | |
492 | eigen_assert(!nonStandardPatches()); |
493 | |
494 | if ((patchDepth() % packetSize) == 0) { |
495 | return loadPacketFast(patchId, rowIndex, colIndex, otherIndex); |
496 | } |
497 | |
498 | // Offsets and input calculation here are identical to |
499 | // loadCoeffStandard(...), but repeated twice. |
500 | const Index patchOffsets[2] = {patchId / m_fastDimZero, |
501 | (patchId + packetSize - 1) / m_fastDimZero}; |
502 | const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, |
503 | patchOffsets[1] / m_fastColStride}; |
504 | const Index inputCols[2] = {colIndex + colOffsets[0], |
505 | colIndex + colOffsets[1]}; |
506 | |
507 | if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { |
508 | // all zeros |
509 | return internal::pset1<Packet>(Scalar(0)); |
510 | } |
511 | if (inputCols[0] == inputCols[1]) { |
512 | return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, |
513 | otherIndex, patchOffsets, |
514 | colOffsets, inputCols); |
515 | } |
516 | return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex); |
517 | } |
518 | |
519 | // Load standard packet from a patch specified by the "within patch offset" |
520 | // (patchId) and the precomputed indices of the first element of the patch. |
521 | // This function will be called if partial packet loading is available for |
522 | // the TensorEvaluator and if the packet type supports masked load. |
523 | // The only difference between this and the other case is that if the packet |
524 | // to load is split across two columns, then in this case instead of going to |
525 | // the slow (element-by-element) load, we load two packets - each containing |
526 | // elements from one of the columns (rest of the elements of the packets are |
527 | // zeroes), and then combine these two packets to generate the required |
528 | // packet. The idea is to enable fast load (if possible) of these 'partial' |
529 | // packets. |
530 | template <typename PacketT, typename TensorEvaluatorT> |
531 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if< |
532 | TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, |
533 | PacketT>::type |
534 | loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, |
535 | Index otherIndex) const { |
536 | const Index packetSize = internal::unpacket_traits<PacketT>::size; |
537 | EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) |
538 | eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols); |
539 | |
540 | eigen_assert(!nonStandardPatches()); |
541 | |
542 | if ((patchDepth() % packetSize) == 0) { |
543 | return loadPacketFast(patchId, rowIndex, colIndex, otherIndex); |
544 | } |
545 | |
546 | // Offsets and input calculation here are identical to |
547 | // loadCoeffStandard(...), but repeated twice. |
548 | const Index patchOffsets[2] = {patchId / m_fastDimZero, |
549 | (patchId + packetSize - 1) / m_fastDimZero}; |
550 | const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, |
551 | patchOffsets[1] / m_fastColStride}; |
552 | const Index inputCols[2] = {colIndex + colOffsets[0], |
553 | colIndex + colOffsets[1]}; |
554 | |
555 | if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { |
556 | // all zeros |
557 | return internal::pset1<PacketT>(Scalar(0)); |
558 | } |
559 | if (inputCols[0] == inputCols[1]) { |
560 | return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, |
561 | otherIndex, patchOffsets, |
562 | colOffsets, inputCols); |
563 | } |
564 | if (inputCols[1] == inputCols[0] + 1) { |
565 | return loadPacketStandardFromTwoColumns( |
566 | patchId, rowIndex, colIndex, otherIndex, patchOffsets, colOffsets); |
567 | } |
568 | return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex); |
569 | } |
570 | |
571 | EIGEN_DEVICE_FUNC |
572 | EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex, |
573 | Index colIndex, |
574 | Index otherIndex) const { |
575 | const Index packetSize = internal::unpacket_traits<Packet>::size; |
576 | EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) |
577 | eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols); |
578 | |
579 | eigen_assert(!nonStandardPatches()); |
580 | eigen_assert((patchDepth() % packetSize) == 0); |
581 | // Find the offset of the element wrt the location of the first element. |
582 | const Index patchOffset = patchId / m_fastDimZero; |
583 | eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset); |
584 | |
585 | const Index colOffset = patchOffset / m_fastColStride; |
586 | const Index rowOffset = patchOffset - colOffset * m_colStride; |
587 | const Index inputCol = colIndex + colOffset; |
588 | const Index inputRow = rowIndex + rowOffset; |
589 | if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols || |
590 | inputRow >= m_inputRows) { |
591 | // all zeros |
592 | return internal::pset1<Packet>(Scalar(0)); |
593 | } |
594 | // no padding |
595 | const Index depth = patchId - patchOffset * patchDepth(); |
596 | const Index inputIndex = depth + inputRow * m_rowInputStride + |
597 | inputCol * m_colInputStride + otherIndex; |
598 | return m_impl.template packet<Unaligned>(inputIndex); |
599 | } |
600 | |
601 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero( |
602 | Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const { |
603 | const int packetSize = internal::unpacket_traits<Packet>::size; |
604 | EIGEN_ALIGN_MAX |
605 | std::remove_const_t<Scalar> values[packetSize]; |
606 | for (int i = 0; i < packetSize; ++i) { |
607 | values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex); |
608 | } |
609 | Packet rslt = internal::pload<Packet>(values); |
610 | return rslt; |
611 | } |
612 | |
613 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices( |
614 | Index patchIndex, Index& rowIndex, Index& colIndex, |
615 | Index& otherIndex) const { |
616 | const size_t NumInputDims = array_size< |
617 | typename TensorEvaluator<ArgType, Device>::Dimensions>::value; |
618 | otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches; |
619 | const Index patch2DIndex = (NumInputDims == 3) |
620 | ? patchIndex |
621 | : (patchIndex - otherIndex * m_num_patches); |
622 | otherIndex *= m_patchInputStride; |
623 | colIndex = patch2DIndex / m_fastOutputRows; |
624 | rowIndex = patch2DIndex - colIndex * m_outputRows; |
625 | colIndex = colIndex * m_col_strides - m_colPaddingLeft; |
626 | rowIndex = rowIndex * m_row_strides - m_rowPaddingTop; |
627 | } |
628 | |
629 | Index m_patch_cols; // number of columns in the patch |
630 | Index m_num_patches; // number of patches to extract. |
631 | |
632 | // Strides for navigating through the single patch. |
633 | Index m_patch_row_stride; |
634 | Index m_patch_col_stride; |
635 | internal::TensorIntDivisor<Index> m_fastPatchRowStride; |
636 | internal::TensorIntDivisor<Index> m_fastPatchColStride; |
637 | |
638 | Index m_patch_row_inflate_strides; // the strides for row inflation in the |
639 | // image patch |
640 | Index m_patch_col_inflate_strides; // the strides for col inflation in the |
641 | // image patch |
642 | // Fast representation of inflation strides. |
643 | internal::TensorIntDivisor<Index> m_fastInputRowStride; |
644 | internal::TensorIntDivisor<Index> m_fastInputColStride; |
645 | |
646 | Index m_otherStride; |
647 | Index m_colStride; |
648 | internal::TensorIntDivisor<Index> m_fastNumPatches; |
649 | internal::TensorIntDivisor<Index> m_fastColStride; |
650 | |
651 | Index m_rowInputStride; // row stride in the input tensor |
652 | Index m_colInputStride; // col stride in the input tensor |
653 | Index m_patchInputStride; // patch stride in the input tensor |
654 | |
655 | Index m_inputRows; // Number of rows in the input tensor |
656 | Index m_inputCols; // Number of cols in the input tensor |
657 | |
658 | Index m_outputRows; // Number of convolution output rows |
659 | Index m_outputCols; // Number of convolution output column |
660 | |
661 | Index m_row_strides; // User specified row stride |
662 | Index m_col_strides; // User specified col stride |
663 | |
664 | Index m_in_row_strides; // User specified input row stride |
665 | Index m_in_col_strides; // User specified input col stride |
666 | |
667 | Index m_rowPaddingTop; // Row padding |
668 | Index m_colPaddingLeft; // Column padding |
669 | |
670 | internal::TensorIntDivisor<Index> m_fastOutputRows; |
671 | internal::TensorIntDivisor<Index> m_fastDimZero; |
672 | |
673 | const TensorEvaluator<ArgType, Device> m_impl; |
674 | }; |
675 | |
676 | template <typename NewDimension, Index Rows, Index Cols, typename ArgType, |
677 | typename Device, typename Scalar, typename Index, |
678 | typename nocontract_t, typename contract_t, int Side, int packet_size, |
679 | bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> |
680 | class TensorContractionSubMapper< |
681 | Scalar, Index, Side, |
682 | TensorEvaluator< |
683 | const TensorReshapingOp<NewDimension, |
684 | const TensorImagePatchOp<Rows, Cols, ArgType> >, |
685 | Device>, |
686 | nocontract_t, contract_t, packet_size, inner_dim_contiguous, |
687 | inner_dim_reordered, Alignment> { |
688 | public: |
689 | typedef typename packet_traits<Scalar>::type Packet; |
690 | typedef typename packet_traits<Scalar>::half HalfPacket; |
691 | |
692 | typedef TensorContractionInputMapper< |
693 | Scalar, Index, Side, |
694 | TensorEvaluator< |
695 | const TensorReshapingOp< |
696 | NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, |
697 | Device>, |
698 | nocontract_t, contract_t, packet_size, inner_dim_contiguous, |
699 | inner_dim_reordered, Alignment> |
700 | ParentMapper; |
701 | |
702 | typedef TensorContractionSubMapper< |
703 | Scalar, Index, Side, |
704 | TensorEvaluator< |
705 | const TensorReshapingOp< |
706 | NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, |
707 | Device>, |
708 | nocontract_t, contract_t, packet_size, inner_dim_contiguous, |
709 | inner_dim_reordered, Alignment> |
710 | Self; |
711 | |
712 | typedef Self LinearMapper; |
713 | |
714 | typedef typename ParentMapper::TensorEvaluatorT TensorEvaluatorT; |
715 | |
716 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( |
717 | const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) |
718 | : m_depth_offset(vert_offset), |
719 | m_col_offset(horiz_offset), |
720 | m_base_mapper(base_mapper) { |
721 | m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, |
722 | m_otherIndex); |
723 | } |
724 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( |
725 | const Self& base_mapper, Index vert_offset, Index horiz_offset) |
726 | : m_depth_offset(vert_offset + base_mapper.m_depth_offset), |
727 | m_col_offset(horiz_offset + base_mapper.m_col_offset), |
728 | m_base_mapper(base_mapper.m_base_mapper) { |
729 | m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, |
730 | m_otherIndex); |
731 | } |
732 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { |
733 | return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex, |
734 | m_otherIndex); |
735 | } |
736 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, |
737 | Index j) const { |
738 | return m_base_mapper(i + m_depth_offset, j + m_col_offset); |
739 | } |
740 | |
741 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { |
742 | return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex, |
743 | m_otherIndex); |
744 | } |
745 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, |
746 | Index j) const { |
747 | return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset, |
748 | j + m_col_offset); |
749 | } |
750 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar |
751 | loadCoeffStandard(Index i) const { |
752 | return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex, |
753 | m_colIndex, m_otherIndex); |
754 | } |
755 | |
756 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const { |
757 | return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex, |
758 | m_colIndex, m_otherIndex); |
759 | } |
760 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet |
761 | loadPacketStandard(Index i) const { |
762 | typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT; |
763 | return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>( |
764 | i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex); |
765 | } |
766 | template <typename Packet> |
767 | EIGEN_DEVICE_FUNC bool aligned(Index) const { |
768 | return false; |
769 | } |
770 | |
771 | EIGEN_DEVICE_FUNC |
772 | EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { |
773 | return m_base_mapper.nonStandardPatches(); |
774 | } |
775 | |
776 | // Max(Col|Row|Depth): compute the upper limit for the column, row and depth |
777 | // index respectively that fits into the peeled_k elements starting at |
778 | // m_depth_offset. |
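// For example (illustrative numbers): with patchDepth() == 8 and
// patchRows() == 3 (so patchColStride() == 24), m_depth_offset == 10 and
// peeled_k == 40 give a last peeled element at index 10 + 40 - 1 == 49, so
// maxCol(40) == 49 / 24 + 1 == 3 (capped at patchCols()).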
779 | |
780 | EIGEN_DEVICE_FUNC |
781 | EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const { |
782 | const Index max_col = |
783 | (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) / |
784 | fastPatchColStride(); |
785 | return std::min<Index>(1 + max_col, patchCols()); |
786 | } |
787 | |
788 | EIGEN_DEVICE_FUNC |
789 | EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, |
790 | const Index col) const { |
791 | const Index max_row = (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) - |
792 | col * patchColStride()) / |
793 | fastPatchRowStride(); |
794 | return std::min<Index>(1 + max_row, patchRows()); |
795 | } |
796 | |
797 | EIGEN_DEVICE_FUNC |
798 | EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col, |
799 | Index row) const { |
800 | const Index max_depth = m_depth_offset + peeled_k - // |
801 | col * patchColStride() - // |
802 | row * patchRowStride(); |
803 | return std::min<Index>(max_depth, patchDepth()); |
804 | } |
805 | |
806 | // MaxDepth uses only the remaining number of elements in the peeled_k. |
807 | EIGEN_DEVICE_FUNC |
808 | EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, |
809 | const Index start_depth) const { |
810 | return std::min<Index>(start_depth + num_elements, patchDepth()); |
811 | } |
812 | |
// Every register matters in this code, so sometimes to prevent register
// spilling, instead of the variable that you would expect to see, we use
// another one that is guaranteed to have the same value. E.g. patch depth is
// always the same as input depth, and it's also the same as input row stride.
// A bunch of other parameters have similar relations.
818 | |
819 | typedef internal::TensorIntDivisor<Index> IndexDivisor; |
820 | |
821 | EIGEN_DEVICE_FUNC |
822 | EIGEN_ALWAYS_INLINE Index patchDepth() const { |
823 | return m_base_mapper.m_rowInputStride; |
824 | } |
825 | EIGEN_DEVICE_FUNC |
826 | EIGEN_ALWAYS_INLINE Index patchRows() const { |
827 | return m_base_mapper.m_colStride; |
828 | } |
829 | EIGEN_DEVICE_FUNC |
830 | EIGEN_ALWAYS_INLINE Index patchCols() const { |
831 | return m_base_mapper.m_patch_cols; |
832 | } |
833 | |
834 | EIGEN_DEVICE_FUNC |
835 | EIGEN_ALWAYS_INLINE Index patchRowStride() const { |
836 | eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride && |
837 | "Patch depth must be equal to patch row stride." ); |
838 | return patchDepth(); |
839 | } |
840 | EIGEN_DEVICE_FUNC |
841 | EIGEN_ALWAYS_INLINE Index patchColStride() const { |
842 | return m_base_mapper.m_patch_col_stride; |
843 | } |
844 | |
845 | EIGEN_DEVICE_FUNC |
846 | EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const { |
847 | eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride && |
848 | "Patch depth must be equal to patch row stride." ); |
849 | return m_base_mapper.m_fastDimZero; // patch_depth |
850 | } |
851 | EIGEN_DEVICE_FUNC |
852 | EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const { |
853 | return m_base_mapper.m_fastPatchColStride; |
854 | } |
855 | |
856 | EIGEN_DEVICE_FUNC |
857 | EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, |
858 | const Index baseIndex) const { |
859 | const Index inputIndex = depth + baseIndex; |
860 | return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex); |
861 | } |
862 | EIGEN_DEVICE_FUNC |
863 | EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth, |
864 | const Index baseIndex) const { |
865 | const Index inputIndex = depth + baseIndex; |
866 | return m_base_mapper.m_impl.coeff(inputIndex); |
867 | } |
868 | template <typename PacketT = Packet> |
869 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if< |
870 | TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, |
871 | PacketT>::type |
872 | partialPacketNoPadding(const Index depth, const Index baseIndex, |
873 | Index num_coeffs) const { |
874 | const Index inputIndex = depth + baseIndex; |
875 | return m_base_mapper.m_impl.template partialPacket<PacketT>( |
876 | inputIndex, mask<PacketT>(0, num_coeffs)); |
877 | } |
878 | EIGEN_DEVICE_FUNC |
879 | EIGEN_ALWAYS_INLINE bool hasPadding() const { |
// TODO(ezhulenev): It does seem that for an inflated filter it's still
// possible to guarantee "no padding or skipping" for non-standard packing.
882 | if (nonStandardPatches()) return true; |
883 | |
// Non-zero padding before.
885 | if (m_base_mapper.m_rowPaddingTop > 0) return true; |
886 | if (m_base_mapper.m_colPaddingLeft > 0) return true; |
887 | |
// Non-zero padding after in rows.
889 | const Index last_row = |
890 | (m_base_mapper.m_outputRows - 1) * m_base_mapper.m_row_strides; |
891 | if (last_row + (patchRows() - 1) >= m_base_mapper.m_inputRows) return true; |
892 | |
// Non-zero padding after in cols.
894 | const Index last_col = |
895 | (m_base_mapper.m_outputCols - 1) * m_base_mapper.m_col_strides; |
896 | if (last_col + (patchCols() - 1) >= m_base_mapper.m_inputCols) return true; |
897 | |
898 | return false; |
899 | } |
900 | EIGEN_DEVICE_FUNC |
901 | EIGEN_ALWAYS_INLINE bool padRow(const Index row) const { |
902 | const Index r = m_rowIndex + row; |
903 | return r < 0 || r >= m_base_mapper.m_inputRows; |
904 | } |
905 | EIGEN_DEVICE_FUNC |
906 | EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row, |
907 | const Index last_row) const { |
908 | return m_rowIndex + first_row < 0 || |
909 | m_rowIndex + last_row >= m_base_mapper.m_inputRows; |
910 | } |
911 | EIGEN_DEVICE_FUNC |
912 | EIGEN_ALWAYS_INLINE bool padOrSkipRow(const Index row, |
913 | Index* orig_row) const { |
914 | eigen_assert(nonStandardPatches()); |
915 | |
916 | const Index input_row = m_rowIndex + row * m_base_mapper.m_in_row_strides; |
917 | *orig_row = (m_base_mapper.m_patch_row_inflate_strides == 1) |
918 | ? input_row |
919 | : ((input_row >= 0) |
920 | ? (input_row / m_base_mapper.m_fastInputRowStride) |
921 | : 0); |
922 | |
923 | return (*orig_row < 0 || *orig_row >= m_base_mapper.m_inputRows) || |
924 | (input_row != *orig_row * m_base_mapper.m_patch_row_inflate_strides); |
925 | } |
926 | EIGEN_DEVICE_FUNC |
927 | EIGEN_ALWAYS_INLINE bool padCol(const Index col) const { |
928 | const Index c = m_colIndex + col; |
929 | return c < 0 || c >= m_base_mapper.m_inputCols; |
930 | } |
931 | EIGEN_DEVICE_FUNC |
932 | EIGEN_ALWAYS_INLINE bool padOrSkipCol(const Index col, |
933 | Index* orig_col) const { |
934 | eigen_assert(nonStandardPatches()); |
935 | |
936 | const Index input_col = m_colIndex + col * m_base_mapper.m_in_col_strides; |
937 | *orig_col = (m_base_mapper.m_patch_col_inflate_strides == 1) |
938 | ? input_col |
939 | : ((input_col >= 0) |
940 | ? (input_col / m_base_mapper.m_fastInputColStride) |
941 | : 0); |
942 | |
943 | return (*orig_col < 0 || *orig_col >= m_base_mapper.m_inputCols) || |
944 | (input_col != *orig_col * m_base_mapper.m_patch_col_inflate_strides); |
945 | } |
946 | EIGEN_DEVICE_FUNC |
947 | EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const { |
948 | const Index r = m_rowIndex + row; |
949 | const Index c = m_colIndex + col; |
950 | return r * m_base_mapper.m_rowInputStride + |
951 | c * m_base_mapper.m_colInputStride + m_otherIndex; |
952 | } |
953 | // Compute a base index when original input row and column were precomputed |
954 | // using padOrSkipRow and padOrSkipCol. Used only for non standard patches. |
955 | EIGEN_DEVICE_FUNC |
956 | EIGEN_ALWAYS_INLINE Index origBaseIndex(const Index orig_row, |
957 | const Index orig_col) const { |
958 | return orig_row * m_base_mapper.m_rowInputStride + |
959 | orig_col * m_base_mapper.m_colInputStride + m_otherIndex; |
960 | } |
961 | |
962 | EIGEN_DEVICE_FUNC |
963 | EIGEN_ALWAYS_INLINE Index rowStride() const { |
964 | return m_base_mapper.m_row_strides; |
965 | } |
966 | EIGEN_DEVICE_FUNC |
967 | EIGEN_ALWAYS_INLINE Index colStride() const { |
968 | return m_base_mapper.m_col_strides; |
969 | } |
970 | |
971 | EIGEN_DEVICE_FUNC |
972 | EIGEN_ALWAYS_INLINE Index rowOffset() const { |
973 | const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; |
974 | const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; |
975 | return patchOffset - colOffset * m_base_mapper.m_colStride; |
976 | } |
977 | |
978 | EIGEN_DEVICE_FUNC |
979 | EIGEN_ALWAYS_INLINE Index colOffset() const { |
980 | const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; |
981 | const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; |
982 | return colOffset; |
983 | } |
984 | |
985 | EIGEN_DEVICE_FUNC |
986 | EIGEN_ALWAYS_INLINE Index depthOffset() const { |
987 | return m_depth_offset % patchDepth(); |
988 | } |
989 | |
990 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper |
991 | getLinearMapper(Index i, Index j) const { |
992 | return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset); |
993 | } |
994 | |
995 | private: |
996 | Index m_depth_offset; // First row in the input matrix |
997 | Index m_col_offset; // First col in the input matrix |
998 | |
999 | // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base |
1000 | // indices for the first element in a patch specified by col_offset |
1001 | // (see computeBaseIndices(...) for details). |
1002 | Index m_rowIndex; |
1003 | Index m_colIndex; |
1004 | Index m_otherIndex; |
1005 | |
1006 | const ParentMapper m_base_mapper; // Keeping a copy instead of a reference |
1007 | // performs better in benchmarks. |
1008 | }; |
1009 | |
1010 | // Arrange a block of the right input matrix (in our case it's always a "virtual |
1011 | // matrix" constructed from extracted image patches) in contiguous memory. |
1012 | // |
1013 | // Given column major input (A0 beside A1 in memory): |
1014 | // A0 B0 C0 D0 E0 F0 G0 H0 ... Z0 |
1015 | // A1 B1 C1 D1 E1 F1 G1 H1 ... Z1 |
1016 | // A2 B2 C2 D2 E2 F2 G2 H2 ... Z2 |
1017 | // A3 B3 C3 D3 E3 F3 G3 H3 ... Z3 |
1018 | // A4 B4 C4 D4 E4 F4 G4 H4 ... Z4 |
1019 | // A5 B5 C5 D5 E5 F5 G5 H5 ... Z5 |
1020 | // A6 B6 C6 D6 E6 F6 G6 H6 ... Z6 |
1021 | // A7 B7 C7 D7 E7 F7 G7 H7 ... Z7 |
1022 | // A8 ... |
1023 | // ... |
1024 | // |
1025 | // *) A, B, C, ... - patches extracted from the original input. |
1026 | // *) A0, A1, A2 ... - values from the same patch at different offsets. |
1027 | // |
// The traversal (packed rhs memory) order (B0 beside A0 in memory):
1029 | // A0 B0 C0 D0 A1 B1 C1 D1 ... |
1030 | // E0 F0 G0 H0 E1 F1 G1 H1 ... |
1031 | // ... |
1032 | // Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4) |
1033 | // |
1034 | // This traversal order must be the same as in default gemm_pack_rhs defined in |
1035 | // GeneralBlockPanelKernel.h. |
1036 | // |
1037 | // *) nr - number of registers along the 'n' dimension. |
1038 | // See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix |
1039 | // Multiplication" paper. |
1040 | template <typename NewDimension, Index Rows, Index Cols, typename ArgType, |
1041 | typename Device, typename Scalar, typename Index, |
1042 | typename nocontract_t, typename contract_t, int packet_size, |
1043 | bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, |
1044 | int nr> |
1045 | struct gemm_pack_rhs< |
1046 | Scalar, Index, |
1047 | TensorContractionSubMapper< |
1048 | Scalar, Index, Rhs, |
1049 | TensorEvaluator< |
1050 | const TensorReshapingOp< |
1051 | NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, |
1052 | Device>, |
1053 | nocontract_t, contract_t, packet_size, inner_dim_contiguous, |
1054 | inner_dim_reordered, Alignment>, |
1055 | nr, ColMajor, false, false> { |
1056 | typedef TensorContractionSubMapper< |
1057 | Scalar, Index, Rhs, |
1058 | TensorEvaluator< |
1059 | const TensorReshapingOp< |
1060 | NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, |
1061 | Device>, |
1062 | nocontract_t, contract_t, packet_size, inner_dim_contiguous, |
1063 | inner_dim_reordered, Alignment> |
1064 | SubMapper; |
1065 | typedef SubMapper DataMapper; |
1066 | typedef typename packet_traits<Scalar>::type Packet; |
1067 | |
1068 | EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE) |
1069 | |
1070 | EIGEN_DEVICE_FUNC |
1071 | EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, |
1072 | Index depth, Index cols, Index stride = 0, |
1073 | Index offset = 0) const { |
1074 | eigen_assert(stride == 0); |
1075 | eigen_assert(offset == 0); |
1076 | |
1077 | const Index packet_cols4 = (cols / 4) * 4; |
1078 | const Index peeled_k = (depth / packet_size) * packet_size; |
1079 | const bool non_standard_patches = rhs.nonStandardPatches(); |
1080 | |
1081 | for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { |
1082 | const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); |
1083 | const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); |
1084 | const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); |
1085 | const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); |
1086 | |
1087 | Index k = 0; |
1088 | if ((packet_size % 4) == 0 && !non_standard_patches) { |
1089 | // FAST PATH: |
// Iterate over patch columns and rows if we know that a single
// packet does not span across multiple rows or columns.
1092 | if ((rhs.patchDepth() % packet_size) == 0) { |
1093 | const Index start_col = rhs.colOffset(); |
1094 | const Index max_col = rhs.maxCol(peeled_k); |
1095 | |
1096 | for (Index c = start_col; c < max_col; ++c) { |
1097 | eigen_assert(k <= peeled_k); |
1098 | |
1099 | const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; |
1100 | const Index max_row = rhs.maxRow(peeled_k, c); |
1101 | |
1102 | const bool pad_col0 = dm0.padCol(c); |
1103 | const bool pad_col1 = dm1.padCol(c); |
1104 | const bool pad_col2 = dm2.padCol(c); |
1105 | const bool pad_col3 = dm3.padCol(c); |
1106 | |
1107 | // Check if we can squeeze reads along the `row` and `depth` |
1108 | // dimensions (two innermost dimensions). |
1109 | if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && // |
1110 | !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && // |
1111 | !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && // |
1112 | !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && // |
1113 | !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) { |
1114 | // Compute how many elements we can squeeze read. |
1115 | const Index start_depth = |
1116 | (c == start_col) ? rhs.depthOffset() : 0; |
1117 | |
1118 | // Upper bound for the number of elements in the depth dimension |
1119 | // that we can squeeze read. |
1120 | const Index squeeze_length = |
1121 | (max_row - start_row) * rhs.patchDepth() - start_depth; |
1122 | |
1123 | // Do not overshoot beyond the block size. |
1124 | const Index max_depth = |
1125 | start_depth + std::min<Index>(peeled_k - k, squeeze_length); |
1126 | eigen_assert((max_depth - start_depth) % packet_size == 0); |
1127 | |
1128 | const Index idx0 = dm0.baseIndex(start_row, c); |
1129 | const Index idx1 = dm1.baseIndex(start_row, c); |
1130 | const Index idx2 = dm2.baseIndex(start_row, c); |
1131 | const Index idx3 = dm3.baseIndex(start_row, c); |
1132 | |
1133 | for (Index d = start_depth; d < max_depth; d += packet_size) { |
1134 | eigen_assert(k < peeled_k); |
1135 | PacketBlock<Packet, 4> kernel; |
1136 | kernel.packet[0] = rhs.packetNoPadding(d, idx0); |
1137 | kernel.packet[1] = rhs.packetNoPadding(d, idx1); |
1138 | kernel.packet[2] = rhs.packetNoPadding(d, idx2); |
1139 | kernel.packet[3] = rhs.packetNoPadding(d, idx3); |
1140 | ptranspose(kernel); |
1141 | pstoreu(block + 0 * packet_size, kernel.packet[0]); |
1142 | pstoreu(block + 1 * packet_size, kernel.packet[1]); |
1143 | pstoreu(block + 2 * packet_size, kernel.packet[2]); |
1144 | pstoreu(block + 3 * packet_size, kernel.packet[3]); |
1145 | block += 4 * packet_size; |
1146 | k += packet_size; |
1147 | } |
1148 | |
1149 | // Go to the next column. |
1150 | continue; |
1151 | } |
1152 | |
1153 | // If we can't squeeze reads, process rows one by one. |
1154 | for (Index r = start_row; r < max_row; ++r) { |
1155 | eigen_assert(k <= peeled_k); |
1156 | |
1157 | const bool pad0 = pad_col0 || dm0.padRow(r); |
1158 | const bool pad1 = pad_col1 || dm1.padRow(r); |
1159 | const bool pad2 = pad_col2 || dm2.padRow(r); |
1160 | const bool pad3 = pad_col3 || dm3.padRow(r); |
1161 | |
1162 | const Index idx0 = dm0.baseIndex(r, c); |
1163 | const Index idx1 = dm1.baseIndex(r, c); |
1164 | const Index idx2 = dm2.baseIndex(r, c); |
1165 | const Index idx3 = dm3.baseIndex(r, c); |
1166 | |
1167 | const Index start_depth = ((c == start_col) && (r == start_row)) |
1168 | ? rhs.depthOffset() |
1169 | : 0; |
1170 | const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); |
1171 | eigen_assert((max_depth - start_depth) % packet_size == 0); |
1172 | |
1173 | for (Index d = start_depth; d < max_depth; d += packet_size) { |
1174 | eigen_assert(k < peeled_k); |
1175 | PacketBlock<Packet, 4> kernel; |
1176 | kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) |
1177 | : rhs.packetNoPadding(d, idx0); |
1178 | kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) |
1179 | : rhs.packetNoPadding(d, idx1); |
1180 | kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0)) |
1181 | : rhs.packetNoPadding(d, idx2); |
1182 | kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0)) |
1183 | : rhs.packetNoPadding(d, idx3); |
1184 | ptranspose(kernel); |
1185 | pstoreu(block + 0 * packet_size, kernel.packet[0]); |
1186 | pstoreu(block + 1 * packet_size, kernel.packet[1]); |
1187 | pstoreu(block + 2 * packet_size, kernel.packet[2]); |
1188 | pstoreu(block + 3 * packet_size, kernel.packet[3]); |
1189 | block += 4 * packet_size; |
1190 | k += packet_size; |
1191 | } |
1192 | } |
1193 | } |
1194 | |
1195 | // The loop above should fill peeled_k elements. |
1196 | eigen_assert(peeled_k == k); |
1197 | |
1198 | } else { |
1199 | for (; k < peeled_k; k += packet_size) { |
1200 | PacketBlock<Packet, 4> kernel; |
1201 | kernel.packet[0] = dm0.loadPacketStandard(k); |
1202 | kernel.packet[1] = dm1.loadPacketStandard(k); |
1203 | kernel.packet[2] = dm2.loadPacketStandard(k); |
1204 | kernel.packet[3] = dm3.loadPacketStandard(k); |
1205 | ptranspose(kernel); |
1206 | pstoreu(block + 0 * packet_size, kernel.packet[0]); |
1207 | pstoreu(block + 1 * packet_size, kernel.packet[1]); |
1208 | pstoreu(block + 2 * packet_size, kernel.packet[2]); |
1209 | pstoreu(block + 3 * packet_size, kernel.packet[3]); |
1210 | block += 4 * packet_size; |
1211 | } |
1212 | } |
1213 | } |
1214 | |
1215 | // Copy the remaining coefficients of the column block after the peeled_k. |
1216 | if (!rhs.nonStandardPatches()) { |
1217 | for (; k < depth; k++) { |
1218 | block[0] = dm0.loadCoeffStandard(k); |
1219 | block[1] = dm1.loadCoeffStandard(k); |
1220 | block[2] = dm2.loadCoeffStandard(k); |
1221 | block[3] = dm3.loadCoeffStandard(k); |
1222 | block += 4; |
1223 | } |
1224 | } else { |
1225 | for (; k < depth; k++) { |
1226 | block[0] = dm0(k); |
1227 | block[1] = dm1(k); |
1228 | block[2] = dm2(k); |
1229 | block[3] = dm3(k); |
1230 | block += 4; |
1231 | } |
1232 | } |
1233 | } |
1234 | |
// Copy the remaining columns one at a time (nr==1).
1236 | for (Index j2 = packet_cols4; j2 < cols; ++j2) { |
1237 | const SubMapper dm0 = rhs.getLinearMapper(0, j2); |
1238 | for (Index k = 0; k < depth; k++) { |
1239 | *block = dm0(k); |
1240 | block += 1; |
1241 | } |
1242 | } |
1243 | } |
1244 | }; |
1245 | |
1246 | // Template specialization for packet_size = 2. We must special-case packet |
1247 | // blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>. |
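// A sketch of the difference, using Packet2d as an example: a 4-column step
// cannot use a single square ptranspose, so the loops below transpose two
// PacketBlock<Packet, 2> tiles (kernel0 and kernel1) per step instead.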
1248 | template <typename NewDimension, Index Rows, Index Cols, typename ArgType, |
1249 | typename Device, typename Scalar, typename Index, |
1250 | typename nocontract_t, typename contract_t, bool inner_dim_contiguous, |
1251 | bool inner_dim_reordered, int Alignment, int nr> |
1252 | struct gemm_pack_rhs< |
1253 | Scalar, Index, |
1254 | TensorContractionSubMapper< |
1255 | Scalar, Index, Rhs, |
1256 | TensorEvaluator< |
1257 | const TensorReshapingOp< |
1258 | NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, |
1259 | Device>, |
1260 | nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, |
1261 | Alignment>, |
1262 | nr, ColMajor, false, false> { |
1263 | typedef TensorContractionSubMapper< |
1264 | Scalar, Index, Rhs, |
1265 | TensorEvaluator< |
1266 | const TensorReshapingOp< |
1267 | NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, |
1268 | Device>, |
1269 | nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, |
1270 | Alignment> |
1271 | SubMapper; |
1272 | typedef SubMapper DataMapper; |
1273 | typedef typename packet_traits<Scalar>::type Packet; |
1274 | |
1275 | EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE) |
1276 | |
1277 | EIGEN_DEVICE_FUNC |
1278 | EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, |
1279 | Index depth, Index cols, Index stride = 0, |
1280 | Index offset = 0) const { |
1281 | eigen_assert(stride == 0); |
1282 | eigen_assert(offset == 0); |
1283 | |
1284 | const int packet_size = 2; |
1285 | const Index packet_cols4 = (cols / 4) * 4; |
1286 | const Index peeled_k = (depth / packet_size) * packet_size; |
1287 | const bool non_standard_patches = rhs.nonStandardPatches(); |
1288 | |
1289 | for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { |
1290 | const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); |
1291 | const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); |
1292 | const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); |
1293 | const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); |
1294 | |
1295 | Index k = 0; |
1296 | if (!non_standard_patches) { |
1297 | // FAST PATH: |
// Iterate over patch columns and rows, if we know that a single
// packet does not span multiple rows or columns. This is the case
// when the patch depth is a multiple of the packet size (checked
// below), because depth is the innermost patch dimension.
1300 | if ((rhs.patchDepth() % packet_size) == 0) { |
1301 | const Index start_col = rhs.colOffset(); |
1302 | const Index max_col = rhs.maxCol(peeled_k); |
1303 | |
1304 | for (Index c = start_col; c < max_col; ++c) { |
1305 | eigen_assert(k <= peeled_k); |
1306 | |
1307 | const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; |
1308 | const Index max_row = rhs.maxRow(peeled_k, c); |
1309 | |
1310 | const bool pad_col0 = dm0.padCol(c); |
1311 | const bool pad_col1 = dm1.padCol(c); |
1312 | const bool pad_col2 = dm2.padCol(c); |
1313 | const bool pad_col3 = dm3.padCol(c); |
1314 | |
1315 | // We can squeeze reads along the `row` and `depth` dimensions if |
1316 | // the row stride is `1`, which means that `row` and `depth` |
1317 | // dimensions are contiguous (two innermost dimensions). |
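// For example, with rowStride() == 1 the last depth element of row r is
// immediately followed in memory by the first depth element of row r + 1,
// so the whole [start_row, max_row) range can be read as one contiguous
// run of packet loads via packetNoPadding().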
1318 | if (rhs.rowStride() == 1 && // |
1319 | !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && // |
1320 | !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && // |
1321 | !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && // |
1322 | !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && // |
1323 | !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) { |
// Compute how many elements we can read contiguously (squeezed reads).
1325 | const Index start_depth = |
1326 | (c == start_col) ? rhs.depthOffset() : 0; |
1327 | |
// Upper bound on the number of elements in the depth dimension
// that we can read contiguously.
1330 | const Index squeeze_length = |
1331 | (max_row - start_row) * rhs.patchDepth() - start_depth; |
1332 | |
1333 | // Do not overshoot beyond the block size. |
1334 | const Index max_depth = |
1335 | start_depth + std::min<Index>(peeled_k - k, squeeze_length); |
1336 | eigen_assert((max_depth - start_depth) % packet_size == 0); |
1337 | |
1338 | const Index idx0 = dm0.baseIndex(start_row, c); |
1339 | const Index idx1 = dm1.baseIndex(start_row, c); |
1340 | const Index idx2 = dm2.baseIndex(start_row, c); |
1341 | const Index idx3 = dm3.baseIndex(start_row, c); |
1342 | |
1343 | for (Index d = start_depth; d < max_depth; d += packet_size) { |
1344 | PacketBlock<Packet, 2> kernel0; |
1345 | PacketBlock<Packet, 2> kernel1; |
1346 | kernel0.packet[0] = rhs.packetNoPadding(d, idx0); |
1347 | kernel0.packet[1] = rhs.packetNoPadding(d, idx1); |
1348 | kernel1.packet[0] = rhs.packetNoPadding(d, idx2); |
1349 | kernel1.packet[1] = rhs.packetNoPadding(d, idx3); |
1350 | ptranspose(kernel0); |
1351 | ptranspose(kernel1); |
1352 | pstoreu(block + 0 * packet_size, kernel0.packet[0]); |
1353 | pstoreu(block + 1 * packet_size, kernel1.packet[0]); |
1354 | pstoreu(block + 2 * packet_size, kernel0.packet[1]); |
1355 | pstoreu(block + 3 * packet_size, kernel1.packet[1]); |
1356 | block += 4 * packet_size; |
1357 | k += packet_size; |
1358 | } |
1359 | |
1360 | // Go to the next column. |
1361 | continue; |
1362 | } |
1363 | |
1364 | // If we can't squeeze reads, process rows one by one. |
1365 | for (Index r = start_row; r < max_row; ++r) { |
1366 | eigen_assert(k <= peeled_k); |
1367 | |
1368 | const bool pad0 = pad_col0 || dm0.padRow(r); |
1369 | const bool pad1 = pad_col1 || dm1.padRow(r); |
1370 | const bool pad2 = pad_col2 || dm2.padRow(r); |
1371 | const bool pad3 = pad_col3 || dm3.padRow(r); |
1372 | |
1373 | const Index idx0 = dm0.baseIndex(r, c); |
1374 | const Index idx1 = dm1.baseIndex(r, c); |
1375 | const Index idx2 = dm2.baseIndex(r, c); |
1376 | const Index idx3 = dm3.baseIndex(r, c); |
1377 | |
1378 | const Index start_depth = ((c == start_col) && (r == start_row)) |
1379 | ? rhs.depthOffset() |
1380 | : 0; |
1381 | const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); |
1382 | eigen_assert((max_depth - start_depth) % packet_size == 0); |
1383 | |
1384 | for (Index d = start_depth; d < max_depth; d += packet_size) { |
1385 | eigen_assert(k < peeled_k); |
1386 | PacketBlock<Packet, 2> kernel0; |
1387 | PacketBlock<Packet, 2> kernel1; |
1388 | kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) |
1389 | : rhs.packetNoPadding(d, idx0); |
1390 | kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) |
1391 | : rhs.packetNoPadding(d, idx1); |
1392 | kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0)) |
1393 | : rhs.packetNoPadding(d, idx2); |
1394 | kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0)) |
1395 | : rhs.packetNoPadding(d, idx3); |
1396 | ptranspose(kernel0); |
1397 | ptranspose(kernel1); |
1398 | pstoreu(block + 0 * packet_size, kernel0.packet[0]); |
1399 | pstoreu(block + 1 * packet_size, kernel1.packet[0]); |
1400 | pstoreu(block + 2 * packet_size, kernel0.packet[1]); |
1401 | pstoreu(block + 3 * packet_size, kernel1.packet[1]); |
1402 | block += 4 * packet_size; |
1403 | k += packet_size; |
1404 | } |
1405 | } |
1406 | } |
1407 | |
1408 | // The loop above should fill peeled_k elements. |
1409 | eigen_assert(peeled_k == k); |
1410 | |
1411 | } else { |
// A packet can span multiple rows or columns, so we have to go
// through the slower "standard" path.
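// (loadPacketStandard assembles such packets coefficient by coefficient
// when they cross a patch row/column boundary or touch padding.)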
1414 | for (; k < peeled_k; k += packet_size) { |
1415 | PacketBlock<Packet, 2> kernel0; |
1416 | PacketBlock<Packet, 2> kernel1; |
1417 | kernel0.packet[0] = dm0.loadPacketStandard(k); |
1418 | kernel0.packet[1] = dm1.loadPacketStandard(k); |
1419 | kernel1.packet[0] = dm2.loadPacketStandard(k); |
1420 | kernel1.packet[1] = dm3.loadPacketStandard(k); |
1421 | ptranspose(kernel0); |
1422 | ptranspose(kernel1); |
1423 | pstoreu(block + 0 * packet_size, kernel0.packet[0]); |
1424 | pstoreu(block + 1 * packet_size, kernel1.packet[0]); |
1425 | pstoreu(block + 2 * packet_size, kernel0.packet[1]); |
1426 | pstoreu(block + 3 * packet_size, kernel1.packet[1]); |
1427 | block += 4 * packet_size; |
1428 | } |
1429 | } |
1430 | } |
1431 | |
1432 | // Copy the remaining coefficients of the column block after the peeled_k. |
1433 | if (!non_standard_patches) { |
1434 | for (; k < depth; k++) { |
1435 | block[0] = dm0.loadCoeffStandard(k); |
1436 | block[1] = dm1.loadCoeffStandard(k); |
1437 | block[2] = dm2.loadCoeffStandard(k); |
1438 | block[3] = dm3.loadCoeffStandard(k); |
1439 | block += 4; |
1440 | } |
1441 | } else { |
1442 | for (; k < depth; k++) { |
1443 | block[0] = dm0(k); |
1444 | block[1] = dm1(k); |
1445 | block[2] = dm2(k); |
1446 | block[3] = dm3(k); |
1447 | block += 4; |
1448 | } |
1449 | } |
1450 | } |
1451 | |
1452 | // Copy the remaining columns one at a time (nr==1). |
1453 | for (Index j2 = packet_cols4; j2 < cols; ++j2) { |
1454 | const SubMapper dm0 = rhs.getLinearMapper(0, j2); |
1455 | for (Index k = 0; k < depth; k++) { |
1456 | *block = dm0(k); |
1457 | block += 1; |
1458 | } |
1459 | } |
1460 | } |
1461 | }; |
1462 | |
1463 | // Special case for non-vectorized types such as float16. |
1464 | template <typename NewDimension, Index Rows, Index Cols, typename ArgType, |
1465 | typename Device, typename Scalar, typename Index, |
1466 | typename nocontract_t, typename contract_t, bool inner_dim_contiguous, |
1467 | bool inner_dim_reordered, int Alignment, int nr> |
1468 | struct gemm_pack_rhs< |
1469 | Scalar, Index, |
1470 | TensorContractionSubMapper< |
1471 | Scalar, Index, Rhs, |
1472 | TensorEvaluator< |
1473 | const TensorReshapingOp< |
1474 | NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, |
1475 | Device>, |
1476 | nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, |
1477 | Alignment>, |
1478 | nr, ColMajor, false, false> { |
1479 | typedef TensorContractionSubMapper< |
1480 | Scalar, Index, Rhs, |
1481 | TensorEvaluator< |
1482 | const TensorReshapingOp< |
1483 | NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, |
1484 | Device>, |
1485 | nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, |
1486 | Alignment> |
1487 | SubMapper; |
1488 | typedef SubMapper DataMapper; |
1489 | |
1490 | EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE) |
1491 | |
1492 | EIGEN_DEVICE_FUNC |
1493 | EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, |
1494 | Index depth, Index cols, Index stride = 0, |
1495 | Index offset = 0) const { |
1496 | eigen_assert(stride == 0); |
1497 | eigen_assert(offset == 0); |
1498 | |
1499 | const Index packet_cols4 = (cols / 4) * 4; |
1500 | |
1501 | for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { |
1502 | const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); |
1503 | const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); |
1504 | const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); |
1505 | const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); |
1506 | |
1507 | if (!rhs.nonStandardPatches()) { |
1508 | for (Index k = 0; k < depth; k++) { |
1509 | block[0] = dm0.loadCoeffStandard(k); |
1510 | block[1] = dm1.loadCoeffStandard(k); |
1511 | block[2] = dm2.loadCoeffStandard(k); |
1512 | block[3] = dm3.loadCoeffStandard(k); |
1513 | block += 4; |
1514 | } |
1515 | } else { |
1516 | for (Index k = 0; k < depth; k++) { |
1517 | block[0] = dm0(k); |
1518 | block[1] = dm1(k); |
1519 | block[2] = dm2(k); |
1520 | block[3] = dm3(k); |
1521 | block += 4; |
1522 | } |
1523 | } |
1524 | } |
1525 | |
1526 | // Copy the remaining columns one at a time (nr==1). |
1527 | for (Index j2 = packet_cols4; j2 < cols; ++j2) { |
1528 | const SubMapper dm0 = rhs.getLinearMapper(0, j2); |
1529 | for (Index k = 0; k < depth; k++) { |
1530 | *block = dm0(k); |
1531 | block += 1; |
1532 | } |
1533 | } |
1534 | } |
1535 | }; |
1536 | #endif |
1537 | } // end namespace internal |
1538 | |
1539 | /** SpatialConvolution |
1540 | * \ingroup CXX11_NeuralNetworks_Module |
1541 | * |
1542 | * \brief Applies a 2D convolution over a multichannel input image. |
1543 | * |
* The input parameter is expected to be a tensor with a rank of 3 or more
* (channels, height, width, and optionally others).
* The kernel parameter is expected to be a 4D tensor (filters, channels,
* kernel_height, kernel_width).
* The input and the kernel must use the same layout (both col-major or both
* row-major); the result will use the same layout as the input.
1550 | * |
* If col_in_stride or row_in_stride > 1, the convolution is applied with
* holes (a.k.a. atrous or dilated convolution), sampling every
* col_in_stride-th input column and row_in_stride-th input row.
1554 | * |
1555 | * If padding_top, padding_bottom, padding_left, or padding_right is specified, |
1556 | * then those paddings will be used to pad the input, and padding_type must be |
1557 | * PADDING_VALID. |
1558 | * |
1559 | * The result can be assigned to a tensor of rank equal to the rank of the |
1560 | * input. The dimensions of the result will be filters, height, width (and |
1561 | * others if applicable). |
1562 | * |
1563 | * It is possible to swap the order of the width and height dimensions provided |
1564 | * that the same order is used in the input, the kernel, and the output. |
1565 | * |
* It is also possible to add an output kernel to the contraction; the output
* kernel is called by Eigen when it "finalizes" a block of the output tensor.
1568 | * |
1569 | */ |
1570 | template <typename Input, typename Kernel, |
1571 | typename OutputKernel = const NoOpOutputKernel> |
1572 | EIGEN_ALWAYS_INLINE static const std::conditional_t< |
1573 | internal::traits<Input>::Layout == ColMajor, |
1574 | TensorReshapingOp< |
1575 | const DSizes<typename internal::traits<Input>::Index, |
1576 | internal::traits<Input>::NumDimensions>, |
1577 | const TensorContractionOp< |
1578 | const array<IndexPair<typename internal::traits<Input>::Index>, 1>, |
1579 | const TensorReshapingOp< |
1580 | const DSizes<typename internal::traits<Input>::Index, 2>, |
1581 | const Kernel>, |
1582 | const TensorReshapingOp< |
1583 | const DSizes<typename internal::traits<Input>::Index, 2>, |
1584 | const TensorImagePatchOp<Dynamic, Dynamic, const Input> >, |
1585 | const OutputKernel> >, |
1586 | TensorReshapingOp< |
1587 | const DSizes<typename internal::traits<Input>::Index, |
1588 | internal::traits<Input>::NumDimensions>, |
1589 | const TensorContractionOp< |
1590 | const array<IndexPair<typename internal::traits<Input>::Index>, 1>, |
1591 | const TensorReshapingOp< |
1592 | const DSizes<typename internal::traits<Input>::Index, 2>, |
1593 | const TensorImagePatchOp<Dynamic, Dynamic, const Input> >, |
1594 | const TensorReshapingOp< |
1595 | const DSizes<typename internal::traits<Input>::Index, 2>, |
1596 | const Kernel>, |
1597 | const OutputKernel> > > |
1598 | SpatialConvolution(const Input& input, const Kernel& kernel, |
1599 | const Index row_stride = 1, const Index col_stride = 1, |
1600 | const PaddingType padding_type = PADDING_SAME, |
1601 | const Index row_in_stride = 1, const Index col_in_stride = 1, |
1602 | const OutputKernel& output_kernel = OutputKernel(), |
1603 | Index padding_top = 0, Index padding_bottom = 0, |
1604 | Index padding_left = 0, Index padding_right = 0) { |
1605 | typedef typename internal::traits<Input>::Index TensorIndex; |
1606 | typedef typename internal::traits<Input>::Scalar InputScalar; |
1607 | TensorRef<Tensor<InputScalar, internal::traits<Input>::NumDimensions, |
1608 | internal::traits<Input>::Layout, TensorIndex> > |
1609 | in(input); |
1610 | TensorRef<Tensor<typename internal::traits<Kernel>::Scalar, |
1611 | internal::traits<Kernel>::NumDimensions, |
1612 | internal::traits<Kernel>::Layout, TensorIndex> > |
1613 | kern(kernel); |
1614 | |
1615 | EIGEN_STATIC_ASSERT( |
1616 | internal::traits<Input>::Layout == internal::traits<Kernel>::Layout, |
1617 | YOU_MADE_A_PROGRAMMING_MISTAKE) |
1618 | const bool isColMajor = (internal::traits<Input>::Layout == ColMajor); |
1619 | |
1620 | const int NumDims = internal::traits<Input>::NumDimensions; |
1621 | |
// Number of filters to apply. This is the same as the output depth of the
// result.
1624 | const TensorIndex kernelFilters = |
1625 | isColMajor ? kern.dimensions()[0] : kern.dimensions()[3]; |
1626 | // Number of channels. This is the same as the input depth. |
1627 | const TensorIndex kernelChannels = |
1628 | isColMajor ? kern.dimensions()[1] : kern.dimensions()[2]; |
1629 | const TensorIndex kernelRows = |
1630 | isColMajor ? kern.dimensions()[2] : kern.dimensions()[1]; |
1631 | const TensorIndex kernelCols = |
1632 | isColMajor ? kern.dimensions()[3] : kern.dimensions()[0]; |
1633 | |
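// Effective kernel size after dilation: a kernel with k taps and input
// dilation d covers k + (k - 1) * (d - 1) input pixels.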
1634 | const Index kernelRowsEff = |
1635 | kernelRows + (kernelRows - 1) * (row_in_stride - 1); |
1636 | const Index kernelColsEff = |
1637 | kernelCols + (kernelCols - 1) * (col_in_stride - 1); |
1638 | |
1639 | array<IndexPair<TensorIndex>, 1> contract_dims; |
1640 | contract_dims[0] = IndexPair<TensorIndex>(1, 0); |
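// The contraction pairs dimension 1 of the first operand with dimension 0 of
// the second; in both layout branches below these are the matching
// "kernelChannels * kernelRows * kernelCols" dimensions.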
1641 | |
1642 | const TensorIndex InputRows = |
1643 | isColMajor ? in.dimension(1) : in.dimension(NumDims - 2); |
1644 | const TensorIndex InputCols = |
1645 | isColMajor ? in.dimension(2) : in.dimension(NumDims - 3); |
1646 | const bool padding_explicit = |
1647 | (padding_top || padding_bottom || padding_left || padding_right); |
1648 | |
1649 | TensorIndex out_height; |
1650 | TensorIndex out_width; |
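// divup is Eigen's ceiling division: divup(a, b) == (a + b - 1) / b.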
1651 | switch (padding_type) { |
1652 | case PADDING_VALID: { |
1653 | const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom; |
1654 | const TensorIndex InputColsEff = InputCols + padding_left + padding_right; |
1655 | out_height = divup(InputRowsEff - kernelRowsEff + 1, row_stride); |
1656 | out_width = divup(InputColsEff - kernelColsEff + 1, col_stride); |
1657 | break; |
1658 | } |
1659 | case PADDING_SAME: { |
1660 | eigen_assert(!padding_explicit); |
1661 | out_height = divup(InputRows, row_stride); |
1662 | out_width = divup(InputCols, col_stride); |
1663 | break; |
1664 | } |
1665 | default: { |
1666 | // Initialize unused variables to avoid a compiler warning |
1667 | out_height = 0; |
1668 | out_width = 0; |
eigen_assert(false && "unexpected padding");
1670 | } |
1671 | } |
1672 | |
// Molds the output of the patch extraction code into a 2D tensor:
1674 | // - the first dimension (dims[0]): the patch values to be multiplied with the |
1675 | // kernels |
1676 | // - the second dimension (dims[1]): everything else |
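// E.g. for the illustrative col-major shapes in the doc comment above
// (3x32x32 input, 8x3x5x5 kernel, SAME padding, unit strides):
// pre_contract_dims = {3 * 5 * 5, 32 * 32} = {75, 1024}.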
1677 | DSizes<TensorIndex, 2> pre_contract_dims; |
1678 | if (isColMajor) { |
1679 | pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols; |
1680 | pre_contract_dims[1] = out_height * out_width; |
1681 | for (int i = 3; i < NumDims; ++i) { |
1682 | pre_contract_dims[1] *= in.dimension(i); |
1683 | } |
1684 | } else { |
1685 | pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols; |
1686 | pre_contract_dims[0] = out_height * out_width; |
1687 | for (int i = 0; i < NumDims - 3; ++i) { |
1688 | pre_contract_dims[0] *= in.dimension(i); |
1689 | } |
1690 | } |
1691 | |
// Molds the output of the contraction into the shape expected by the user
// (assuming this is ColMajor):
1694 | // - 1st dim: kernel filters |
1695 | // - 2nd dim: output height |
1696 | // - 3rd dim: output width |
1697 | // - 4th dim and beyond: everything else including batch size |
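// Continuing the illustrative shapes: post_contract_dims = {8, 32, 32}.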
1698 | DSizes<TensorIndex, NumDims> post_contract_dims; |
1699 | if (isColMajor) { |
1700 | post_contract_dims[0] = kernelFilters; |
1701 | post_contract_dims[1] = out_height; |
1702 | post_contract_dims[2] = out_width; |
1703 | for (int i = 3; i < NumDims; ++i) { |
1704 | post_contract_dims[i] = in.dimension(i); |
1705 | } |
1706 | } else { |
1707 | post_contract_dims[NumDims - 1] = kernelFilters; |
1708 | post_contract_dims[NumDims - 2] = out_height; |
1709 | post_contract_dims[NumDims - 3] = out_width; |
1710 | for (int i = 0; i < NumDims - 3; ++i) { |
1711 | post_contract_dims[i] = in.dimension(i); |
1712 | } |
1713 | } |
1714 | |
1715 | DSizes<TensorIndex, 2> kernel_dims; |
1716 | if (isColMajor) { |
1717 | kernel_dims[0] = kernelFilters; |
1718 | kernel_dims[1] = kernelChannels * kernelRows * kernelCols; |
1719 | } else { |
1720 | kernel_dims[0] = kernelChannels * kernelRows * kernelCols; |
1721 | kernel_dims[1] = kernelFilters; |
1722 | } |
1723 | if (padding_explicit) { |
1724 | return choose( |
1725 | Cond<internal::traits<Input>::Layout == ColMajor>(), |
1726 | kernel.reshape(kernel_dims) |
1727 | .contract(input |
1728 | .extract_image_patches( |
1729 | kernelRows, kernelCols, row_stride, col_stride, |
1730 | row_in_stride, col_in_stride, |
1731 | /*row_inflate_stride=*/1, |
1732 | /*col_inflate_stride=*/1, padding_top, |
1733 | padding_bottom, padding_left, padding_right, |
1734 | /*padding_value=*/static_cast<InputScalar>(0)) |
1735 | .reshape(pre_contract_dims), |
1736 | contract_dims, output_kernel) |
1737 | .reshape(post_contract_dims), |
1738 | input |
1739 | .extract_image_patches( |
1740 | kernelRows, kernelCols, row_stride, col_stride, row_in_stride, |
1741 | col_in_stride, |
1742 | /*row_inflate_stride=*/1, |
1743 | /*col_inflate_stride=*/1, padding_top, padding_bottom, |
1744 | padding_left, padding_right, |
1745 | /*padding_value=*/static_cast<InputScalar>(0)) |
1746 | .reshape(pre_contract_dims) |
1747 | .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel) |
1748 | .reshape(post_contract_dims)); |
1749 | } else { |
1750 | return choose( |
1751 | Cond<internal::traits<Input>::Layout == ColMajor>(), |
1752 | kernel.reshape(kernel_dims) |
1753 | .contract(input |
1754 | .extract_image_patches( |
1755 | kernelRows, kernelCols, row_stride, col_stride, |
1756 | row_in_stride, col_in_stride, padding_type) |
1757 | .reshape(pre_contract_dims), |
1758 | contract_dims, output_kernel) |
1759 | .reshape(post_contract_dims), |
1760 | input |
1761 | .extract_image_patches(kernelRows, kernelCols, row_stride, |
1762 | col_stride, row_in_stride, col_in_stride, |
1763 | padding_type) |
1764 | .reshape(pre_contract_dims) |
1765 | .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel) |
1766 | .reshape(post_contract_dims)); |
1767 | } |
1768 | } |
1769 | |
1770 | } // end namespace Eigen |
1771 | |
1772 | #endif // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_ |
1773 | |