/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_

// Depending on the build configuration this header provides a custom kernel
// for Eigen tensor contractions (the small matrix multiplication kernel used
// to multiply together blocks of the original tensors).
//
// 1) --define tensorflow_mkldnn_contraction_kernel=1
//    Use the MKL-DNN single-threaded sgemm. The MKL-DNN kernels are generated
//    at runtime and use avx/avx2/fma/avx512 based on the CPU status registers
//    (https://en.wikipedia.org/wiki/CPUID).
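//
//    For example, a typical Bazel invocation enabling this kernel might look
//    like this (the target label below is just a placeholder):
//
//      bazel build --define=tensorflow_mkldnn_contraction_kernel=1 //your:target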
//
// If you use `tensor.contract(other_tensor)` in your code, you must include
// this header to get the benefit of the custom contraction kernel:
//
//   #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
//   #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
//   #endif
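//
// As a minimal sketch of such a contraction (dimensions and variable names
// are illustrative, not prescriptive):
//
//   Eigen::Tensor<float, 2> lhs(64, 32), rhs(32, 48);
//   lhs.setRandom();
//   rhs.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {
//       Eigen::IndexPair<int>(1, 0)};
//   Eigen::Tensor<float, 2> out = lhs.contract(rhs, dims);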

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

// FixedPoint header must be included after Tensor.
// clang-format off
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
// clang-format on

#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
#include "dnnl.h"
#endif

#include "tensorflow/core/platform/dynamic_annotations.h"

namespace Eigen {
namespace internal {

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
// Returns `true` iff we can use custom contraction kernels. This is a runtime
// check that uses environment variables.
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE bool UseCustomContractionKernels();

// Packs a 2D block of a Tensor expression into a contiguous block of memory
// with col-major storage order. We do not have access to the underlying Tensor
// expression; we only have a DataMapper (TensorContractionInputMapper for
// tensor contractions, or blas_data_mapper for plain tensors) that provides a
// two-dimensional view into the Tensor expression.
//
// The default Eigen gemm_pack_rhs and gemm_pack_lhs pack blocks of tensor
// expressions into the packed format described in the "Anatomy of
// High-Performance Matrix Multiplication" paper (1).
// Eigen::internal::gebp_kernel relies on this packing format for efficient
// micro-panel multiplication.
//
// In contrast, this simple col-major packing can be used with any '?gemm'
// function from BLAS libraries that work with col-major matrices.
//
// (1) http://www.cs.utexas.edu/~flame/pubs/GotoTOMS_revision.pdf
//
// IMPORTANT: `gemm_pack_colmajor_block` always packs the block in column-major
// order; DataMapperStorageOrder specifies the storage order of the underlying
// Tensor expression.
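//
// As an illustrative example (all numbers are made up): packing a rows = 3,
// cols = 2 block whose DataMapper view has a column stride of 5 copies six
// scalars into `block` in this order:
//
//   block[0..2] = src(0, 0), src(1, 0), src(2, 0)   // first column
//   block[3..5] = src(0, 1), src(1, 1), src(2, 1)   // second column
//
// i.e. any gaps introduced by the source stride disappear in the packed copy.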
template <typename Scalar, typename IndexType, typename DataMapper,
          int DataMapperStorageOrder>
struct gemm_pack_colmajor_block;

// gemm_pack_colmajor_block for ColMajor storage order.
template <typename Scalar, typename IndexType, typename DataMapper>
struct gemm_pack_colmajor_block<Scalar, IndexType, DataMapper,
                                /*DataMapperStorageOrder*/ ColMajor> {
  typedef typename internal::packet_traits<Scalar>::type Packet;
  typedef typename DataMapper::LinearMapper LinearMapper;

  enum { PacketSize = internal::packet_traits<Scalar>::size };

  EIGEN_DONT_INLINE
  void operator()(Scalar* block, const DataMapper& data_mapper, IndexType rows,
                  IndexType cols) {
    const IndexType unrolled_rows = rows - 4 * PacketSize;
    const IndexType vectorized_rows = rows - PacketSize;

    for (IndexType col = 0; col < cols; ++col) {
      LinearMapper lm = data_mapper.getLinearMapper(0, col);

      IndexType row = 0;
      // Give the compiler a strong hint to unroll the loop.
      for (; row <= unrolled_rows; row += 4 * PacketSize) {
        for (IndexType j = 0; j < 4; ++j) {
          const Packet p = lm.template loadPacket<Packet>(row + j * PacketSize);
          internal::pstoreu(block + j * PacketSize, p);
        }
        block += 4 * PacketSize;
      }
      // Process remaining rows with packets.
      for (; row <= vectorized_rows; row += PacketSize) {
        const Packet p = lm.template loadPacket<Packet>(row);
        internal::pstoreu(block, p);
        block += PacketSize;
      }
      // Finalize with coefficients.
      for (; row < rows; ++row) {
        *block = lm(row);
        ++block;
      }
    }
  }
};

#endif  // TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL

// Enabled by build option: "--define tensorflow_mkldnn_contraction_kernel=1"
#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)

template <typename Scalar, typename IndexType, typename OutputMapper,
          bool ConjugateLhs = false, bool ConjugateRhs = false>
struct dnnl_gemm_kernel;

// dnnl_gemm_kernel for floats, defined as a thin layer on top of dnnl_sgemm.
template <typename IndexType, typename OutputMapper, bool ConjugateLhs,
          bool ConjugateRhs>
struct dnnl_gemm_kernel</*Scalar*/ float, IndexType, OutputMapper, ConjugateLhs,
                        ConjugateRhs> {
  static_assert(!ConjugateLhs, "DNNL kernel doesn't support ConjugateLhs");
  static_assert(!ConjugateRhs, "DNNL kernel doesn't support ConjugateRhs");

  static constexpr int kComputeStrideFromBlockDimensions = -1;

  using LhsScalar = float;
  using RhsScalar = float;
  using ResScalar = float;

  EIGEN_DONT_INLINE
  void operator()(const OutputMapper& output, const LhsScalar* blockA,
                  const RhsScalar* blockB, const IndexType rows,
                  const IndexType depth, const IndexType cols, float alpha,
                  float beta, int ldA = kComputeStrideFromBlockDimensions,
                  int ldB = kComputeStrideFromBlockDimensions,
                  char transposeA = 'N', char transposeB = 'N') {
    static const int max_index = (std::numeric_limits<int>::max)();

    eigen_assert(max_index >= rows);
    eigen_assert(max_index >= cols);
    eigen_assert(max_index >= depth);
    eigen_assert(max_index >= output.stride());

    const int m = static_cast<int>(rows);
    const int n = static_cast<int>(cols);
    const int k = static_cast<int>(depth);

    ldA = ldA == kComputeStrideFromBlockDimensions ? m : ldA;
    ldB = ldB == kComputeStrideFromBlockDimensions ? k : ldB;
    const int ldC = static_cast<int>(output.stride());

    // DNNL takes row-major matrices. Our packed column-major matrices can be
    // viewed as transposed row-major matrices, i.e.,
    //   C_colmajor = C_rowmajor^T = (A_rowmajor * B_rowmajor)^T
    //              = B_rowmajor^T * A_rowmajor^T
    //              = B_colmajor * A_colmajor
    // So we can just swap the input matrices A and B for DNNL.
    // TODO(penporn): Switch to row-major packing instead.
    dnnl_status_t st =
        dnnl_sgemm(transposeB, transposeA, n, m, k, alpha, blockB, ldB, blockA,
                   ldA, beta, const_cast<ResScalar*>(output.data()), ldC);
    eigen_assert(st == 0);

#if DYNAMIC_ANNOTATIONS_ENABLED == 1 || defined(MEMORY_SANITIZER)
    for (IndexType col = 0; col < cols; ++col) {
      ResScalar* row_base = &output(0, col);
      EIGEN_UNUSED_VARIABLE(row_base);  // Suppress unused variable error.
      TF_ANNOTATE_MEMORY_IS_INITIALIZED(row_base, sizeof(ResScalar) * rows);
    }
#endif

    // eigen_assert is a no-op in optimized mode so we add these to avoid
    // compiler's unused-variable errors.
    EIGEN_UNUSED_VARIABLE(max_index);
    EIGEN_UNUSED_VARIABLE(st);
  }
};

template <typename IndexType, typename OutputMapper, bool ConjugateLhs = false,
          bool ConjugateRhs = false>
struct mkldnn_gemm_s8u8s32_kernel {
  static_assert(!ConjugateLhs, "DNNL kernel doesn't support ConjugateLhs");
  static_assert(!ConjugateRhs, "DNNL kernel doesn't support ConjugateRhs");

  static constexpr int kComputeStrideFromBlockDimensions = -1;

  using LhsScalar = Eigen::QInt8;
  using RhsScalar = Eigen::QUInt8;
  using ResScalar = Eigen::QInt32;

  EIGEN_DONT_INLINE
  void operator()(const OutputMapper& output, const LhsScalar* blockA,
                  const RhsScalar* blockB, const IndexType rows,
                  const IndexType depth, const IndexType cols, float alpha,
                  float beta, int ldA = kComputeStrideFromBlockDimensions,
                  int ldB = kComputeStrideFromBlockDimensions,
                  char transposeA = 'N', char transposeB = 'N') {
    static const int max_index = (std::numeric_limits<int>::max)();

    eigen_assert(max_index >= rows);
    eigen_assert(max_index >= cols);
    eigen_assert(max_index >= depth);
    eigen_assert(max_index >= output.stride());

    const int m = static_cast<int>(rows);
    const int n = static_cast<int>(cols);
    const int k = static_cast<int>(depth);

    ldA = ldA == kComputeStrideFromBlockDimensions ? m : ldA;
    ldB = ldB == kComputeStrideFromBlockDimensions ? k : ldB;
    const int ldC = static_cast<int>(output.stride());

    // Currently we support only symmetric quantization with zero point at 0.
    const int8_t ao = 0;
    const int8_t bo = 0;

    // Don't add any offset to the result C.
    const char offsetc = 'F';
    const int32_t co = 0;

    const auto* A = reinterpret_cast<const int8_t*>(blockA);
    const auto* B = reinterpret_cast<const uint8_t*>(blockB);
    auto* C = reinterpret_cast<int32_t*>(const_cast<ResScalar*>(output.data()));

    // DNNL takes row-major matrices. Our packed column-major matrices can be
    // viewed as transposed row-major matrices, i.e.,
    //   C_colmajor = C_rowmajor^T = (A_rowmajor * B_rowmajor)^T
    //              = B_rowmajor^T * A_rowmajor^T
    //              = B_colmajor * A_colmajor
    // So we can just swap the input matrices A and B for DNNL.
    // TODO(penporn): Switch to row-major packing instead.
    dnnl_status_t st = dnnl_gemm_u8s8s32(transposeB, transposeA, offsetc,  //
                                         n, m, k,                          //
                                         alpha,                            //
                                         B, ldB, bo,                       //
                                         A, ldA, ao,                       //
                                         beta,                             //
                                         C, ldC, &co);
    eigen_assert(st == 0);

#if DYNAMIC_ANNOTATIONS_ENABLED == 1 || defined(MEMORY_SANITIZER)
    for (IndexType col = 0; col < cols; ++col) {
      ResScalar* row_base = &output(0, col);
      EIGEN_UNUSED_VARIABLE(row_base);  // Suppress unused variable error.
      TF_ANNOTATE_MEMORY_IS_INITIALIZED(row_base, sizeof(ResScalar) * rows);
    }
#endif

    // eigen_assert is a no-op in optimized mode so we add these to avoid
    // compiler's unused-variable errors.
    EIGEN_UNUSED_VARIABLE(max_index);
    EIGEN_UNUSED_VARIABLE(st);
  }
};

// For mkldnn_sgemm, having the right dimensions (especially for small
// matrices) is more important than fitting the whole working set into the
// L1/L2 caches.
// TODO(ezhulenev): Do better heuristics.
template <typename StorageIndex, int sharding_type>
class TensorContractionBlocking<float, float, float, StorageIndex,
                                sharding_type> {
  // For now mkldnn has only mkldnn_sgemm (gemm for floats).
  using Scalar = float;

  // Adjust the block sizes to work well with mkldnn kernels.

  // Multiply the default choice of block size along the M and N dimensions.
  // TODO(ezhulenev): Explore if this can work in general (kScaleM=2.0 worked
  // well in some models).
  static constexpr float kScaleM = 1.5;
  static constexpr float kScaleN = 1.0;

  // Mkldnn Avx/Avx2/Avx512 unroll factors are: 8/16/48.
  static constexpr StorageIndex kUnrollM = 48;

  // Mkldnn Avx/Avx2/Avx512 unroll factors are: 6/6/8.
  static constexpr StorageIndex kUnrollN = 24;
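
  // Illustrative example of the refinement done in the constructor below
  // (numbers are made up): if Eigen's default heuristic picks mc_ = 100, then
  // mc_ * kScaleM = 150, divup(150, kUnrollM = 48) = 4, and the refined block
  // size becomes min(m, 4 * 48) = min(m, 192). The same round-up to a multiple
  // of the unroll factor is applied to nc_.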

 public:
  TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n,
                            StorageIndex num_threads = 1)
      : kc_(k), mc_(m), nc_(n) {
    // 1. Compute block sizes using default Eigen heuristics.
    if (sharding_type == ShardByCol) {
      computeProductBlockingSizes<Scalar, Scalar, 1>(kc_, mc_, nc_,
                                                     num_threads);
    } else {
      computeProductBlockingSizes<Scalar, Scalar, 1>(kc_, nc_, mc_,
                                                     num_threads);
    }

    // If the dimensions do not pass basic sanity checks, return immediately.
    if (kc_ <= 0 || mc_ <= 0 || nc_ <= 0) return;

    // If we are using the default Eigen gebp kernel there is no need to
    // adjust the block sizes for DNNL.
    if (!UseCustomContractionKernels()) return;

    // 2. Then refine them to work well with the mkldnn sgemm kernel.
    mc_ = (std::min)(
        m, Eigen::divup(static_cast<StorageIndex>(mc_ * kScaleM), kUnrollM) *
               kUnrollM);
    nc_ = (std::min)(
        n, Eigen::divup(static_cast<StorageIndex>(nc_ * kScaleN), kUnrollN) *
               kUnrollN);

    // We split the K dimension into roughly equal slices.
    StorageIndex target_k_slices =
        (std::max)(StorageIndex(1), Eigen::divup(k, kc_));
    StorageIndex packet_size = internal::packet_traits<Scalar>::size;
    if (packet_size < 8) packet_size = 8;
    StorageIndex target_bk =
        Eigen::divup(k / target_k_slices, packet_size) * packet_size;
    kc_ = (std::min)(k, target_bk);
  }

  EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; }
  EIGEN_ALWAYS_INLINE StorageIndex mc() const { return mc_; }
  EIGEN_ALWAYS_INLINE StorageIndex nc() const { return nc_; }

 private:
  StorageIndex kc_;
  StorageIndex mc_;
  StorageIndex nc_;
};

template <typename StorageIndex, int sharding_type>
class TensorContractionBlocking<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
                                StorageIndex, sharding_type> {
  // TODO(ezhulenev): Define proper gebp_traits in Eigen for quantized types?

  // The default Eigen block heuristics for `QInt8xQUInt8 -> QInt32` are wrong,
  // mostly because gebp_traits are not correctly defined for these types. But
  // we know that we are going to use s8u8s32_gemm from DNNL, so we use float
  // heuristics and adjust them to work well with DNNL.
  using LhsScalar = Eigen::QInt8;
  using RhsScalar = Eigen::QUInt8;
  using ResScalar = Eigen::QInt32;

  // Multiply the default choice of block size along the M, N and K dimensions.
  static constexpr float kScaleM = 1.5;
  static constexpr float kScaleN = 1.5;
  static constexpr float kScaleK = 1.5;

 public:
  TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n,
                            StorageIndex num_threads = 1)
      : kc_(k), mc_(m), nc_(n) {
    // Each dimension is a multiple of 32 (fits into __m256i).
    mc_ = (std::min)(m, static_cast<StorageIndex>(192));
    nc_ = (std::min)(n, static_cast<StorageIndex>(288));
    kc_ = (std::min)(k, static_cast<StorageIndex>(320));
  }

  EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; }
  EIGEN_ALWAYS_INLINE StorageIndex mc() const { return mc_; }
  EIGEN_ALWAYS_INLINE StorageIndex nc() const { return nc_; }

 private:
  StorageIndex kc_;
  StorageIndex mc_;
  StorageIndex nc_;
};

// If the Lhs or Rhs Tensor expressions are already evaluated and provide
// access to their raw data, we can skip the packing step, set up pointers and
// a stride to the underlying memory buffer, and pass them directly to GEMM.
template <typename Scalar, typename StorageIndex>
struct ColMajorBlock {
  bool is_direct_access;

  // Valid iff `is_direct_access == false`
  Scalar* packed_data;

  // Valid iff `is_direct_access == true`
  Scalar* raw_data;
  StorageIndex stride;
  char transpose;
};

template <typename DataMapper>
struct DirectColMajorAccess {
  enum { value = false };

  template <typename Scalar, typename StorageIndex>
  static bool block(const typename DataMapper::SubMapper& data_mapper,
                    const StorageIndex rows, const StorageIndex cols,
                    const StorageIndex num_kernels,
                    ColMajorBlock<Scalar, StorageIndex>* block) {
    eigen_assert(false && "Not implemented");
    return false;
  }
};

// If we have access to the raw memory of the contraction input, we can safely
// skip packing if:
//   (1) Packing is a no-op.
//   (2) The packed block will be used just once.
//
// If a packed block is used many times, it's more efficient to pack it into a
// contiguous block of memory to reduce pressure on the TLB.
//
// TODO(ezhulenev): Add support for more tensor expressions that matter.
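//
// Illustrative example (all numbers are made up): a float block with
// stride == 1024 and cols == 32 addresses 1024 * 32 * 4 bytes = 128 KiB; if
// it is read by only two GEMM kernels, that is below the 256 KiB threshold
// used below, so we pass the raw pointer and stride instead of packing.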
#define REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_EXPR) \
  template <typename Scalar, typename StorageIndex, int Side, typename Device, \
            typename nocontract_t, typename contract_t, int packet_size, \
            int Alignment> \
  struct DirectColMajorAccess<TensorContractionInputMapper< \
      Scalar, StorageIndex, Side, TensorEvaluator<TENSOR_EXPR, Device>, \
      nocontract_t, contract_t, packet_size, /*inner_dim_contiguous=*/true, \
      /*inner_dim_reordered=*/false, Alignment>> { \
    enum { value = true }; \
  \
    using DataMapper = TensorContractionInputMapper< \
        Scalar, StorageIndex, Side, TensorEvaluator<TENSOR_EXPR, Device>, \
        nocontract_t, contract_t, packet_size, /*inner_dim_contiguous=*/true, \
        /*inner_dim_reordered=*/false, Alignment>; \
  \
    static bool block(const typename DataMapper::SubMapper& data_mapper, \
                      const StorageIndex rows, const StorageIndex cols, \
                      const StorageIndex num_kernels, \
                      ColMajorBlock<Scalar, StorageIndex>* block) { \
      static_assert(DataMapper::DirectOffsets == true, \
                    "DataMapper must support direct offsets"); \
  \
      const StorageIndex vert_offset = data_mapper.vert_offset(); \
      const StorageIndex horiz_offset = data_mapper.horiz_offset(); \
      const StorageIndex stride = \
          Side == Lhs ? data_mapper.base_mapper().stride() \
                      : data_mapper.base_mapper().nocontract_strides()[0]; \
      const Scalar* data = data_mapper.base_mapper().tensor().data(); \
      data = Side == Lhs ? data : data + vert_offset + horiz_offset * stride; \
  \
      const bool is_no_op_packing = stride == rows; \
      const StorageIndex addressable_mem = (stride * cols * sizeof(Scalar)); \
      const bool use_direct_access = \
          is_no_op_packing || num_kernels == 1 /* used once */ || \
          ((num_kernels == 2) && \
           (addressable_mem < (256 << 10) /* 256 kb */)); \
  \
      if (use_direct_access) { \
        block->is_direct_access = true; \
        block->raw_data = const_cast<Scalar*>(data); \
        block->stride = stride; \
        block->transpose = 'N'; \
        return true; \
      } \
      return false; \
    } \
  }

#define SIMPLE_TENSOR const Tensor<Scalar, 2, Eigen::ColMajor, StorageIndex>

#define TENSOR_MAP_ROWMAJOR \
  const TensorMap<Tensor<const Scalar, 2, Eigen::RowMajor, StorageIndex>, \
                  Eigen::Aligned>

#define TENSOR_MAP_COLMAJOR \
  const TensorMap<Tensor<const Scalar, 2, Eigen::ColMajor, StorageIndex>, \
                  Eigen::Aligned>

#define TENSOR_MAP_CONST_ROWMAJOR \
  const TensorMap<Tensor<Scalar, 2, Eigen::RowMajor, StorageIndex>, \
                  Eigen::Aligned>

#define TENSOR_MAP_CONST_COLMAJOR \
  const TensorMap<Tensor<Scalar, 2, Eigen::ColMajor, StorageIndex>, \
                  Eigen::Aligned>

// This is the reshaped convolution filter from `eigen_spatial_convolutions.h`.
#define TENSOR_RESHAPE \
  const TensorReshapingOp< \
      const Eigen::DSizes<StorageIndex, 2>, \
      const TensorMap<Tensor<const Scalar, 4, Eigen::RowMajor, StorageIndex>, \
                      Eigen::Aligned>>

REGISTER_DIRECT_COL_MAJOR_ACCESS(SIMPLE_TENSOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_ROWMAJOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_COLMAJOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_CONST_ROWMAJOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_CONST_COLMAJOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_RESHAPE);

#undef SIMPLE_TENSOR
#undef TENSOR_MAP_ROWMAJOR
#undef TENSOR_MAP_COLMAJOR
#undef TENSOR_MAP_CONST_ROWMAJOR
#undef TENSOR_MAP_CONST_COLMAJOR
#undef TENSOR_RESHAPE
#undef REGISTER_DIRECT_COL_MAJOR_ACCESS

template <typename ResScalar, typename LhsScalar, typename RhsScalar,
          typename StorageIndex, typename OutputMapper>
struct GemmKernelProvider {
  enum { Defined = 0 };
  using GemmKernel = void;
};

template <typename StorageIndex, typename OutputMapper>
struct GemmKernelProvider<float, float, float, StorageIndex, OutputMapper> {
  enum { Defined = 1 };
  using GemmKernel = dnnl_gemm_kernel<float, StorageIndex, OutputMapper>;
};

template <typename StorageIndex, typename OutputMapper>
struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
                          StorageIndex, OutputMapper> {
  enum { Defined = 1 };
  using GemmKernel = mkldnn_gemm_s8u8s32_kernel<StorageIndex, OutputMapper>;
};
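
// For example, GemmKernelProvider<float, float, float, StorageIndex,
// OutputMapper>::GemmKernel resolves to the dnnl_gemm_kernel specialization
// above, while an unregistered combination of scalar types keeps
// `Defined == 0` and trips the static_assert in the contraction kernels below.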

// NOTE: 'std::enable_if' doesn't work for template specializations. See
// "default template argument in a class template partial specialization".

// Tensor contraction kernel that can fall back on the Eigen gebp_kernel at
// runtime.
#define REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK( \
    RES_SCALAR, LHS_SCALAR, RHS_SCALAR) \
  \
  template <typename StorageIndex, typename OutputMapper, typename LhsMapper, \
            typename RhsMapper> \
  struct TensorContractionKernel<RES_SCALAR, LHS_SCALAR, RHS_SCALAR, \
                                 StorageIndex, OutputMapper, LhsMapper, \
                                 RhsMapper> { \
    TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n, \
                            StorageIndex bm, StorageIndex bk, StorageIndex bn) \
        : m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {} \
  \
    enum { HasBeta = true }; \
  \
    using ResScalar = RES_SCALAR; \
    using LhsScalar = LHS_SCALAR; \
    using RhsScalar = RHS_SCALAR; \
  \
    using Traits = typename internal::gebp_traits<LhsScalar, RhsScalar>; \
  \
    using LhsBlock = ColMajorBlock<LhsScalar, StorageIndex>; \
    using RhsBlock = ColMajorBlock<RhsScalar, StorageIndex>; \
  \
    using DirectLhsAccess = DirectColMajorAccess<LhsMapper>; \
    using DirectRhsAccess = DirectColMajorAccess<RhsMapper>; \
  \
    /* Packed Lhs/Rhs block memory allocator. */ \
    typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar> \
        BlockMemAllocator; \
    typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle; \
  \
    using LhsPacker = \
        gemm_pack_colmajor_block<LhsScalar, StorageIndex, \
                                 typename LhsMapper::SubMapper, ColMajor>; \
    using RhsPacker = \
        gemm_pack_colmajor_block<RhsScalar, StorageIndex, \
                                 typename RhsMapper::SubMapper, ColMajor>; \
  \
    using GemmKernelProviderType = \
        GemmKernelProvider<ResScalar, LhsScalar, RhsScalar, StorageIndex, \
                           OutputMapper>; \
    static_assert( \
        GemmKernelProviderType::Defined, \
        "Custom GEMM kernel is not registered for given scalar types"); \
    using GemmKernel = typename GemmKernelProviderType::GemmKernel; \
  \
    /* Fall back on the default Eigen pack and GEBP kernel if the custom */ \
    /* contraction kernels are disabled at runtime. */ \
    using EigenLhsPacker = \
        gemm_pack_lhs<LhsScalar, StorageIndex, typename LhsMapper::SubMapper, \
                      Traits::mr, Traits::LhsProgress, \
                      typename Traits::LhsPacket4Packing, ColMajor>; \
    using EigenRhsPacker = \
        gemm_pack_rhs<RhsScalar, StorageIndex, typename RhsMapper::SubMapper, \
                      Traits::nr, ColMajor>; \
    using GebpKernel = \
        gebp_kernel<LhsScalar, RhsScalar, StorageIndex, OutputMapper, \
                    Traits::mr, Traits::nr, /*ConjugateLhs*/ false, \
                    /*ConjugateRhs*/ false>; \
  \
    template <typename Device> \
    EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block, \
                                              RhsBlock* rhs_block) { \
      return BlockMemAllocator::allocate( \
          d, bm, bk, bn, &lhs_block->packed_data, &rhs_block->packed_data); \
    } \
  \
    template <typename Device> \
    EIGEN_DEVICE_FUNC BlockMemHandle \
    allocateSlices(Device& d, const int num_lhs, const int num_rhs, \
                   const int num_slices, std::vector<LhsBlock>* lhs_blocks, \
                   std::vector<RhsBlock>* rhs_blocks) { \
      eigen_assert(num_slices > 0); \
      std::vector<std::vector<LhsScalar*>> lhs_mem(num_slices); \
      std::vector<std::vector<RhsScalar*>> rhs_mem(num_slices); \
  \
      BlockMemHandle block_mem = BlockMemAllocator::allocateSlices( \
          d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_mem.data(), \
          rhs_mem.data()); \
  \
      for (Index x = 0; x < num_slices; x++) { \
        if (num_lhs > 0) lhs_blocks[x].resize(num_lhs); \
        for (Index m = 0; m < num_lhs; m++) { \
          lhs_blocks[x][m].packed_data = lhs_mem[x][m]; \
        } \
        if (num_rhs > 0) rhs_blocks[x].resize(num_rhs); \
        for (Index n = 0; n < num_rhs; n++) { \
          rhs_blocks[x][n].packed_data = rhs_mem[x][n]; \
        } \
      } \
  \
      return block_mem; \
    } \
  \
    template <typename Device> \
    EIGEN_DEVICE_FUNC static void deallocate(Device& d, \
                                             BlockMemHandle handle) { \
      BlockMemAllocator::deallocate(d, handle); \
    } \
  \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs( \
        LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper, \
        const StorageIndex depth, const StorageIndex rows) { \
      if (UseCustomContractionKernels()) { \
        const bool is_direct_access = \
            DirectLhsAccess::value && \
            DirectLhsAccess::block(data_mapper, rows, depth, \
                                   bn > 0 ? divup(n, bn) : 0, lhsBlock); \
  \
        if (!is_direct_access) { \
          lhsBlock->is_direct_access = false; \
          LhsPacker()(lhsBlock->packed_data, data_mapper, rows, depth); \
        } \
      } else { \
        lhsBlock->is_direct_access = false; \
        EigenLhsPacker()(lhsBlock->packed_data, data_mapper, depth, rows, \
                         /*stride*/ 0, /*offset*/ 0); \
      } \
    } \
  \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs( \
        RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper, \
        const StorageIndex depth, const StorageIndex cols) { \
      if (UseCustomContractionKernels()) { \
        const bool is_direct_access = \
            DirectRhsAccess::value && \
            DirectRhsAccess::block(data_mapper, depth, cols, \
                                   bm > 0 ? divup(m, bm) : 0, rhsBlock); \
  \
        if (!is_direct_access) { \
          rhsBlock->is_direct_access = false; \
          RhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols); \
        } \
      } else { \
        rhsBlock->is_direct_access = false; \
        EigenRhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols); \
      } \
    } \
  \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke( \
        const OutputMapper& output_mapper, const LhsBlock& lhsBlock, \
        const RhsBlock& rhsBlock, const StorageIndex rows, \
        const StorageIndex depth, const StorageIndex cols, const float alpha, \
        const float beta) { \
      if (UseCustomContractionKernels()) { \
        if ((DirectLhsAccess::value && lhsBlock.is_direct_access) && \
            (DirectRhsAccess::value && rhsBlock.is_direct_access)) { \
          GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.raw_data, \
                       rows, depth, cols, alpha, beta, \
                       /*ldA=*/lhsBlock.stride, /*ldB=*/rhsBlock.stride, \
                       /*transposeA=*/lhsBlock.transpose, \
                       /*transposeB=*/rhsBlock.transpose); \
  \
        } else if (DirectLhsAccess::value && lhsBlock.is_direct_access) { \
          GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.packed_data, \
                       rows, depth, cols, alpha, beta, \
                       /*ldA=*/lhsBlock.stride, \
                       /*ldB=*/GemmKernel::kComputeStrideFromBlockDimensions, \
                       /*transposeA=*/lhsBlock.transpose, /*transposeB=*/'N'); \
  \
        } else if (DirectRhsAccess::value && rhsBlock.is_direct_access) { \
          GemmKernel()(output_mapper, lhsBlock.packed_data, rhsBlock.raw_data, \
                       rows, depth, cols, alpha, beta, \
                       /*ldA=*/GemmKernel::kComputeStrideFromBlockDimensions, \
                       /*ldB=*/rhsBlock.stride, /*transposeA=*/'N', \
                       /*transposeB=*/rhsBlock.transpose); \
  \
        } else { \
          GemmKernel()(output_mapper, lhsBlock.packed_data, \
                       rhsBlock.packed_data, rows, depth, cols, alpha, beta); \
        } \
      } else { \
        /* The gebp kernel does not support beta, so we have to clear the */ \
        /* output memory in the output mapper manually. */ \
        /* WARNING(ezhulenev): This is optimized into a memset in a loop, */ \
        /* which could be much slower for small matrices. Currently this */ \
        /* code path is used only for testing, and performance does not */ \
        /* matter. */ \
        if (beta == 0.0) { \
          for (StorageIndex col = 0; col < cols; ++col) { \
            ResScalar* output_base = &output_mapper(0, col); \
            typedef Array<ResScalar, Dynamic, 1> OutputRow; \
            typedef Map<OutputRow, 0, InnerStride<1>> OutputRowMap; \
            OutputRowMap(output_base, rows).setZero(); \
          } \
        } \
  \
        GebpKernel()( \
            output_mapper, lhsBlock.packed_data, rhsBlock.packed_data, rows, \
            depth, cols, alpha, \
            /*strideA*/ GemmKernel::kComputeStrideFromBlockDimensions, \
            /*strideB*/ GemmKernel::kComputeStrideFromBlockDimensions, \
            /*offsetA*/ 0, /*offsetB*/ 0); \
      } \
    } \
  \
   private: \
    /* These are the dimensions of the original Tensors, and the selected */ \
    /* block sizes. The actual block sizes passed to all functions above */ \
    /* might be smaller because of the partial blocks at the end. */ \
    const StorageIndex m; \
    const StorageIndex k; \
    const StorageIndex n; \
    const StorageIndex bm; \
    const StorageIndex bk; \
    const StorageIndex bn; \
  }

// Tensor contraction kernel that does not fall back on Eigen. Currently not
// all data types are supported by the Eigen data packing and the default
// gebp_kernel.
#define REGISTER_TENSOR_CONTRACTION_KERNEL_NO_FALLBACK(RES_SCALAR, LHS_SCALAR, \
                                                       RHS_SCALAR) \
  \
  template <typename StorageIndex, typename OutputMapper, typename LhsMapper, \
            typename RhsMapper> \
  struct TensorContractionKernel<RES_SCALAR, LHS_SCALAR, RHS_SCALAR, \
                                 StorageIndex, OutputMapper, LhsMapper, \
                                 RhsMapper> { \
    TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n, \
                            StorageIndex bm, StorageIndex bk, StorageIndex bn) \
        : m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {} \
  \
    enum { HasBeta = true }; \
  \
    using ResScalar = RES_SCALAR; \
    using LhsScalar = LHS_SCALAR; \
    using RhsScalar = RHS_SCALAR; \
  \
    using Traits = typename internal::gebp_traits<LhsScalar, RhsScalar>; \
  \
    using LhsBlock = ColMajorBlock<LhsScalar, StorageIndex>; \
    using RhsBlock = ColMajorBlock<RhsScalar, StorageIndex>; \
  \
    using DirectLhsAccess = DirectColMajorAccess<LhsMapper>; \
    using DirectRhsAccess = DirectColMajorAccess<RhsMapper>; \
  \
    /* Packed Lhs/Rhs block memory allocator. */ \
    typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar> \
        BlockMemAllocator; \
    typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle; \
  \
    using LhsPacker = \
        gemm_pack_colmajor_block<LhsScalar, StorageIndex, \
                                 typename LhsMapper::SubMapper, ColMajor>; \
    using RhsPacker = \
        gemm_pack_colmajor_block<RhsScalar, StorageIndex, \
                                 typename RhsMapper::SubMapper, ColMajor>; \
  \
    using GemmKernelProviderType = \
        GemmKernelProvider<ResScalar, LhsScalar, RhsScalar, StorageIndex, \
                           OutputMapper>; \
    static_assert( \
        GemmKernelProviderType::Defined, \
        "Custom GEMM kernel is not registered for given scalar types"); \
    using GemmKernel = typename GemmKernelProviderType::GemmKernel; \
  \
    template <typename Device> \
    EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block, \
                                              RhsBlock* rhs_block) { \
      return BlockMemAllocator::allocate( \
          d, bm, bk, bn, &lhs_block->packed_data, &rhs_block->packed_data); \
    } \
  \
    template <typename Device> \
    EIGEN_DEVICE_FUNC BlockMemHandle \
    allocateSlices(Device& d, const int num_lhs, const int num_rhs, \
                   const int num_slices, std::vector<LhsBlock>* lhs_blocks, \
                   std::vector<RhsBlock>* rhs_blocks) { \
      eigen_assert(num_slices > 0); \
      std::vector<std::vector<LhsScalar*>> lhs_mem(num_slices); \
      std::vector<std::vector<RhsScalar*>> rhs_mem(num_slices); \
  \
      BlockMemHandle block_mem = BlockMemAllocator::allocateSlices( \
          d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_mem.data(), \
          rhs_mem.data()); \
  \
      for (Index x = 0; x < num_slices; x++) { \
        if (num_lhs > 0) lhs_blocks[x].resize(num_lhs); \
        for (Index m = 0; m < num_lhs; m++) { \
          lhs_blocks[x][m].packed_data = lhs_mem[x][m]; \
        } \
        if (num_rhs > 0) rhs_blocks[x].resize(num_rhs); \
        for (Index n = 0; n < num_rhs; n++) { \
          rhs_blocks[x][n].packed_data = rhs_mem[x][n]; \
        } \
      } \
  \
      return block_mem; \
    } \
  \
    template <typename Device> \
    EIGEN_DEVICE_FUNC static void deallocate(Device& d, \
                                             BlockMemHandle handle) { \
      BlockMemAllocator::deallocate(d, handle); \
    } \
  \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs( \
        LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper, \
        const StorageIndex depth, const StorageIndex rows) { \
      const bool is_direct_access = \
          DirectLhsAccess::value && \
          DirectLhsAccess::block(data_mapper, rows, depth, \
                                 bn > 0 ? divup(n, bn) : 0, lhsBlock); \
  \
      if (!is_direct_access) { \
        lhsBlock->is_direct_access = false; \
        LhsPacker()(lhsBlock->packed_data, data_mapper, rows, depth); \
      } \
    } \
  \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs( \
        RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper, \
        const StorageIndex depth, const StorageIndex cols) { \
      const bool is_direct_access = \
          DirectRhsAccess::value && \
          DirectRhsAccess::block(data_mapper, depth, cols, \
                                 bm > 0 ? divup(m, bm) : 0, rhsBlock); \
  \
      if (!is_direct_access) { \
        rhsBlock->is_direct_access = false; \
        RhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols); \
      } \
    } \
  \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke( \
        const OutputMapper& output_mapper, const LhsBlock& lhsBlock, \
        const RhsBlock& rhsBlock, const StorageIndex rows, \
        const StorageIndex depth, const StorageIndex cols, const float alpha, \
        const float beta) { \
      if ((DirectLhsAccess::value && lhsBlock.is_direct_access) && \
          (DirectRhsAccess::value && rhsBlock.is_direct_access)) { \
        GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.raw_data, \
                     rows, depth, cols, alpha, beta, /*ldA=*/lhsBlock.stride, \
                     /*ldB=*/rhsBlock.stride, \
                     /*transposeA=*/lhsBlock.transpose, \
                     /*transposeB=*/rhsBlock.transpose); \
  \
      } else if (DirectLhsAccess::value && lhsBlock.is_direct_access) { \
        GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.packed_data, \
                     rows, depth, cols, alpha, beta, /*ldA=*/lhsBlock.stride, \
                     /*ldB=*/GemmKernel::kComputeStrideFromBlockDimensions, \
                     /*transposeA=*/lhsBlock.transpose, /*transposeB=*/'N'); \
  \
      } else if (DirectRhsAccess::value && rhsBlock.is_direct_access) { \
        GemmKernel()(output_mapper, lhsBlock.packed_data, rhsBlock.raw_data, \
                     rows, depth, cols, alpha, beta, \
                     /*ldA=*/GemmKernel::kComputeStrideFromBlockDimensions, \
                     /*ldB=*/rhsBlock.stride, /*transposeA=*/'N', \
                     /*transposeB=*/rhsBlock.transpose); \
  \
      } else { \
        GemmKernel()(output_mapper, lhsBlock.packed_data, \
                     rhsBlock.packed_data, rows, depth, cols, alpha, beta); \
      } \
    } \
  \
   private: \
    /* These are the dimensions of the original Tensors, and the selected */ \
    /* block sizes. The actual block sizes passed to all functions above */ \
    /* might be smaller because of the partial blocks at the end. */ \
    const StorageIndex m; \
    const StorageIndex k; \
    const StorageIndex n; \
    const StorageIndex bm; \
    const StorageIndex bk; \
    const StorageIndex bn; \
  }

REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK(float, float, float);
REGISTER_TENSOR_CONTRACTION_KERNEL_NO_FALLBACK(Eigen::QInt32, Eigen::QInt8,
                                               Eigen::QUInt8);

#undef REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK
#undef REGISTER_TENSOR_CONTRACTION_KERNEL_NO_FALLBACK

#endif  // defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)

}  // namespace internal
}  // namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_