1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_
17#define TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_
18
19#include "tensorflow/core/kernels/eigen_convolution_helpers.h"
20
21// Note this header is used in both TF and TFLite.
22namespace Eigen {
23
24namespace internal {
25
26#if !EIGEN_ALTIVEC_USE_CUSTOM_PACK
27// WARNING: Most of the code here implicitly assumes that the matrix is in
28// ColMajor layout. This is guaranteed by the tensor contraction (see
29// TensorContraction.h).
30//
31// Inside Eigen a tensor contraction is represented by a matrix multiplication.
// We don't want to actually extract image patches and reshape the result into
// a matrix (this would require allocating a huge amount of extra memory), so
// the patch extraction and reshape operations are implicit.
//
// TensorContractionInputMapper takes a matrix index and returns the coefficient
// (or the packet) of the "virtual tensor" that would be at that index if we
// were to actually reshape the result of patch extraction.
39//
40// TensorContractionSubMapper provides a similar view into the "virtual matrix"
41// at the given vertical and horizontal offsets.
42//
43// "Virtual matrix" dimensions:
//   *0: kernelChannels * kernelRows * kernelCols
//    1: out_height * out_width * OTHERS (e.g. batches, etc...)
//
// *) extracted patches are contiguous in memory (innermost dimension assuming
//    col major layout)
//
// With these dimensions:
//   row - offset within a single patch (in code: patchId)
//   col - index of the extracted patch (in code: patchIndex)
//         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
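//
// A concrete example (illustrative only, hypothetical numbers): a 3x3 kernel
// over a 16-channel input producing a 10x10 output on a batch of 32 gives a
// "virtual matrix" with
//   rows: 16 * 3 * 3   = 144   (one extracted patch, stored contiguously)
//   cols: 10 * 10 * 32 = 3200  (one column per extracted patch, across the
//                               whole batch)
// so, using zero-based indices, row 17 of column 5 is the coefficient at
// depth 1, kernel row 1, kernel col 0 of the 6th extracted patch (still in
// the first batch entry).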
54//
55// TODO(ezhulenev): Consolidate this part of the code with the image patch
56// extraction code since they are both very similar.
57
58template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
59 typename Device, typename Scalar_, typename Index,
60 typename nocontract_t, typename contract_t, int Side, int packet_size,
61 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
62class TensorContractionInputMapper<
63 Scalar_, Index, Side,
64 TensorEvaluator<
65 const TensorReshapingOp<NewDimension,
66 const TensorImagePatchOp<Rows, Cols, ArgType> >,
67 Device>,
68 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
69 inner_dim_reordered, Alignment> {
70 public:
71 typedef Scalar_ Scalar;
72
73 typedef TensorContractionInputMapper<
74 Scalar, Index, Side,
75 TensorEvaluator<
76 const TensorReshapingOp<
77 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
78 Device>,
79 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
80 inner_dim_reordered, Alignment>
81 Self;
82
83 typedef TensorContractionSubMapper<
84 Scalar, Index, Side,
85 TensorEvaluator<
86 const TensorReshapingOp<
87 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
88 Device>,
89 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
90 inner_dim_reordered, Alignment>
91 SubMapper;
92
93 typedef SubMapper VectorMapper;
94 typedef SubMapper LinearMapper;
95 typedef typename packet_traits<Scalar>::type Packet;
96
97 typedef TensorEvaluator<ArgType, Device> TensorEvaluatorT;
98
99 EIGEN_DEVICE_FUNC
100 TensorContractionInputMapper(
101 const TensorEvaluator<
102 const TensorReshapingOp<
103 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
104 Device>& tensor,
105 const nocontract_t&, const nocontract_t&, const contract_t&,
106 const contract_t&)
107 : m_impl(tensor.impl().impl()) {
108 Index patch_rows;
109 Index patch_depth;
110 if (internal::traits<ArgType>::Layout == ColMajor) {
111 patch_depth = tensor.impl().dimensions()[0];
112 patch_rows = tensor.impl().dimensions()[1];
113 m_patch_cols = tensor.impl().dimensions()[2];
114 m_num_patches = tensor.impl().dimensions()[3];
115 } else {
116 const size_t NumDims = tensor.impl().dimensions().size();
117 patch_depth = tensor.impl().dimensions()[NumDims - 1];
118 patch_rows = tensor.impl().dimensions()[NumDims - 2];
119 m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
120 m_num_patches = tensor.impl().dimensions()[NumDims - 4];
121 }
122
123 // Strides for navigating through the single patch.
124 m_patch_row_stride = patch_depth;
125 m_patch_col_stride = patch_rows * m_patch_row_stride;
126
127 m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
128 m_patch_col_inflate_strides = tensor.impl().colInflateStride();
129
130 m_colStride = patch_rows;
131
132 m_outputRows = tensor.impl().outputRows();
133 m_outputCols = tensor.impl().outputCols();
134 m_row_strides = tensor.impl().userRowStride();
135 m_col_strides = tensor.impl().userColStride();
136
137 m_in_row_strides = tensor.impl().userInRowStride();
138 m_in_col_strides = tensor.impl().userInColStride();
139
140 if (internal::traits<ArgType>::Layout == ColMajor) {
141 m_inputRows = tensor.impl().impl().dimensions()[1];
142 m_inputCols = tensor.impl().impl().dimensions()[2];
143 } else {
144 const int NumDims = tensor.impl().impl().dimensions().size();
145 m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
146 m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
147 }
148
149 m_rowInputStride = patch_depth;
150 m_colInputStride = patch_depth * m_inputRows;
151 m_patchInputStride = patch_depth * m_inputRows * m_inputCols;
152
153 m_rowPaddingTop = tensor.impl().rowPaddingTop();
154 m_colPaddingLeft = tensor.impl().colPaddingLeft();
155
156 m_fastPatchRowStride =
157 internal::TensorIntDivisor<Index>(m_patch_row_stride);
158 m_fastPatchColStride =
159 internal::TensorIntDivisor<Index>(m_patch_col_stride);
160 m_fastInputRowStride =
161 internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
162 m_fastInputColStride =
163 internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
164 m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
165 m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
166 m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
167 m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
168 }
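
  // A quick illustration of the strides computed above (hypothetical numbers,
  // for illustration only): with patch_depth = 16 and a 28x28 input
  // (m_inputRows = m_inputCols = 28),
  //   m_rowInputStride   = 16             (step to the next input row)
  //   m_colInputStride   = 16 * 28 = 448  (step to the next input column)
  //   m_patchInputStride = 16 * 28 * 28   (step to the next batch entry)
  // i.e. the input tensor is assumed to keep depth as its innermost (fastest
  // varying) dimension, which is what loadCoeff() and loadPacket() below rely
  // on.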
169
170 EIGEN_DEVICE_FUNC
171 TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
172 : m_impl(base_mapper.m_impl) {
173 m_patch_cols = base_mapper.m_patch_cols;
174 m_num_patches = base_mapper.m_num_patches;
175
176 m_patch_row_stride = base_mapper.m_patch_row_stride;
177 m_patch_col_stride = base_mapper.m_patch_col_stride;
178
179 m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
180 m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
181
182 m_colStride = base_mapper.m_colStride;
183
184 m_rowInputStride = base_mapper.m_rowInputStride;
185 m_colInputStride = base_mapper.m_colInputStride;
186 m_patchInputStride = base_mapper.m_patchInputStride;
187
188 m_inputRows = base_mapper.m_inputRows;
189 m_inputCols = base_mapper.m_inputCols;
190
191 m_outputRows = base_mapper.m_outputRows;
192 m_outputCols = base_mapper.m_outputCols;
193 m_row_strides = base_mapper.m_row_strides;
194 m_col_strides = base_mapper.m_col_strides;
195
196 m_in_row_strides = base_mapper.m_in_row_strides;
197 m_in_col_strides = base_mapper.m_in_col_strides;
198
199 m_rowPaddingTop = base_mapper.m_rowPaddingTop;
200 m_colPaddingLeft = base_mapper.m_colPaddingLeft;
201
202 m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
203 m_fastPatchColStride = base_mapper.m_fastPatchColStride;
204 m_fastInputRowStride = base_mapper.m_fastInputRowStride;
205 m_fastInputColStride = base_mapper.m_fastInputColStride;
206 m_fastNumPatches = base_mapper.m_fastNumPatches;
207 m_fastColStride = base_mapper.m_fastColStride;
208 m_fastOutputRows = base_mapper.m_fastOutputRows;
209 m_fastDimZero = base_mapper.m_fastDimZero;
210 }
211
  // Returns true if the image patches are "non-standard", i.e. there are
  // non-trivial input strides or inflation strides, in which case some
  // optimizations for loading packets are turned off.
215 EIGEN_DEVICE_FUNC
216 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
217 return m_in_row_strides != 1 || m_in_col_strides != 1 ||
218 m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
219 }
220
221 EIGEN_DEVICE_FUNC
222 EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
223 return SubMapper(*this, i, j);
224 }
225
226 EIGEN_DEVICE_FUNC
227 EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
228 return LinearMapper(*this, i, j);
229 }
230
231 EIGEN_DEVICE_FUNC
232 EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
233 Index rowIndex, colIndex, otherIndex;
234 computeBaseIndices(0, rowIndex, colIndex, otherIndex);
235 return loadCoeff(row, rowIndex, colIndex, otherIndex);
236 }
237
  // Load the coefficient at the patchIndex location instead of the usual
  // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
  // gpu code.
  EIGEN_DEVICE_FUNC
243 EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
244 Index rowIndex, colIndex, otherIndex;
245 computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
246 return loadCoeff(row, rowIndex, colIndex, otherIndex);
247 }
248
249 EIGEN_DEVICE_FUNC
250 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
251 Index rowIndex, colIndex, otherIndex;
252 computeBaseIndices(0, rowIndex, colIndex, otherIndex);
253 return loadPacket(row, rowIndex, colIndex, otherIndex);
254 }
255
256 // Load the packet at the patchIndex location instead of the usual m_rowIndex,
257 // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
258 EIGEN_DEVICE_FUNC
259 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
260 Index rowIndex, colIndex, otherIndex;
261 computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
262 return loadPacket(row, rowIndex, colIndex, otherIndex);
263 }
264
265 EIGEN_DEVICE_FUNC
266 EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
267 return m_impl;
268 }
269
270 EIGEN_DEVICE_FUNC
271 EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
272 EIGEN_DEVICE_FUNC
273 EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
274 EIGEN_DEVICE_FUNC
275 EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
276
277 private:
278 friend class TensorContractionSubMapper<
279 Scalar, Index, Side,
280 TensorEvaluator<
281 const TensorReshapingOp<
282 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
283 Device>,
284 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
285 inner_dim_reordered, Alignment>;
286
287 // Load coefficient from a patch specified by the "within patch offset"
288 // (patchId) and the precomputed indices of the first element of the patch.
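  //
  // As an illustration (hypothetical numbers): with patchDepth() = 16 and
  // patchRows() = 3, patchId = 52 decomposes below into
  //   patchOffset = 52 / 16 = 3
  //   colOffset   = 3 / 3 = 1,  rowOffset = 3 - 1 * 3 = 0
  //   depth       = 52 - 3 * 16 = 4
  // i.e. the coefficient at depth 4, kernel row 0, kernel col 1, which is then
  // mapped back into the input with the in_stride/inflation adjustments below.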
289 EIGEN_DEVICE_FUNC
290 EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
291 Index colIndex, Index otherIndex) const {
292 // Find the offset of the element wrt the location of the first element.
293 const Index patchOffset = patchId / m_fastDimZero;
294
295 const Index colOffset = patchOffset / m_fastColStride;
296 const Index inputCol = colIndex + colOffset * m_in_col_strides;
297 const Index origInputCol =
298 (m_patch_col_inflate_strides == 1)
299 ? inputCol
300 : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
301
302 const Index rowOffset = patchOffset - colOffset * m_colStride;
303 const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
304 const Index origInputRow =
305 (m_patch_row_inflate_strides == 1)
306 ? inputRow
307 : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
308 if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
309 origInputRow >= m_inputRows ||
310 (inputCol != origInputCol * m_patch_col_inflate_strides) ||
311 (inputRow != origInputRow * m_patch_row_inflate_strides)) {
312 return Scalar(0);
313 }
314 const Index depth = patchId - patchOffset * patchDepth();
315 const Index inputIndex = depth + origInputRow * m_rowInputStride +
316 origInputCol * m_colInputStride + otherIndex;
317 return m_impl.coeff(inputIndex);
318 }
319
320 // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
321 // and `in_strides` equal to 1 (template specialization without templates).
322 EIGEN_DEVICE_FUNC
323 EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
324 Index colIndex,
325 Index otherIndex) const {
326 eigen_assert(!nonStandardPatches());
327
328 // Find the offset of the element wrt the location of the first element.
329 const Index patchOffset = patchId / m_fastDimZero;
330 const Index colOffset = patchOffset / m_fastColStride;
331 const Index rowOffset = patchOffset - colOffset * m_colStride;
332 const Index inputCol = colIndex + colOffset;
333 const Index inputRow = rowIndex + rowOffset;
334 if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
335 inputRow >= m_inputRows) {
336 return Scalar(0);
337 }
338 const Index depth = patchId - patchOffset * patchDepth();
339 const Index inputIndex = depth + inputRow * m_rowInputStride +
340 inputCol * m_colInputStride + otherIndex;
341 return m_impl.coeff(inputIndex);
342 }
343
344 // Load packet from a patch specified by the "within patch offset"
345 // (patchId) and the precomputed indices of the first element of the patch.
346 EIGEN_DEVICE_FUNC
347 EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
348 Index colIndex,
349 Index otherIndex) const {
350 const Index packetSize = internal::unpacket_traits<Packet>::size;
351 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
352 eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
353
354 if (nonStandardPatches()) {
355 return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
356 }
357 typedef decltype(m_impl) TensorEvaluatorT;
358 return loadPacketStandard<Packet, TensorEvaluatorT>(patchId, rowIndex,
359 colIndex, otherIndex);
360 }
361
362 // Helper function to load a 'partial' packet - this is the single column
363 // part of a packet that is split across two columns. In the 'partial' packet,
364 // the elements corresponding to the column (specified through colOffset) are
365 // loaded and the rest of the elements are zero-filled into the 'partial'
366 // packet. This function is called from loadPacketStandardFromTwoColumns().
367 // This code path is exercised only when the packet type supports masked load
368 // and when the partial packet load is available in the TensorEvaluator.
369 EIGEN_DEVICE_FUNC
370 EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(
371 Index rowIndex, Index colIndex, Index otherIndex, Index patchId,
372 const Index span[], const Index patchOffsets[], Index colOffset) const {
373 const Index inputCol = colIndex + colOffset;
374 const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride,
375 patchOffsets[1] - colOffset * m_colStride};
376 const Index inputRows[2] = {rowIndex + rowOffsets[0],
377 rowIndex + rowOffsets[1]};
378
379 if (inputRows[0] >= m_inputRows || inputRows[1] < 0 ||
380 inputCol >= m_inputCols || inputCol < 0) {
381 // Partial packet is all zeros
382 return internal::pset1<Packet>(Scalar(0));
383 } else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
      // From inputIndex-span[0], we need to load elements starting from index
      // span[0] all the way up to (and including) span[1].
386 const Index depth = patchId - patchOffsets[0] * patchDepth();
387 const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
388 inputCol * m_colInputStride + otherIndex;
389 return m_impl.template partialPacket<Packet>(
390 inputIndex - span[0], mask<Packet>(span[0], span[1] + 1));
391 } else {
      // Using slow path for this partial packet.
      // We need to load elements starting from index span[0] all the way up to
      // (and including) span[1]. We split this load into 3 parts:
      // 0 : span[0]-1 - Zeros will be loaded for these indices
      // span[0] : span[1] - Elements will be loaded here for these indices
      // span[1]+1 : packetSize-1 - Zeros will be loaded for these indices
398 const Index packetSize = internal::unpacket_traits<Packet>::size;
399 EIGEN_ALIGN_MAX
400 std::remove_const_t<Scalar> values[packetSize];
401 for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0);
402 for (int i = span[0]; i < span[1] + 1; ++i)
403 values[i] =
404 loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex);
405 for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0);
406 return internal::pload<Packet>(values);
407 }
408 }
409
410 // Helper function to load a packet that is split across two columns.
411 // If required, this function is called from loadPacketStandard() when the
412 // packet type supports masked load and when the partial packet load is
413 // available in the TensorEvaluator.
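  //
  // A sketch of the split (hypothetical numbers): with patchDepth() = 3 and
  // patchRows() = 5 (15 coefficients per patch column) and packetSize = 4, a
  // load at patchId = 13 covers patchIds 13..16 and therefore crosses the
  // column boundary between patchIds 14 and 15. The code below computes
  // patchIdSplit = 14, so packet elements [0..1] are loaded from the first
  // column and elements [2..3] from the second, and the two partial packets
  // are OR-ed together.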
414 EIGEN_DEVICE_FUNC
415 EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns(
416 Index patchId, Index rowIndex, Index colIndex, Index otherIndex,
417 const Index patchOffsets[], const Index colOffsets[]) const {
418 eigen_assert(colOffsets[1] == colOffsets[0] + 1);
419 const Index packetSize = internal::unpacket_traits<Packet>::size;
420
421 // Packet to load will be split into 2 parts where each part spans a single
422 // column. First determine where to split.
423 const Index patchIdSplit =
424 ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1;
425 const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;
426
427 // patchIds[i]: patchId corresponding to partial packet i
428 // spans[i]: Start and end indices corresponding to the elements
429 // to be loaded for partial packet i
430 // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
431 const Index patchIds[2] = {patchId, patchIdSplit + 1};
432 const Index spans[2][2] = {{0, patchIdSplit - patchId},
433 {patchIdSplit - patchId + 1, packetSize - 1}};
434 const Index patchOffsets2Cols[2][2] = {
435 {patchOffsets[0], patchOffsetSplit},
436 {patchOffsetSplit + 1, patchOffsets[1]}};
437
438 // Load partial packets and do bit-wise OR to generate required packet
439 return internal::por<Packet>(
440 loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0],
441 spans[0], patchOffsets2Cols[0],
442 colOffsets[0]),
443 loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1],
444 spans[1], patchOffsets2Cols[1],
445 colOffsets[1]));
446 }
447
  // Helper function to load a packet that is present in a single column.
449 // If required, this function is called from loadPacketStandard().
450 EIGEN_DEVICE_FUNC
451 EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn(
452 Index patchId, Index rowIndex, Index colIndex, Index otherIndex,
453 const Index patchOffsets[], const Index colOffsets[],
454 const Index inputCols[]) const {
455 eigen_assert(colOffsets[0] == colOffsets[1]);
456 const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride,
457 patchOffsets[1] - colOffsets[1] * m_colStride};
458 eigen_assert(rowOffsets[0] <= rowOffsets[1]);
459 const Index inputRows[2] = {rowIndex + rowOffsets[0],
460 rowIndex + rowOffsets[1]};
461
462 if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
      // The requested rows are all outside of the input: return all zeros.
      return internal::pset1<Packet>(Scalar(0));
465 }
466
467 if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
468 // no padding
469 const Index depth = patchId - patchOffsets[0] * patchDepth();
470 const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
471 inputCols[0] * m_colInputStride + otherIndex;
472 return m_impl.template packet<Unaligned>(inputIndex);
473 }
474 return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
475 }
476
477 // Load standard packet from a patch specified by the "within patch offset"
478 // (patchId) and the precomputed indices of the first element of the patch.
479 // This function will be called if partial packet loading is not available
480 // for the TensorEvaluator or if the packet type does not support masked
481 // load.
482 template <typename PacketT, typename TensorEvaluatorT>
483 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
484 !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
485 PacketT>::type
486 loadPacketStandard(Index patchId, Index rowIndex, Index colIndex,
487 Index otherIndex) const {
488 const Index packetSize = internal::unpacket_traits<Packet>::size;
489 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
490 eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
491
492 eigen_assert(!nonStandardPatches());
493
494 if ((patchDepth() % packetSize) == 0) {
495 return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
496 }
497
498 // Offsets and input calculation here are identical to
499 // loadCoeffStandard(...), but repeated twice.
500 const Index patchOffsets[2] = {patchId / m_fastDimZero,
501 (patchId + packetSize - 1) / m_fastDimZero};
502 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
503 patchOffsets[1] / m_fastColStride};
504 const Index inputCols[2] = {colIndex + colOffsets[0],
505 colIndex + colOffsets[1]};
506
507 if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
508 // all zeros
509 return internal::pset1<Packet>(Scalar(0));
510 }
511 if (inputCols[0] == inputCols[1]) {
512 return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex,
513 otherIndex, patchOffsets,
514 colOffsets, inputCols);
515 }
516 return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
517 }
518
519 // Load standard packet from a patch specified by the "within patch offset"
520 // (patchId) and the precomputed indices of the first element of the patch.
521 // This function will be called if partial packet loading is available for
522 // the TensorEvaluator and if the packet type supports masked load.
  // The only difference from the other case is that when the packet to load is
  // split across two columns, instead of falling back to the slow
  // (element-by-element) load, we load two packets - each containing the
  // elements from one of the columns (the rest of the elements are zeros) -
  // and then combine them to generate the required packet. The idea is to
  // enable a fast load (if possible) of these 'partial' packets.
530 template <typename PacketT, typename TensorEvaluatorT>
531 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
532 TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
533 PacketT>::type
534 loadPacketStandard(Index patchId, Index rowIndex, Index colIndex,
535 Index otherIndex) const {
536 const Index packetSize = internal::unpacket_traits<PacketT>::size;
537 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
538 eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
539
540 eigen_assert(!nonStandardPatches());
541
542 if ((patchDepth() % packetSize) == 0) {
543 return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
544 }
545
546 // Offsets and input calculation here are identical to
547 // loadCoeffStandard(...), but repeated twice.
548 const Index patchOffsets[2] = {patchId / m_fastDimZero,
549 (patchId + packetSize - 1) / m_fastDimZero};
550 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
551 patchOffsets[1] / m_fastColStride};
552 const Index inputCols[2] = {colIndex + colOffsets[0],
553 colIndex + colOffsets[1]};
554
555 if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
556 // all zeros
557 return internal::pset1<PacketT>(Scalar(0));
558 }
559 if (inputCols[0] == inputCols[1]) {
560 return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex,
561 otherIndex, patchOffsets,
562 colOffsets, inputCols);
563 }
564 if (inputCols[1] == inputCols[0] + 1) {
565 return loadPacketStandardFromTwoColumns(
566 patchId, rowIndex, colIndex, otherIndex, patchOffsets, colOffsets);
567 }
568 return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
569 }
570
571 EIGEN_DEVICE_FUNC
572 EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex,
573 Index colIndex,
574 Index otherIndex) const {
575 const Index packetSize = internal::unpacket_traits<Packet>::size;
576 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
577 eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
578
579 eigen_assert(!nonStandardPatches());
580 eigen_assert((patchDepth() % packetSize) == 0);
581 // Find the offset of the element wrt the location of the first element.
582 const Index patchOffset = patchId / m_fastDimZero;
583 eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
584
585 const Index colOffset = patchOffset / m_fastColStride;
586 const Index rowOffset = patchOffset - colOffset * m_colStride;
587 const Index inputCol = colIndex + colOffset;
588 const Index inputRow = rowIndex + rowOffset;
589 if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
590 inputRow >= m_inputRows) {
591 // all zeros
592 return internal::pset1<Packet>(Scalar(0));
593 }
594 // no padding
595 const Index depth = patchId - patchOffset * patchDepth();
596 const Index inputIndex = depth + inputRow * m_rowInputStride +
597 inputCol * m_colInputStride + otherIndex;
598 return m_impl.template packet<Unaligned>(inputIndex);
599 }
600
601 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(
602 Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
603 const int packetSize = internal::unpacket_traits<Packet>::size;
604 EIGEN_ALIGN_MAX
605 std::remove_const_t<Scalar> values[packetSize];
606 for (int i = 0; i < packetSize; ++i) {
607 values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
608 }
609 Packet rslt = internal::pload<Packet>(values);
610 return rslt;
611 }
612
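  // Maps a column of the "virtual matrix" (patchIndex) to the coordinates of
  // the first element of the corresponding patch in the input (possibly
  // negative because of padding), plus the batch offset. As an illustration
  // (hypothetical numbers): with m_num_patches = 100, m_outputRows = 10, unit
  // user strides and a padding of 1, patchIndex = 237 yields
  //   otherIndex   = 237 / 100 = 2   (third batch entry, scaled below by
  //                                   m_patchInputStride)
  //   patch2DIndex = 37,  colIndex = 37 / 10 = 3,  rowIndex = 7
  //   colIndex     = 3 * 1 - 1 = 2,  rowIndex = 7 * 1 - 1 = 6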
613 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
614 Index patchIndex, Index& rowIndex, Index& colIndex,
615 Index& otherIndex) const {
616 const size_t NumInputDims = array_size<
617 typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
618 otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
619 const Index patch2DIndex = (NumInputDims == 3)
620 ? patchIndex
621 : (patchIndex - otherIndex * m_num_patches);
622 otherIndex *= m_patchInputStride;
623 colIndex = patch2DIndex / m_fastOutputRows;
624 rowIndex = patch2DIndex - colIndex * m_outputRows;
625 colIndex = colIndex * m_col_strides - m_colPaddingLeft;
626 rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
627 }
628
629 Index m_patch_cols; // number of columns in the patch
630 Index m_num_patches; // number of patches to extract.
631
632 // Strides for navigating through the single patch.
633 Index m_patch_row_stride;
634 Index m_patch_col_stride;
635 internal::TensorIntDivisor<Index> m_fastPatchRowStride;
636 internal::TensorIntDivisor<Index> m_fastPatchColStride;
637
638 Index m_patch_row_inflate_strides; // the strides for row inflation in the
639 // image patch
640 Index m_patch_col_inflate_strides; // the strides for col inflation in the
641 // image patch
642 // Fast representation of inflation strides.
643 internal::TensorIntDivisor<Index> m_fastInputRowStride;
644 internal::TensorIntDivisor<Index> m_fastInputColStride;
645
646 Index m_otherStride;
647 Index m_colStride;
648 internal::TensorIntDivisor<Index> m_fastNumPatches;
649 internal::TensorIntDivisor<Index> m_fastColStride;
650
651 Index m_rowInputStride; // row stride in the input tensor
652 Index m_colInputStride; // col stride in the input tensor
653 Index m_patchInputStride; // patch stride in the input tensor
654
655 Index m_inputRows; // Number of rows in the input tensor
656 Index m_inputCols; // Number of cols in the input tensor
657
658 Index m_outputRows; // Number of convolution output rows
  Index m_outputCols;  // Number of convolution output columns
660
661 Index m_row_strides; // User specified row stride
662 Index m_col_strides; // User specified col stride
663
664 Index m_in_row_strides; // User specified input row stride
665 Index m_in_col_strides; // User specified input col stride
666
667 Index m_rowPaddingTop; // Row padding
668 Index m_colPaddingLeft; // Column padding
669
670 internal::TensorIntDivisor<Index> m_fastOutputRows;
671 internal::TensorIntDivisor<Index> m_fastDimZero;
672
673 const TensorEvaluator<ArgType, Device> m_impl;
674};
675
676template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
677 typename Device, typename Scalar, typename Index,
678 typename nocontract_t, typename contract_t, int Side, int packet_size,
679 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
680class TensorContractionSubMapper<
681 Scalar, Index, Side,
682 TensorEvaluator<
683 const TensorReshapingOp<NewDimension,
684 const TensorImagePatchOp<Rows, Cols, ArgType> >,
685 Device>,
686 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
687 inner_dim_reordered, Alignment> {
688 public:
689 typedef typename packet_traits<Scalar>::type Packet;
690 typedef typename packet_traits<Scalar>::half HalfPacket;
691
692 typedef TensorContractionInputMapper<
693 Scalar, Index, Side,
694 TensorEvaluator<
695 const TensorReshapingOp<
696 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
697 Device>,
698 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
699 inner_dim_reordered, Alignment>
700 ParentMapper;
701
702 typedef TensorContractionSubMapper<
703 Scalar, Index, Side,
704 TensorEvaluator<
705 const TensorReshapingOp<
706 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
707 Device>,
708 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
709 inner_dim_reordered, Alignment>
710 Self;
711
712 typedef Self LinearMapper;
713
714 typedef typename ParentMapper::TensorEvaluatorT TensorEvaluatorT;
715
716 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
717 const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
718 : m_depth_offset(vert_offset),
719 m_col_offset(horiz_offset),
720 m_base_mapper(base_mapper) {
721 m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
722 m_otherIndex);
723 }
724 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
725 const Self& base_mapper, Index vert_offset, Index horiz_offset)
726 : m_depth_offset(vert_offset + base_mapper.m_depth_offset),
727 m_col_offset(horiz_offset + base_mapper.m_col_offset),
728 m_base_mapper(base_mapper.m_base_mapper) {
729 m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
730 m_otherIndex);
731 }
732 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
733 return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex,
734 m_otherIndex);
735 }
736 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
737 Index j) const {
738 return m_base_mapper(i + m_depth_offset, j + m_col_offset);
739 }
740
741 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
742 return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex,
743 m_otherIndex);
744 }
745 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
746 Index j) const {
747 return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
748 j + m_col_offset);
749 }
750 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
751 loadCoeffStandard(Index i) const {
752 return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex,
753 m_colIndex, m_otherIndex);
754 }
755
756 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
757 return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex,
758 m_colIndex, m_otherIndex);
759 }
760 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
761 loadPacketStandard(Index i) const {
762 typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
763 return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
764 i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
765 }
766 template <typename Packet>
767 EIGEN_DEVICE_FUNC bool aligned(Index) const {
768 return false;
769 }
770
771 EIGEN_DEVICE_FUNC
772 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
773 return m_base_mapper.nonStandardPatches();
774 }
775
776 // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
777 // index respectively that fits into the peeled_k elements starting at
778 // m_depth_offset.
779
780 EIGEN_DEVICE_FUNC
781 EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
782 const Index max_col =
783 (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) /
784 fastPatchColStride();
785 return std::min<Index>(1 + max_col, patchCols());
786 }
787
788 EIGEN_DEVICE_FUNC
789 EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
790 const Index col) const {
791 const Index max_row = (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) -
792 col * patchColStride()) /
793 fastPatchRowStride();
794 return std::min<Index>(1 + max_row, patchRows());
795 }
796
797 EIGEN_DEVICE_FUNC
798 EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col,
799 Index row) const {
800 const Index max_depth = m_depth_offset + peeled_k - //
801 col * patchColStride() - //
802 row * patchRowStride();
803 return std::min<Index>(max_depth, patchDepth());
804 }
805
806 // MaxDepth uses only the remaining number of elements in the peeled_k.
807 EIGEN_DEVICE_FUNC
808 EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
809 const Index start_depth) const {
810 return std::min<Index>(start_depth + num_elements, patchDepth());
811 }
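
  // For example (hypothetical numbers): with patchDepth() = 8, patchRows() = 3
  // and patchCols() = 3 (so patchColStride() = 24), m_depth_offset = 0 and
  // peeled_k = 40, the helpers above give maxCol = 2, maxRow(40, 1) = 2 and
  // maxDepth(40, 1, 1) = 8, i.e. the peeled part covers all of patch column 0
  // plus rows 0..1 of column 1.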
812
  // Every register matters in this code, so sometimes, to prevent register
  // spilling, instead of the variable that you would expect to see we use
  // another one that is guaranteed to have the same value. E.g. patch depth is
  // always the same as input depth, and it's also the same as input row
  // stride. A number of other parameters have similar relations.
818
819 typedef internal::TensorIntDivisor<Index> IndexDivisor;
820
821 EIGEN_DEVICE_FUNC
822 EIGEN_ALWAYS_INLINE Index patchDepth() const {
823 return m_base_mapper.m_rowInputStride;
824 }
825 EIGEN_DEVICE_FUNC
826 EIGEN_ALWAYS_INLINE Index patchRows() const {
827 return m_base_mapper.m_colStride;
828 }
829 EIGEN_DEVICE_FUNC
830 EIGEN_ALWAYS_INLINE Index patchCols() const {
831 return m_base_mapper.m_patch_cols;
832 }
833
834 EIGEN_DEVICE_FUNC
835 EIGEN_ALWAYS_INLINE Index patchRowStride() const {
836 eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
837 "Patch depth must be equal to patch row stride.");
838 return patchDepth();
839 }
840 EIGEN_DEVICE_FUNC
841 EIGEN_ALWAYS_INLINE Index patchColStride() const {
842 return m_base_mapper.m_patch_col_stride;
843 }
844
845 EIGEN_DEVICE_FUNC
846 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
847 eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
848 "Patch depth must be equal to patch row stride.");
849 return m_base_mapper.m_fastDimZero; // patch_depth
850 }
851 EIGEN_DEVICE_FUNC
852 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
853 return m_base_mapper.m_fastPatchColStride;
854 }
855
856 EIGEN_DEVICE_FUNC
857 EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
858 const Index baseIndex) const {
859 const Index inputIndex = depth + baseIndex;
860 return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
861 }
862 EIGEN_DEVICE_FUNC
863 EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth,
864 const Index baseIndex) const {
865 const Index inputIndex = depth + baseIndex;
866 return m_base_mapper.m_impl.coeff(inputIndex);
867 }
868 template <typename PacketT = Packet>
869 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
870 TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
871 PacketT>::type
872 partialPacketNoPadding(const Index depth, const Index baseIndex,
873 Index num_coeffs) const {
874 const Index inputIndex = depth + baseIndex;
875 return m_base_mapper.m_impl.template partialPacket<PacketT>(
876 inputIndex, mask<PacketT>(0, num_coeffs));
877 }
878 EIGEN_DEVICE_FUNC
879 EIGEN_ALWAYS_INLINE bool hasPadding() const {
    // TODO(ezhulenev): It does seem that for an inflated filter it's still
    // possible to guarantee "no padding or skipping" for non-standard packing.
882 if (nonStandardPatches()) return true;
883
884 // Non zero padding before.
885 if (m_base_mapper.m_rowPaddingTop > 0) return true;
886 if (m_base_mapper.m_colPaddingLeft > 0) return true;
887
888 // Non zero padding after in rows.
889 const Index last_row =
890 (m_base_mapper.m_outputRows - 1) * m_base_mapper.m_row_strides;
891 if (last_row + (patchRows() - 1) >= m_base_mapper.m_inputRows) return true;
892
893 // Non zero padding after in cols.
894 const Index last_col =
895 (m_base_mapper.m_outputCols - 1) * m_base_mapper.m_col_strides;
896 if (last_col + (patchCols() - 1) >= m_base_mapper.m_inputCols) return true;
897
898 return false;
899 }
900 EIGEN_DEVICE_FUNC
901 EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
902 const Index r = m_rowIndex + row;
903 return r < 0 || r >= m_base_mapper.m_inputRows;
904 }
905 EIGEN_DEVICE_FUNC
906 EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row,
907 const Index last_row) const {
908 return m_rowIndex + first_row < 0 ||
909 m_rowIndex + last_row >= m_base_mapper.m_inputRows;
910 }
911 EIGEN_DEVICE_FUNC
912 EIGEN_ALWAYS_INLINE bool padOrSkipRow(const Index row,
913 Index* orig_row) const {
914 eigen_assert(nonStandardPatches());
915
916 const Index input_row = m_rowIndex + row * m_base_mapper.m_in_row_strides;
917 *orig_row = (m_base_mapper.m_patch_row_inflate_strides == 1)
918 ? input_row
919 : ((input_row >= 0)
920 ? (input_row / m_base_mapper.m_fastInputRowStride)
921 : 0);
922
923 return (*orig_row < 0 || *orig_row >= m_base_mapper.m_inputRows) ||
924 (input_row != *orig_row * m_base_mapper.m_patch_row_inflate_strides);
925 }
926 EIGEN_DEVICE_FUNC
927 EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
928 const Index c = m_colIndex + col;
929 return c < 0 || c >= m_base_mapper.m_inputCols;
930 }
931 EIGEN_DEVICE_FUNC
932 EIGEN_ALWAYS_INLINE bool padOrSkipCol(const Index col,
933 Index* orig_col) const {
934 eigen_assert(nonStandardPatches());
935
936 const Index input_col = m_colIndex + col * m_base_mapper.m_in_col_strides;
937 *orig_col = (m_base_mapper.m_patch_col_inflate_strides == 1)
938 ? input_col
939 : ((input_col >= 0)
940 ? (input_col / m_base_mapper.m_fastInputColStride)
941 : 0);
942
943 return (*orig_col < 0 || *orig_col >= m_base_mapper.m_inputCols) ||
944 (input_col != *orig_col * m_base_mapper.m_patch_col_inflate_strides);
945 }
946 EIGEN_DEVICE_FUNC
947 EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const {
948 const Index r = m_rowIndex + row;
949 const Index c = m_colIndex + col;
950 return r * m_base_mapper.m_rowInputStride +
951 c * m_base_mapper.m_colInputStride + m_otherIndex;
952 }
953 // Compute a base index when original input row and column were precomputed
  // using padOrSkipRow and padOrSkipCol. Used only for non-standard patches.
955 EIGEN_DEVICE_FUNC
956 EIGEN_ALWAYS_INLINE Index origBaseIndex(const Index orig_row,
957 const Index orig_col) const {
958 return orig_row * m_base_mapper.m_rowInputStride +
959 orig_col * m_base_mapper.m_colInputStride + m_otherIndex;
960 }
961
962 EIGEN_DEVICE_FUNC
963 EIGEN_ALWAYS_INLINE Index rowStride() const {
964 return m_base_mapper.m_row_strides;
965 }
966 EIGEN_DEVICE_FUNC
967 EIGEN_ALWAYS_INLINE Index colStride() const {
968 return m_base_mapper.m_col_strides;
969 }
970
971 EIGEN_DEVICE_FUNC
972 EIGEN_ALWAYS_INLINE Index rowOffset() const {
973 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
974 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
975 return patchOffset - colOffset * m_base_mapper.m_colStride;
976 }
977
978 EIGEN_DEVICE_FUNC
979 EIGEN_ALWAYS_INLINE Index colOffset() const {
980 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
981 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
982 return colOffset;
983 }
984
985 EIGEN_DEVICE_FUNC
986 EIGEN_ALWAYS_INLINE Index depthOffset() const {
987 return m_depth_offset % patchDepth();
988 }
989
990 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
991 getLinearMapper(Index i, Index j) const {
992 return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
993 }
994
995 private:
996 Index m_depth_offset; // First row in the input matrix
997 Index m_col_offset; // First col in the input matrix
998
999 // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
1000 // indices for the first element in a patch specified by col_offset
1001 // (see computeBaseIndices(...) for details).
1002 Index m_rowIndex;
1003 Index m_colIndex;
1004 Index m_otherIndex;
1005
1006 const ParentMapper m_base_mapper; // Keeping a copy instead of a reference
1007 // performs better in benchmarks.
1008};
1009
1010// Arrange a block of the right input matrix (in our case it's always a "virtual
1011// matrix" constructed from extracted image patches) in contiguous memory.
1012//
1013// Given column major input (A0 beside A1 in memory):
1014// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
1015// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
1016// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
1017// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
1018// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
1019// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
1020// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
1021// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
1022// A8 ...
1023// ...
1024//
1025// *) A, B, C, ... - patches extracted from the original input.
1026// *) A0, A1, A2 ... - values from the same patch at different offsets.
1027//
// The traversal (packed rhs memory) order (B0 beside A0 in memory):
1029// A0 B0 C0 D0 A1 B1 C1 D1 ...
1030// E0 F0 G0 H0 E1 F1 G1 H1 ...
1031// ...
1032// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
1033//
1034// This traversal order must be the same as in default gemm_pack_rhs defined in
1035// GeneralBlockPanelKernel.h.
1036//
1037// *) nr - number of registers along the 'n' dimension.
1038// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
1039// Multiplication" paper.
1040template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
1041 typename Device, typename Scalar, typename Index,
1042 typename nocontract_t, typename contract_t, int packet_size,
1043 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
1044 int nr>
1045struct gemm_pack_rhs<
1046 Scalar, Index,
1047 TensorContractionSubMapper<
1048 Scalar, Index, Rhs,
1049 TensorEvaluator<
1050 const TensorReshapingOp<
1051 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1052 Device>,
1053 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1054 inner_dim_reordered, Alignment>,
1055 nr, ColMajor, false, false> {
1056 typedef TensorContractionSubMapper<
1057 Scalar, Index, Rhs,
1058 TensorEvaluator<
1059 const TensorReshapingOp<
1060 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1061 Device>,
1062 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1063 inner_dim_reordered, Alignment>
1064 SubMapper;
1065 typedef SubMapper DataMapper;
1066 typedef typename packet_traits<Scalar>::type Packet;
1067
1068 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1069
1070 EIGEN_DEVICE_FUNC
1071 EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1072 Index depth, Index cols, Index stride = 0,
1073 Index offset = 0) const {
1074 eigen_assert(stride == 0);
1075 eigen_assert(offset == 0);
1076
1077 const Index packet_cols4 = (cols / 4) * 4;
1078 const Index peeled_k = (depth / packet_size) * packet_size;
1079 const bool non_standard_patches = rhs.nonStandardPatches();
1080
1081 for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1082 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1083 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1084 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1085 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1086
1087 Index k = 0;
1088 if ((packet_size % 4) == 0 && !non_standard_patches) {
1089 // FAST PATH:
        // Iterate over patch columns and rows if we know that a single
        // packet does not span across multiple rows or columns.
1092 if ((rhs.patchDepth() % packet_size) == 0) {
1093 const Index start_col = rhs.colOffset();
1094 const Index max_col = rhs.maxCol(peeled_k);
1095
1096 for (Index c = start_col; c < max_col; ++c) {
1097 eigen_assert(k <= peeled_k);
1098
1099 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1100 const Index max_row = rhs.maxRow(peeled_k, c);
1101
1102 const bool pad_col0 = dm0.padCol(c);
1103 const bool pad_col1 = dm1.padCol(c);
1104 const bool pad_col2 = dm2.padCol(c);
1105 const bool pad_col3 = dm3.padCol(c);
1106
1107 // Check if we can squeeze reads along the `row` and `depth`
1108 // dimensions (two innermost dimensions).
1109 if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && //
1110 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && //
1111 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && //
1112 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && //
1113 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) {
1114 // Compute how many elements we can squeeze read.
1115 const Index start_depth =
1116 (c == start_col) ? rhs.depthOffset() : 0;
1117
1118 // Upper bound for the number of elements in the depth dimension
1119 // that we can squeeze read.
1120 const Index squeeze_length =
1121 (max_row - start_row) * rhs.patchDepth() - start_depth;
1122
1123 // Do not overshoot beyond the block size.
1124 const Index max_depth =
1125 start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1126 eigen_assert((max_depth - start_depth) % packet_size == 0);
1127
1128 const Index idx0 = dm0.baseIndex(start_row, c);
1129 const Index idx1 = dm1.baseIndex(start_row, c);
1130 const Index idx2 = dm2.baseIndex(start_row, c);
1131 const Index idx3 = dm3.baseIndex(start_row, c);
1132
1133 for (Index d = start_depth; d < max_depth; d += packet_size) {
1134 eigen_assert(k < peeled_k);
1135 PacketBlock<Packet, 4> kernel;
1136 kernel.packet[0] = rhs.packetNoPadding(d, idx0);
1137 kernel.packet[1] = rhs.packetNoPadding(d, idx1);
1138 kernel.packet[2] = rhs.packetNoPadding(d, idx2);
1139 kernel.packet[3] = rhs.packetNoPadding(d, idx3);
1140 ptranspose(kernel);
1141 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1142 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1143 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1144 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1145 block += 4 * packet_size;
1146 k += packet_size;
1147 }
1148
1149 // Go to the next column.
1150 continue;
1151 }
1152
1153 // If we can't squeeze reads, process rows one by one.
1154 for (Index r = start_row; r < max_row; ++r) {
1155 eigen_assert(k <= peeled_k);
1156
1157 const bool pad0 = pad_col0 || dm0.padRow(r);
1158 const bool pad1 = pad_col1 || dm1.padRow(r);
1159 const bool pad2 = pad_col2 || dm2.padRow(r);
1160 const bool pad3 = pad_col3 || dm3.padRow(r);
1161
1162 const Index idx0 = dm0.baseIndex(r, c);
1163 const Index idx1 = dm1.baseIndex(r, c);
1164 const Index idx2 = dm2.baseIndex(r, c);
1165 const Index idx3 = dm3.baseIndex(r, c);
1166
1167 const Index start_depth = ((c == start_col) && (r == start_row))
1168 ? rhs.depthOffset()
1169 : 0;
1170 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1171 eigen_assert((max_depth - start_depth) % packet_size == 0);
1172
1173 for (Index d = start_depth; d < max_depth; d += packet_size) {
1174 eigen_assert(k < peeled_k);
1175 PacketBlock<Packet, 4> kernel;
1176 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
1177 : rhs.packetNoPadding(d, idx0);
1178 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
1179 : rhs.packetNoPadding(d, idx1);
1180 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
1181 : rhs.packetNoPadding(d, idx2);
1182 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
1183 : rhs.packetNoPadding(d, idx3);
1184 ptranspose(kernel);
1185 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1186 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1187 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1188 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1189 block += 4 * packet_size;
1190 k += packet_size;
1191 }
1192 }
1193 }
1194
1195 // The loop above should fill peeled_k elements.
1196 eigen_assert(peeled_k == k);
1197
1198 } else {
1199 for (; k < peeled_k; k += packet_size) {
1200 PacketBlock<Packet, 4> kernel;
1201 kernel.packet[0] = dm0.loadPacketStandard(k);
1202 kernel.packet[1] = dm1.loadPacketStandard(k);
1203 kernel.packet[2] = dm2.loadPacketStandard(k);
1204 kernel.packet[3] = dm3.loadPacketStandard(k);
1205 ptranspose(kernel);
1206 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1207 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1208 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1209 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1210 block += 4 * packet_size;
1211 }
1212 }
1213 }
1214
1215 // Copy the remaining coefficients of the column block after the peeled_k.
1216 if (!rhs.nonStandardPatches()) {
1217 for (; k < depth; k++) {
1218 block[0] = dm0.loadCoeffStandard(k);
1219 block[1] = dm1.loadCoeffStandard(k);
1220 block[2] = dm2.loadCoeffStandard(k);
1221 block[3] = dm3.loadCoeffStandard(k);
1222 block += 4;
1223 }
1224 } else {
1225 for (; k < depth; k++) {
1226 block[0] = dm0(k);
1227 block[1] = dm1(k);
1228 block[2] = dm2(k);
1229 block[3] = dm3(k);
1230 block += 4;
1231 }
1232 }
1233 }
1234
    // Copy the remaining columns one at a time (nr==1).
1236 for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1237 const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1238 for (Index k = 0; k < depth; k++) {
1239 *block = dm0(k);
1240 block += 1;
1241 }
1242 }
1243 }
1244};
1245
1246// Template specialization for packet_size = 2. We must special-case packet
1247// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
1248template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
1249 typename Device, typename Scalar, typename Index,
1250 typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
1251 bool inner_dim_reordered, int Alignment, int nr>
1252struct gemm_pack_rhs<
1253 Scalar, Index,
1254 TensorContractionSubMapper<
1255 Scalar, Index, Rhs,
1256 TensorEvaluator<
1257 const TensorReshapingOp<
1258 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1259 Device>,
1260 nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
1261 Alignment>,
1262 nr, ColMajor, false, false> {
1263 typedef TensorContractionSubMapper<
1264 Scalar, Index, Rhs,
1265 TensorEvaluator<
1266 const TensorReshapingOp<
1267 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1268 Device>,
1269 nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
1270 Alignment>
1271 SubMapper;
1272 typedef SubMapper DataMapper;
1273 typedef typename packet_traits<Scalar>::type Packet;
1274
1275 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1276
1277 EIGEN_DEVICE_FUNC
1278 EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1279 Index depth, Index cols, Index stride = 0,
1280 Index offset = 0) const {
1281 eigen_assert(stride == 0);
1282 eigen_assert(offset == 0);
1283
1284 const int packet_size = 2;
1285 const Index packet_cols4 = (cols / 4) * 4;
1286 const Index peeled_k = (depth / packet_size) * packet_size;
1287 const bool non_standard_patches = rhs.nonStandardPatches();
1288
1289 for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1290 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1291 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1292 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1293 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1294
1295 Index k = 0;
1296 if (!non_standard_patches) {
1297 // FAST PATH:
        // Iterate over patch columns and rows if we know that a single
        // packet does not span across multiple rows or columns.
1300 if ((rhs.patchDepth() % packet_size) == 0) {
1301 const Index start_col = rhs.colOffset();
1302 const Index max_col = rhs.maxCol(peeled_k);
1303
1304 for (Index c = start_col; c < max_col; ++c) {
1305 eigen_assert(k <= peeled_k);
1306
1307 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1308 const Index max_row = rhs.maxRow(peeled_k, c);
1309
1310 const bool pad_col0 = dm0.padCol(c);
1311 const bool pad_col1 = dm1.padCol(c);
1312 const bool pad_col2 = dm2.padCol(c);
1313 const bool pad_col3 = dm3.padCol(c);
1314
1315 // We can squeeze reads along the `row` and `depth` dimensions if
1316 // the row stride is `1`, which means that `row` and `depth`
1317 // dimensions are contiguous (two innermost dimensions).
1318 if (rhs.rowStride() == 1 && //
1319 !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && //
1320 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && //
1321 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && //
1322 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && //
1323 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) {
1324 // Compute how many elements we can squeeze read.
1325 const Index start_depth =
1326 (c == start_col) ? rhs.depthOffset() : 0;
1327
1328 // Upper bound for the number of elements in the depth dimension
1329 // that we can squeeze read.
1330 const Index squeeze_length =
1331 (max_row - start_row) * rhs.patchDepth() - start_depth;
1332
1333 // Do not overshoot beyond the block size.
1334 const Index max_depth =
1335 start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1336 eigen_assert((max_depth - start_depth) % packet_size == 0);
1337
1338 const Index idx0 = dm0.baseIndex(start_row, c);
1339 const Index idx1 = dm1.baseIndex(start_row, c);
1340 const Index idx2 = dm2.baseIndex(start_row, c);
1341 const Index idx3 = dm3.baseIndex(start_row, c);
1342
1343 for (Index d = start_depth; d < max_depth; d += packet_size) {
1344 PacketBlock<Packet, 2> kernel0;
1345 PacketBlock<Packet, 2> kernel1;
1346 kernel0.packet[0] = rhs.packetNoPadding(d, idx0);
1347 kernel0.packet[1] = rhs.packetNoPadding(d, idx1);
1348 kernel1.packet[0] = rhs.packetNoPadding(d, idx2);
1349 kernel1.packet[1] = rhs.packetNoPadding(d, idx3);
1350 ptranspose(kernel0);
1351 ptranspose(kernel1);
1352 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1353 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1354 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1355 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1356 block += 4 * packet_size;
1357 k += packet_size;
1358 }
1359
1360 // Go to the next column.
1361 continue;
1362 }
1363
            // If we can't squeeze reads, process rows one by one.
            for (Index r = start_row; r < max_row; ++r) {
              eigen_assert(k <= peeled_k);

              const bool pad0 = pad_col0 || dm0.padRow(r);
              const bool pad1 = pad_col1 || dm1.padRow(r);
              const bool pad2 = pad_col2 || dm2.padRow(r);
              const bool pad3 = pad_col3 || dm3.padRow(r);

              const Index idx0 = dm0.baseIndex(r, c);
              const Index idx1 = dm1.baseIndex(r, c);
              const Index idx2 = dm2.baseIndex(r, c);
              const Index idx3 = dm3.baseIndex(r, c);

              const Index start_depth = ((c == start_col) && (r == start_row))
                                            ? rhs.depthOffset()
                                            : 0;
              const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                PacketBlock<Packet, 2> kernel0;
                PacketBlock<Packet, 2> kernel1;
                kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx0);
                kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx1);
                kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx2);
                kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx3);
                ptranspose(kernel0);
                ptranspose(kernel1);
                pstoreu(block + 0 * packet_size, kernel0.packet[0]);
                pstoreu(block + 1 * packet_size, kernel1.packet[0]);
                pstoreu(block + 2 * packet_size, kernel0.packet[1]);
                pstoreu(block + 3 * packet_size, kernel1.packet[1]);
                block += 4 * packet_size;
                k += packet_size;
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          // Packet can span multiple rows or columns, so we have to go
          // through the slower "standard" path.
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 2> kernel0;
            PacketBlock<Packet, 2> kernel1;
            kernel0.packet[0] = dm0.loadPacketStandard(k);
            kernel0.packet[1] = dm1.loadPacketStandard(k);
            kernel1.packet[0] = dm2.loadPacketStandard(k);
            kernel1.packet[1] = dm3.loadPacketStandard(k);
            ptranspose(kernel0);
            ptranspose(kernel1);
            pstoreu(block + 0 * packet_size, kernel0.packet[0]);
            pstoreu(block + 1 * packet_size, kernel1.packet[0]);
            pstoreu(block + 2 * packet_size, kernel0.packet[1]);
            pstoreu(block + 3 * packet_size, kernel1.packet[1]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after the peeled_k.
      if (!non_standard_patches) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Special case for non-vectorized types such as float16.
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
        Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
      Alignment>
      SubMapper;
  typedef SubMapper DataMapper;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const Index packet_cols4 = (cols / 4) * 4;

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      if (!rhs.nonStandardPatches()) {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};
#endif  // !EIGEN_ALTIVEC_USE_CUSTOM_PACK
}  // end namespace internal

/** SpatialConvolution
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Applies a 2D convolution over a multichannel input image.
 *
 * The input parameter is expected to be a tensor with a rank of 3 or more
 * (channels, height, width, and optionally others).
 * The kernel parameter is expected to be a 4D tensor (filters, channels,
 * kernel_height, kernel_width).
 * The input and the kernel must both be in col-major layout. The result will
 * also be in col-major layout.
 *
 * If col_in_stride or row_in_stride > 1, the convolution is applied with holes
 * (a.k.a. atrous or dilated convolution), sampling every col_in_stride-th
 * column and every row_in_stride-th row of the input.
 *
 * If padding_top, padding_bottom, padding_left, or padding_right is specified,
 * then those paddings will be used to pad the input, and padding_type must be
 * PADDING_VALID.
 *
 * The result can be assigned to a tensor of rank equal to the rank of the
 * input. The dimensions of the result will be filters, height, width (and
 * others if applicable).
 *
 * It is possible to swap the order of the width and height dimensions provided
 * that the same order is used in the input, the kernel, and the output.
 *
 * It is also possible to add an output kernel to the contraction; the output
 * kernel is called by Eigen when it "finalizes" a block of the output tensor.
 *
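 * Example (a minimal usage sketch; the tensor shapes below are illustrative
 * only and assume the default ColMajor layout):
 *
 *   Eigen::Tensor<float, 4> input(3, 32, 32, 8);    // channels, height, width, batch
 *   Eigen::Tensor<float, 4> kernel(16, 3, 5, 5);    // filters, channels, rows, cols
 *   Eigen::Tensor<float, 4> output(16, 32, 32, 8);  // filters, height, width, batch
 *   input.setRandom();
 *   kernel.setRandom();
 *   output = Eigen::SpatialConvolution(input, kernel);  // stride 1, PADDING_SAME
 *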
 */
template <typename Input, typename Kernel,
          typename OutputKernel = const NoOpOutputKernel>
EIGEN_ALWAYS_INLINE static const std::conditional_t<
    internal::traits<Input>::Layout == ColMajor,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
            const OutputKernel> >,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel>,
            const OutputKernel> > >
SpatialConvolution(const Input& input, const Kernel& kernel,
                   const Index row_stride = 1, const Index col_stride = 1,
                   const PaddingType padding_type = PADDING_SAME,
                   const Index row_in_stride = 1, const Index col_in_stride = 1,
                   const OutputKernel& output_kernel = OutputKernel(),
                   Index padding_top = 0, Index padding_bottom = 0,
                   Index padding_left = 0, Index padding_right = 0) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  typedef typename internal::traits<Input>::Scalar InputScalar;
  TensorRef<Tensor<InputScalar, internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex> >
      in(input);
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex> >
      kern(kernel);

  EIGEN_STATIC_ASSERT(
      internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
      YOU_MADE_A_PROGRAMMING_MISTAKE)
  const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);

  const int NumDims = internal::traits<Input>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result.
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];

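  // Effective kernel extents after dilation by the in-strides. For example
  // (illustrative numbers), a 3x3 kernel with row_in_stride == col_in_stride
  // == 2 covers an effective 5x5 window of the input.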
  const Index kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const Index kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  const TensorIndex InputRows =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex InputCols =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
  const bool padding_explicit =
      (padding_top || padding_bottom || padding_left || padding_right);

  TensorIndex out_height;
  TensorIndex out_width;
  switch (padding_type) {
    case PADDING_VALID: {
      const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
      const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
      out_height = divup(InputRowsEff - kernelRowsEff + 1, row_stride);
      out_width = divup(InputColsEff - kernelColsEff + 1, col_stride);
      break;
    }
    case PADDING_SAME: {
      eigen_assert(!padding_explicit);
      out_height = divup(InputRows, row_stride);
      out_width = divup(InputCols, col_stride);
      break;
    }
    default: {
      // Initialize unused variables to avoid a compiler warning
      out_height = 0;
      out_width = 0;
      eigen_assert(false && "unexpected padding");
    }
  }
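  // Worked example (illustrative numbers only): with InputRows == 10,
  // kernelRowsEff == 3, row_stride == 2 and no explicit padding,
  // PADDING_VALID gives out_height == divup(10 - 3 + 1, 2) == 4, while
  // PADDING_SAME gives out_height == divup(10, 2) == 5.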

  // Molds the output of the patch extraction code into a 2d tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with the
  //   kernels
  // - the second dimension (dims[1]): everything else
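  // For example (ColMajor, illustrative shapes): with kernelChannels == 3, a
  // 5x5 kernel and a batch of 8, pre_contract_dims becomes
  // [3 * 5 * 5, out_height * out_width * 8].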
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[1] = out_height * out_width;
    for (int i = 3; i < NumDims; ++i) {
      pre_contract_dims[1] *= in.dimension(i);
    }
  } else {
    pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[0] = out_height * out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      pre_contract_dims[0] *= in.dimension(i);
    }
  }

  // Molds the output of the contraction into the shape expected by the user
  // (assuming this is ColMajor):
  // - 1st dim: kernel filters
  // - 2nd dim: output height
  // - 3rd dim: output width
  // - 4th dim and beyond: everything else including batch size
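  // For example (ColMajor, illustrative shapes): an input of shape
  // [channels, height, width, batch] convolved with kernelFilters filters is
  // reshaped to [kernelFilters, out_height, out_width, batch].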
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = out_height;
    post_contract_dims[2] = out_width;
    for (int i = 3; i < NumDims; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelFilters;
    post_contract_dims[NumDims - 2] = out_height;
    post_contract_dims[NumDims - 3] = out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  }

  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters;
    kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
  } else {
    kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
    kernel_dims[1] = kernelFilters;
  }
  if (padding_explicit) {
    return choose(
        Cond<internal::traits<Input>::Layout == ColMajor>(),
        kernel.reshape(kernel_dims)
            .contract(input
                          .extract_image_patches(
                              kernelRows, kernelCols, row_stride, col_stride,
                              row_in_stride, col_in_stride,
                              /*row_inflate_stride=*/1,
                              /*col_inflate_stride=*/1, padding_top,
                              padding_bottom, padding_left, padding_right,
                              /*padding_value=*/static_cast<InputScalar>(0))
                          .reshape(pre_contract_dims),
                      contract_dims, output_kernel)
            .reshape(post_contract_dims),
        input
            .extract_image_patches(
                kernelRows, kernelCols, row_stride, col_stride, row_in_stride,
                col_in_stride,
                /*row_inflate_stride=*/1,
                /*col_inflate_stride=*/1, padding_top, padding_bottom,
                padding_left, padding_right,
                /*padding_value=*/static_cast<InputScalar>(0))
            .reshape(pre_contract_dims)
            .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
            .reshape(post_contract_dims));
  } else {
    return choose(
        Cond<internal::traits<Input>::Layout == ColMajor>(),
        kernel.reshape(kernel_dims)
            .contract(input
                          .extract_image_patches(
                              kernelRows, kernelCols, row_stride, col_stride,
                              row_in_stride, col_in_stride, padding_type)
                          .reshape(pre_contract_dims),
                      contract_dims, output_kernel)
            .reshape(post_contract_dims),
        input
            .extract_image_patches(kernelRows, kernelCols, row_stride,
                                   col_stride, row_in_stride, col_in_stride,
                                   padding_type)
            .reshape(pre_contract_dims)
            .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
            .reshape(post_contract_dims));
  }
}

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_