/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_

// Depending on the build configuration, this header provides a custom kernel
// for Eigen tensor contractions (the small matrix multiplication kernel used
// to multiply together blocks of the original tensors).
//
// 1) --define tensorflow_mkldnn_contraction_kernel=1
//    Use the single-threaded mkldnn sgemm. The mkldnn kernels are generated
//    at runtime and use avx/avx2/fma/avx512 instructions depending on the
//    CPU features reported by CPUID
//    (https://en.wikipedia.org/wiki/CPUID).
//
// If you use `tensor.contract(other_tensor)` in your code, you must include
// this header to get the benefit of the custom contraction kernel:
//
//   #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
//   #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
//   #endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

// FixedPoint header must be included after Tensor.
// clang-format off
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
// clang-format on

#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
#include "dnnl.h"
#endif

#include "tensorflow/core/platform/dynamic_annotations.h"

namespace Eigen {
namespace internal {

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
// Returns `true` iff we can use custom contraction kernels. This is a runtime
// check that uses environment variables.
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE bool UseCustomContractionKernels();
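
// A minimal sketch of a typical runtime guard (illustrative only; the real
// dispatch happens inside the contraction kernels defined below):
//
//   if (Eigen::internal::UseCustomContractionKernels()) {
//     // ... use the custom mkldnn-backed GEMM kernel ...
//   } else {
//     // ... fall back on the default Eigen gebp code path ...
//   }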

// Packs a 2D block of a Tensor expression into a contiguous block of memory
// with col-major storage order. We do not have access to the underlying Tensor
// expression; we only have a DataMapper (TensorContractionInputMapper for
// tensor contractions, or blas_data_mapper for plain tensors) that provides a
// two-dimensional view into the Tensor expression.
//
// Default Eigen gemm_pack_rhs and gemm_pack_lhs pack blocks of tensor
// expressions into the packed format described in the "Anatomy of
// High-Performance Matrix Multiplication" paper (1).
// Eigen::internal::gebp_kernel relies on this packing format for efficient
// micro-panel multiplication.
//
// This simple packing can be used with any '?gemm' function from BLAS
// libraries that work with col-major matrices.
//
// (1) http://www.cs.utexas.edu/~flame/pubs/GotoTOMS_revision.pdf
//
// IMPORTANT: `gemm_pack_colmajor_block` always packs the block in column major
// order; DataMapperStorageOrder specifies the storage order of the underlying
// Tensor expression.
template <typename Scalar, typename IndexType, typename DataMapper,
          int DataMapperStorageOrder>
struct gemm_pack_colmajor_block;
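
// For example (illustrative), packing a 3x2 block through a DataMapper `m`
// produces the contiguous col-major sequence
//
//   block = [m(0,0), m(1,0), m(2,0), m(0,1), m(1,1), m(2,1)]
//
// regardless of how the underlying Tensor expression is stored in memory.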

// gemm_pack_colmajor_block for ColMajor storage order.
template <typename Scalar, typename IndexType, typename DataMapper>
struct gemm_pack_colmajor_block<Scalar, IndexType, DataMapper,
                                /*DataMapperStorageOrder*/ ColMajor> {
  typedef typename internal::packet_traits<Scalar>::type Packet;
  typedef typename DataMapper::LinearMapper LinearMapper;

  enum { PacketSize = internal::packet_traits<Scalar>::size };

  EIGEN_DONT_INLINE
  void operator()(Scalar* block, const DataMapper& data_mapper, IndexType rows,
                  IndexType cols) {
    const IndexType unrolled_rows = rows - 4 * PacketSize;
    const IndexType vectorized_rows = rows - PacketSize;

    for (IndexType col = 0; col < cols; ++col) {
      LinearMapper lm = data_mapper.getLinearMapper(0, col);

      IndexType row = 0;
      // Give the compiler a strong hint to unroll the loop.
      for (; row <= unrolled_rows; row += 4 * PacketSize) {
        for (IndexType j = 0; j < 4; ++j) {
          const Packet p =
              lm.template loadPacket<Packet>(row + j * PacketSize);
          internal::pstoreu(block + j * PacketSize, p);
        }
        block += 4 * PacketSize;
      }
      // Process remaining rows with packets.
      for (; row <= vectorized_rows; row += PacketSize) {
        const Packet p = lm.template loadPacket<Packet>(row);
        internal::pstoreu(block, p);
        block += PacketSize;
      }
      // Finalize with coefficients.
      for (; row < rows; ++row) {
        *block = lm(row);
        ++block;
      }
    }
  }
};
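
// A minimal usage sketch (illustrative; assumes `mapper` is a col-major
// DataMapper over floats and `block` points to at least rows * cols floats):
//
//   gemm_pack_colmajor_block<float, Eigen::Index, decltype(mapper), ColMajor>
//       pack;
//   pack(block, mapper, rows, cols);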

#endif  // TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL

// Enabled by build option: "--define tensorflow_mkldnn_contraction_kernel=1"
#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)

template <typename Scalar, typename IndexType, typename OutputMapper,
          bool ConjugateLhs = false, bool ConjugateRhs = false>
struct dnnl_gemm_kernel;

// dnnl_gemm_kernel for floats, defined as a thin layer on top of dnnl_sgemm.
template <typename IndexType, typename OutputMapper, bool ConjugateLhs,
          bool ConjugateRhs>
struct dnnl_gemm_kernel</*Scalar*/ float, IndexType, OutputMapper,
                        ConjugateLhs, ConjugateRhs> {
  static_assert(!ConjugateLhs, "DNNL kernel doesn't support ConjugateLhs");
  static_assert(!ConjugateRhs, "DNNL kernel doesn't support ConjugateRhs");

  static constexpr int kComputeStrideFromBlockDimensions = -1;

  using LhsScalar = float;
  using RhsScalar = float;
  using ResScalar = float;

  EIGEN_DONT_INLINE
  void operator()(const OutputMapper& output, const LhsScalar* blockA,
                  const RhsScalar* blockB, const IndexType rows,
                  const IndexType depth, const IndexType cols, float alpha,
                  float beta, int ldA = kComputeStrideFromBlockDimensions,
                  int ldB = kComputeStrideFromBlockDimensions,
                  char transposeA = 'N', char transposeB = 'N') {
    static const int max_index = (std::numeric_limits<int>::max)();

    eigen_assert(max_index >= rows);
    eigen_assert(max_index >= cols);
    eigen_assert(max_index >= depth);
    eigen_assert(max_index >= output.stride());

    const int m = static_cast<int>(rows);
    const int n = static_cast<int>(cols);
    const int k = static_cast<int>(depth);

    ldA = ldA == kComputeStrideFromBlockDimensions ? m : ldA;
    ldB = ldB == kComputeStrideFromBlockDimensions ? k : ldB;
    const int ldC = static_cast<int>(output.stride());

    // DNNL takes row-major matrices. Our packed column-major matrices can be
    // viewed as a transposed row-major matrix, i.e.,
    //   C_colmajor = C_rowmajor^T = (A_rowmajor * B_rowmajor)^T
    //              = B_rowmajor^T * A_rowmajor^T
    //              = B_colmajor * A_colmajor
    // So we can just swap the input matrices A and B for DNNL.
    // TODO(penporn): Switch to row-major packing instead.
    dnnl_status_t st =
        dnnl_sgemm(transposeB, transposeA, n, m, k, alpha, blockB, ldB, blockA,
                   ldA, beta, const_cast<ResScalar*>(output.data()), ldC);
    eigen_assert(st == 0);

#if DYNAMIC_ANNOTATIONS_ENABLED == 1 || defined(MEMORY_SANITIZER)
    for (IndexType col = 0; col < cols; ++col) {
      ResScalar* row_base = &output(0, col);
      EIGEN_UNUSED_VARIABLE(row_base);  // Suppress unused variable error.
      TF_ANNOTATE_MEMORY_IS_INITIALIZED(row_base, sizeof(ResScalar) * rows);
    }
#endif

    // eigen_assert is a no-op in optimized mode so we add these to avoid
    // compiler's unused-variable errors.
    EIGEN_UNUSED_VARIABLE(max_index);
    EIGEN_UNUSED_VARIABLE(st);
  }
};
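
// A minimal call sketch (illustrative; `out` is an OutputMapper over the
// col-major result buffer, and `lhs`/`rhs` point to packed col-major blocks):
//
//   dnnl_gemm_kernel<float, Eigen::Index, OutputMapper> gemm;
//   gemm(out, lhs, rhs, rows, depth, cols, /*alpha=*/1.0f, /*beta=*/0.0f);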

template <typename IndexType, typename OutputMapper, bool ConjugateLhs = false,
          bool ConjugateRhs = false>
struct mkldnn_gemm_s8u8s32_kernel {
  static_assert(!ConjugateLhs, "DNNL kernel doesn't support ConjugateLhs");
  static_assert(!ConjugateRhs, "DNNL kernel doesn't support ConjugateRhs");

  static constexpr int kComputeStrideFromBlockDimensions = -1;

  using LhsScalar = Eigen::QInt8;
  using RhsScalar = Eigen::QUInt8;
  using ResScalar = Eigen::QInt32;

  EIGEN_DONT_INLINE
  void operator()(const OutputMapper& output, const LhsScalar* blockA,
                  const RhsScalar* blockB, const IndexType rows,
                  const IndexType depth, const IndexType cols, float alpha,
                  float beta, int ldA = kComputeStrideFromBlockDimensions,
                  int ldB = kComputeStrideFromBlockDimensions,
                  char transposeA = 'N', char transposeB = 'N') {
    static const int max_index = (std::numeric_limits<int>::max)();

    eigen_assert(max_index >= rows);
    eigen_assert(max_index >= cols);
    eigen_assert(max_index >= depth);
    eigen_assert(max_index >= output.stride());

    const int m = static_cast<int>(rows);
    const int n = static_cast<int>(cols);
    const int k = static_cast<int>(depth);

    ldA = ldA == kComputeStrideFromBlockDimensions ? m : ldA;
    ldB = ldB == kComputeStrideFromBlockDimensions ? k : ldB;
    const int ldC = static_cast<int>(output.stride());

    // Currently we support only symmetric quantization with zero point at 0.
    const int8_t ao = 0;
    const int8_t bo = 0;

    // Don't add any offset to the result C: 'F' (fixed) applies the single
    // offset `co` to every element of C, and co is zero here.
    const char offsetc = 'F';
    const int32_t co = 0;

    const auto* A = reinterpret_cast<const int8_t*>(blockA);
    const auto* B = reinterpret_cast<const uint8_t*>(blockB);
    auto* C =
        reinterpret_cast<int32_t*>(const_cast<ResScalar*>(output.data()));

    // DNNL takes row-major matrices. Our packed column-major matrices can be
    // viewed as a transposed row-major matrix, i.e.,
    //   C_colmajor = C_rowmajor^T = (A_rowmajor * B_rowmajor)^T
    //              = B_rowmajor^T * A_rowmajor^T
    //              = B_colmajor * A_colmajor
    // So we can just swap the input matrices A and B for DNNL.
    // TODO(penporn): Switch to row-major packing instead.
    dnnl_status_t st = dnnl_gemm_u8s8s32(transposeB, transposeA, offsetc,  //
                                         n, m, k,                         //
                                         alpha,                           //
                                         B, ldB, bo,                      //
                                         A, ldA, ao,                      //
                                         beta,                            //
                                         C, ldC, &co);
    eigen_assert(st == 0);

#if DYNAMIC_ANNOTATIONS_ENABLED == 1 || defined(MEMORY_SANITIZER)
    for (IndexType col = 0; col < cols; ++col) {
      ResScalar* row_base = &output(0, col);
      EIGEN_UNUSED_VARIABLE(row_base);  // Suppress unused variable error.
      TF_ANNOTATE_MEMORY_IS_INITIALIZED(row_base, sizeof(ResScalar) * rows);
    }
#endif

    // eigen_assert is a no-op in optimized mode so we add these to avoid
    // compiler's unused-variable errors.
    EIGEN_UNUSED_VARIABLE(max_index);
    EIGEN_UNUSED_VARIABLE(st);
  }
};
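
// A minimal call sketch (illustrative; `out` maps the col-major QInt32
// result, and `lhs`/`rhs` point to packed QInt8/QUInt8 blocks):
//
//   mkldnn_gemm_s8u8s32_kernel<Eigen::Index, OutputMapper> gemm;
//   gemm(out, lhs, rhs, rows, depth, cols, /*alpha=*/1.0f, /*beta=*/0.0f);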

// For mkldnn_sgemm, having the right dimensions (especially for small
// matrices) is more important than fitting the entire working set in the
// L1/L2 caches.
// TODO(ezhulenev): Do better heuristics.
template <typename StorageIndex, int sharding_type>
class TensorContractionBlocking<float, float, float, StorageIndex,
                                sharding_type> {
  // For now mkldnn has only mkldnn_sgemm (gemm for floats).
  using Scalar = float;

  // Adjust the block sizes to work well with mkldnn kernels.

  // Multiply the default choice of block size along the M and N dimensions.
  // TODO(ezhulenev): Explore if this can work in general (kScaleM=2.0 worked
  // well in some models).
  static constexpr float kScaleM = 1.5;
  static constexpr float kScaleN = 1.0;

  // Mkldnn Avx/Avx2/Avx512 unroll factors along M are 8/16/48, and 48 is a
  // multiple of all three.
  static constexpr StorageIndex kUnrollM = 48;

  // Mkldnn Avx/Avx2/Avx512 unroll factors along N are 6/6/8, and 24 is their
  // least common multiple.
  static constexpr StorageIndex kUnrollN = 24;

 public:
  TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n,
                            StorageIndex num_threads = 1)
      : kc_(k), mc_(m), nc_(n) {
    // 1. Compute block sizes using default Eigen heuristics.
    if (sharding_type == ShardByCol) {
      computeProductBlockingSizes<Scalar, Scalar, 1>(kc_, mc_, nc_,
                                                     num_threads);
    } else {
      computeProductBlockingSizes<Scalar, Scalar, 1>(kc_, nc_, mc_,
                                                     num_threads);
    }

    // If the dimensions do not pass basic sanity checks, return immediately.
    if (kc_ <= 0 || mc_ <= 0 || nc_ <= 0) return;

    // If we are using the default Eigen gebp kernel, there is no need to
    // adjust the block sizes for DNNL.
    if (!UseCustomContractionKernels()) return;

    // 2. Refine the block sizes to work well with the mkldnn sgemm.
    mc_ = (std::min)(
        m, Eigen::divup(static_cast<StorageIndex>(mc_ * kScaleM), kUnrollM) *
               kUnrollM);
    nc_ = (std::min)(
        n, Eigen::divup(static_cast<StorageIndex>(nc_ * kScaleN), kUnrollN) *
               kUnrollN);

    // We split the K dimension into roughly equal slices.
    StorageIndex target_k_slices =
        (std::max)(StorageIndex(1), Eigen::divup(k, kc_));
    StorageIndex packet_size = internal::packet_traits<Scalar>::size;
    if (packet_size < 8) packet_size = 8;
    StorageIndex target_bk =
        Eigen::divup(k / target_k_slices, packet_size) * packet_size;
    kc_ = (std::min)(k, target_bk);
  }

  EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; }
  EIGEN_ALWAYS_INLINE StorageIndex mc() const { return mc_; }
  EIGEN_ALWAYS_INLINE StorageIndex nc() const { return nc_; }

 private:
  StorageIndex kc_;
  StorageIndex mc_;
  StorageIndex nc_;
};
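
// To illustrate the refinement above with made-up numbers: for m = 1000 and
// an Eigen-suggested mc_ = 256,
//
//   mc_ = min(1000, divup(256 * 1.5, 48) * 48) = min(1000, 8 * 48) = 384,
//
// i.e. the Eigen block size is scaled by kScaleM and rounded up to a multiple
// of the mkldnn unroll factor kUnrollM.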

template <typename StorageIndex, int sharding_type>
class TensorContractionBlocking<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
                                StorageIndex, sharding_type> {
  // TODO(ezhulenev): Define proper gebp_traits in Eigen for quantized types?

  // Default Eigen block heuristics for `QInt8xQUInt8 -> QInt32` are wrong,
  // mostly because gebp_traits are not correctly defined for these types. But
  // we know that we are going to use s8u8s32_gemm from DNNL, so we use the
  // float heuristics and adjust them to work well with DNNL.
  using LhsScalar = Eigen::QInt8;
  using RhsScalar = Eigen::QUInt8;
  using ResScalar = Eigen::QInt32;

  // Multiply default choice of block size along M, N and K dimensions.
  static constexpr float kScaleM = 1.5;
  static constexpr float kScaleN = 1.5;
  static constexpr float kScaleK = 1.5;

 public:
  TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n,
                            StorageIndex num_threads = 1)
      : kc_(k), mc_(m), nc_(n) {
    // Each dimension is a multiple of 32 (fits into __m256i):
    // 192 = 6 * 32, 288 = 9 * 32, 320 = 10 * 32.
    mc_ = (std::min)(m, static_cast<StorageIndex>(192));
    nc_ = (std::min)(n, static_cast<StorageIndex>(288));
    kc_ = (std::min)(k, static_cast<StorageIndex>(320));
  }

  EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; }
  EIGEN_ALWAYS_INLINE StorageIndex mc() const { return mc_; }
  EIGEN_ALWAYS_INLINE StorageIndex nc() const { return nc_; }

 private:
  StorageIndex kc_;
  StorageIndex mc_;
  StorageIndex nc_;
};

// If the Lhs or Rhs Tensor expressions are already evaluated and we have
// access to their raw data, we can skip the packing step, set up pointers and
// a stride into the underlying memory buffer, and pass them directly to GEMM.
template <typename Scalar, typename StorageIndex>
struct ColMajorBlock {
  bool is_direct_access;

  // Valid iff `is_direct_access == false`
  Scalar* packed_data;

  // Valid iff `is_direct_access == true`
  Scalar* raw_data;
  StorageIndex stride;
  char transpose;
};

template <typename DataMapper>
struct DirectColMajorAccess {
  enum { value = false };

  template <typename Scalar, typename StorageIndex>
  static bool block(const typename DataMapper::SubMapper& data_mapper,
                    const StorageIndex rows, const StorageIndex cols,
                    const StorageIndex num_kernels,
                    ColMajorBlock<Scalar, StorageIndex>* block) {
    eigen_assert(false && "Not implemented");
    return false;
  }
};

// If we have access to the raw memory of the contraction input, we can safely
// skip packing if:
// (1) Packing is a no-op.
// (2) The packed block will be used just once.
//
// If a packed block is used many times, it's more efficient to pack it into a
// contiguous block of memory to reduce pressure on the TLB.
//
// TODO(ezhulenev): Add support for more tensor expressions that matter.
#define REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_EXPR)                          \
  template <typename Scalar, typename StorageIndex, int Side, typename Device, \
            typename nocontract_t, typename contract_t, int packet_size,       \
            int Alignment>                                                      \
  struct DirectColMajorAccess<TensorContractionInputMapper<                    \
      Scalar, StorageIndex, Side, TensorEvaluator<TENSOR_EXPR, Device>,        \
      nocontract_t, contract_t, packet_size, /*inner_dim_contiguous=*/true,    \
      /*inner_dim_reordered=*/false, Alignment>> {                             \
    enum { value = true };                                                      \
                                                                                \
    using DataMapper = TensorContractionInputMapper<                           \
        Scalar, StorageIndex, Side, TensorEvaluator<TENSOR_EXPR, Device>,      \
        nocontract_t, contract_t, packet_size, /*inner_dim_contiguous=*/true,  \
        /*inner_dim_reordered=*/false, Alignment>;                             \
                                                                                \
    static bool block(const typename DataMapper::SubMapper& data_mapper,       \
                      const StorageIndex rows, const StorageIndex cols,        \
                      const StorageIndex num_kernels,                          \
                      ColMajorBlock<Scalar, StorageIndex>* block) {            \
      static_assert(DataMapper::DirectOffsets == true,                         \
                    "DataMapper must support direct offsets");                 \
                                                                                \
      const StorageIndex vert_offset = data_mapper.vert_offset();              \
      const StorageIndex horiz_offset = data_mapper.horiz_offset();            \
      const StorageIndex stride =                                              \
          Side == Lhs ? data_mapper.base_mapper().stride()                     \
                      : data_mapper.base_mapper().nocontract_strides()[0];     \
      const Scalar* data = data_mapper.base_mapper().tensor().data();          \
      data = Side == Lhs ? data : data + vert_offset + horiz_offset * stride;  \
                                                                                \
      const bool is_no_op_packing = stride == rows;                            \
      const StorageIndex addressable_mem = (stride * cols * sizeof(Scalar));   \
      const bool use_direct_access =                                           \
          is_no_op_packing || num_kernels == 1 /* used once */ ||              \
          ((num_kernels == 2) &&                                               \
           (addressable_mem < (256 << 10) /* 256 kb */));                      \
                                                                                \
      if (use_direct_access) {                                                 \
        block->is_direct_access = true;                                        \
        block->raw_data = const_cast<Scalar*>(data);                           \
        block->stride = stride;                                                \
        block->transpose = 'N';                                                \
        return true;                                                           \
      }                                                                        \
      return false;                                                            \
    }                                                                          \
  }
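
// For example (illustrative): when `stride == rows`, a col-major
// [rows x cols] block is already laid out exactly as
// `gemm_pack_colmajor_block` would pack it, so packing is a no-op and the
// GEMM kernel can read the source buffer directly with the original stride.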

#define SIMPLE_TENSOR const Tensor<Scalar, 2, Eigen::ColMajor, StorageIndex>

#define TENSOR_MAP_ROWMAJOR                                               \
  const TensorMap<Tensor<const Scalar, 2, Eigen::RowMajor, StorageIndex>, \
                  Eigen::Aligned>

#define TENSOR_MAP_COLMAJOR                                               \
  const TensorMap<Tensor<const Scalar, 2, Eigen::ColMajor, StorageIndex>, \
                  Eigen::Aligned>

#define TENSOR_MAP_CONST_ROWMAJOR                                   \
  const TensorMap<Tensor<Scalar, 2, Eigen::RowMajor, StorageIndex>, \
                  Eigen::Aligned>

#define TENSOR_MAP_CONST_COLMAJOR                                   \
  const TensorMap<Tensor<Scalar, 2, Eigen::ColMajor, StorageIndex>, \
                  Eigen::Aligned>

// This is the reshaped convolution filter from `eigen_spatial_convolutions.h`.
#define TENSOR_RESHAPE                                                        \
  const TensorReshapingOp<                                                    \
      const Eigen::DSizes<StorageIndex, 2>,                                   \
      const TensorMap<Tensor<const Scalar, 4, Eigen::RowMajor, StorageIndex>, \
                      Eigen::Aligned>>

REGISTER_DIRECT_COL_MAJOR_ACCESS(SIMPLE_TENSOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_ROWMAJOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_COLMAJOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_CONST_ROWMAJOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_CONST_COLMAJOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_RESHAPE);

#undef SIMPLE_TENSOR
#undef TENSOR_MAP_ROWMAJOR
#undef TENSOR_MAP_COLMAJOR
#undef TENSOR_MAP_CONST_ROWMAJOR
#undef TENSOR_MAP_CONST_COLMAJOR
#undef TENSOR_RESHAPE
#undef REGISTER_DIRECT_COL_MAJOR_ACCESS

template <typename ResScalar, typename LhsScalar, typename RhsScalar,
          typename StorageIndex, typename OutputMapper>
struct GemmKernelProvider {
  enum { Defined = 0 };
  using GemmKernel = void;
};

template <typename StorageIndex, typename OutputMapper>
struct GemmKernelProvider<float, float, float, StorageIndex, OutputMapper> {
  enum { Defined = 1 };
  using GemmKernel = dnnl_gemm_kernel<float, StorageIndex, OutputMapper>;
};

template <typename StorageIndex, typename OutputMapper>
struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
                          StorageIndex, OutputMapper> {
  enum { Defined = 1 };
  using GemmKernel = mkldnn_gemm_s8u8s32_kernel<StorageIndex, OutputMapper>;
};
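
// A hypothetical sketch of plugging in a custom kernel for another scalar
// combination (`my_gemm_kernel` is illustrative and not defined anywhere in
// this file):
//
//   template <typename StorageIndex, typename OutputMapper>
//   struct GemmKernelProvider<double, double, double, StorageIndex,
//                             OutputMapper> {
//     enum { Defined = 1 };
//     using GemmKernel = my_gemm_kernel<StorageIndex, OutputMapper>;
//   };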

// NOTE: 'std::enable_if' doesn't work for template specializations. See
// "default template argument in a class template partial specialization".

// Tensor contraction kernel that can fall back on the Eigen gebp_kernel at
// runtime.
#define REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK(                     \
    RES_SCALAR, LHS_SCALAR, RHS_SCALAR)                                       \
                                                                               \
  template <typename StorageIndex, typename OutputMapper, typename LhsMapper, \
            typename RhsMapper>                                               \
  struct TensorContractionKernel<RES_SCALAR, LHS_SCALAR, RHS_SCALAR,          \
                                 StorageIndex, OutputMapper, LhsMapper,       \
                                 RhsMapper> {                                 \
    TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n,   \
                            StorageIndex bm, StorageIndex bk,                 \
                            StorageIndex bn)                                  \
        : m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {}                         \
                                                                               \
    enum { HasBeta = true };                                                   \
                                                                               \
    using ResScalar = RES_SCALAR;                                             \
    using LhsScalar = LHS_SCALAR;                                             \
    using RhsScalar = RHS_SCALAR;                                             \
                                                                               \
    using Traits = typename internal::gebp_traits<LhsScalar, RhsScalar>;      \
                                                                               \
    using LhsBlock = ColMajorBlock<LhsScalar, StorageIndex>;                  \
    using RhsBlock = ColMajorBlock<RhsScalar, StorageIndex>;                  \
                                                                               \
    using DirectLhsAccess = DirectColMajorAccess<LhsMapper>;                  \
    using DirectRhsAccess = DirectColMajorAccess<RhsMapper>;                  \
                                                                               \
    /* Packed Lhs/Rhs block memory allocator.*/                               \
    typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar>          \
        BlockMemAllocator;                                                    \
    typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle;        \
                                                                               \
    using LhsPacker =                                                         \
        gemm_pack_colmajor_block<LhsScalar, StorageIndex,                     \
                                 typename LhsMapper::SubMapper, ColMajor>;    \
    using RhsPacker =                                                         \
        gemm_pack_colmajor_block<RhsScalar, StorageIndex,                     \
                                 typename RhsMapper::SubMapper, ColMajor>;    \
                                                                               \
    using GemmKernelProviderType =                                            \
        GemmKernelProvider<ResScalar, LhsScalar, RhsScalar, StorageIndex,     \
                           OutputMapper>;                                     \
    static_assert(                                                            \
        GemmKernelProviderType::Defined,                                      \
        "Custom GEMM kernel is not registered for given scalar types");       \
    using GemmKernel = typename GemmKernelProviderType::GemmKernel;           \
                                                                               \
    /* Fall back on the default Eigen pack and GEBP kernel if custom */       \
    /* contraction kernels are disabled at runtime. */                        \
    using EigenLhsPacker =                                                    \
        gemm_pack_lhs<LhsScalar, StorageIndex, typename LhsMapper::SubMapper, \
                      Traits::mr, Traits::LhsProgress,                        \
                      typename Traits::LhsPacket4Packing, ColMajor>;          \
    using EigenRhsPacker =                                                    \
        gemm_pack_rhs<RhsScalar, StorageIndex, typename RhsMapper::SubMapper, \
                      Traits::nr, ColMajor>;                                  \
    using GebpKernel =                                                        \
        gebp_kernel<LhsScalar, RhsScalar, StorageIndex, OutputMapper,         \
                    Traits::mr, Traits::nr, /*ConjugateLhs*/ false,           \
                    /*ConjugateRhs*/ false>;                                  \
                                                                               \
    template <typename Device>                                                \
    EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block, \
                                              RhsBlock* rhs_block) {          \
      return BlockMemAllocator::allocate(                                     \
          d, bm, bk, bn, &lhs_block->packed_data, &rhs_block->packed_data);   \
    }                                                                         \
                                                                               \
    template <typename Device>                                                \
    EIGEN_DEVICE_FUNC BlockMemHandle                                          \
    allocateSlices(Device& d, const int num_lhs, const int num_rhs,           \
                   const int num_slices, std::vector<LhsBlock>* lhs_blocks,   \
                   std::vector<RhsBlock>* rhs_blocks) {                       \
      eigen_assert(num_slices > 0);                                           \
      std::vector<std::vector<LhsScalar*>> lhs_mem(num_slices);               \
      std::vector<std::vector<RhsScalar*>> rhs_mem(num_slices);               \
                                                                               \
      BlockMemHandle block_mem = BlockMemAllocator::allocateSlices(           \
          d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_mem.data(),        \
          rhs_mem.data());                                                    \
                                                                               \
      for (Index x = 0; x < num_slices; x++) {                                \
        if (num_lhs > 0) lhs_blocks[x].resize(num_lhs);                       \
        for (Index m = 0; m < num_lhs; m++) {                                 \
          lhs_blocks[x][m].packed_data = lhs_mem[x][m];                       \
        }                                                                     \
        if (num_rhs > 0) rhs_blocks[x].resize(num_rhs);                       \
        for (Index n = 0; n < num_rhs; n++) {                                 \
          rhs_blocks[x][n].packed_data = rhs_mem[x][n];                       \
        }                                                                     \
      }                                                                       \
                                                                               \
      return block_mem;                                                       \
    }                                                                         \
                                                                               \
    template <typename Device>                                                \
    EIGEN_DEVICE_FUNC static void deallocate(Device& d,                       \
                                             BlockMemHandle handle) {         \
      BlockMemAllocator::deallocate(d, handle);                               \
    }                                                                         \
                                                                               \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs(                         \
        LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper, \
        const StorageIndex depth, const StorageIndex rows) {                  \
      if (UseCustomContractionKernels()) {                                    \
        const bool is_direct_access =                                         \
            DirectLhsAccess::value &&                                         \
            DirectLhsAccess::block(data_mapper, rows, depth,                  \
                                   bn > 0 ? divup(n, bn) : 0, lhsBlock);      \
                                                                               \
        if (!is_direct_access) {                                              \
          lhsBlock->is_direct_access = false;                                 \
          LhsPacker()(lhsBlock->packed_data, data_mapper, rows, depth);       \
        }                                                                     \
      } else {                                                                \
        lhsBlock->is_direct_access = false;                                   \
        EigenLhsPacker()(lhsBlock->packed_data, data_mapper, depth, rows,     \
                         /*stride*/ 0, /*offset*/ 0);                         \
      }                                                                       \
    }                                                                         \
                                                                               \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs(                         \
        RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper, \
        const StorageIndex depth, const StorageIndex cols) {                  \
      if (UseCustomContractionKernels()) {                                    \
        const bool is_direct_access =                                         \
            DirectRhsAccess::value &&                                         \
            DirectRhsAccess::block(data_mapper, depth, cols,                  \
                                   bm > 0 ? divup(m, bm) : 0, rhsBlock);      \
                                                                               \
        if (!is_direct_access) {                                              \
          rhsBlock->is_direct_access = false;                                 \
          RhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols);       \
        }                                                                     \
      } else {                                                                \
        rhsBlock->is_direct_access = false;                                   \
        EigenRhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols);    \
      }                                                                       \
    }                                                                         \
                                                                               \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke(                          \
        const OutputMapper& output_mapper, const LhsBlock& lhsBlock,          \
        const RhsBlock& rhsBlock, const StorageIndex rows,                    \
        const StorageIndex depth, const StorageIndex cols, const float alpha, \
        const float beta) {                                                   \
      if (UseCustomContractionKernels()) {                                    \
        if ((DirectLhsAccess::value && lhsBlock.is_direct_access) &&          \
            (DirectRhsAccess::value && rhsBlock.is_direct_access)) {          \
          GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.raw_data,   \
                       rows, depth, cols, alpha, beta,                        \
                       /*ldA=*/lhsBlock.stride, /*ldB=*/rhsBlock.stride,      \
                       /*transposeA=*/lhsBlock.transpose,                     \
                       /*transposeB=*/rhsBlock.transpose);                    \
                                                                               \
        } else if (DirectLhsAccess::value && lhsBlock.is_direct_access) {     \
          GemmKernel()(output_mapper, lhsBlock.raw_data,                      \
                       rhsBlock.packed_data, rows, depth, cols, alpha, beta,  \
                       /*ldA=*/lhsBlock.stride,                               \
                       /*ldB=*/GemmKernel::kComputeStrideFromBlockDimensions, \
                       /*transposeA=*/lhsBlock.transpose,                     \
                       /*transposeB=*/'N');                                   \
                                                                               \
        } else if (DirectRhsAccess::value && rhsBlock.is_direct_access) {     \
          GemmKernel()(output_mapper, lhsBlock.packed_data,                   \
                       rhsBlock.raw_data, rows, depth, cols, alpha, beta,     \
                       /*ldA=*/GemmKernel::kComputeStrideFromBlockDimensions, \
                       /*ldB=*/rhsBlock.stride, /*transposeA=*/'N',           \
                       /*transposeB=*/rhsBlock.transpose);                    \
                                                                               \
        } else {                                                              \
          GemmKernel()(output_mapper, lhsBlock.packed_data,                   \
                       rhsBlock.packed_data, rows, depth, cols, alpha, beta); \
        }                                                                     \
      } else {                                                                \
        /* The gebp kernel does not support beta, so we have to clear the */  \
        /* output memory through the output mapper manually. */               \
        /* WARNING(ezhulenev): This is optimized into a memset in a loop, */  \
        /* which could be much slower for small matrices. Currently this */   \
        /* code path is used only for testing, and performance does not */    \
        /* matter there. */                                                   \
        if (beta == 0.0) {                                                    \
          for (StorageIndex col = 0; col < cols; ++col) {                     \
            ResScalar* output_base = &output_mapper(0, col);                  \
            typedef Array<ResScalar, Dynamic, 1> OutputRow;                   \
            typedef Map<OutputRow, 0, InnerStride<1>> OutputRowMap;           \
            OutputRowMap(output_base, rows).setZero();                        \
          }                                                                   \
        }                                                                     \
                                                                               \
        GebpKernel()(                                                         \
            output_mapper, lhsBlock.packed_data, rhsBlock.packed_data, rows,  \
            depth, cols, alpha,                                               \
            /*strideA*/ GemmKernel::kComputeStrideFromBlockDimensions,        \
            /*strideB*/ GemmKernel::kComputeStrideFromBlockDimensions,        \
            /*offsetA*/ 0, /*offsetB*/ 0);                                    \
      }                                                                       \
    }                                                                         \
                                                                               \
   private:                                                                   \
    /* These are the dimensions of the original Tensors and the selected */   \
    /* block sizes. The actual block sizes passed to the functions above */   \
    /* may be smaller because of partial blocks at the end. */                \
    const StorageIndex m;                                                     \
    const StorageIndex k;                                                     \
    const StorageIndex n;                                                     \
    const StorageIndex bm;                                                    \
    const StorageIndex bk;                                                    \
    const StorageIndex bn;                                                    \
  }

// Tensor contraction kernel that does not fall back on Eigen. Currently not
// all data types are supported by Eigen's data packing and the default
// gebp_kernel.
#define REGISTER_TENSOR_CONTRACTION_KERNEL_NO_FALLBACK(RES_SCALAR,            \
                                                       LHS_SCALAR,            \
                                                       RHS_SCALAR)            \
                                                                               \
  template <typename StorageIndex, typename OutputMapper, typename LhsMapper, \
            typename RhsMapper>                                               \
  struct TensorContractionKernel<RES_SCALAR, LHS_SCALAR, RHS_SCALAR,          \
                                 StorageIndex, OutputMapper, LhsMapper,       \
                                 RhsMapper> {                                 \
    TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n,   \
                            StorageIndex bm, StorageIndex bk,                 \
                            StorageIndex bn)                                  \
        : m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {}                         \
                                                                               \
    enum { HasBeta = true };                                                   \
                                                                               \
    using ResScalar = RES_SCALAR;                                             \
    using LhsScalar = LHS_SCALAR;                                             \
    using RhsScalar = RHS_SCALAR;                                             \
                                                                               \
    using Traits = typename internal::gebp_traits<LhsScalar, RhsScalar>;      \
                                                                               \
    using LhsBlock = ColMajorBlock<LhsScalar, StorageIndex>;                  \
    using RhsBlock = ColMajorBlock<RhsScalar, StorageIndex>;                  \
                                                                               \
    using DirectLhsAccess = DirectColMajorAccess<LhsMapper>;                  \
    using DirectRhsAccess = DirectColMajorAccess<RhsMapper>;                  \
                                                                               \
    /* Packed Lhs/Rhs block memory allocator.*/                               \
    typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar>          \
        BlockMemAllocator;                                                    \
    typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle;        \
                                                                               \
    using LhsPacker =                                                         \
        gemm_pack_colmajor_block<LhsScalar, StorageIndex,                     \
                                 typename LhsMapper::SubMapper, ColMajor>;    \
    using RhsPacker =                                                         \
        gemm_pack_colmajor_block<RhsScalar, StorageIndex,                     \
                                 typename RhsMapper::SubMapper, ColMajor>;    \
                                                                               \
    using GemmKernelProviderType =                                            \
        GemmKernelProvider<ResScalar, LhsScalar, RhsScalar, StorageIndex,     \
                           OutputMapper>;                                     \
    static_assert(                                                            \
        GemmKernelProviderType::Defined,                                      \
        "Custom GEMM kernel is not registered for given scalar types");       \
    using GemmKernel = typename GemmKernelProviderType::GemmKernel;           \
                                                                               \
    template <typename Device>                                                \
    EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block, \
                                              RhsBlock* rhs_block) {          \
      return BlockMemAllocator::allocate(                                     \
          d, bm, bk, bn, &lhs_block->packed_data, &rhs_block->packed_data);   \
    }                                                                         \
                                                                               \
    template <typename Device>                                                \
    EIGEN_DEVICE_FUNC BlockMemHandle                                          \
    allocateSlices(Device& d, const int num_lhs, const int num_rhs,           \
                   const int num_slices, std::vector<LhsBlock>* lhs_blocks,   \
                   std::vector<RhsBlock>* rhs_blocks) {                       \
      eigen_assert(num_slices > 0);                                           \
      std::vector<std::vector<LhsScalar*>> lhs_mem(num_slices);               \
      std::vector<std::vector<RhsScalar*>> rhs_mem(num_slices);               \
                                                                               \
      BlockMemHandle block_mem = BlockMemAllocator::allocateSlices(           \
          d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_mem.data(),        \
          rhs_mem.data());                                                    \
                                                                               \
      for (Index x = 0; x < num_slices; x++) {                                \
        if (num_lhs > 0) lhs_blocks[x].resize(num_lhs);                       \
        for (Index m = 0; m < num_lhs; m++) {                                 \
          lhs_blocks[x][m].packed_data = lhs_mem[x][m];                       \
        }                                                                     \
        if (num_rhs > 0) rhs_blocks[x].resize(num_rhs);                       \
        for (Index n = 0; n < num_rhs; n++) {                                 \
          rhs_blocks[x][n].packed_data = rhs_mem[x][n];                       \
        }                                                                     \
      }                                                                       \
                                                                               \
      return block_mem;                                                       \
    }                                                                         \
                                                                               \
    template <typename Device>                                                \
    EIGEN_DEVICE_FUNC static void deallocate(Device& d,                       \
                                             BlockMemHandle handle) {         \
      BlockMemAllocator::deallocate(d, handle);                               \
    }                                                                         \
                                                                               \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs(                         \
        LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper, \
        const StorageIndex depth, const StorageIndex rows) {                  \
      const bool is_direct_access =                                           \
          DirectLhsAccess::value &&                                           \
          DirectLhsAccess::block(data_mapper, rows, depth,                    \
                                 bn > 0 ? divup(n, bn) : 0, lhsBlock);        \
                                                                               \
      if (!is_direct_access) {                                                \
        lhsBlock->is_direct_access = false;                                   \
        LhsPacker()(lhsBlock->packed_data, data_mapper, rows, depth);         \
      }                                                                       \
    }                                                                         \
                                                                               \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs(                         \
        RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper, \
        const StorageIndex depth, const StorageIndex cols) {                  \
      const bool is_direct_access =                                           \
          DirectRhsAccess::value &&                                           \
          DirectRhsAccess::block(data_mapper, depth, cols,                    \
                                 bm > 0 ? divup(m, bm) : 0, rhsBlock);        \
                                                                               \
      if (!is_direct_access) {                                                \
        rhsBlock->is_direct_access = false;                                   \
        RhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols);         \
      }                                                                       \
    }                                                                         \
                                                                               \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke(                          \
        const OutputMapper& output_mapper, const LhsBlock& lhsBlock,          \
        const RhsBlock& rhsBlock, const StorageIndex rows,                    \
        const StorageIndex depth, const StorageIndex cols, const float alpha, \
        const float beta) {                                                   \
      if ((DirectLhsAccess::value && lhsBlock.is_direct_access) &&            \
          (DirectRhsAccess::value && rhsBlock.is_direct_access)) {            \
        GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.raw_data,     \
                     rows, depth, cols, alpha, beta, /*ldA=*/lhsBlock.stride, \
                     /*ldB=*/rhsBlock.stride,                                 \
                     /*transposeA=*/lhsBlock.transpose,                       \
                     /*transposeB=*/rhsBlock.transpose);                      \
                                                                               \
      } else if (DirectLhsAccess::value && lhsBlock.is_direct_access) {       \
        GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.packed_data,  \
                     rows, depth, cols, alpha, beta, /*ldA=*/lhsBlock.stride, \
                     /*ldB=*/GemmKernel::kComputeStrideFromBlockDimensions,   \
                     /*transposeA=*/lhsBlock.transpose, /*transposeB=*/'N');  \
                                                                               \
      } else if (DirectRhsAccess::value && rhsBlock.is_direct_access) {       \
        GemmKernel()(output_mapper, lhsBlock.packed_data, rhsBlock.raw_data,  \
                     rows, depth, cols, alpha, beta,                          \
                     /*ldA=*/GemmKernel::kComputeStrideFromBlockDimensions,   \
                     /*ldB=*/rhsBlock.stride, /*transposeA=*/'N',             \
                     /*transposeB=*/rhsBlock.transpose);                      \
                                                                               \
      } else {                                                                \
        GemmKernel()(output_mapper, lhsBlock.packed_data,                     \
                     rhsBlock.packed_data, rows, depth, cols, alpha, beta);   \
      }                                                                       \
    }                                                                         \
                                                                               \
   private:                                                                   \
    /* These are the dimensions of the original Tensors and the selected */   \
    /* block sizes. The actual block sizes passed to the functions above */   \
    /* may be smaller because of partial blocks at the end. */                \
    const StorageIndex m;                                                     \
    const StorageIndex k;                                                     \
    const StorageIndex n;                                                     \
    const StorageIndex bm;                                                    \
    const StorageIndex bk;                                                    \
    const StorageIndex bn;                                                    \
  }

REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK(float, float, float);
REGISTER_TENSOR_CONTRACTION_KERNEL_NO_FALLBACK(Eigen::QInt32, Eigen::QInt8,
                                               Eigen::QUInt8);

#undef REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK
#undef REGISTER_TENSOR_CONTRACTION_KERNEL_NO_FALLBACK

#endif  // defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)

}  // namespace internal
}  // namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_