1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
17#define TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
18
19#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20
21#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
22#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
23#endif
24
25#include "tensorflow/core/kernels/eigen_convolution_helpers.h"
26
27namespace Eigen {
28
29namespace internal {
30
31#if !EIGEN_ALTIVEC_USE_CUSTOM_PACK
32// WARNING: Most of the code here implicitly assumes that the matrix is in
33// ColMajor layout. This is guaranteed by the tensor contraction (see
34// TensorContraction.h).
35//
36// Inside Eigen a tensor contraction is represented by a matrix multiplication.
37// We don't want to actually extract volume patches and reshape the result into
38// a matrix (this involves allocating huge extra memory), so the patch
39// extraction and reshape operations are implicit.
40//
// TensorContractionInputMapper takes a matrix index and returns the coefficient
// (or the packet) of the "virtual tensor" that would be at that index if we
// were to actually reshape the result of patch extraction.
44//
45// TensorContractionSubMapper provides a similar view into the "virtual matrix"
46// at the given vertical and horizontal offsets.
47//
48// "Virtual matrix" dimensions:
49// *0: kernelChannels * kernelPlanes * kernelRows * kernelCols
// 1: out_planes * out_height * out_width * OTHERS (e.g. batches, etc...)
//
// *) extracted patches are contiguous in memory (innermost dimension, assuming
//    col major layout)
//
// With these dimensions:
56// row - offset within a single patch (in code: patchId)
57// col - index of the extracted patch (in code: patchIndex)
58// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
59//
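// Illustration only (hypothetical patch sizes, not taken from this file): with
// patch_depth = 4, patch_planes = 2, patch_rows = 3, patch_cols = 3, the row
// index of the virtual matrix ("patchId") decomposes as
//
//   Index patchOffset = patchId / patch_depth;  // strip the channel dimension
//   Index colOffset   = patchOffset / (patch_planes * patch_rows);
//   Index rowOffset   =
//       (patchOffset - colOffset * patch_planes * patch_rows) / patch_planes;
//   Index planeOffset = patchOffset - colOffset * patch_planes * patch_rows -
//                       rowOffset * patch_planes;
//   Index depth       = patchId - patchOffset * patch_depth;
//
// e.g. patchId == 30 yields depth = 2, planeOffset = 1, rowOffset = 0,
// colOffset = 1, which is exactly the arithmetic done in loadCoeffStandard()
// below with the precomputed fast divisors.
//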
60template <typename NewDimension, Index Planes, Index Rows, Index Cols,
61 typename ArgType, typename Device, typename Scalar_, typename Index,
62 typename nocontract_t, typename contract_t, int Side, int packet_size,
63 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
64class TensorContractionInputMapper<
65 Scalar_, Index, Side,
66 TensorEvaluator<const TensorReshapingOp<NewDimension,
67 const TensorVolumePatchOp<
68 Planes, Rows, Cols, ArgType> >,
69 Device>,
70 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
71 inner_dim_reordered, Alignment> {
72 public:
73 typedef Scalar_ Scalar;
74 typedef TensorContractionInputMapper<
75 Scalar, Index, Side,
76 TensorEvaluator<const TensorReshapingOp<
77 NewDimension, const TensorVolumePatchOp<
78 Planes, Rows, Cols, ArgType> >,
79 Device>,
80 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
81 inner_dim_reordered, Alignment>
82 Self;
83 typedef TensorContractionSubMapper<
84 Scalar, Index, Side,
85 TensorEvaluator<const TensorReshapingOp<
86 NewDimension, const TensorVolumePatchOp<
87 Planes, Rows, Cols, ArgType> >,
88 Device>,
89 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
90 inner_dim_reordered, Alignment>
91 SubMapper;
92 typedef SubMapper VectorMapper;
93 typedef SubMapper LinearMapper;
94 typedef typename packet_traits<Scalar>::type Packet;
95
96 EIGEN_DEVICE_FUNC
97 TensorContractionInputMapper(
98 const TensorEvaluator<
99 const TensorReshapingOp<
100 NewDimension,
101 const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >,
102 Device>& tensor,
103 const nocontract_t&, const nocontract_t&, const contract_t&,
104 const contract_t&)
105 : m_impl(tensor.impl().impl()) {
106 if (internal::traits<ArgType>::Layout == ColMajor) {
107 m_patch_depth = tensor.impl().dimensions()[0];
108 m_patch_planes = tensor.impl().dimensions()[1];
109 m_patch_rows = tensor.impl().dimensions()[2];
110 m_patch_cols = tensor.impl().dimensions()[3];
111 m_num_patches = tensor.impl().dimensions()[4];
112 } else {
113 const int NumDims = tensor.impl().dimensions().size();
114 m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
115 m_patch_planes = tensor.impl().dimensions()[NumDims - 2];
116 m_patch_rows = tensor.impl().dimensions()[NumDims - 3];
117 m_patch_cols = tensor.impl().dimensions()[NumDims - 4];
118 m_num_patches = tensor.impl().dimensions()[NumDims - 5];
119 }
120
121 // Strides for navigating through the single patch.
122 m_patch_plane_stride = m_patch_depth;
123 m_patch_row_stride = m_patch_planes * m_patch_plane_stride;
124 m_patch_col_stride = m_patch_rows * m_patch_row_stride;
125
126 // Strides for the output tensor.
    // IMPORTANT: These strides are used to locate an element in a patch at
    // depth zero (channel), which is not quite the same as the "traditional"
    // stride.
130 m_rowStride = m_patch_planes;
131 m_colStride = m_patch_rows * m_rowStride;
132 m_patchStride = m_colStride * m_patch_cols * m_patch_depth;
133 m_otherStride = m_patchStride * m_num_patches;
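
    // Illustration only (hypothetical sizes, not taken from this file): with
    // patch_depth = 4, patch_planes = 2, patch_rows = 3, patch_cols = 3 and
    // num_patches = 10, the values computed above are
    //   m_patch_plane_stride = 4    m_rowStride   = 2
    //   m_patch_row_stride   = 8    m_colStride   = 6
    //   m_patch_col_stride   = 24   m_patchStride = 72 (6 * 3 * 4)
    //                               m_otherStride = 720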
134
135 m_outputPlanes = tensor.impl().outputPlanes();
136 m_outputRows = tensor.impl().outputRows();
137 m_outputCols = tensor.impl().outputCols();
138
139 m_outputPlanesRows = m_outputPlanes * m_outputRows;
140
141 m_plane_strides = tensor.impl().userPlaneStride();
142 m_row_strides = tensor.impl().userRowStride();
143 m_col_strides = tensor.impl().userColStride();
144
145 m_in_plane_strides = tensor.impl().userInPlaneStride();
146 m_in_row_strides = tensor.impl().userInRowStride();
147 m_in_col_strides = tensor.impl().userInColStride();
148
149 m_patch_plane_inflate_strides = tensor.impl().planeInflateStride();
150 m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
151 m_patch_col_inflate_strides = tensor.impl().colInflateStride();
152
153 if (internal::traits<ArgType>::Layout == ColMajor) {
154 m_inputDepth = tensor.impl().impl().dimensions()[0];
155 m_inputPlanes = tensor.impl().impl().dimensions()[1];
156 m_inputRows = tensor.impl().impl().dimensions()[2];
157 m_inputCols = tensor.impl().impl().dimensions()[3];
158 } else {
159 const int NumDims = tensor.impl().impl().dimensions().size();
160 m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1];
161 m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2];
162 m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3];
163 m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4];
164 }
165
166 // Strides for navigating through the input tensor.
167 m_planeInputStride = m_inputDepth;
168 m_rowInputStride = m_inputDepth * m_inputPlanes;
169 m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
170 m_patchInputStride =
171 m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;
172
173 m_planePaddingTop = tensor.impl().planePaddingTop();
174 m_rowPaddingTop = tensor.impl().rowPaddingTop();
175 m_colPaddingLeft = tensor.impl().colPaddingLeft();
176
177 m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
178
179 m_fastPatchPlaneStride =
180 internal::TensorIntDivisor<Index>(m_patch_plane_stride);
181 m_fastPatchRowStride =
182 internal::TensorIntDivisor<Index>(m_patch_row_stride);
183 m_fastPatchColStride =
184 internal::TensorIntDivisor<Index>(m_patch_col_stride);
185
186 m_fastInputPlaneStride =
187 internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
188 m_fastInputRowStride =
189 internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
190 m_fastInputColStride =
191 internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
192
193 m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
194 m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
195
196 m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);
    m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
    m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
    m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols);
201
202 m_fastOutputPlanesRows =
203 internal::TensorIntDivisor<Index>(m_outputPlanesRows);
204 }
205
206 EIGEN_DEVICE_FUNC
207 TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
208 : m_impl(base_mapper.m_impl) {
209 m_patch_depth = base_mapper.m_patch_depth;
210 m_patch_planes = base_mapper.m_patch_planes;
211 m_patch_rows = base_mapper.m_patch_rows;
212 m_patch_cols = base_mapper.m_patch_cols;
213 m_num_patches = base_mapper.m_num_patches;
214
215 m_patch_plane_stride = base_mapper.m_patch_plane_stride;
216 m_patch_row_stride = base_mapper.m_patch_row_stride;
217 m_patch_col_stride = base_mapper.m_patch_col_stride;
218
219 m_rowStride = base_mapper.m_rowStride;
220 m_colStride = base_mapper.m_colStride;
221 m_patchStride = base_mapper.m_patchStride;
222 m_otherStride = base_mapper.m_otherStride;
223
224 m_planeInputStride = base_mapper.m_planeInputStride;
225 m_rowInputStride = base_mapper.m_rowInputStride;
226 m_colInputStride = base_mapper.m_colInputStride;
227 m_patchInputStride = base_mapper.m_patchInputStride;
228 m_otherInputStride = base_mapper.m_otherInputStride;
229
230 m_inputDepth = base_mapper.m_inputDepth;
231 m_inputPlanes = base_mapper.m_inputPlanes;
232 m_inputRows = base_mapper.m_inputRows;
233 m_inputCols = base_mapper.m_inputCols;
234
235 m_outputPlanes = base_mapper.m_outputPlanes;
236 m_outputRows = base_mapper.m_outputRows;
237 m_outputCols = base_mapper.m_outputCols;
238
239 m_plane_strides = base_mapper.m_plane_strides;
240 m_row_strides = base_mapper.m_row_strides;
241 m_col_strides = base_mapper.m_col_strides;
242
243 m_in_plane_strides = base_mapper.m_in_plane_strides;
244 m_in_row_strides = base_mapper.m_in_row_strides;
245 m_in_col_strides = base_mapper.m_in_col_strides;
246
247 m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides;
248 m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
249 m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
250
251 m_planePaddingTop = base_mapper.m_planePaddingTop;
252 m_rowPaddingTop = base_mapper.m_rowPaddingTop;
253 m_colPaddingLeft = base_mapper.m_colPaddingLeft;
254
255 m_outputPlanesRows = base_mapper.m_outputPlanesRows;
256
257 m_fastNumPatches = base_mapper.m_fastNumPatches;
258 m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride;
259 m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
260 m_fastPatchColStride = base_mapper.m_fastPatchColStride;
261 m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
262 m_fastInputRowStride = base_mapper.m_fastInputRowStride;
263 m_fastInputColStride = base_mapper.m_fastInputColStride;
264 m_fastRowStride = base_mapper.m_fastRowStride;
265 m_fastColStride = base_mapper.m_fastColStride;
266 m_fastOutputPlanes = base_mapper.m_fastOutputPlanes;
267 m_fastOutputRows = base_mapper.m_fastOutputRows;
268 m_fastOutputCols = base_mapper.m_fastOutputCols;
269 m_fastDimZero = base_mapper.m_fastDimZero;
270 m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows;
271 }
272
  // Returns true if some optimizations for loading packets have to be turned
  // off, because the image patches are "non-standard": e.g. there are
  // non-trivial strides or inflations in the input.
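  //
  // Example (hypothetical values): an atrous convolution evaluated with
  // m_in_row_strides == 2, or an inflated input with
  // m_patch_col_inflate_strides == 2, makes this predicate return true and
  // routes every packet load through the element-by-element
  // packetWithPossibleZero() path.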
276 EIGEN_DEVICE_FUNC
277 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
278 return m_in_plane_strides != 1 || m_in_row_strides != 1 ||
279 m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 ||
280 m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
281 }
282
283 EIGEN_DEVICE_FUNC
284 EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
285 return SubMapper(*this, i, j);
286 }
287
288 EIGEN_DEVICE_FUNC
289 EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
290 return LinearMapper(*this, i, j);
291 }
292
293 EIGEN_DEVICE_FUNC
294 EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
295 Index planeIndex, rowIndex, colIndex, otherIndex;
296 computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
297 return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
298 }
299
300 // Load the coefficient at the patchIndex location instead of the usual
301 // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
302 // gpu code.
303 EIGEN_DEVICE_FUNC
304 EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
305 Index planeIndex, rowIndex, colIndex, otherIndex;
306 computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
307 return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
308 }
309
310 EIGEN_DEVICE_FUNC
311 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
312 Index planeIndex, rowIndex, colIndex, otherIndex;
313 computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
314 return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
315 }
316
317 // Load the packet at the patchIndex location instead of the usual m_rowIndex,
318 // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
319 EIGEN_DEVICE_FUNC
320 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
321 Index planeIndex, rowIndex, colIndex, otherIndex;
322 computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
323 return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
324 }
325
326 EIGEN_DEVICE_FUNC
327 EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
328 return m_impl;
329 }
330
331 EIGEN_DEVICE_FUNC
332 EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; }
333 EIGEN_DEVICE_FUNC
334 EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; }
335 EIGEN_DEVICE_FUNC
336 EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
337 EIGEN_DEVICE_FUNC
338 EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
339
340 private:
341 friend class TensorContractionSubMapper<
342 Scalar, Index, Side,
343 TensorEvaluator<const TensorReshapingOp<
344 NewDimension, const TensorVolumePatchOp<
345 Planes, Rows, Cols, ArgType> >,
346 Device>,
347 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
348 inner_dim_reordered, Alignment>;
349
350 // Load coefficient from a patch specified by the "within patch offset"
351 // (patchId) and the precomputed indices of the first element of the patch.
352 EIGEN_DEVICE_FUNC
353 EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex,
354 Index rowIndex, Index colIndex,
355 Index otherIndex) const {
356 // Find the offset of the element wrt the location of the first element.
357 const Index patchOffset = patchId / m_fastDimZero;
358
359 const Index colOffset = patchOffset / m_fastColStride;
360 const Index inputCol = colIndex + colOffset * m_in_col_strides;
361 const Index origInputCol =
362 (m_patch_col_inflate_strides == 1)
363 ? inputCol
364 : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
365
366 const Index rowOffset =
367 (patchOffset - colOffset * m_colStride) / m_fastRowStride;
368 const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
369 const Index origInputRow =
370 (m_patch_row_inflate_strides == 1)
371 ? inputRow
372 : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
373
374 const Index planeOffset =
375 patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
376 const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides;
377 const Index origInputPlane =
378 (m_patch_plane_inflate_strides == 1)
379 ? inputPlane
380 : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
381
382 if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 ||
383 origInputCol >= m_inputCols || origInputRow >= m_inputRows ||
384 origInputPlane >= m_inputPlanes ||
385 (inputCol != origInputCol * m_patch_col_inflate_strides) ||
386 (inputRow != origInputRow * m_patch_row_inflate_strides) ||
387 (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) {
388 return Scalar(0);
389 }
390
391 const Index depth = patchId - patchOffset * patchDepth();
392 const Index inputIndex = depth + origInputPlane * m_planeInputStride +
393 origInputRow * m_rowInputStride +
394 origInputCol * m_colInputStride + otherIndex;
395
396 return m_impl.coeff(inputIndex);
397 }
398
399 // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
400 // and `in_strides` equal to 1 (template specialization without templates).
401 EIGEN_DEVICE_FUNC
402 EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex,
403 Index rowIndex, Index colIndex,
404 Index otherIndex) const {
405 eigen_assert(!nonStandardPatches());
406
407 // Find the offset of the element wrt the location of the first element.
408 const Index patchOffset = patchId / m_fastDimZero;
409
410 const Index colOffset = patchOffset / m_fastColStride;
411 const Index rowOffset =
412 (patchOffset - colOffset * m_colStride) / m_fastRowStride;
413 const Index planeOffset =
414 patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
415
416 const Index inputCol = colIndex + colOffset;
417 const Index inputRow = rowIndex + rowOffset;
418 const Index inputPlane = planeIndex + planeOffset;
419
420 if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
421 inputRow >= m_inputRows || inputPlane < 0 ||
422 inputPlane >= m_inputPlanes) {
423 return Scalar(0);
424 }
425
426 const Index depth = patchId - patchOffset * patchDepth();
427 const Index inputIndex = depth + inputPlane * m_planeInputStride +
428 inputRow * m_rowInputStride +
429 inputCol * m_colInputStride + otherIndex;
430
431 return m_impl.coeff(inputIndex);
432 }
433
434 // Load packet from a patch specified by the "within patch offset"
435 // (patchId) and the precomputed indices of the first element of the patch.
436 EIGEN_DEVICE_FUNC
437 EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex,
438 Index rowIndex, Index colIndex,
439 Index otherIndex) const {
440 const Index packetSize = internal::unpacket_traits<Packet>::size;
441
442 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
443 eigen_assert(patchId <
444 patchDepth() * patchPlanes() * patchRows() * patchCols());
445
446 if (nonStandardPatches()) {
447 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
448 otherIndex);
449 }
450 typedef decltype(m_impl) TensorEvaluatorT;
451 return loadPacketStandard<Packet, TensorEvaluatorT>(
452 patchId, planeIndex, rowIndex, colIndex, otherIndex);
453 }
454
455 // Helper function to load a 'partial' packet - this is the single row part of
456 // a packet that is split across two rows (but single column). In the
457 // 'partial' packet, the elements corresponding to the row (specified through
458 // rowOffset) are loaded and the rest of the elements are zero-filled into the
459 // 'partial' packet. This function is called from
460 // loadPacketStandardFromSingleColumnTwoRows(). This code path is exercised
461 // only when the packet type supports masked load and when the partial packet
462 // load is available in the TensorEvaluator.
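  //
  // A small worked example (hypothetical numbers): for a packet of size 4 and
  // span == {1, 2}, the resulting 'partial' packet is laid out as
  //
  //   values[0] = 0                     // index < span[0]
  //   values[1] = coeff(inputIndex)     // span[0] .. span[1]
  //   values[2] = coeff(inputIndex + 1)
  //   values[3] = 0                     // index > span[1]
  //
  // and on the fast path it is produced by a single masked partialPacket load
  // at (inputIndex - span[0]) with mask<Packet>(span[0], span[1] + 1).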
463 EIGEN_DEVICE_FUNC
464 EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(
465 Index planeIndex, Index rowIndex, Index colIndex, Index otherIndex,
466 Index patchId, const Index span[], const Index patchOffsets[],
467 Index colOffset, Index rowOffset) const {
468 const Index inputCol = colIndex + colOffset;
469 const Index inputRow = rowIndex + rowOffset;
470 const Index planeOffsets[2] = {
471 patchOffsets[0] - colOffset * m_colStride - rowOffset * m_rowStride,
472 patchOffsets[1] - colOffset * m_colStride - rowOffset * m_rowStride};
473 const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
474 planeIndex + planeOffsets[1]};
475
476 if (inputRow >= m_inputRows || inputRow < 0 || inputCol >= m_inputCols ||
477 inputCol < 0 || inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
478 // Partial packet is all zeros
479 return internal::pset1<Packet>(Scalar(0));
480 } else if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
      // From inputIndex-span[0], we need to load elements starting from index
      // span[0] all the way up to (and including) span[1].
483 const Index depth = patchId - patchOffsets[0] * patchDepth();
484 const Index inputIndex = depth + inputPlanes[0] * m_planeInputStride +
485 inputRow * m_rowInputStride +
486 inputCol * m_colInputStride + otherIndex;
487 return m_impl.template partialPacket<Packet>(
488 inputIndex - span[0], mask<Packet>(span[0], span[1] + 1));
489 } else {
490 // Using slow path for this partial packet.
      // We need to load elements starting from index span[0] all the way up to
      // (and including) span[1]. We split this load into 3 parts:
      // 0 : span[0]-1 - Zeros will be loaded for these indices
      // span[0] : span[1] - Elements will be loaded here for these indices
      // span[1]+1 : packetSize-1 - Zeros will be loaded for these indices
496 const Index packetSize = internal::unpacket_traits<Packet>::size;
497 EIGEN_ALIGN_MAX
498 std::remove_const_t<Scalar> values[packetSize];
499 for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0);
500 for (int i = span[0]; i < span[1] + 1; ++i)
501 values[i] = loadCoeff(patchId - span[0] + i, planeIndex, rowIndex,
502 colIndex, otherIndex);
503 for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0);
504 return internal::pload<Packet>(values);
505 }
506 }
507
508 // Helper function to load a packet that is split across two rows (but single
509 // column). If required, this function is called from loadPacketStandard()
510 // when the packet type supports masked load and when the partial packet load
511 // is available in the TensorEvaluator.
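  //
  // Sketch (hypothetical sizes): with a packet of size 4, patch_depth = 3 and
  // a packet that starts at the last element of one row, the load is split as
  //
  //   patchIds[0] = patchId      spans[0] = {0, 0}  // tail of the first row
  //   patchIds[1] = patchId + 1  spans[1] = {1, 3}  // head of the second row
  //
  // Each half is fetched with loadPartialPacketStandard() (zero-filled outside
  // its span) and the two halves are combined with a bit-wise por().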
512 EIGEN_DEVICE_FUNC
513 EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumnTwoRows(
514 Index patchId, Index planeIndex, Index rowIndex, Index colIndex,
515 Index otherIndex, const Index patchOffsets[], const Index colOffsets[],
516 const Index rowOffsets[]) const {
517 eigen_assert(colOffsets[1] == colOffsets[0] &&
518 rowOffsets[1] == rowOffsets[0] + 1);
519 const Index packetSize = internal::unpacket_traits<Packet>::size;
520
521 // Packet to load will be split into 2 parts where each part spans a single
522 // row and both the parts span the same column.
523 // First determine where to split.
524 const Index patchIdSplit =
525 (((rowOffsets[1] * m_rowStride) + (colOffsets[0] * m_colStride)) *
526 m_patch_depth) -
527 1;
528 const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;
529
530 // patchIds[i]: patchId corresponding to partial packet i
531 // spans[i]: Start and end indices corresponding to the elements
532 // to be loaded for partial packet i
533 // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
534 const Index patchIds[2] = {patchId, patchIdSplit + 1};
535 const Index spans[2][2] = {{0, patchIdSplit - patchId},
536 {patchIdSplit - patchId + 1, packetSize - 1}};
537 const Index patchOffsets2Cols[2][2] = {
538 {patchOffsets[0], patchOffsetSplit},
539 {patchOffsetSplit + 1, patchOffsets[1]}};
540
541 // Load partial packets and do bit-wise OR to generate required packet
542 return internal::por<Packet>(
543 loadPartialPacketStandard(planeIndex, rowIndex, colIndex, otherIndex,
544 patchIds[0], spans[0], patchOffsets2Cols[0],
545 colOffsets[0], rowOffsets[0]),
546 loadPartialPacketStandard(planeIndex, rowIndex, colIndex, otherIndex,
547 patchIds[1], spans[1], patchOffsets2Cols[1],
548 colOffsets[1], rowOffsets[1]));
549 }
550
551 // Helper function to load a packet that is present in a single column and
552 // row. If required, this function is called from loadPacketStandard().
553 EIGEN_DEVICE_FUNC
554 EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumnSingleRow(
555 Index patchId, Index planeIndex, Index rowIndex, Index colIndex,
556 Index otherIndex, const Index patchOffsets[], const Index colOffsets[],
557 const Index rowOffsets[], const Index inputCols[],
558 const Index inputRows[]) const {
559 eigen_assert(colOffsets[1] == colOffsets[0] &&
560 rowOffsets[1] == rowOffsets[0]);
561 const Index planeOffsets[2] = {
562 patchOffsets[0] - colOffsets[0] * m_colStride -
563 rowOffsets[0] * m_rowStride,
564 patchOffsets[1] - colOffsets[1] * m_colStride -
565 rowOffsets[1] * m_rowStride};
566 eigen_assert(planeOffsets[0] <= planeOffsets[1]);
567 const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
568 planeIndex + planeOffsets[1]};
569
570 if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
571 return internal::pset1<Packet>(Scalar(0));
572 }
573 if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
574 const Index depth = patchId - patchOffsets[0] * patchDepth();
575 const Index inputIndex = depth + inputPlanes[0] * m_planeInputStride +
576 inputRows[0] * m_rowInputStride +
577 inputCols[0] * m_colInputStride + otherIndex;
578 return m_impl.template packet<Unaligned>(inputIndex);
579 }
580 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
581 otherIndex);
582 }
583
584 // Load standard packet from a patch specified by the "within patch offset"
585 // (patchId) and the precomputed indices of the first element of the patch.
586 // This function will be called if partial packet loading is not available
587 // for the TensorEvaluator or if the packet type does not support masked
588 // load.
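  //
  // For example (hypothetical sizes): with patchDepth() == 8 and a packet of
  // 4 floats, every packet lies within a single (plane, row, col) coordinate,
  // so loadPacketFast() is used. With patchDepth() == 6, the packet starting
  // at patchId == 4 straddles two planes and the per-case logic below applies.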
589 template <typename PacketT, typename TensorEvaluatorT>
590 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
591 !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
592 PacketT>::type
593 loadPacketStandard(Index patchId, Index planeIndex, Index rowIndex,
594 Index colIndex, Index otherIndex) const {
595 const Index packetSize = internal::unpacket_traits<Packet>::size;
596 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
597 eigen_assert(patchId <
598 patchDepth() * patchPlanes() * patchRows() * patchCols());
599 eigen_assert(!nonStandardPatches());
600
601 if ((patchDepth() % packetSize) == 0) {
602 return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
603 otherIndex);
604 } else {
605 // Offsets and input calculation here are identical to
606 // loadCoeffStandard(...), but repeated twice.
607
608 const Index patchOffsets[2] = {
609 patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
610
611 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
612 patchOffsets[1] / m_fastColStride};
613 eigen_assert(colOffsets[0] <= colOffsets[1]);
614
615 const Index inputCols[2] = {colIndex + colOffsets[0],
616 colIndex + colOffsets[1]};
617 if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
618 return internal::pset1<Packet>(Scalar(0));
619 }
620
621 if (inputCols[0] == inputCols[1]) {
622 const Index rowOffsets[2] = {
623 (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
624 (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
625 eigen_assert(rowOffsets[0] <= rowOffsets[1]);
626 const Index inputRows[2] = {rowIndex + rowOffsets[0],
627 rowIndex + rowOffsets[1]};
628
629 if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
630 return internal::pset1<Packet>(Scalar(0));
631 }
632
633 if (inputRows[0] == inputRows[1]) {
634 return loadPacketStandardFromSingleColumnSingleRow(
635 patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets,
636 colOffsets, rowOffsets, inputCols, inputRows);
637 }
638 }
639 }
640
641 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
642 otherIndex);
643 }
644
645 // Load standard packet from a patch specified by the "within patch offset"
646 // (patchId) and the precomputed indices of the first element of the patch.
647 // This function will be called if partial packet loading is available for
648 // the TensorEvaluator and if the packet type supports masked load.
  // The only difference from the other case is that, when the packet to load
  // is split across two rows (but within the same column), we do not fall back
  // to the slow element-by-element load. Instead we load two 'partial' packets,
  // each containing the elements from one of the two rows (the remaining
  // elements of each packet are zeroed out), and then combine them to produce
  // the required packet. The idea is to enable a fast load of these 'partial'
  // packets whenever possible.
656 template <typename PacketT, typename TensorEvaluatorT>
657 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
658 TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
659 PacketT>::type
660 loadPacketStandard(Index patchId, Index planeIndex, Index rowIndex,
661 Index colIndex, Index otherIndex) const {
662 const Index packetSize = internal::unpacket_traits<Packet>::size;
663 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
664 eigen_assert(patchId <
665 patchDepth() * patchPlanes() * patchRows() * patchCols());
666 eigen_assert(!nonStandardPatches());
667
668 if ((patchDepth() % packetSize) == 0) {
669 return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
670 otherIndex);
671 } else {
672 // Offsets and input calculation here are identical to
673 // loadCoeffStandard(...), but repeated twice.
674
675 const Index patchOffsets[2] = {
676 patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
677
678 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
679 patchOffsets[1] / m_fastColStride};
680 eigen_assert(colOffsets[0] <= colOffsets[1]);
681
682 const Index inputCols[2] = {colIndex + colOffsets[0],
683 colIndex + colOffsets[1]};
684 if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
685 return internal::pset1<Packet>(Scalar(0));
686 }
687
688 if (inputCols[0] == inputCols[1]) {
689 const Index rowOffsets[2] = {
690 (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
691 (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
692 eigen_assert(rowOffsets[0] <= rowOffsets[1]);
693 const Index inputRows[2] = {rowIndex + rowOffsets[0],
694 rowIndex + rowOffsets[1]};
695
696 if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
697 return internal::pset1<Packet>(Scalar(0));
698 }
699
700 if (inputRows[0] == inputRows[1]) {
701 return loadPacketStandardFromSingleColumnSingleRow(
702 patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets,
703 colOffsets, rowOffsets, inputCols, inputRows);
704 }
705 if (inputRows[0] + 1 == inputRows[1]) {
706 return loadPacketStandardFromSingleColumnTwoRows(
707 patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets,
708 colOffsets, rowOffsets);
709 }
710 }
711 }
712
713 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
714 otherIndex);
715 }
716
717 EIGEN_DEVICE_FUNC
718 EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex,
719 Index rowIndex, Index colIndex,
720 Index otherIndex) const {
721 const Index packetSize = internal::unpacket_traits<Packet>::size;
722 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
723 eigen_assert(patchId <
724 patchDepth() * patchPlanes() * patchRows() * patchCols());
725
726 eigen_assert(!nonStandardPatches());
727 eigen_assert((patchDepth() % packetSize) == 0);
728
729 // Find the offset of the element wrt the location of the first element.
730 const Index patchOffset = patchId / m_fastDimZero;
731 eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
732
733 const Index colOffset = patchOffset / m_fastColStride;
734 const Index rowOffset =
735 (patchOffset - colOffset * m_colStride) / m_fastRowStride;
736 const Index planeOffset =
737 patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
738
739 const Index inputCol = colIndex + colOffset;
740 const Index inputRow = rowIndex + rowOffset;
741 const Index inputPlane = planeIndex + planeOffset;
742
743 if (inputCol < 0 || inputRow < 0 || inputPlane < 0 ||
744 inputCol >= m_inputCols || inputRow >= m_inputRows ||
745 inputPlane >= m_inputPlanes) {
746 return internal::pset1<Packet>(Scalar(0));
747 }
748
749 const Index depth = patchId - patchOffset * patchDepth();
750 const Index inputIndex = depth + inputPlane * m_planeInputStride +
751 inputRow * m_rowInputStride +
752 inputCol * m_colInputStride + otherIndex;
753 return m_impl.template packet<Unaligned>(inputIndex);
754 }
755
756 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
757 packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex,
758 Index colIndex, Index otherIndex) const {
759 const int packetSize = internal::unpacket_traits<Packet>::size;
760 EIGEN_ALIGN_MAX
761 std::remove_const_t<Scalar> values[packetSize];
762 for (int i = 0; i < packetSize; ++i) {
763 values[i] =
764 loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex);
765 }
766 Packet rslt = internal::pload<Packet>(values);
767 return rslt;
768 }
769
770 // Precompute the indices (plane, row, col, other) of the first element of
771 // the given patch index, within the output tensor of the TensorVolumePatchOp.
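  //
  // A rough example (hypothetical sizes): with m_outputPlanes = 2,
  // m_outputRows = 3, m_outputCols = 4 and m_num_patches = 24, patchIndex = 50
  // decomposes (before strides and padding are applied) as
  //
  //   otherIndex   = 50 / 24             = 2  // batch / other dimensions
  //   patch3DIndex = 50 - 2 * 24         = 2
  //   colIndex     = 2 / (2 * 3)         = 0
  //   rowIndex     = (2 - 0 * 6) / 2     = 1
  //   planeIndex   = 2 - (0 * 3 + 1) * 2 = 0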
772 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
773 Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex,
774 Index& otherIndex) const {
775 const size_t NumInputDims = array_size<
776 typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
777
778 // Check if patchIndex might contain batch and other dimensions.
779 otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches;
780
781 // Compute index of the patch within the batch (and other dimensions).
782 const Index patch3DIndex = (NumInputDims == 4)
783 ? patchIndex
784 : (patchIndex - otherIndex * m_num_patches);
785
786 otherIndex *= m_patchInputStride;
787
788 colIndex = patch3DIndex / m_fastOutputPlanesRows;
789 rowIndex =
790 (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
791 planeIndex =
792 patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes;
793
794 colIndex = colIndex * m_col_strides - m_colPaddingLeft;
795 rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
796 planeIndex = planeIndex * m_plane_strides - m_planePaddingTop;
797 }
798
799 Index m_patch_depth; // number of channels in the patch
800 Index m_patch_planes; // number of planes in the patch
801 Index m_patch_rows; // number of rows in the patch
802 Index m_patch_cols; // number of columns in the patch
803 Index m_num_patches; // number of patches to extract
804
805 // Strides for navigating through the single patch.
806 Index m_patch_plane_stride;
807 Index m_patch_row_stride;
808 Index m_patch_col_stride;
809
  // Strides for the output tensor (depth is not part of the stride).
811 Index m_rowStride;
812 Index m_colStride;
813 Index m_patchStride;
814 Index m_otherStride;
815
816 Index m_planeInputStride; // Plane stride in the input tensor
817 Index m_rowInputStride; // Row stride in the input tensor
818 Index m_colInputStride; // Col stride in the input tensor
819 Index m_patchInputStride; // Patch stride in the input tensor
820 Index m_otherInputStride;
821
822 Index m_inputDepth; // Depth of the input tensor
823 Index m_inputPlanes; // Number of planes in the input tensor
824 Index m_inputRows; // Number of rows in the input tensor
825 Index m_inputCols; // Number of cols in the input tensor
826
827 Index m_outputPlanes; // Number of output planes
828 Index m_outputRows; // Number of output rows
829 Index m_outputCols; // Number of output cols
830 Index m_outputPlanesRows; // Cached outputPlanes * outputRows.
831
832 Index m_plane_strides; // User specified plane stride
833 Index m_row_strides; // User specified row stride
834 Index m_col_strides; // User specified col stride
835
836 // User specified plane/row/col atrous convolution strides.
837 Index m_in_plane_strides;
838 Index m_in_row_strides;
839 Index m_in_col_strides;
840
841 // User specified plane/row/col inflation strides in the image patch.
842 Index m_patch_plane_inflate_strides;
843 Index m_patch_row_inflate_strides;
844 Index m_patch_col_inflate_strides;
845
846 Index m_planePaddingTop; // Plane padding
847 Index m_rowPaddingTop; // Row padding
848 Index m_colPaddingLeft; // Column padding
849
850 // Fast representation of various divisors.
851 internal::TensorIntDivisor<Index> m_fastNumPatches;
852
853 internal::TensorIntDivisor<Index> m_fastPatchPlaneStride;
854 internal::TensorIntDivisor<Index> m_fastPatchRowStride;
855 internal::TensorIntDivisor<Index> m_fastPatchColStride;
856
857 internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
858 internal::TensorIntDivisor<Index> m_fastInputRowStride;
859 internal::TensorIntDivisor<Index> m_fastInputColStride;
860
861 internal::TensorIntDivisor<Index> m_fastRowStride;
862 internal::TensorIntDivisor<Index> m_fastColStride;
863
  internal::TensorIntDivisor<Index> m_fastDimZero;  // aka patch depth
865 internal::TensorIntDivisor<Index> m_fastOutputPlanes;
866 internal::TensorIntDivisor<Index> m_fastOutputRows;
867 internal::TensorIntDivisor<Index> m_fastOutputCols;
868 internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
869
870 const TensorEvaluator<ArgType, Device> m_impl;
871};
872
873template <typename NewDimension, Index Planes, Index Rows, Index Cols,
874 typename ArgType, typename Device, typename Scalar, typename Index,
875 typename nocontract_t, typename contract_t, int Side, int packet_size,
876 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
877class TensorContractionSubMapper<
878 Scalar, Index, Side,
879 TensorEvaluator<const TensorReshapingOp<NewDimension,
880 const TensorVolumePatchOp<
881 Planes, Rows, Cols, ArgType> >,
882 Device>,
883 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
884 inner_dim_reordered, Alignment> {
885 public:
886 typedef typename packet_traits<Scalar>::type Packet;
887 typedef typename packet_traits<Scalar>::half HalfPacket;
888
889 typedef TensorContractionInputMapper<
890 Scalar, Index, Side,
891 TensorEvaluator<const TensorReshapingOp<
892 NewDimension, const TensorVolumePatchOp<
893 Planes, Rows, Cols, ArgType> >,
894 Device>,
895 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
896 inner_dim_reordered, Alignment>
897 ParentMapper;
898 typedef TensorContractionSubMapper<
899 Scalar, Index, Side,
900 TensorEvaluator<const TensorReshapingOp<
901 NewDimension, const TensorVolumePatchOp<
902 Planes, Rows, Cols, ArgType> >,
903 Device>,
904 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
905 inner_dim_reordered, Alignment>
906 Self;
907 typedef Self LinearMapper;
908
909 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
910 const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
911 : m_base_mapper(base_mapper),
912 m_depth_offset(vert_offset),
913 m_col_offset(horiz_offset) {
914 m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
915 m_colIndex, m_otherIndex);
916 }
917 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
918 const Self& base_mapper, Index vert_offset, Index horiz_offset)
919 : m_base_mapper(base_mapper.m_base_mapper),
920 m_depth_offset(vert_offset + base_mapper.m_depth_offset),
921 m_col_offset(horiz_offset + base_mapper.m_col_offset) {
922 m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
923 m_colIndex, m_otherIndex);
924 }
925 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
926 return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex,
927 m_colIndex, m_otherIndex);
928 }
929 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
930 Index j) const {
931 return m_base_mapper(i + m_depth_offset, j + m_col_offset);
932 }
933
934 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
935 return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex,
936 m_rowIndex, m_colIndex, m_otherIndex);
937 }
938 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
939 Index j) const {
940 return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
941 j + m_col_offset);
942 }
943
944 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
945 loadCoeffStandard(Index i) const {
946 return m_base_mapper.loadCoeffStandard(
947 i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
948 }
949
950 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
951 return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex,
952 m_rowIndex, m_colIndex, m_otherIndex);
953 }
954 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
955 loadPacketStandard(Index i) const {
956 typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
957 return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
958 i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
959 }
960 template <typename Packet>
961 EIGEN_DEVICE_FUNC bool aligned(Index) const {
962 return false;
963 }
964
965 EIGEN_DEVICE_FUNC
966 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
967 return m_base_mapper.nonStandardPatches();
968 }
969
970 // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row,
971 // plane and depth index respectively that fits into the peeled_k elements
972 // starting at m_depth_offset.
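  //
  // For instance (hypothetical sizes): with patchDepth() = 4, patchPlanes() = 2,
  // patchRows() = 3, patchCols() = 3, m_depth_offset = 0 and peeled_k = 32,
  //   maxCol(32)         = min(1 + 32 / 24, 3)            = 2
  //   maxRow(32, 1)      = min(1 + (32 - 24) / 8, 3)      = 2
  //   maxPlane(32, 1, 1) = min(1 + (32 - 24 - 8) / 4, 2)  = 1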
973
974 EIGEN_DEVICE_FUNC
975 EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
976 const Index max_col =
977 fastPatchColStride().divide(m_depth_offset + peeled_k);
978 return std::min<Index>(1 + max_col, patchCols());
979 }
980
981 EIGEN_DEVICE_FUNC
982 EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
983 const Index col) const {
984 const Index max_row = fastPatchRowStride().divide(
985 m_depth_offset + peeled_k - col * patchColStride());
986 return std::min<Index>(1 + max_row, patchRows());
987 }
988
989 EIGEN_DEVICE_FUNC
990 EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col,
991 const Index row) const {
992 const Index max_plane = fastPatchPlaneStride().divide(
993 m_depth_offset + peeled_k - col * patchColStride() -
994 row * patchRowStride());
995 return std::min<Index>(1 + max_plane, patchPlanes());
996 }
997
998 // MaxDepth uses only the remaining number of elements in the peeled_k.
999 EIGEN_DEVICE_FUNC
1000 EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
1001 const Index start_depth) const {
1002 return std::min<Index>(start_depth + num_elements, patchDepth());
1003 }
1004
  // Every register matters in this code, so sometimes, to prevent register
  // spilling, instead of the variable that you would expect to see we use
  // another one that is guaranteed to have the same value. E.g. patch depth is
  // always the same as input depth, and it is also the same as the input plane
  // stride. A number of other parameters have similar relations.
1010
1011 typedef internal::TensorIntDivisor<Index> IndexDivisor;
1012
1013 EIGEN_DEVICE_FUNC
1014 EIGEN_ALWAYS_INLINE Index patchDepth() const {
1015 eigen_assert(m_base_mapper.m_patch_depth ==
1016 m_base_mapper.m_planeInputStride &&
1017 "Patch depth must be equal to plane input stride.");
1018 return m_base_mapper.m_planeInputStride;
1019 }
1020
1021 EIGEN_DEVICE_FUNC
1022 EIGEN_ALWAYS_INLINE Index patchPlanes() const {
1023 eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride &&
1024 "Patch planes must be equal to row stride.");
1025 return m_base_mapper.m_rowStride;
1026 }
1027 EIGEN_DEVICE_FUNC
1028 EIGEN_ALWAYS_INLINE Index patchRows() const {
1029 return m_base_mapper.m_patch_rows;
1030 }
1031 EIGEN_DEVICE_FUNC
1032 EIGEN_ALWAYS_INLINE Index patchCols() const {
1033 return m_base_mapper.m_patch_cols;
1034 }
1035
1036 EIGEN_DEVICE_FUNC
1037 EIGEN_ALWAYS_INLINE Index patchPlaneStride() const {
1038 eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
1039 "Patch depth must be equal to patch plane stride.");
1040 return patchDepth();
1041 }
1042 EIGEN_DEVICE_FUNC
1043 EIGEN_ALWAYS_INLINE Index patchRowStride() const {
1044 return m_base_mapper.m_patch_row_stride;
1045 }
1046 EIGEN_DEVICE_FUNC
1047 EIGEN_ALWAYS_INLINE Index patchColStride() const {
1048 return m_base_mapper.m_patch_col_stride;
1049 }
1050
1051 EIGEN_DEVICE_FUNC
1052 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const {
1053 eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
1054 "Patch depth must be equal to patch plane stride.");
1055 return m_base_mapper.m_fastDimZero; // patch_depth
1056 }
1057 EIGEN_DEVICE_FUNC
1058 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
1059 return m_base_mapper.m_fastPatchRowStride;
1060 }
1061 EIGEN_DEVICE_FUNC
1062 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
1063 return m_base_mapper.m_fastPatchColStride;
1064 }
1065
1066 EIGEN_DEVICE_FUNC
1067 EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
1068 const Index baseIndex) const {
1069 const Index inputIndex = depth + baseIndex;
1070 return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
1071 }
1072 EIGEN_DEVICE_FUNC
1073 EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth,
1074 const Index baseIndex) const {
1075 const Index inputIndex = depth + baseIndex;
1076 return m_base_mapper.m_impl.coeff(inputIndex);
1077 }
1078
1079 EIGEN_DEVICE_FUNC
1080 EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const {
1081 const Index p = m_planeIndex + plane;
1082 return p < 0 || p >= m_base_mapper.m_inputPlanes;
1083 }
1084 EIGEN_DEVICE_FUNC
1085 EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
1086 const Index r = m_rowIndex + row;
1087 return r < 0 || r >= m_base_mapper.m_inputRows;
1088 }
1089 EIGEN_DEVICE_FUNC
1090 EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
1091 const Index c = m_colIndex + col;
1092 return c < 0 || c >= m_base_mapper.m_inputCols;
1093 }
1094 EIGEN_DEVICE_FUNC
1095 EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row,
1096 const Index col) const {
1097 const Index p = m_planeIndex + plane;
1098 const Index r = m_rowIndex + row;
1099 const Index c = m_colIndex + col;
1100 return p * m_base_mapper.m_planeInputStride +
1101 r * m_base_mapper.m_rowInputStride +
1102 c * m_base_mapper.m_colInputStride + m_otherIndex;
1103 }
1104
1105 EIGEN_DEVICE_FUNC
1106 EIGEN_ALWAYS_INLINE Index planeOffset() const {
1107 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
1108 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
1109 const Index rowOffset =
1110 (patchOffset - colOffset * m_base_mapper.m_colStride) /
1111 m_base_mapper.m_fastRowStride;
1112 const Index planeOffset = patchOffset -
1113 colOffset * m_base_mapper.m_colStride -
1114 rowOffset * m_base_mapper.m_rowStride;
1115 return planeOffset;
1116 }
1117
1118 EIGEN_DEVICE_FUNC
1119 EIGEN_ALWAYS_INLINE Index rowOffset() const {
1120 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
1121 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
1122 const Index rowOffset =
1123 (patchOffset - colOffset * m_base_mapper.m_colStride) /
1124 m_base_mapper.m_fastRowStride;
1125 return rowOffset;
1126 }
1127
1128 EIGEN_DEVICE_FUNC
1129 EIGEN_ALWAYS_INLINE Index colOffset() const {
1130 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
1131 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
1132 return colOffset;
1133 }
1134
1135 EIGEN_DEVICE_FUNC
1136 EIGEN_ALWAYS_INLINE Index depthOffset() const {
1137 return m_depth_offset % patchDepth();
1138 }
1139
1140 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
1141 getLinearMapper(Index i, Index j) const {
1142 return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
1143 }
1144
1145 private:
1146 const ParentMapper m_base_mapper; // Keeping a copy instead of a reference
1147 // performs better in benchmarks.
1148
1149 Index m_depth_offset; // First row in the input matrix
1150 Index m_col_offset; // First col in the input matrix
1151
1152 // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
1153 // indices for the first element in a patch specified by col_offset
1154 // (see computeBaseIndices(...) for details).
1155 Index m_planeIndex;
1156 Index m_rowIndex;
1157 Index m_colIndex;
1158 Index m_otherIndex;
1159};
1160
1161// Arrange a block of the right input matrix (in our case it's always a "virtual
1162// matrix" constructed from extracted volume patches) in contiguous memory.
1163//
1164// Given column major input (A0 beside A1 in memory):
1165// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
1166// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
1167// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
1168// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
1169// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
1170// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
1171// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
1172// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
1173// A8 ...
1174// ...
1175//
1176// *) A, B, C, ... - patches extracted from the original input.
1177// *) A0, A1, A2 ... - values from the same patch at different offsets.
1178//
// The traversal (packed rhs memory) order (B0 beside A0 in memory):
1180// A0 B0 C0 D0 A1 B1 C1 D1 ...
1181// E0 F0 G0 H0 E1 F1 G1 H1 ...
1182// ...
1183// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
1184//
1185// This traversal order must be the same as in default gemm_pack_rhs defined in
1186// GeneralBlockPanelKernel.h.
1187//
1188// *) nr - number of registers along the 'n' dimension.
1189// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
1190// Multiplication" paper.
1191//
1192// TODO(ezhulenev): Add support for squeezing reads along two innermost
1193// dimensions (see eigen_spatial_convolutions).
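//
// Illustration only (hypothetical shapes): for nr == 4 and a packet of size 4,
// the packed buffer for one column block is filled roughly as
//
//   Scalar* block = ...;                 // packed rhs output
//   for (Index k = 0; k < peeled_k; k += 4) {
//     PacketBlock<Packet, 4> kernel;     // A(k..k+3) | B | C | D
//     kernel.packet[0] = dm0.loadPacketStandard(k);
//     kernel.packet[1] = dm1.loadPacketStandard(k);
//     kernel.packet[2] = dm2.loadPacketStandard(k);
//     kernel.packet[3] = dm3.loadPacketStandard(k);
//     ptranspose(kernel);                // -> A0 B0 C0 D0 | A1 B1 C1 D1 | ...
//     pstoreu(block, kernel.packet[0]);  // etc.
//   }
//
// which reproduces the "A0 B0 C0 D0 A1 B1 C1 D1 ..." order shown above.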
1194template <typename NewDimension, Index Planes, Index Rows, Index Cols,
1195 typename ArgType, typename Device, typename Scalar, typename Index,
1196 typename nocontract_t, typename contract_t, int packet_size,
1197 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
1198 int nr>
1199struct gemm_pack_rhs<
1200 Scalar, Index,
1201 TensorContractionSubMapper<
1202 Scalar, Index, Rhs,
1203 TensorEvaluator<const TensorReshapingOp<
1204 NewDimension, const TensorVolumePatchOp<
1205 Planes, Rows, Cols, ArgType> >,
1206 Device>,
1207 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1208 inner_dim_reordered, Alignment>,
1209 nr, ColMajor, false, false> {
1210 typedef TensorContractionSubMapper<
1211 Scalar, Index, Rhs,
1212 TensorEvaluator<const TensorReshapingOp<
1213 NewDimension, const TensorVolumePatchOp<
1214 Planes, Rows, Cols, ArgType> >,
1215 Device>,
1216 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1217 inner_dim_reordered, Alignment>
1218 SubMapper;
1219
1220 typedef SubMapper DataMapper;
1221 typedef typename packet_traits<Scalar>::type Packet;
1222
1223 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
1224
1225 EIGEN_DEVICE_FUNC
1226 EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1227 Index depth, Index cols, Index stride = 0,
1228 Index offset = 0) const {
1229 eigen_assert(stride == 0);
1230 eigen_assert(offset == 0);
1231
1232 const Index packet_cols4 = (cols / 4) * 4;
1233 const Index peeled_k = (depth / packet_size) * packet_size;
1234 const bool non_standard_patches = rhs.nonStandardPatches();
1235
1236 for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1237 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1238 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1239 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1240 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1241
1242 Index k = 0;
1243 if ((packet_size % 4) == 0 && !non_standard_patches) {
1244 // FAST PATH:
        // Iterate over patch columns, rows and planes if we know that a single
        // packet does not span across multiple planes, rows or columns.
1247 if ((rhs.patchDepth() % packet_size) == 0) {
1248 const Index start_col = rhs.colOffset();
1249 const Index max_col = rhs.maxCol(peeled_k);
1250
1251 for (Index c = start_col; c < max_col; ++c) {
1252 eigen_assert(k <= peeled_k);
1253
1254 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1255 const Index max_row = rhs.maxRow(peeled_k, c);
1256
1257 const bool pad_col0 = dm0.padCol(c);
1258 const bool pad_col1 = dm1.padCol(c);
1259 const bool pad_col2 = dm2.padCol(c);
1260 const bool pad_col3 = dm3.padCol(c);
1261
1262 for (Index r = start_row; r < max_row; ++r) {
1263 eigen_assert(k <= peeled_k);
1264
1265 const Index start_plane = ((c == start_col) && (r == start_row))
1266 ? rhs.planeOffset()
1267 : 0;
1268 const Index max_plane = rhs.maxPlane(peeled_k, c, r);
1269
1270 const bool pad_row0 = pad_col0 || dm0.padRow(r);
1271 const bool pad_row1 = pad_col1 || dm1.padRow(r);
1272 const bool pad_row2 = pad_col2 || dm2.padRow(r);
1273 const bool pad_row3 = pad_col3 || dm3.padRow(r);
1274
1275 for (Index p = start_plane; p < max_plane; ++p) {
1276 eigen_assert(k <= peeled_k);
1277
1278 const bool pad0 = pad_row0 || dm0.padPlane(p);
1279 const bool pad1 = pad_row1 || dm1.padPlane(p);
1280 const bool pad2 = pad_row2 || dm2.padPlane(p);
1281 const bool pad3 = pad_row3 || dm3.padPlane(p);
1282
1283 const Index idx0 = dm0.baseIndex(p, r, c);
1284 const Index idx1 = dm1.baseIndex(p, r, c);
1285 const Index idx2 = dm2.baseIndex(p, r, c);
1286 const Index idx3 = dm3.baseIndex(p, r, c);
1287
1288 const Index start_depth =
1289 ((c == start_col) && (r == start_row) && (p == start_plane))
1290 ? rhs.depthOffset()
1291 : 0;
1292 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1293 eigen_assert((max_depth - start_depth) % packet_size == 0);
1294
1295 for (Index d = start_depth; d < max_depth; d += packet_size) {
1296 eigen_assert(k < peeled_k);
1297 PacketBlock<Packet, 4> kernel;
1298 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
1299 : rhs.packetNoPadding(d, idx0);
1300 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
1301 : rhs.packetNoPadding(d, idx1);
1302 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
1303 : rhs.packetNoPadding(d, idx2);
1304 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
1305 : rhs.packetNoPadding(d, idx3);
1306 ptranspose(kernel);
1307 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1308 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1309 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1310 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1311 block += 4 * packet_size;
1312 k += packet_size;
1313 }
1314 }
1315 }
1316 }
1317
1318 // The loop above should fill peeled_k elements.
1319 eigen_assert(peeled_k == k);
1320
1321 } else {
        // Packet can span multiple planes, rows or columns, so we have to go
        // through the slower "standard" path.
1324 for (; k < peeled_k; k += packet_size) {
1325 PacketBlock<Packet, 4> kernel;
1326 kernel.packet[0] = dm0.loadPacketStandard(k);
1327 kernel.packet[1] = dm1.loadPacketStandard(k);
1328 kernel.packet[2] = dm2.loadPacketStandard(k);
1329 kernel.packet[3] = dm3.loadPacketStandard(k);
1330 ptranspose(kernel);
1331 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1332 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1333 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1334 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1335 block += 4 * packet_size;
1336 }
1337 }
1338 }
1339
1340 // Copy the remaining coefficients of the column block after the peeled_k.
1341 if (!non_standard_patches) {
1342 for (; k < depth; k++) {
1343 block[0] = dm0.loadCoeffStandard(k);
1344 block[1] = dm1.loadCoeffStandard(k);
1345 block[2] = dm2.loadCoeffStandard(k);
1346 block[3] = dm3.loadCoeffStandard(k);
1347 block += 4;
1348 }
1349 } else {
1350 for (; k < depth; k++) {
1351 block[0] = dm0(k);
1352 block[1] = dm1(k);
1353 block[2] = dm2(k);
1354 block[3] = dm3(k);
1355 block += 4;
1356 }
1357 }
1358 }
1359
1360 // Copy the remaining columns one at a time (nr==1).
1361 for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1362 const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1363 for (Index k = 0; k < depth; k++) {
1364 *block = dm0(k);
1365 block += 1;
1366 }
1367 }
1368 }
1369};
1370
1371// Template specialization for packet_size = 2. We must special-case packet
1372// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
1373//
1374// TODO(ezhulenev): Add support for squeezing reads along two innermost
1375// dimensions (see eigen_spatial_convolutions).
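//
// Sketch of the special case (illustrative, assuming Packet2d with 2 lanes):
// a 2x4 chunk cannot be transposed as a single PacketBlock<Packet2d, 4>, so it
// is handled as two PacketBlock<Packet2d, 2> halves, e.g.
//
//   PacketBlock<Packet2d, 2> kernel0, kernel1;
//   kernel0.packet[0] = dm0.loadPacketStandard(k);
//   kernel0.packet[1] = dm1.loadPacketStandard(k);
//   kernel1.packet[0] = dm2.loadPacketStandard(k);
//   kernel1.packet[1] = dm3.loadPacketStandard(k);
//   ptranspose(kernel0);
//   ptranspose(kernel1);
//
// before the transposed halves are interleaved into the packed block.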
1376template <typename NewDimension, Index Planes, Index Rows, Index Cols,
1377 typename ArgType, typename Device, typename Scalar, typename Index,
1378 typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
1379 bool inner_dim_reordered, int Alignment, int nr>
1380struct gemm_pack_rhs<
1381 Scalar, Index,
1382 TensorContractionSubMapper<
1383 Scalar, Index, Rhs,
1384 TensorEvaluator<const TensorReshapingOp<
1385 NewDimension, const TensorVolumePatchOp<
1386 Planes, Rows, Cols, ArgType> >,
1387 Device>,
1388 nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
1389 inner_dim_reordered, Alignment>,
1390 nr, ColMajor, false, false> {
1391 typedef TensorContractionSubMapper<
1392 Scalar, Index, Rhs,
1393 TensorEvaluator<const TensorReshapingOp<
1394 NewDimension, const TensorVolumePatchOp<
1395 Planes, Rows, Cols, ArgType> >,
1396 Device>,
1397 nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
1398 inner_dim_reordered, Alignment>
1399 SubMapper;
1400 typedef SubMapper DataMapper;
1401 typedef typename packet_traits<Scalar>::type Packet;
1402
1403 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
1404
1405 EIGEN_DEVICE_FUNC
1406 EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1407 Index depth, Index cols, Index stride = 0,
1408 Index offset = 0) const {
1409 eigen_assert(stride == 0);
1410 eigen_assert(offset == 0);
1411
1412 const int packet_size = 2;
1413
1414 const Index packet_cols4 = (cols / 4) * 4;
1415 const Index peeled_k = (depth / packet_size) * packet_size;
1416 const bool non_standard_patches = rhs.nonStandardPatches();
1417
1418 for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1419 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1420 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1421 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1422 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1423
1424 Index k = 0;
1425 if (!non_standard_patches) {
1426 // FAST PATH:
        // Iterate over patch columns, rows and planes if we know that a single
        // packet does not span multiple planes, rows or columns.
1429 if ((rhs.patchDepth() % packet_size) == 0) {
1430 const Index start_col = rhs.colOffset();
1431 const Index max_col = rhs.maxCol(peeled_k);
1432
1433 for (Index c = start_col; c < max_col; ++c) {
1434 eigen_assert(k <= peeled_k);
1435
1436 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1437 const Index max_row = rhs.maxRow(peeled_k, c);
1438
1439 const bool pad_col0 = dm0.padCol(c);
1440 const bool pad_col1 = dm1.padCol(c);
1441 const bool pad_col2 = dm2.padCol(c);
1442 const bool pad_col3 = dm3.padCol(c);
1443
1444 for (Index r = start_row; r < max_row; ++r) {
1445 eigen_assert(k <= peeled_k);
1446
1447 const Index start_plane = ((c == start_col) && (r == start_row))
1448 ? rhs.planeOffset()
1449 : 0;
1450 const Index max_plane = rhs.maxPlane(peeled_k, c, r);
1451
1452 const bool pad_row0 = dm0.padRow(r);
1453 const bool pad_row1 = dm1.padRow(r);
1454 const bool pad_row2 = dm2.padRow(r);
1455 const bool pad_row3 = dm3.padRow(r);
1456
1457 for (Index p = start_plane; p < max_plane; ++p) {
1458 eigen_assert(k <= peeled_k);
1459
1460 const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
1461 const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
1462 const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
1463 const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
1464
1465 const Index idx0 = dm0.baseIndex(p, r, c);
1466 const Index idx1 = dm1.baseIndex(p, r, c);
1467 const Index idx2 = dm2.baseIndex(p, r, c);
1468 const Index idx3 = dm3.baseIndex(p, r, c);
1469
1470 const Index start_depth =
1471 ((c == start_col) && (r == start_row) && (p == start_plane))
1472 ? rhs.depthOffset()
1473 : 0;
1474 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1475 eigen_assert((max_depth - start_depth) % packet_size == 0);
1476
1477 for (Index d = start_depth; d < max_depth; d += packet_size) {
1478 eigen_assert(k < peeled_k);
1479 PacketBlock<Packet, 2> kernel0;
1480 PacketBlock<Packet, 2> kernel1;
1481 kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
1482 : rhs.packetNoPadding(d, idx0);
1483 kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
1484 : rhs.packetNoPadding(d, idx1);
1485 kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
1486 : rhs.packetNoPadding(d, idx2);
1487 kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
1488 : rhs.packetNoPadding(d, idx3);
1489 ptranspose(kernel0);
1490 ptranspose(kernel1);
1491 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1492 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1493 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1494 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1495 block += 4 * packet_size;
1496 k += packet_size;
1497 }
1498 }
1499 }
1500 }
1501
1502 // The loop above should fill peeled_k elements.
1503 eigen_assert(peeled_k == k);
1504
        } else {
          // Packet can span multiple planes, rows or columns, so we have to go
          // through the slower "standard" path.
1506 for (; k < peeled_k; k += packet_size) {
1507 PacketBlock<Packet, 2> kernel0;
1508 PacketBlock<Packet, 2> kernel1;
1509 kernel0.packet[0] = dm0.loadPacketStandard(k);
1510 kernel0.packet[1] = dm1.loadPacketStandard(k);
1511 kernel1.packet[0] = dm2.loadPacketStandard(k);
1512 kernel1.packet[1] = dm3.loadPacketStandard(k);
1513 ptranspose(kernel0);
1514 ptranspose(kernel1);
1515 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1516 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1517 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1518 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1519 block += 4 * packet_size;
1520 }
1521 }
1522 }
1523
      // Copy the remaining coefficients (beyond peeled_k) of the column block.
      if (!non_standard_patches) {
1526 for (; k < depth; k++) {
1527 block[0] = dm0.loadCoeffStandard(k);
1528 block[1] = dm1.loadCoeffStandard(k);
1529 block[2] = dm2.loadCoeffStandard(k);
1530 block[3] = dm3.loadCoeffStandard(k);
1531 block += 4;
1532 }
1533 } else {
1534 for (; k < depth; k++) {
1535 block[0] = dm0(k);
1536 block[1] = dm1(k);
1537 block[2] = dm2(k);
1538 block[3] = dm3(k);
1539 block += 4;
1540 }
1541 }
1542 }
1543
1544 // Copy the remaining columns one at a time (nr==1).
1545 for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1546 const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1547 for (Index k = 0; k < depth; k++) {
1548 *block = dm0(k);
1549 block += 1;
1550 }
1551 }
1552 }
1553};
1554
1555// Special case for non-vectorized types such as float16 (packet_size = 1).
1556template <typename NewDimension, Index Planes, Index Rows, Index Cols,
1557 typename ArgType, typename Device, typename Scalar, typename Index,
1558 typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
1559 bool inner_dim_reordered, int Alignment, int nr>
1560struct gemm_pack_rhs<
1561 Scalar, Index,
1562 TensorContractionSubMapper<
1563 Scalar, Index, Rhs,
1564 TensorEvaluator<const TensorReshapingOp<
1565 NewDimension, const TensorVolumePatchOp<
1566 Planes, Rows, Cols, ArgType> >,
1567 Device>,
1568 nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous,
1569 inner_dim_reordered, Alignment>,
1570 nr, ColMajor, false, false> {
1571 typedef TensorContractionSubMapper<
1572 Scalar, Index, Rhs,
1573 TensorEvaluator<const TensorReshapingOp<
1574 NewDimension, const TensorVolumePatchOp<
1575 Planes, Rows, Cols, ArgType> >,
1576 Device>,
1577 nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
1578 Alignment>
1579 SubMapper;
1580 typedef SubMapper DataMapper;
1581
1582 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
1583
1584 EIGEN_DEVICE_FUNC
1585 EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1586 Index depth, Index cols, Index stride = 0,
1587 Index offset = 0) const {
1588 eigen_assert(stride == 0);
1589 eigen_assert(offset == 0);
1590
1591 const Index packet_cols4 = (cols / 4) * 4;
1592
1593 for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1594 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1595 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1596 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1597 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1598
1599 if (!rhs.nonStandardPatches()) {
1600 for (Index k = 0; k < depth; k++) {
1601 block[0] = dm0.loadCoeffStandard(k);
1602 block[1] = dm1.loadCoeffStandard(k);
1603 block[2] = dm2.loadCoeffStandard(k);
1604 block[3] = dm3.loadCoeffStandard(k);
1605 block += 4;
1606 }
1607 } else {
1608 for (Index k = 0; k < depth; k++) {
1609 block[0] = dm0(k);
1610 block[1] = dm1(k);
1611 block[2] = dm2(k);
1612 block[3] = dm3(k);
1613 block += 4;
1614 }
1615 }
1616 }
1617
1618 // Copy the remaining columns one at a time (nr==1).
1619 for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1620 const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1621 for (Index k = 0; k < depth; k++) {
1622 *block = dm0(k);
1623 block += 1;
1624 }
1625 }
1626 }
1627};
#endif  // !EIGEN_ALTIVEC_USE_CUSTOM_PACK
1629
1630#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
// Pack a block of the right input matrix (in our case it's always a "virtual
// matrix" constructed from extracted volume patches) into a contiguous block
// in column-major storage order. Knowing the properties of the original patch
// op, we can do this more efficiently than the default
// gemm_pack_colmajor_block.
1635//
// TODO(ezhulenev): gemm_pack_colmajor_block for spatial convolutions supports
// squeezing reads along the two innermost dimensions; add it here if needed.
1638template <typename NewDimension, Index Planes, Index Rows, Index Cols,
1639 typename ArgType, typename Device, typename Scalar,
1640 typename StorageIndex, typename nocontract_t, typename contract_t,
1641 int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
1642 int Alignment>
1643struct gemm_pack_colmajor_block<
1644 Scalar, StorageIndex,
1645 TensorContractionSubMapper<
1646 Scalar, StorageIndex, Rhs,
1647 TensorEvaluator<const TensorReshapingOp<
1648 NewDimension, const TensorVolumePatchOp<
1649 Planes, Rows, Cols, ArgType> >,
1650 Device>,
1651 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1652 inner_dim_reordered, Alignment>,
1653 ColMajor> {
1654 typedef TensorContractionSubMapper<
1655 Scalar, StorageIndex, Rhs,
1656 TensorEvaluator<const TensorReshapingOp<
1657 NewDimension, const TensorVolumePatchOp<
1658 Planes, Rows, Cols, ArgType> >,
1659 Device>,
1660 nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1661 inner_dim_reordered, Alignment>
1662 SubMapper;
1663
1664 typedef SubMapper DataMapper;
1665 typedef typename packet_traits<Scalar>::type Packet;
1666
1667 EIGEN_DONT_INLINE
1668 void operator()(Scalar* block, const DataMapper& rhs, StorageIndex rows,
1669 StorageIndex cols) {
1670 const bool standard_patches = !rhs.nonStandardPatches();
1671
1672 if (standard_patches && rhs.patchDepth() % packet_size == 0) {
1673 packStandardPatches<true>(block, rhs, rows, cols);
1674
1675 } else if (standard_patches) {
1676 packStandardPatches<false>(block, rhs, rows, cols);
1677
1678 } else {
1679 // With non-standard patches we don't do any vectorized loads.
      // TODO(ezhulenev): It doesn't look like we should completely give up on
      // packets. Make this code path faster!
1682 for (StorageIndex col = 0; col < cols; ++col) {
1683 SubMapper lm = rhs.getLinearMapper(0, col);
1684 for (StorageIndex i = 0; i < rows; ++i) {
1685 *block = lm(i);
1686 ++block;
1687 }
1688 }
1689 }
1690 }
1691
1692 private:
1693 // Pack standard volume patches:
1694 //
  // - patch_depth_is_multiple_of_packet_size=true: We are guaranteed that the
  //   depth dimension size is a multiple of the packet size, so we can skip
  //   all non-vectorized loads and checks.
1698 //
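  // - patch_depth_is_multiple_of_packet_size=false: The depth loop is split
  //   into a vectorized part and a scalar tail. For illustration (the numbers
  //   are made up): with a patch depth of 6 and a packet size of 4, a full
  //   depth scan loads one packet of four values and then copies the
  //   remaining two values with scalar loads.
  //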
1699 template <bool patch_depth_is_multiple_of_packet_size>
1700 EIGEN_ALWAYS_INLINE void packStandardPatches(Scalar* block,
1701 const DataMapper& rhs,
1702 StorageIndex rows,
1703 StorageIndex cols) {
1704 eigen_assert(!rhs.nonStandardPatches());
1705
    // Use the name peeled_k for the number of rows we can process with
    // vectorized loads, to match the gemm_pack_rhs specializations above.
1707 const Index peeled_k = (rows / packet_size) * packet_size;
1708
1709 const Index start_col = rhs.colOffset();
1710 const Index max_col = rhs.maxCol(peeled_k);
1711
1712 for (StorageIndex col = 0; col < cols; ++col) {
1713 SubMapper lm = rhs.getLinearMapper(0, col);
1714
1715 Index k = 0;
1716 for (Index c = start_col; c < max_col; ++c) {
1717 eigen_assert(k <= peeled_k);
1718
1719 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1720 const Index max_row = rhs.maxRow(peeled_k, c);
1721 const bool pad_col = lm.padCol(c);
1722
1723 for (Index r = start_row; r < max_row; ++r) {
1724 eigen_assert(k <= peeled_k);
1725
1726 const Index start_plane =
1727 ((c == start_col) && (r == start_row)) ? rhs.planeOffset() : 0;
1728 const Index max_plane = rhs.maxPlane(peeled_k, c, r);
1729 const bool pad_row = pad_col || lm.padRow(r);
1730
1731 for (Index p = start_plane; p < max_plane; ++p) {
1732 eigen_assert(k <= peeled_k);
1733
1734 const Index start_depth =
1735 ((c == start_col) && (r == start_row) && (p == start_plane))
1736 ? rhs.depthOffset()
1737 : 0;
1738 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1739
1740 const bool pad = pad_col || pad_row || lm.padPlane(p);
1741 const Index base_idx = lm.baseIndex(p, r, c);
1742
1743 if (patch_depth_is_multiple_of_packet_size)
1744 eigen_assert((max_depth - start_depth) % packet_size == 0);
1745
            // If patch depth is a multiple of packet size, it's guaranteed
            // that we can process all values in the depth dimension with
            // packets.
1748 const Index max_vectorized_depth =
1749 patch_depth_is_multiple_of_packet_size
1750 ? max_depth
1751 : max_depth - packet_size;
1752
1753 Index d = start_depth;
1754
1755 // 1. Process depth dimension with vectorized instructions.
1756 for (; d < max_vectorized_depth; d += packet_size) {
1757 eigen_assert(k < peeled_k);
1758 const Packet packet = pad ? pset1<Packet>(Scalar(0))
1759 : rhs.packetNoPadding(d, base_idx);
1760 internal::pstoreu(block, packet);
1761 block += packet_size;
1762 k += packet_size;
1763 }
1764
1765 // 2. Finish with coefficients.
1766 if (!patch_depth_is_multiple_of_packet_size) {
1767 for (; d < max_depth; d++) {
1768 eigen_assert(k < peeled_k);
1769 *block = pad ? Scalar(0) : rhs.coeffNoPadding(d, base_idx);
1770 ++block;
1771 ++k;
1772 }
1773 }
1774 }
1775 }
1776 }
1777
1778 // The loop above should fill peeled_k elements.
1779 eigen_assert(peeled_k == k);
1780
1781 // Fill remaining elements using loadCoeffStandard.
1782 for (; k < rows; ++k) {
1783 *block = lm.loadCoeffStandard(k);
1784 ++block;
1785 }
1786 }
1787 }
1788};
1789#endif // defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
1790
1791} // namespace internal
1792
1793/** CuboidConvolution
1794 * \ingroup CXX11_NeuralNetworks_Module
1795 *
1796 * \brief Applies a 3D convolution over a multichannel input voxel block.
1797 *
1798 * The input parameter is expected to be a tensor with a rank of 4 or more
1799 * (channels, depth, height, width, and optionally others).
1800 * The kernel parameter is expected to be a 5D tensor (filters, channels,
1801 * kernel_depth, kernel_height, kernel_width).
1802 * The result can be assigned to a tensor of rank equal to the rank of the
1803 * input. The dimensions of the result will be filters, depth, height, width
1804 * (and others if applicable).
1805 *
1806 * The input and kernel have to be in the same layout, and both row-major and
1807 * col-major are supported. The shapes given above are for col-major layout.
1808 * For row-major, all dimensions should be reversed.
1809 *
1810 * It is possible to swap the order of the depth, width, and height dimensions
1811 * provided that the same order is used in the input, the kernel, and the
1812 * output.
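 *
 * A minimal usage sketch (col-major layout; the sizes below are illustrative
 * only):
 * \code
 * // (channels, planes, rows, cols, batch)
 * Eigen::Tensor<float, 5> input(4, 8, 16, 16, 32);
 * // (filters, channels, kernel_planes, kernel_rows, kernel_cols)
 * Eigen::Tensor<float, 5> filters(7, 4, 3, 3, 3);
 * input.setRandom();
 * filters.setRandom();
 * // Default strides of 1 and PADDING_SAME preserve the spatial dimensions.
 * Eigen::Tensor<float, 5> output(7, 8, 16, 16, 32);
 * output = Eigen::CuboidConvolution(input, filters);
 * \endcode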
1813 */
1814template <typename Input, typename Kernel>
1815EIGEN_ALWAYS_INLINE static const std::conditional_t<
1816 internal::traits<Input>::Layout == ColMajor,
1817 TensorReshapingOp<
1818 const DSizes<typename internal::traits<Input>::Index,
1819 internal::traits<Input>::NumDimensions>,
1820 const TensorContractionOp<
1821 const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1822 const TensorReshapingOp<
1823 const DSizes<typename internal::traits<Input>::Index, 2>,
1824 const Kernel>,
1825 const TensorReshapingOp<
1826 const DSizes<typename internal::traits<Input>::Index, 2>,
1827 const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
1828 const Input> > > >,
1829 TensorReshapingOp<
1830 const DSizes<typename internal::traits<Input>::Index,
1831 internal::traits<Input>::NumDimensions>,
1832 const TensorContractionOp<
1833 const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1834 const TensorReshapingOp<
1835 const DSizes<typename internal::traits<Input>::Index, 2>,
1836 const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
1837 const Input> >,
1838 const TensorReshapingOp<
1839 const DSizes<typename internal::traits<Input>::Index, 2>,
1840 const Kernel> > > >
1841CuboidConvolution(const Input& input, const Kernel& kernel,
1842 const Index stridePlanes = 1, const Index strideRows = 1,
1843 const Index strideCols = 1,
1844 const PaddingType padding_type = PADDING_SAME) {
1845 typedef typename internal::traits<Input>::Index TensorIndex;
1846 TensorRef<Tensor<typename internal::traits<Input>::Scalar,
1847 internal::traits<Input>::NumDimensions,
1848 internal::traits<Input>::Layout, TensorIndex> >
1849 in(input);
1850 TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
1851 internal::traits<Kernel>::NumDimensions,
1852 internal::traits<Kernel>::Layout, TensorIndex> >
1853 kern(kernel);
1854
1855 EIGEN_STATIC_ASSERT(
1856 internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
1857 YOU_MADE_A_PROGRAMMING_MISTAKE);
1858 static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
1859 static const int NumDims = internal::traits<Input>::NumDimensions;
1860
1861 // Number of filters to apply. This is the same as the output depth of the
1862 // result.
1863 const TensorIndex kernelFilters =
1864 isColMajor ? kern.dimensions()[0] : kern.dimensions()[4];
1865 const TensorIndex kernelChannels =
1866 isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];
1867
  // Spatial size of the kernel.
  // Note: for the planes dimension the index is 2 in both layouts, because a
  // row-major kernel has the col-major dimensions in reverse order.
  const TensorIndex kernelPlanes =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
1871 const TensorIndex kernelRows =
1872 isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
1873 const TensorIndex kernelCols =
1874 isColMajor ? kern.dimensions()[4] : kern.dimensions()[0];
1875
1876 if (isColMajor) {
1877 eigen_assert(kernelChannels == in.dimension(0));
1878 } else {
1879 eigen_assert(kernelChannels == in.dimension(NumDims - 1));
1880 }
1881
1882 const TensorIndex inputPlanes =
1883 isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
1884 const TensorIndex inputRows =
1885 isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
1886 const TensorIndex inputCols =
1887 isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);
1888
1889 TensorIndex out_planes;
1890 TensorIndex out_height;
1891 TensorIndex out_width;
1892 switch (padding_type) {
1893 case PADDING_VALID:
1894 out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1,
1895 static_cast<TensorIndex>(stridePlanes));
1896 out_height = Eigen::divup(inputRows - kernelRows + 1,
1897 static_cast<TensorIndex>(strideRows));
1898 out_width = Eigen::divup(inputCols - kernelCols + 1,
1899 static_cast<TensorIndex>(strideCols));
1900 break;
1901 case PADDING_SAME:
1902 out_planes =
1903 Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
1904 out_height =
1905 Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
1906 out_width = Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
1907 break;
1908 default:
1909 out_planes = 0;
1910 out_height = 0;
1911 out_width = 0;
1912 eigen_assert(false && "unexpected padding");
1913 }
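
  // For illustration (not used below): with inputRows = 7, kernelRows = 3 and
  // strideRows = 2, PADDING_VALID gives divup(7 - 3 + 1, 2) = 3 output rows,
  // while PADDING_SAME gives divup(7, 2) = 4.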
1914
1915 DSizes<TensorIndex, 2> kernel_dims;
1916 if (isColMajor) {
1917 kernel_dims[0] = kernelFilters;
1918 kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
1919 } else {
1920 kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
1921 kernel_dims[1] = kernelFilters;
1922 }
1923
  // Molds the output of the patch extraction into a 2D tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with
  //   the kernels
  // - the second dimension (dims[1]): everything else
1928 DSizes<TensorIndex, 2> pre_contract_dims;
1929 if (isColMajor) {
1930 pre_contract_dims[0] =
1931 kernelChannels * kernelPlanes * kernelRows * kernelCols;
1932 pre_contract_dims[1] = out_planes * out_height * out_width;
1933 for (int i = 4; i < NumDims; ++i) {
1934 pre_contract_dims[1] *= in.dimension(i);
1935 }
1936 } else {
1937 pre_contract_dims[1] =
1938 kernelChannels * kernelPlanes * kernelRows * kernelCols;
1939 pre_contract_dims[0] = out_planes * out_height * out_width;
1940 for (int i = 0; i < NumDims - 4; ++i) {
1941 pre_contract_dims[0] *= in.dimension(i);
1942 }
1943 }
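
  // For example (col-major, illustrative numbers only): with kernelChannels=3,
  // a 2x2x2 kernel, a 4x4x4 output volume and a batch of 8, the patches are
  // viewed as a [3 * 2 * 2 * 2, 4 * 4 * 4 * 8] = [24, 512] virtual matrix.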
1944
1945 array<IndexPair<TensorIndex>, 1> contract_dims;
1946 contract_dims[0] = IndexPair<TensorIndex>(1, 0);
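  // In the col-major branch below this contracts dimension 1 of the reshaped
  // kernel (the patch size K = kernelChannels * kernelPlanes * kernelRows *
  // kernelCols) with dimension 0 of the reshaped patches, i.e. a
  // [kernelFilters, K] x [K, N] product that yields a [kernelFilters, N]
  // matrix, where N = out_planes * out_height * out_width * OTHERS.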
1947
1948 // Molds the output of the contraction into the shape expected by the user
1949 // (assuming ColMajor):
1950 // - 1st dim: kernel filters
  // - 2nd dim: output depth (planes)
  // - 3rd dim: output height
  // - 4th dim: output width
1954 // - 5th dim and beyond: everything else including batch size
1955 DSizes<TensorIndex, NumDims> post_contract_dims;
1956 if (isColMajor) {
1957 post_contract_dims[0] = kernelFilters;
1958 post_contract_dims[1] = out_planes;
1959 post_contract_dims[2] = out_height;
1960 post_contract_dims[3] = out_width;
1961 for (int i = 4; i < NumDims; ++i) {
1962 post_contract_dims[i] = in.dimension(i);
1963 }
1964 } else {
1965 post_contract_dims[NumDims - 1] = kernelFilters;
1966 post_contract_dims[NumDims - 2] = out_planes;
1967 post_contract_dims[NumDims - 3] = out_height;
1968 post_contract_dims[NumDims - 4] = out_width;
1969 for (int i = 0; i < NumDims - 4; ++i) {
1970 post_contract_dims[i] = in.dimension(i);
1971 }
1972 }
1973
1974 return choose(
1975 Cond<internal::traits<Input>::Layout == ColMajor>(),
1976 kernel.reshape(kernel_dims)
1977 .contract(input
1978 .extract_volume_patches(
1979 kernelPlanes, kernelRows, kernelCols, stridePlanes,
1980 strideRows, strideCols, padding_type)
1981 .reshape(pre_contract_dims),
1982 contract_dims)
1983 .reshape(post_contract_dims),
1984 input
1985 .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
1986 stridePlanes, strideRows, strideCols,
1987 padding_type)
1988 .reshape(pre_contract_dims)
1989 .contract(kernel.reshape(kernel_dims), contract_dims)
1990 .reshape(post_contract_dims));
1991}
1992
1993} // end namespace Eigen
1994
1995#endif // TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
1996