1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/math_ops.cc.
17
18#define EIGEN_USE_THREADS
19
20#include "tensorflow/core/kernels/sparse_matmul_op.h"
21
22#include <map>
23#include <memory>
24#include <vector>
25
26#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27#include "tensorflow/core/common_runtime/device.h"
28#include "tensorflow/core/framework/bfloat16.h"
29#include "tensorflow/core/framework/op.h"
30#include "tensorflow/core/framework/op_kernel.h"
31#include "tensorflow/core/framework/types.h"
32#include "tensorflow/core/kernels/fill_functor.h"
33#include "tensorflow/core/lib/core/threadpool.h"
34#include "tensorflow/core/platform/blocking_counter.h"
35#include "tensorflow/core/platform/errors.h"
36#include "tensorflow/core/platform/logging.h"
37#include "tensorflow/core/platform/macros.h"
38#include "tensorflow/core/platform/mutex.h"
39#include "tensorflow/core/platform/thread_annotations.h"
40#include "tensorflow/core/platform/types.h"
41#ifdef TENSORFLOW_USE_LIBXSMM
42#include "include/libxsmm_intrinsics_x86.h"
43#include "include/libxsmm_malloc.h"
44#include "include/libxsmm_spmdm.h"
45#endif
46
47#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
48#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
49#endif
50
51#define ALWAYS_INLINE EIGEN_ALWAYS_INLINE
52
53namespace tensorflow {
54namespace {
55
56template <typename T>
57using BasicMatrix = Eigen::Tensor<T, 2, Eigen::RowMajor>;
58
59template <typename T>
60using BasicMatrixMap =
61 Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>;
62
63using Matrix = BasicMatrix<float>;
64using MatrixMap = BasicMatrixMap<float>;
65using CPUDevice = Eigen::ThreadPoolDevice;
66using DSizes = Eigen::DSizes<Eigen::DenseIndex, 2>;
67
// Two commonly used static dsizes. We use Eigen::type2index to allow as much
// compile-time optimization as possible.
70inline Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>
71dsizes_00() {
72 return Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>();
73}
74inline Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>
75dsizes_10() {
76 return Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>();
77}
78
// Block sizes.
80// TODO(agarwal): compute these sizes based on cache sizes.
81const int K = 64;
82const int M = 64;
83const int N = 128;
84
// This stores a sparse representation of a slice of a matrix with size
// (num_rows, num_cols). The slice is represented as a series of blocks of size
// (num_rows, b), where b = block_size for all but the last block, which may
// have fewer columns.
//
// num_rows and block_size are assumed to be <= 256. This allows the indices to
// be stored as uint8.
//
// For each block, we store all the non-zero entries in the data/data3 vectors
// and the corresponding coordinates of the elements in the index/index3
// vectors. The index3 vector stores the indices of 3 elements in the same row,
// so that these elements can share the same row coordinate. Each entry in
// index3 corresponds to 3 entries in data3.
//
// Note that the data/indices of all the blocks are stored in the same vectors
// respectively. To identify block boundaries, we store the block offsets using
// index3_offset/index_offset. If there are n blocks in the slice, index3_offset
// and index_offset have n entries each. The entries of index3 belonging to the
// ith block are those at positions [index3_offset[i-1], index3_offset[i]), with
// index3_offset[-1] taken to be 0; similarly for index_offset and index.
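//
// Illustrative example (hypothetical values, not from the original code): if a
// row m of one block is [0, 1.0, 0, 2.0, 3.0, 0, 4.0], the first three
// non-zeros are packed as a triple: data3 receives {1.0, 2.0, 3.0} and index3
// receives {m, k1 = 1, k2 = 3, k3 = 4}. The leftover non-zero is stored
// individually: data receives {4.0} and index receives {m, k = 6}.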
106template <typename T>
107struct SparseSlice {
108 using ConstMatrixMap = BasicMatrixMap<const T>;
109
110 public:
111 // Indices of three elements on the same row.
112 struct Index3 {
113 uint8 m; // row
114 // columns
115 uint8 k1;
116 uint8 k2;
117 uint8 k3;
118 };
119
120 // Index of one element.
121 struct Index {
122 uint8 m;
123 uint8 k;
124 };
125
126 SparseSlice(int nrows, int ncols, int bsize)
127 : num_rows(nrows), num_cols(ncols), block_size(bsize) {
128 DCHECK_LE(nrows, 256);
129 DCHECK_LE(block_size, 256);
130 }
131
132 // Initializes the slice with data starting at mat(0, col_offset) and with
133 // size (num_rows, num_cols).
134 // If Transpose is true, implicitly transposes mat.
135 template <bool Transpose = false>
136 void Initialize(const ConstMatrixMap& mat, int col_offset);
137
138 void Clear();
139
140 // See comments above.
141 std::vector<int> index3_offset;
142 std::vector<Index3> index3;
143 std::vector<T> data3;
144
145 // See comments above. Similar to "index3" except that each element in "index"
146 // corresponds to one element in data.
147 std::vector<int> index_offset;
148 std::vector<Index> index;
149 std::vector<T> data;
150
151 // Number of rows and columns for the slice.
152 const int num_rows;
153 const int num_cols;
154
155 // Block size used to initialize from a matrix.
156 const int block_size;
157};
158
159template <typename T>
160bool IsZero(T v);
161
162template <>
163ALWAYS_INLINE bool IsZero(bfloat16 v) {
164 return !static_cast<bool>(v);
165}
166
167template <>
168ALWAYS_INLINE bool IsZero(float v) {
169 return v == 0.0f;
170}
171
172template <typename T>
173template <bool Transpose>
174void SparseSlice<T>::Initialize(
175 const typename SparseSlice<T>::ConstMatrixMap& mat, int col_offset) {
176 const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
177 const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
178 DCHECK_LE(num_rows, mat_rows);
179 DCHECK_LE(num_cols + col_offset, mat_cols);
180
181 int num_blocks = (num_cols + block_size - 1) / block_size;
182 int mat_size = num_rows * num_cols;
183
184 index3_offset.reserve(num_blocks);
185 data3.reserve(mat_size);
186 index3.reserve(mat_size / 3);
187
188 index_offset.reserve(num_blocks);
189 data.reserve(num_blocks * num_rows * 2);
190 index.reserve(num_blocks * num_rows * 2);
191
192 Index3 idx3;
193 const int stride = Transpose ? mat.dimension(1) : 1;
194
195 for (int i = 0; i < num_blocks; ++i) {
196 int num_block_cols = std::min(block_size, num_cols - block_size * i);
197 for (int row = 0; row < num_rows; ++row) {
198 idx3.m = static_cast<uint8>(row);
199 // Safety note: The following code has a race, since it checks whether
200 // *curr is nonzero and then reads it again on use. However, the result
201 // of the race is only that some of the "nonzeros" in the resulting sparse
202 // representation may actually be zero, which is harmless.
203 const auto* start =
204 Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
205 const auto* curr = start;
206 const auto* end = start + stride * num_block_cols;
207 uint8 k = 0;
208#define NEXT_ELEM \
209 curr += stride; \
210 ++k;
211#define EAT_ZEROS \
212 while (curr < end && IsZero<T>(*curr)) { \
213 NEXT_ELEM; \
214 }
215 while (true) {
216 EAT_ZEROS
217 if (curr >= end) break;
218 idx3.k1 = k;
219 const T value1 = *curr;
220 NEXT_ELEM;
221
222 EAT_ZEROS
223 if (curr >= end) {
224 data.push_back(value1);
225 index.push_back({idx3.m, idx3.k1});
226 break;
227 }
228 idx3.k2 = k;
229 const T value2 = *curr;
230 NEXT_ELEM;
231
232 EAT_ZEROS
233 if (curr >= end) {
234 data.push_back(value2);
235 index.push_back({idx3.m, idx3.k2});
236 data.push_back(value1);
237 index.push_back({idx3.m, idx3.k1});
238 break;
239 }
240 idx3.k3 = k;
241 data3.push_back(value1);
242 data3.push_back(value2);
243 data3.push_back(*curr);
244 NEXT_ELEM;
245 index3.push_back(idx3);
246#undef NEXT_ELEM
247#undef EAT_ZEROS
248 }
249 }
250 col_offset += block_size;
251 index3_offset.push_back(index3.size());
252 index_offset.push_back(index.size());
253 }
254 DCHECK_EQ(index3_offset.size(), num_blocks);
255 DCHECK_EQ(index_offset.size(), num_blocks);
256 DCHECK_EQ(3 * index3.size(), data3.size());
257 DCHECK_EQ(index.size(), data.size());
258}
259
260template <typename T>
261void SparseSlice<T>::Clear() {
262 index3_offset.clear();
263 index3.clear();
264 data3.clear();
265 index_offset.clear();
266 index.clear();
267 data.clear();
268}
269
270using Packet = Eigen::internal::packet_traits<float>::type;
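// Number of float lanes in one SIMD packet (e.g. 4 with SSE, 8 with AVX).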
271const int kNumOperands = (sizeof(Packet) / sizeof(float));
272#define LOAD(x) Eigen::internal::pload<Packet>(x);
273#define EXPAND_BFLOAT_L(x, y) \
274 const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
275#define EXPAND_BFLOAT_U(x, y) \
276 const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x);
277#define STORE(x, y) Eigen::internal::pstore<float>(x, y);
278#define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);
279
280ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) {
281 float out = 0;
282 auto tmp = reinterpret_cast<bfloat16*>(&out);
283#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
284 tmp[0] = *src;
285#else
286 tmp[1] = *src;
287#endif
288 return out;
289}
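// E.g. the bfloat16 bit pattern 0x3F80 becomes the float bit pattern 0x3F800000
// (1.0f): bfloat16 is the upper 16 bits of a float32, so the conversion above
// simply zero-fills the low 16 mantissa bits.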
290
291ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) {
292 return Eigen::internal::pload4bf16<Packet>(
293 reinterpret_cast<const float*>(src));
294}
295
296ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) {
297 return Eigen::internal::pload2bf16<Packet>(
298 reinterpret_cast<const float*>(src));
299}
300
301ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) {
302 **out += a * **inp;
303 ++*inp;
304 ++*out;
305}
306
307ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp,
308 float** out) {
309 float inp_f = ConvertBfloat16ToFloat(*inp);
310 **out += a * inp_f;
311 ++*inp;
312 ++*out;
313}
314ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
315 const float a3, const bfloat16** inp1,
316 const bfloat16** inp2,
317 const bfloat16** inp3, float** out) {
318 float inp1_f = ConvertBfloat16ToFloat(*inp1);
319 float inp2_f = ConvertBfloat16ToFloat(*inp2);
320 float inp3_f = ConvertBfloat16ToFloat(*inp3);
321 **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f;
322 ++*out;
323 ++*inp1;
324 ++*inp2;
325 ++*inp3;
326}
327
328ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
329 const float a3, const float** inp1,
330 const float** inp2, const float** inp3,
331 float** out) {
332 **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3;
333 ++*out;
334 ++*inp1;
335 ++*inp2;
336 ++*inp3;
337}
338
339ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) {
340 auto tmp = ConvertBfloat16ToFloat(*data);
341 *l = Eigen::internal::pset1<Packet>(tmp);
342 ++*data;
343}
344
345ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1,
346 Packet* l2) {
347 if (kNumOperands >= 2) {
348 auto tmp = ConvertTwoBfloat16ToFloat(*data);
349 *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
350 *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
351 *data += 2;
352 } else {
353 LoadSingleScalar(data, l1);
354 LoadSingleScalar(data, l2);
355 }
356}
357
358ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1,
359 Packet* l2, Packet* l3, Packet* l4) {
360 if (kNumOperands >= 4) {
361 auto tmp = ConvertFourBfloat16ToFloat(*data);
362 *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
363 *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
364 *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp);
365 *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp);
366 *data += 4;
367 } else {
368 LoadTwoScalars(data, l1, l2);
369 LoadTwoScalars(data, l3, l4);
370 }
371}
372
373ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) {
374 *l = Eigen::internal::pload1<Packet>(*data);
375 ++(*data);
376}
377
378ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) {
379 LoadSingleScalar(data, l1);
380 LoadSingleScalar(data, l2);
381}
382
383ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2,
384 Packet* l3, Packet* l4) {
385 LoadTwoScalars(data, l1, l2);
386 LoadTwoScalars(data, l3, l4);
387}
388
389template <typename T>
390ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2,
391 Packet* l3) {
392 LoadTwoScalars(data, l1, l2);
393 LoadSingleScalar(data, l3);
394}
395
396template <typename T>
397ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2,
398 Packet* l3, Packet* l4, Packet* l5,
399 Packet* l6) {
400 LoadFourScalars(data, l1, l2, l3, l4);
401 LoadTwoScalars(data, l5, l6);
402}
403
404// Vectorized version of ScalarMulAdd.
405ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) {
406 auto inp = reinterpret_cast<const float*>(*binp);
407 const auto b = LOAD(inp);
408 EXPAND_BFLOAT_L(b, b_0);
409 EXPAND_BFLOAT_U(b, b_1);
410 *binp += 2 * kNumOperands;
411 auto c1 = LOAD(*out);
412 auto c2 = LOAD(*out + kNumOperands);
413 FMA(a, b_0, c1, c1);
414 FMA(a, b_1, c2, c2);
415 STORE(*out, c1);
416 STORE(*out + kNumOperands, c2);
417 *out += 2 * kNumOperands;
418}
419
420// Vectorized version of ScalarMulAdd3Way.
421ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
422 const bfloat16** binp1, const bfloat16** binp2,
423 const bfloat16** binp3, float** out) {
424 auto inp1 = reinterpret_cast<const float*>(*binp1);
425 auto inp2 = reinterpret_cast<const float*>(*binp2);
426 auto inp3 = reinterpret_cast<const float*>(*binp3);
427 auto c1 = LOAD(*out);
428 auto c2 = LOAD(*out + kNumOperands);
429 const auto b1 = LOAD(inp1);
430 EXPAND_BFLOAT_L(b1, b1_0);
431 EXPAND_BFLOAT_U(b1, b1_1);
432 *binp1 += 2 * kNumOperands;
433 const auto b2 = LOAD(inp2);
434 EXPAND_BFLOAT_L(b2, b2_0);
435 EXPAND_BFLOAT_U(b2, b2_1);
436 *binp2 += 2 * kNumOperands;
437 const auto b3 = LOAD(inp3);
438 EXPAND_BFLOAT_L(b3, b3_0);
439 EXPAND_BFLOAT_U(b3, b3_1);
440 *binp3 += 2 * kNumOperands;
441 FMA(a1, b1_0, c1, c1);
442 FMA(a1, b1_1, c2, c2);
443 FMA(a2, b2_0, c1, c1);
444 FMA(a2, b2_1, c2, c2);
445 FMA(a3, b3_0, c1, c1);
446 FMA(a3, b3_1, c2, c2);
447 STORE(*out, c1);
448 STORE(*out + kNumOperands, c2);
449 *out += 2 * kNumOperands;
450}
451
452// Unroll MulAdd3Way for two iterations
453ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
454 const Packet a3, const bfloat16** binp1,
455 const bfloat16** binp2, const bfloat16** binp3,
456 float** out) {
457 auto inp1 = reinterpret_cast<const float*>(*binp1);
458 auto inp2 = reinterpret_cast<const float*>(*binp2);
459 auto inp3 = reinterpret_cast<const float*>(*binp3);
460 auto c1 = LOAD(*out);
461 auto c2 = LOAD(*out + kNumOperands);
462 const auto b1 = LOAD(inp1);
463 const auto b2 = LOAD(inp2);
464 const auto b3 = LOAD(inp3);
465
466 EXPAND_BFLOAT_L(b1, b1_0);
467 EXPAND_BFLOAT_U(b1, b1_1);
468 EXPAND_BFLOAT_L(b2, b2_0);
469 EXPAND_BFLOAT_U(b2, b2_1);
470 EXPAND_BFLOAT_L(b3, b3_0);
471 EXPAND_BFLOAT_U(b3, b3_1);
472 auto c3 = LOAD(*out + 2 * kNumOperands);
473 auto c4 = LOAD(*out + 3 * kNumOperands);
474 const auto b4 = LOAD(inp1 + kNumOperands);
475 const auto b5 = LOAD(inp2 + kNumOperands);
476 const auto b6 = LOAD(inp3 + kNumOperands);
477
478 EXPAND_BFLOAT_L(b4, b4_0);
479 EXPAND_BFLOAT_U(b4, b4_1);
480 EXPAND_BFLOAT_L(b5, b5_0);
481 EXPAND_BFLOAT_U(b5, b5_1);
482 EXPAND_BFLOAT_L(b6, b6_0);
483 EXPAND_BFLOAT_U(b6, b6_1);
484
485 FMA(a1, b1_0, c1, c1);
486 FMA(a1, b1_1, c2, c2);
487 FMA(a1, b4_0, c3, c3);
488 FMA(a1, b4_1, c4, c4);
489 FMA(a2, b2_0, c1, c1);
490 FMA(a2, b2_1, c2, c2);
491 FMA(a2, b5_0, c3, c3);
492 FMA(a2, b5_1, c4, c4);
493 FMA(a3, b3_0, c1, c1);
494 FMA(a3, b3_1, c2, c2);
495 FMA(a3, b6_0, c3, c3);
496 FMA(a3, b6_1, c4, c4);
497 STORE(*out, c1);
498 STORE(*out + kNumOperands, c2);
499 STORE(*out + 2 * kNumOperands, c3);
500 STORE(*out + 3 * kNumOperands, c4);
501 *out += 4 * kNumOperands;
502 *binp1 += 4 * kNumOperands;
503 *binp2 += 4 * kNumOperands;
504 *binp3 += 4 * kNumOperands;
505}
506
507// Apply MulAdd3Way on 128 operands.
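// Each TwoMulAdd3Way call consumes 4 * kNumOperands floats of output, so the
// two calls per iteration cover 8 * kNumOperands columns and the loop below
// spans exactly 128 columns (e.g. 2 iterations when kNumOperands == 8).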
508ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
509 const Packet a3, const bfloat16** inp1,
510 const bfloat16** inp2, const bfloat16** inp3,
511 float** out) {
512 for (int k = 0; k < 128 / (8 * kNumOperands); ++k) {
513 TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
514 TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
515 }
516}
517
518// Vectorized version of ScalarMulAdd
519ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) {
520 const auto b = LOAD(*inp);
521 *inp += kNumOperands;
522 auto c = LOAD(*out);
523 FMA(a, b, c, c);
524 STORE(*out, c);
525 *out += kNumOperands;
526}
527
528// Vectorized version of ScalarMulAdd3Way
529ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
530 const float** inp1, const float** inp2,
531 const float** inp3, float** out) {
532 auto c = LOAD(*out);
533 const auto b1 = LOAD(*inp1);
534 *inp1 += kNumOperands;
535 const auto b2 = LOAD(*inp2);
536 *inp2 += kNumOperands;
537 const auto b3 = LOAD(*inp3);
538 *inp3 += kNumOperands;
539 FMA(a1, b1, c, c);
540 FMA(a2, b2, c, c);
541 FMA(a3, b3, c, c);
542 STORE(*out, c);
543 *out += kNumOperands;
544}
545
546// Unroll MulAdd3Way for two iterations
547ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
548 const Packet a3, const float** inp1,
549 const float** inp2, const float** inp3,
550 float** out) {
551 auto c1 = LOAD(*out);
552 const auto b1 = LOAD(*inp1);
553 const auto b2 = LOAD(*inp2);
554 const auto b3 = LOAD(*inp3);
555
556 auto c2 = LOAD(*out + kNumOperands);
557 const auto b4 = LOAD(*inp1 + kNumOperands);
558 const auto b5 = LOAD(*inp2 + kNumOperands);
559 const auto b6 = LOAD(*inp3 + kNumOperands);
560
561 FMA(a1, b1, c1, c1);
562 FMA(a1, b4, c2, c2);
563 FMA(a2, b2, c1, c1);
564 FMA(a2, b5, c2, c2);
565 FMA(a3, b3, c1, c1);
566 FMA(a3, b6, c2, c2);
567 STORE(*out, c1);
568 STORE(*out + kNumOperands, c2);
569 *out += 2 * kNumOperands;
570 *inp1 += 2 * kNumOperands;
571 *inp2 += 2 * kNumOperands;
572 *inp3 += 2 * kNumOperands;
573}
574
575// Unroll MulAdd3Way for four iterations
576ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2,
577 const Packet a3, const float** inp1,
578 const float** inp2, const float** inp3,
579 float** out) {
580 TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
581 TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
582}
583
584// Apply MulAdd3Way on 128 operands.
585ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
586 const Packet a3, const float** inp1,
587 const float** inp2, const float** inp3,
588 float** out) {
589 if (kNumOperands == 8) {
590 FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
591 FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
592 FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
593 FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
594 } else {
595 DCHECK_LE(4 * kNumOperands, 128);
596 for (int i = 0; i < 128 / (4 * kNumOperands); ++i) {
597 MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
598 MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
599 MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
600 MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
601 }
602 }
603}
// Computes the product of "left_slices" with "num_cols" columns of "right" and
// stores the output in *"output".
// Note that left_slices is a list of SparseSlices, which are conceptually
// assumed to be concatenated along the column dimension. Also, each SparseSlice
// is encoded as a list of blocks with up to N columns. See SparseSlice for more
// details.
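// The loops below first process the Index3 triples (unrolled two at a time) and
// then the remaining single-entry Index records (unrolled four at a time),
// using the vectorized MulAdd* helpers for whole packets and the ScalarMulAdd*
// helpers for the cols % kNumOperandsR leftover columns.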
610template <typename TL, typename TR, int Cols>
611inline void GEPP(
612 const std::vector<SparseSlice<TL>*>& left_slices,
613 const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
614 Eigen::Aligned>& right,
615 const int num_cols, Matrix* output) {
616 const int cols = (Cols == -1) ? num_cols : Cols;
617 DCHECK_EQ(num_cols, cols);
618 const int right_num_cols = right.dimension(1);
619 const int output_num_cols = output->dimension(1);
620 static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR);
621 const int cols_mod = cols % kNumOperandsR;
622 int k_offset = 0;
623 // Pre-compute pointers for output matrix.
624 float* out_ptrs[M];
625 float* const out_start = &(*output)(0, 0);
626 for (int j = 0; j < M; ++j) {
627 out_ptrs[j] = out_start + output_num_cols * j;
628 }
629 for (const auto* left_slice : left_slices) {
630 const auto& left = *left_slice;
631 const auto* data3 = (!left.data3.empty()) ? &left.data3[0] : nullptr;
632 const auto* data = (!left.data.empty()) ? &left.data[0] : nullptr;
633 const int num_blocks = left.index3_offset.size();
634 int begin3 = 0;
635 int begin = 0;
636 for (int i = 0; i < num_blocks; ++i) {
637 // Pre-compute pointers for right matrix
638 const TR* right_ptrs[K];
639 const auto* const right_start = &right(k_offset, 0);
640 DCHECK_LT(k_offset, right.dimension(0));
641 for (int j = 0; j < K; ++j) {
642 right_ptrs[j] = right_start + right_num_cols * j;
643 }
644
645 const int end3 = left.index3_offset[i];
646 int j = begin3;
647 // Loop unrolled for 2 iterations.
648 for (; j + 1 < end3; j += 2) {
649 Packet l1, l2, l3, nl1, nl2, nl3;
650 LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3);
651 const auto& index = left.index3[j];
652 const auto& nindex = left.index3[j + 1];
653 float* out = out_ptrs[index.m];
654 float* nout = out_ptrs[nindex.m];
655 const auto* r1 = right_ptrs[index.k1];
656 const auto* r2 = right_ptrs[index.k2];
657 const auto* r3 = right_ptrs[index.k3];
658
659 const auto* nr1 = right_ptrs[nindex.k1];
660 const auto* nr2 = right_ptrs[nindex.k2];
661 const auto* nr3 = right_ptrs[nindex.k3];
662 if (cols == 128) {
663 MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
664 MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
665 } else {
666 for (int n = 0; n < cols / kNumOperandsR; ++n) {
667 MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
668 MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
669 }
670
671 const float sl1 = Eigen::internal::pfirst<Packet>(l1);
672 const float sl2 = Eigen::internal::pfirst<Packet>(l2);
673 const float sl3 = Eigen::internal::pfirst<Packet>(l3);
674 const float nsl1 = Eigen::internal::pfirst<Packet>(nl1);
675 const float nsl2 = Eigen::internal::pfirst<Packet>(nl2);
676 const float nsl3 = Eigen::internal::pfirst<Packet>(nl3);
677 for (int k = 0; k < cols_mod; ++k) {
678 ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
679 ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout);
680 }
681 }
682 }
683 if (j < end3) {
684 Packet l1, l2, l3;
685 LoadThreeScalars(&data3, &l1, &l2, &l3);
686
687 const auto& index = left.index3[j];
688 float* out = out_ptrs[index.m];
689 const auto* r1 = right_ptrs[index.k1];
690 const auto* r2 = right_ptrs[index.k2];
691 const auto* r3 = right_ptrs[index.k3];
692 if (cols == 128) {
693 MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
694 } else {
695 for (int n = 0; n < cols / kNumOperandsR; ++n) {
696 MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
697 }
698 const float sl1 = Eigen::internal::pfirst<Packet>(l1);
699 const float sl2 = Eigen::internal::pfirst<Packet>(l2);
700 const float sl3 = Eigen::internal::pfirst<Packet>(l3);
701 for (int k = 0; k < cols_mod; ++k) {
702 ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
703 }
704 }
705 }
706 begin3 = end3;
707 int end = left.index_offset[i];
708 // Loop unrolled for 4 iterations.
709 j = begin;
710 for (; j + 3 < end; j += 4) {
711 Packet l, nl, n2l, n3l;
712 LoadFourScalars(&data, &l, &nl, &n2l, &n3l);
713
714 const auto& index = left.index[j];
715 const auto& nindex = left.index[j + 1];
716 const auto& n2index = left.index[j + 2];
717 const auto& n3index = left.index[j + 3];
718 const auto* r = right_ptrs[index.k];
719 const auto* nr = right_ptrs[nindex.k];
720 const auto* n2r = right_ptrs[n2index.k];
721 const auto* n3r = right_ptrs[n3index.k];
722 float* out = out_ptrs[index.m];
723 float* nout = out_ptrs[nindex.m];
724 float* n2out = out_ptrs[n2index.m];
725 float* n3out = out_ptrs[n3index.m];
726
727 for (int n = 0; n < cols / kNumOperandsR; ++n) {
728 MulAdd(l, &r, &out);
729 MulAdd(nl, &nr, &nout);
730 MulAdd(n2l, &n2r, &n2out);
731 MulAdd(n3l, &n3r, &n3out);
732 }
733
734 const float sl1 = Eigen::internal::pfirst<Packet>(l);
735 const float sl2 = Eigen::internal::pfirst<Packet>(nl);
736 const float sl3 = Eigen::internal::pfirst<Packet>(n2l);
737 const float sl4 = Eigen::internal::pfirst<Packet>(n3l);
738 for (int k = 0; k < cols_mod; ++k) {
739 ScalarMulAdd(sl1, &r, &out);
740 ScalarMulAdd(sl2, &nr, &nout);
741 ScalarMulAdd(sl3, &n2r, &n2out);
742 ScalarMulAdd(sl4, &n3r, &n3out);
743 }
744 }
745 while (j < end) {
746 Packet l;
747 LoadSingleScalar(&data, &l);
748 const auto& index = left.index[j];
749 const auto* r = right_ptrs[index.k];
750 float* out = out_ptrs[index.m];
751 for (int n = 0; n < cols / kNumOperandsR; ++n) {
752 MulAdd(l, &r, &out);
753 }
754 const float sl = Eigen::internal::pfirst<Packet>(l);
755 for (int k = 0; k < cols_mod; ++k) {
756 ScalarMulAdd(sl, &r, &out);
757 }
758 j++;
759 }
760 k_offset += left.block_size;
761 begin = end;
762 }
763 }
764}
765
766#undef LOAD
767#undef EXPAND_BFLOAT_L
768#undef EXPAND_BFLOAT_U
769#undef STORE
770#undef FMA
771
772} // namespace
773
774template <typename TL, typename TR>
775class SparseMatMul {
776 using MatrixL = BasicMatrix<TL>;
777 using MatrixR = BasicMatrix<TR>;
778 using ConstMatrixMapL = BasicMatrixMap<const TL>;
779 using ConstMatrixMapR = BasicMatrixMap<const TR>;
780 using MatrixMapR = BasicMatrixMap<TR>;
781
782 public:
783 // Not used; added to match interface of LibxsmmSparseMatMul
784 struct TensorInfoCache {};
785
786 // Perform matrix multiplication of "left" and "right", and store the result
787 // in *"output".
788 public:
789 static inline void Compute(TensorInfoCache* cache,
790 const ConstMatrixMapL& left,
791 const ConstMatrixMapR& right, bool transpose_left,
792 const DeviceBase::CpuWorkerThreads* thread_pool,
793 bool transpose_output, MatrixMap* output);
794
795 private:
  // Computes the product of "left" with "num_cols" columns of "right", and
  // stores the output block in *"output" at offsets "output_row_offset" and
  // "output_col_offset". If assign is true, the computed values overwrite the
  // block; otherwise they are added to the existing values.
800 static inline void ComputeOutputBlock(
801 const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right,
802 int num_cols, int output_row_offset, int output_col_offset, bool assign,
803 bool transpose_output, MatrixMap* output);
804
  // Encodes "mat" using a sparse representation and stores that in
  // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
  // "slice_num_cols"; each grid element is converted into a SparseSlice and
  // stored in mat_slices. "slice_block_size" is used to perform further column
  // blocking of each slice.
810 static inline std::unique_ptr<BlockingCounter> CreateSparseSlices(
811 const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
812 int slice_block_size, int slice_num_cols,
813 std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
814 const DeviceBase::CpuWorkerThreads* thread_pool);
815
  // Chops "mat" along the column dimension into pieces with at most N columns,
  // and concatenates the pieces one after the other in "buffer". The list of
  // pieces is returned in "slices"; the returned BlockingCounter should be used
  // to wait for the shuffle operations to complete.
820 static inline std::unique_ptr<BlockingCounter> CreateDenseSlices(
821 const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
822 int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
823 MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);
824
825 // Helper function for CreateDenseSlices to move the data around. It returns a
826 // BlockingCounter which should be used to wait for the shuffle operations to
827 // complete.
828 static inline BlockingCounter* ShuffleMatrix(
829 const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows,
830 int slice_col_start, int slice_num_cols, const int N,
831 const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer);
832
833 // Helper function for CreateDenseSlices to create slices.
834 static inline void SliceMatrix(const MatrixR& mat, const int num_rows,
835 const int num_slices,
836 std::vector<ConstMatrixMapR*>* slices);
837
  // Heuristics to compute various block sizes.
  // KR, NR: block sizes for "right". We run blocking iterations that operate on
  // matrices of at most this size.
  // KL: grid size along the column dimension used while encoding left.
  // IB, JB: number of left and right slices to multiply together. This is used
  // for ordering different ComputeOutputBlock operations inside each blocking
  // iteration so as to potentially reduce the working set size.
845 static inline void ComputeBlockSizes(const ConstMatrixMapL& left,
846 const ConstMatrixMapR& right,
847 bool transpose_left, int num_threads,
848 int* KR, int* NR, int* KL, int* JB,
849 int* IB);
850
851 TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul);
852};
853
854#ifdef TENSORFLOW_USE_LIBXSMM
855template <typename TL, typename TR>
856class LibxsmmSparseMatMul {
857 using MatrixL = BasicMatrix<TL>;
858 using MatrixR = BasicMatrix<TR>;
859 using ConstMatrixMapL = BasicMatrixMap<const TL>;
860 using ConstMatrixMapR = BasicMatrixMap<const TR>;
861 using MatrixMapR = BasicMatrixMap<TR>;
862
863 public:
864 // This structure contains a set of libxsmm kernels for sizes that have been
865 // encountered previously by this operator so that libxsmm does not need to
866 // reallocate its scratchpad memory each time (which hurts performance
867 // substantially).
868 struct TensorInfoCache {
869 struct TensorInfoCacheEntry {
870 // Parameters for kernel
871 int M;
872 int K;
873 int N;
874 int max_threads;
875 // libxsmm handle and matrix data
876 libxsmm_spmdm_handle handle;
877 libxsmm_CSR_sparseslice* output_csr;
878 // Chain to non-libxsmm implementation's cache in case that ever becomes
879 // useful (it is an empty struct right now)
880 typename SparseMatMul<TL, TR>::TensorInfoCache
881 non_libxsmm_cache; // Currently not used
882 };
883 // protects entries; invariant: entries is a valid std::multimap
884 tensorflow::mutex lock;
885 // Because there could be multiple matrix multiplies with the same sizes
886 // going on at the same time, we need to allow multiple cache entries for a
887 // given set of parameters. Taking and returning entries is used to make
888 // sure the same cache entry is not used from two threads at a time.
889 std::multimap<std::tuple<int, int, int, int>,
890 std::unique_ptr<TensorInfoCacheEntry>>
891 entries TF_GUARDED_BY(lock);
892
893 TensorInfoCache() : lock(), entries() {}
894 // Look up and remove first entry with these parameters, creating one if
895 // there isn't one
896 std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N,
897 int max_threads)
898 TF_LOCKS_EXCLUDED(lock) {
899 tensorflow::mutex_lock ml(lock);
900 auto key = std::make_tuple(M, K, N, max_threads);
901 auto it = entries.find(key);
902 if (it != entries.end()) {
903 auto val = std::move(it->second);
904 entries.erase(it);
905 return val;
906 } else {
907 std::unique_ptr<TensorInfoCacheEntry> e{
908 new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
909 // setup scoped allocator, which uses cpu_allocator() for this scope
910 const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
911 libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
912 return e;
913 }
914 }
915 // Add a cache entry with certain parameters
916 void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e)
917 TF_LOCKS_EXCLUDED(lock) {
918 tensorflow::mutex_lock ml(lock);
919 auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads);
920 entries.insert(std::make_pair(key, std::move(e)));
921 }
922 ~TensorInfoCache() {
923 tensorflow::mutex_lock ml(lock);
924 for (auto& p : entries) {
925 libxsmm_spmdm_destroy(&p.second->handle);
926 }
927 entries.clear();
928 }
929
930 private:
931 TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache);
932 };
933
934 // Perform matrix multiplication of "left" and "right", and store the result
935 // in *"output".
936 public:
937 static inline void Compute(TensorInfoCache* cache,
938 const ConstMatrixMapL& left,
939 const ConstMatrixMapR& right, bool transpose_left,
940 const DeviceBase::CpuWorkerThreads* thread_pool,
941 bool transpose_output, MatrixMap* output);
942
943 private:
944 TF_DISALLOW_COPY_AND_ASSIGN(LibxsmmSparseMatMul);
945};
946#endif
947
948template <typename TL, typename TR,
949 template <typename TL2, typename TR2> class DoMatMul>
950class SparseMatMulOp : public OpKernel {
951 using MatrixR = BasicMatrix<TR>;
952 using ConstMatrixMapR = BasicMatrixMap<const TR>;
953
954 public:
955 explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
956 OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
957 OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
958 OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_));
959 OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_));
960 }
961
962 void Compute(OpKernelContext* ctx) override {
963 const Tensor& a = ctx->input(0);
964 const Tensor& b = ctx->input(1);
965 OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
966 errors::InvalidArgument("a is not a matrix"));
967 OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
968 errors::InvalidArgument("b is not a matrix"));
969
970 const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0);
971 const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1);
972 const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1);
973 const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);
974
975 OP_REQUIRES(ctx, k == k2,
976 errors::InvalidArgument(
977 "Matrix size incompatible: a: ", a.shape().DebugString(),
978 ", b: ", b.shape().DebugString()));
979 OP_REQUIRES(ctx, m >= 0 && n >= 0 && k >= 0,
980 errors::InvalidArgument(
981 "Matrix dimensions cannot be negative: a: ",
982 a.shape().DebugString(), ", b: ", b.shape().DebugString()));
983 Tensor* output = nullptr;
984 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));
985
    // Return early if either output dimension is 0.
987 if (m == 0 || n == 0) {
988 return;
989 }
990
991 if (k == 0) {
992 // If the inner dimension k in the matrix multiplication is zero, we fill
993 // the output with zeros.
994 functor::SetZeroFunctor<CPUDevice, float> f;
995 f(ctx->eigen_device<CPUDevice>(), output->flat<float>());
996 return;
997 }
998
999 auto out = output->matrix<float>();
1000
1001 std::unique_ptr<Tensor> a_float;
1002 std::unique_ptr<Tensor> b_float;
1003 if (!a_is_sparse_ && !b_is_sparse_) {
1004 auto left = &a;
1005 auto right = &b;
1006 // TODO(agarwal): multi-thread the conversions from bfloat16 to float.
1007 if (std::is_same<TL, bfloat16>::value) {
1008 a_float.reset(new Tensor(DT_FLOAT, a.shape()));
1009 BFloat16ToFloat(a.flat<bfloat16>().data(),
1010 a_float->flat<float>().data(), a.NumElements());
1011 left = a_float.get();
1012 }
1013 if (std::is_same<TR, bfloat16>::value) {
1014 b_float.reset(new Tensor(DT_FLOAT, b.shape()));
1015 BFloat16ToFloat(b.flat<bfloat16>().data(),
1016 b_float->flat<float>().data(), b.NumElements());
1017 right = b_float.get();
1018 }
1019 Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
1020 dim_pair[0].first = transpose_a_ ? 0 : 1;
1021 dim_pair[0].second = transpose_b_ ? 1 : 0;
1022
1023 out.device(ctx->template eigen_device<CPUDevice>()) =
1024 left->matrix<float>().contract(right->matrix<float>(), dim_pair);
1025 return;
1026 }
1027
1028 auto left = &a;
1029 auto right = &b;
1030 bool transpose_output = false;
1031 bool transpose_a = transpose_a_;
1032 bool transpose_b = transpose_b_;
1033 if (!a_is_sparse_) {
1034 // Swap the order of multiplications using the identity:
1035 // A * B = (B' * A')'.
1036 std::swap(left, right);
1037 std::swap(transpose_a, transpose_b);
1038 transpose_a = !transpose_a;
1039 transpose_b = !transpose_b;
1040 transpose_output = !transpose_output;
1041 }
1042
1043 std::unique_ptr<Tensor> right_tr;
1044 if (transpose_b) {
1045 // TODO(agarwal): avoid transposing the matrix here and directly handle
1046 // transpose in CreateDenseSlices.
      OP_REQUIRES(ctx, right->dim_size(0) != 0,
                  errors::InvalidArgument("b has an entry 0 in its shape."));
      OP_REQUIRES(ctx, right->dim_size(1) != 0,
                  errors::InvalidArgument("b has an entry 0 in its shape."));
1051 right_tr.reset(
1052 new Tensor(right->dtype(),
1053 TensorShape({right->dim_size(1), right->dim_size(0)})));
1054
1055 const auto perm = dsizes_10();
1056 if (transpose_output) {
1057 right_tr->matrix<TL>().device(ctx->template eigen_device<CPUDevice>()) =
1058 right->matrix<TL>().shuffle(perm);
1059 } else {
1060 right_tr->matrix<TR>().device(ctx->template eigen_device<CPUDevice>()) =
1061 right->matrix<TR>().shuffle(perm);
1062 }
1063 right = right_tr.get();
1064 }
1065
1066 if (transpose_output) {
1067 DoMatMul<TR, TL>::Compute(&this->cache_tr_, left->matrix<TR>(),
1068 right->matrix<TL>(), transpose_a,
1069 ctx->device()->tensorflow_cpu_worker_threads(),
1070 transpose_output, &out);
1071 } else {
1072 DoMatMul<TL, TR>::Compute(&this->cache_nt_, left->matrix<TL>(),
1073 right->matrix<TR>(), transpose_a,
1074 ctx->device()->tensorflow_cpu_worker_threads(),
1075 transpose_output, &out);
1076 }
1077 }
1078
1079 private:
1080 bool transpose_a_;
1081 bool transpose_b_;
1082 bool a_is_sparse_;
1083 bool b_is_sparse_;
1084
1085 // Cache for non-transposed-output multiply
1086 typename DoMatMul<TL, TR>::TensorInfoCache cache_nt_;
1087 // Cache for transposed-output multiply
1088 typename DoMatMul<TR, TL>::TensorInfoCache cache_tr_;
1089
1090 TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
1091};
1092
1093template <typename TL, typename TR>
1094inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
1095 const std::vector<SparseSlice<TL>*>& left,
1096 const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
1097 int output_row_offset, int output_col_offset, bool assign,
1098 bool transpose_output, MatrixMap* output) {
1099 const auto perm = dsizes_10();
1100 int num_rows = left[0]->num_rows;
1101 const int rhs_num_cols = right.dimension(1);
1102 DCHECK_LE(num_cols, rhs_num_cols);
1103 Matrix out(num_rows, rhs_num_cols);
1104 out.setZero();
1105 if (num_cols == N) {
1106 GEPP<TL, TR, N>(left, right, num_cols, &out);
1107 } else {
1108 GEPP<TL, TR, -1>(left, right, num_cols, &out);
1109 }
1110 if (!assign) {
1111 const DSizes begin(output_row_offset, output_col_offset);
1112 const DSizes sizes(num_rows, num_cols);
1113 if (transpose_output) {
1114 if (num_cols == rhs_num_cols) {
1115 output->shuffle(perm).slice(begin, sizes) += out;
1116 } else {
1117 const auto zero = dsizes_00();
1118 output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes);
1119 }
1120 } else {
1121 if (num_cols == rhs_num_cols) {
1122 output->slice(begin, sizes) += out;
1123 } else {
1124 const auto zero = dsizes_00();
1125 output->slice(begin, sizes) += out.slice(zero, sizes);
1126 }
1127 }
1128 } else {
1129 std::unique_ptr<Matrix> out_tr;
1130 if (transpose_output) {
1131 out_tr.reset(new Matrix(rhs_num_cols, num_rows));
1132 *out_tr = out.shuffle(perm);
1133 std::swap(output_row_offset, output_col_offset);
1134 std::swap(num_rows, num_cols);
1135 }
1136 const Matrix& final_out = transpose_output ? *out_tr : out;
1137 for (int i = 0; i < num_rows; ++i) {
1138 memcpy(&(*output)(output_row_offset + i, output_col_offset),
1139 &final_out(i, 0), num_cols * sizeof(float));
1140 }
1141 }
1142}
1143
1144template <typename TL, typename TR>
1145inline std::unique_ptr<BlockingCounter>
1146SparseMatMul<TL, TR>::CreateSparseSlices(
1147 const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
1148 int slice_num_rows, int slice_block_size, int slice_num_cols,
1149 std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
1150 const DeviceBase::CpuWorkerThreads* thread_pool) {
1151 const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
1152 const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
1153 const int num_slices_dim0 =
1154 std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows);
1155 const int num_slices_dim1 =
1156 std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols);
1157 mat_slices->resize(num_slices_dim0);
1158 BlockingCounter* counter =
1159 new BlockingCounter(num_slices_dim0 * num_slices_dim1);
1160 auto work = [counter, transpose](SparseSlice<TL>* sparse_slice,
1161 SparseMatMul<TL, TR>::ConstMatrixMapL* slice,
1162 int col_offset) {
1163 if (transpose) {
1164 sparse_slice->template Initialize<true>(*slice, col_offset);
1165 } else {
1166 sparse_slice->template Initialize<false>(*slice, col_offset);
1167 }
1168 delete slice;
1169 counter->DecrementCount();
1170 };
1171 for (int i = 0; i < num_slices_dim0; ++i) {
1172 (*mat_slices)[i].resize(num_slices_dim1);
1173 int num_rows =
1174 std::min<int>(slice_num_rows, mat_num_rows - i * slice_num_rows);
1175 for (int j = 0; j < num_slices_dim1; ++j) {
1176 int num_cols =
1177 std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
1178 SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr;
1179 if (transpose) {
1180 slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
1181 &mat(0, i * slice_num_rows), mat.dimensions());
1182 } else {
1183 DSizes d(num_rows, mat_num_cols);
1184 slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
1185 &mat(i * slice_num_rows, 0), d);
1186 }
1187 auto* sparse_slice =
1188 new SparseSlice<TL>(num_rows, num_cols, slice_block_size);
1189 (*mat_slices)[i][j] = sparse_slice;
1190 thread_pool->workers->Schedule(
1191 [=]() { work(sparse_slice, slice, slice_num_cols * j); });
1192 }
1193 }
1194 return std::unique_ptr<BlockingCounter>(counter);
1195}
1196#define LOAD(x) Eigen::internal::ploadu<Packet>((x));
1197#define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x);
1198#define STORE(x, y) Eigen::internal::pstoreu<float>(x, y);
1199
1200template <int NUM_ELEM = -1>
1201ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc,
1202 int num_elements) {
1203 DCHECK_GE(kNumOperands, 8);
1204 static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16);
1205 const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM;
1206 DCHECK_EQ(num, num_elements);
1207 const float* src = reinterpret_cast<const float*>(bsrc);
1208 float* dst = reinterpret_cast<float*>(bdst);
1209 for (int index = 0; index + kStep <= num; index += kStep) {
1210 auto in = LOAD(src);
1211 auto tmp = INTERLEAVE(in);
1212 STORE(dst, tmp);
1213 src += kNumOperands;
1214 dst += kNumOperands;
1215 }
1216 if (num % kStep != 0) {
1217 memcpy(dst, src, (num % kStep) * sizeof(bfloat16));
1218 }
1219}
1220
1221template <typename T>
1222ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
1223 int num_elements) {
1224 if (std::is_same<T, float>::value || kNumOperands < 8) {
1225 memcpy(dst, src, num_elements * sizeof(T));
1226 } else if (std::is_same<T, bfloat16>::value) {
1227 if (num_elements == N) {
1228 CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements);
1229 } else {
1230 CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements);
1231 }
1232 } else {
1233 LOG(FATAL) << "Unsupported type";
1234 }
1235}
1236
#undef LOAD
#undef INTERLEAVE
#undef STORE
1240
1241template <typename TL, typename TR>
1242inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
1243 const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat,
1244 int slice_row_start, int slice_num_rows, int slice_col_start,
1245 int slice_num_cols, const int N,
1246 const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
1247 DCHECK_EQ(N % 2, 0);
1248 DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N);
1249 // Note(nikhilsarda): This heuristic is optimal in benchmarks as of
1250 // Jan 21, 2020.
1251 int num_threads = std::min(thread_pool->num_threads, 8);
1252 BlockingCounter* counter = new BlockingCounter(num_threads);
1253 DCHECK_EQ(N, buffer->dimension(1));
1254 auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start,
1255 slice_num_cols, N, buffer, counter](int s, int e) {
1256 const int row_start = s % slice_num_rows + slice_row_start;
1257 const int col_start = s / slice_num_rows * N + slice_col_start;
1258 auto* out_start = &(*buffer)(s, 0);
1259 const auto* input_start = &mat(row_start, col_start);
1260 const auto* input_end = &mat(slice_row_start + slice_num_rows - 1,
1261 slice_col_start + slice_num_cols - 1);
1262 const int mat_num_cols = mat.dimension(1);
1263 const int row_slice_size = slice_num_rows * mat_num_cols;
1264
1265 const int aligned_end = slice_num_cols / N * slice_num_rows;
1266 const int e1 = std::min(e, aligned_end);
1267 while (s < e1) {
1268 CopyAndMayBeInterleave<TR>(out_start, input_start, N);
1269 out_start += N;
1270 input_start += mat_num_cols;
1271 if (input_start > input_end) {
1272 input_start = input_start - row_slice_size + N;
1273 }
1274 ++s;
1275 }
1276 int s1 = std::max(s, aligned_end);
1277 const int copy_num_cols = slice_num_cols % N;
1278 while (s1 < e) {
1279 CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols);
1280 out_start += N;
1281 input_start += mat_num_cols;
1282 ++s1;
1283 }
1284 if (counter) counter->DecrementCount();
1285 };
1286
1287 int start = 0;
1288 int end = 0;
1289 int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows;
1290 DCHECK_LE(num_out_rows, buffer->dimension(0));
1291 for (int i = std::max(1, num_threads); i > 0; --i) {
1292 end = start + num_out_rows / i;
1293 thread_pool->workers->Schedule([=]() { shuffle_work(start, end); });
1294 num_out_rows -= (end - start);
1295 start = end;
1296 }
1297 return counter;
1298}
1299
1300template <typename TL, typename TR>
1301inline void SparseMatMul<TL, TR>::SliceMatrix(
1302 const MatrixR& mat, const int num_rows, const int num_slices,
1303 std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
1304 slices->resize(num_slices);
1305 DSizes d(num_rows, mat.dimension(1));
1306 DCHECK_LE(num_rows * num_slices, mat.dimension(0));
1307 for (int i = 0; i < num_slices; ++i) {
1308 (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d);
1309 }
1310}
1311
1312template <typename TL, typename TR>
1313inline std::unique_ptr<BlockingCounter> SparseMatMul<TL, TR>::CreateDenseSlices(
1314 const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
1315 int num_rows, int col_start, int num_cols,
1316 const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
1317 std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
1318 std::unique_ptr<BlockingCounter> shuffle_counter(ShuffleMatrix(
1319 mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer));
1320 const int num_slices = (num_cols + N - 1) / N;
1321 SliceMatrix(*buffer, num_rows, num_slices, slices);
1322 return shuffle_counter;
1323}
1324
1325template <typename TL, typename TR>
1326inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
1327 const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
1328 const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
1329 bool transpose_left, int num_threads, int* KR, int* NR, int* KL, int* JB,
1330 int* IB) {
1331 // Heuristics for calculating block sizes
1332 // Assume two hyperthreads per core.
1333 const int est_num_cores = std::max(1, (num_threads + 1) / 2);
1334 // Use block of rhs with at most 128K floats per core.
1335 const int mem = est_num_cores * 128 * 1024;
1336 *KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
1337 *NR = right.dimension(1);
1338 if (*KR * *NR > mem) {
1339 // 4096 may be enough to amortize the cost of writes.
1340 *KR = std::min<int>(*KR, 4096);
1341 }
1342 // Use sizes that are multiples of K and 256.
1343 *KR = std::max(1, *KR / K) * K;
1344 *NR = std::max(1, *NR / 256) * 256;
1345 if (*KR * *NR > mem) {
1346 *NR = mem / *KR;
1347 }
1348 *NR = std::max(1, *NR / 256) * 256;
1349
1350 const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
1351 const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
1352 for (*KL = 1024; *KL > K; *KL /= 2) {
1353 if (*KR % *KL == 0 &&
1354 std::max<int>(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) {
1355 break;
1356 }
1357 }
1358 DCHECK_EQ(*KL % K, 0);
1359 DCHECK_GE(*KR, *KL);
1360 if (*KR < right.dimension(0)) {
1361 CHECK_EQ(*KR % *KL, 0);
1362 }
1363
1364 *JB = std::max(1, static_cast<int>(sqrt(num_threads) / 2.0));
1365 *IB = 8 * *JB;
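  // As a concrete (hypothetical) illustration: with num_threads = 16 and
  // 1024 x 1024 operands, the heuristics above give est_num_cores = 8,
  // mem = 8 * 128 * 1024 floats, and KR = NR = KL = 1024, JB = 2, IB = 16.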
1366 DCHECK_EQ(N * sizeof(float) % 64, size_t{0});
1367}
1368
1369#ifdef TENSORFLOW_USE_LIBXSMM
1370
1371template <typename F>
1372void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool,
1373 const F& f) {
1374 int num_threads = thread_pool->num_threads;
1375 if (num_threads == 0) {
1376 LOG(FATAL) << "Have 0 threads in thread pool";
1377 } else if (num_threads == 1) {
1378 f(0);
1379 } else {
1380 BlockingCounter counter(num_threads - 1);
1381 for (int i = 1; i < num_threads; ++i) {
1382 thread_pool->workers->Schedule([&, i]() {
1383 f(i);
1384 counter.DecrementCount();
1385 });
1386 }
1387 f(0);
1388 counter.Wait();
1389 }
1390}
1391
1392template <typename T>
1393struct empty_type_wrapper {};
1394
1395// Copies of interface to libxsmm_spmdm_createSparseSlice_*_notrans_thread to
1396// allow overloading
1397void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1398 empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
1399 const float* A, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id,
1400 int tid, int nthreads) {
1401 return libxsmm_spmdm_createSparseSlice_fp32_thread(
1402 handle, transA, A, libxsmm_output_csr_a, block_id, tid, nthreads);
1403}
1404void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1405 empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
1406 char transA, const bfloat16* A,
1407 libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
1408 int nthreads) {
1409 return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
1410 handle, transA, reinterpret_cast<const libxsmm_bfloat16*>(A),
1411 libxsmm_output_csr_a, block_id, tid, nthreads);
1412}
1413
1414void wrapper_libxsmm_spmdm_compute_generic_thread(
1415 empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
1416 char transA, char transB, const bfloat16* alpha,
1417 libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
1418 const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
1419 return libxsmm_spmdm_compute_bfloat16_thread(
1420 handle, transA, transB, reinterpret_cast<const libxsmm_bfloat16*>(alpha),
1421 A_sparse, reinterpret_cast<const libxsmm_bfloat16*>(B), transC,
1422 reinterpret_cast<const libxsmm_bfloat16*>(beta), C, block_id, tid,
1423 nthreads);
1424}
1425void wrapper_libxsmm_spmdm_compute_generic_thread(
1426 empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
1427 char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
1428 const float* B, char transC, const float* beta, float* C, int block_id,
1429 int tid, int nthreads) {
1430 return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
1431 A_sparse, B, transC, beta, C,
1432 block_id, tid, nthreads);
1433}
1434
1435template <typename TL, typename TR>
1436inline void LibxsmmSparseMatMul<TL, TR>::Compute(
1437 typename LibxsmmSparseMatMul<TL, TR>::TensorInfoCache* cache,
1438 const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapL& left,
1439 const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
1440 bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
1441 bool transpose_output, MatrixMap* output) {
1442 const int num_threads = thread_pool->num_threads;
1443 const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
1444 const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
1445 const int right_dim0 = right.dimension(0);
1446 const int right_dim1 = right.dimension(1);
1447 CHECK_EQ(left_dim1, right_dim0);
1448 CHECK_EQ(left_dim0,
1449 (transpose_output ? output->dimension(1) : output->dimension(0)));
1450 CHECK_EQ(right_dim1,
1451 (transpose_output ? output->dimension(0) : output->dimension(1)));
1452 auto left_data = left.data();
1453 auto right_data = right.data();
1454 auto output_data = output->data();
1455 // Initialize libxsmm for this matrix; make sure another thread doesn't use
1456 // this handle
1457 auto entry =
1458 cache->take_cache_entry(left_dim0, right_dim0, right_dim1, num_threads);
1459 // Convert the left matrix to compressed sparse row (CSR) format
1460 ptrdiff_t total_num_creation_blocks =
1461 libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle);
1462 std::atomic<int> cur_create_block_number;
1463 cur_create_block_number.store(0);
1464 do_on_all_threads(thread_pool, [&](int i) {
1465 while (true) {
1466 int work_item = cur_create_block_number.fetch_add(1);
1467 if (work_item >= total_num_creation_blocks) break;
1468 wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1469 empty_type_wrapper<TL>{}, &entry->handle,
1470 (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item,
1471 i, num_threads);
1472 }
1473 });
1474 // Do matrix-matrix multiplication
1475 ptrdiff_t total_num_mult_blocks =
1476 libxsmm_spmdm_get_num_compute_blocks(&entry->handle);
1477 std::atomic<int> cur_mult_block_number;
1478 cur_mult_block_number.store(0);
1479 do_on_all_threads(thread_pool, [&](int i) {
1480 while (true) {
1481 int work_item = cur_mult_block_number.fetch_add(1);
1482 if (work_item >= total_num_mult_blocks) break;
1483 const TL alpha(1.0); // Stored in a variable so we can get a pointer
1484 const TL beta(0.0); // Stored in a variable so we can get a pointer
1485 wrapper_libxsmm_spmdm_compute_generic_thread(
1486 empty_type_wrapper<TL>{}, &entry->handle,
1487 (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
1488 right_data, (transpose_output ? 'T' : 'N'), &beta, output_data,
1489 work_item, i, num_threads);
1490 }
1491 });
1492 // Put handle + CSR storage back into cache
1493 cache->return_cache_entry(std::move(entry));
1494}
1495
1496#endif // TENSORFLOW_USE_LIBXSMM
1497
// Here is an overview of the SparseMatMul code. Note that we assume that the
// left matrix is sparse.
//
// The matrix "left" is divided into a grid with block size (M, KL). Each block
// is encoded as a SparseSlice. These grid elements are stored as
// std::vector<std::vector<SparseSlice>>. Each element of the outer vector
// represents M rows of the left matrix. Let's call these elements l_i and let's
// call each element of the inner vector L_mk.
//
// The matrix "right" is divided into a grid with block size KR * NR. Let's
// denote the blocks on the right as R_kn. Note that we ensure that KL divides
// KR so that for each element R_kn, we don't need to multiply it with any
// partial L_mk blocks.
//
// We then multiply each right side block R_kn with the full "left" matrix and
// update the output. These iterations are run sequentially since R_kn are
// packed into the same underlying temporary buffer.
//
// In each iteration we do the following:
// 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N
//    (=128) columns and then concatenate these slices into a buffer. This is
//    done so that each slice r_j of R_kn is stored contiguously in memory. Note
//    that if R_kn has dimensions (KR, NR), we create NR / N slices, and the
//    buffer has dimensions (KR * NR / N, N) (assuming N divides NR).
// 2. For each (l_i, r_j), we compute the inner product using the GEPP function
//    and update the output block o_ij. These calls are further blocked to
//    reduce the working set size. In each iteration we take IB elements from
//    {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
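//
// Schematically, the implementation below amounts to (illustrative pseudocode,
// not an exact transcription):
//
//   for each column block nb of "right" (width NR):
//     for each row block kb of "right" (height KR):
//       pack the (KR, NR) block into "buffer" as N-column slices r_j
//       for each group of JB slices r_j and IB slices l_i:
//         ComputeOutputBlock(l_i, r_j)  // o_ij += l_i * r_j (= when kb == 0)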
1526template <typename TL, typename TR>
1527inline void SparseMatMul<TL, TR>::Compute(
1528 typename SparseMatMul<TL, TR>::TensorInfoCache* /*cache*/,
1529 const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
1530 const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
1531 bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
1532 bool transpose_output, MatrixMap* output) {
1533 const int num_threads = thread_pool->num_threads;
1534 int KR, NR, KL, JB, IB;
1535 ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
1536 &JB, &IB);
1537 // Slice the left matrix
1538 std::vector<std::vector<SparseSlice<TL>*>> left_slices;
1539 std::unique_ptr<BlockingCounter> sparse_slice_counter =
1540 CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()),
1541 transpose_left, M, K, KL, &left_slices, thread_pool);
1542 const int num_left_slices = left_slices.size();
1543
1544 const int right_dim0 = right.dimension(0);
1545 const int right_dim1 = right.dimension(1);
1546 // Allocate buffer for storing slices of right matrix.
1547 // Note buffer needs enough space to hold at most a KR * NR matrix since that
1548 // is the block size per iteration.
1549 const int buffer_num_rows =
1550 std::min(KR, right_dim0) * ((std::min(NR, right_dim1) + N - 1) / N);
1551 MatrixR buffer(buffer_num_rows, N);
1552 std::vector<ConstMatrixMapR*> right_slices;
1553
1554 std::vector<SparseSlice<TL>*> block_left_slices;
1555 std::vector<std::function<void(void)>> tasks;
1556 // Number of blocks based on block sizes of KR * NR.
1557 const int num_k_blocks = (right_dim0 + KR - 1) / KR;
1558 const int num_n_blocks = (right_dim1 + NR - 1) / NR;
1559 std::unique_ptr<BlockingCounter> dense_slice_counter;
1560
1561 for (int nb = 0; nb < num_n_blocks; ++nb) {
1562 const int right_num_cols =
1563 std::min(NR, static_cast<int>(right_dim1 - NR * nb));
1564 for (int kb = 0; kb < num_k_blocks; ++kb) {
1565 const int right_num_rows =
1566 std::min(KR, static_cast<int>(right_dim0 - KR * kb));
1567 dense_slice_counter = CreateDenseSlices(
1568 right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool,
1569 &buffer, &right_slices);
1570 const int num_right_slices = right_slices.size();
1571 tasks.reserve(num_left_slices * num_right_slices);
1572 for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) {
1573 for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) {
1574 for (int j_inner = j_outer;
1575 j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) {
1576 const int num_cols = std::min(N, right_num_cols - N * j_inner);
1577 for (int i_inner = i_outer;
1578 i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) {
1579 block_left_slices.clear();
1580 int begin = kb * KR / KL;
1581 int end = std::min<int>((kb + 1) * KR / KL,
1582 (right.dimension(0) + KL - 1) / KL);
1583 DCHECK_LT(begin, end);
1584 block_left_slices.insert(block_left_slices.begin(),
1585 left_slices[i_inner].begin() + begin,
1586 left_slices[i_inner].begin() + end);
1587 tasks.push_back(std::bind(
1588 &ComputeOutputBlock, block_left_slices,
1589 std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
1590 N * j_inner + nb * NR, kb == 0, transpose_output, output));
1591 }
1592 }
1593 }
1594 }
1595 if (sparse_slice_counter) {
1596 sparse_slice_counter->Wait();
1597 sparse_slice_counter.reset(nullptr);
1598 }
1599 if (dense_slice_counter) {
1600 dense_slice_counter->Wait();
1601 dense_slice_counter.reset(nullptr);
1602 }
1603 BlockingCounter bc(tasks.size());
1604 for (const auto& t : tasks) {
1605 thread_pool->workers->Schedule([&bc, &t]() {
1606 t();
1607 bc.DecrementCount();
1608 });
1609 }
1610 bc.Wait();
1611 tasks.clear();
1612 for (auto& temp : right_slices) {
1613 delete temp;
1614 }
1615 right_slices.clear();
1616 }
1617 }
1618 for (auto& left_slice : left_slices) {
1619 for (auto& temp : left_slice) {
1620 delete temp;
1621 }
1622 left_slice.clear();
1623 }
1624}
1625
1626#define REGISTER_SPARSE_MATMUL(TA, TB) \
1627 REGISTER_KERNEL_BUILDER(Name("SparseMatMul") \
1628 .Device(DEVICE_CPU) \
1629 .TypeConstraint<TA>("Ta") \
1630 .TypeConstraint<TB>("Tb"), \
1631 SparseMatMulOp<TA, TB, SparseMatMul>);
1632#ifdef TENSORFLOW_USE_LIBXSMM
1633#define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB) \
1634 REGISTER_KERNEL_BUILDER(Name("SparseMatMul") \
1635 .Device(DEVICE_CPU) \
1636 .TypeConstraint<TA>("Ta") \
1637 .TypeConstraint<TB>("Tb"), \
1638 SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
1639#endif
1640
1641REGISTER_SPARSE_MATMUL(float, bfloat16);
1642REGISTER_SPARSE_MATMUL(bfloat16, float);
1643
1644#ifdef TENSORFLOW_USE_LIBXSMM
1645REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16);
1646REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
1647#else
1648REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
1649REGISTER_SPARSE_MATMUL(float, float);
1650#endif
1651
1652#undef REGISTER_SPARSE_MATMUL
1653
1654} // end namespace tensorflow
1655