/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_matmul_op.h"

#include <map>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#ifdef TENSORFLOW_USE_LIBXSMM
#include "include/libxsmm_intrinsics_x86.h"
#include "include/libxsmm_malloc.h"
#include "include/libxsmm_spmdm.h"
#endif

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#define ALWAYS_INLINE EIGEN_ALWAYS_INLINE

namespace tensorflow {
namespace {

template <typename T>
using BasicMatrix = Eigen::Tensor<T, 2, Eigen::RowMajor>;

template <typename T>
using BasicMatrixMap =
    Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>;

using Matrix = BasicMatrix<float>;
using MatrixMap = BasicMatrixMap<float>;
using CPUDevice = Eigen::ThreadPoolDevice;
using DSizes = Eigen::DSizes<Eigen::DenseIndex, 2>;

// Two commonly used static dsizes. We use Eigen::type2index to allow as much
// compile time optimization as possible.
inline Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>
dsizes_00() {
  return Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>();
}
inline Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>
dsizes_10() {
  return Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>();
}

// Blocksizes
// TODO(agarwal): compute these sizes based on cache sizes.
const int K = 64;
const int M = 64;
const int N = 128;

// This stores a sparse representation of a slice of a matrix with size
// (num_rows, num_cols). The slice is represented as a series of blocks of size
// (num_rows, b), where b = block_size for all but the last block, which may
// have fewer columns.
//
// num_rows and block_size are assumed to be <= 256. This allows storing the
// indices as uint8.
//
// For each block, we store all the nonzero entries in the data/data3 vectors,
// with the coordinates of each element in the corresponding index/index3
// vectors. The index3 vector stores the indices of 3 elements in the same
// row, so that these elements can share a single row coordinate. Each entry
// in index3 corresponds to 3 entries in data3.
//
// Note that the data/indices of all the blocks are stored in the same vectors
// respectively. To identify block boundaries, we store the block offsets
// using index3_offset/index_offset. If there are n blocks in the slice,
// index3_offset and index_offset have n entries each. The entries of the ith
// block occupy positions [index3_offset[i-1], index3_offset[i]) in index3
// (with index3_offset[-1] taken to be 0), and similarly for index via
// index_offset.
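//
// Illustrative example (not part of the class contract): for the 2 x 4 slice
//
//   [[1 0 2 3]
//    [0 4 0 0]]
//
// with block_size == 4 there is a single block. Row 0 has three nonzeros at
// columns 0, 2 and 3, so it is encoded as one Index3 {m=0, k1=0, k2=2, k3=3}
// with data3 = {1, 2, 3}; row 1 has a single leftover nonzero, encoded as
// Index {m=1, k=1} with data = {4}. After this one block, index3_offset = {1}
// and index_offset = {1}.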
template <typename T>
struct SparseSlice {
  using ConstMatrixMap = BasicMatrixMap<const T>;

 public:
  // Indices of three elements on the same row.
  struct Index3 {
    uint8 m;  // row
    // columns
    uint8 k1;
    uint8 k2;
    uint8 k3;
  };

  // Index of one element.
  struct Index {
    uint8 m;
    uint8 k;
  };

  SparseSlice(int nrows, int ncols, int bsize)
      : num_rows(nrows), num_cols(ncols), block_size(bsize) {
    DCHECK_LE(nrows, 256);
    DCHECK_LE(block_size, 256);
  }

  // Initializes the slice with data starting at mat(0, col_offset) and with
  // size (num_rows, num_cols).
  // If Transpose is true, implicitly transposes mat.
  template <bool Transpose = false>
  void Initialize(const ConstMatrixMap& mat, int col_offset);

  void Clear();

  // See comments above.
  std::vector<int> index3_offset;
  std::vector<Index3> index3;
  std::vector<T> data3;

  // See comments above. Similar to "index3" except that each element in
  // "index" corresponds to one element in data.
  std::vector<int> index_offset;
  std::vector<Index> index;
  std::vector<T> data;

  // Number of rows and columns for the slice.
  const int num_rows;
  const int num_cols;

  // Block size used to initialize from a matrix.
  const int block_size;
};
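
// A minimal usage sketch of SparseSlice (illustrative only; "buffer", the
// shapes, and the offset below are made up, and the buffer must satisfy
// Eigen::Aligned):
//
//   // Wrap an existing row-major float buffer of shape (64, 256).
//   SparseSlice<float>::ConstMatrixMap mat(buffer, 64, 256);
//   // Encode columns [128, 256) of "mat" using 64-column blocks.
//   SparseSlice<float> slice(/*nrows=*/64, /*ncols=*/128, /*bsize=*/64);
//   slice.Initialize(mat, /*col_offset=*/128);
//   ...
//   slice.Clear();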

template <typename T>
bool IsZero(T v);

template <>
ALWAYS_INLINE bool IsZero(bfloat16 v) {
  return !static_cast<bool>(v);
}

template <>
ALWAYS_INLINE bool IsZero(float v) {
  return v == 0.0f;
}

template <typename T>
template <bool Transpose>
void SparseSlice<T>::Initialize(
    const typename SparseSlice<T>::ConstMatrixMap& mat, int col_offset) {
  const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
  const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
  DCHECK_LE(num_rows, mat_rows);
  DCHECK_LE(num_cols + col_offset, mat_cols);

  int num_blocks = (num_cols + block_size - 1) / block_size;
  int mat_size = num_rows * num_cols;

  index3_offset.reserve(num_blocks);
  data3.reserve(mat_size);
  index3.reserve(mat_size / 3);

  index_offset.reserve(num_blocks);
  data.reserve(num_blocks * num_rows * 2);
  index.reserve(num_blocks * num_rows * 2);

  Index3 idx3;
  const int stride = Transpose ? mat.dimension(1) : 1;

  for (int i = 0; i < num_blocks; ++i) {
    int num_block_cols = std::min(block_size, num_cols - block_size * i);
    for (int row = 0; row < num_rows; ++row) {
      idx3.m = static_cast<uint8>(row);
      // Safety note: The following code has a race, since it checks whether
      // *curr is nonzero and then reads it again on use. However, the result
      // of the race is only that some of the "nonzeros" in the resulting
      // sparse representation may actually be zero, which is harmless.
      const auto* start =
          Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
      const auto* curr = start;
      const auto* end = start + stride * num_block_cols;
      uint8 k = 0;
#define NEXT_ELEM \
  curr += stride; \
  ++k;
#define EAT_ZEROS                          \
  while (curr < end && IsZero<T>(*curr)) { \
    NEXT_ELEM;                             \
  }
      while (true) {
        EAT_ZEROS
        if (curr >= end) break;
        idx3.k1 = k;
        const T value1 = *curr;
        NEXT_ELEM;

        EAT_ZEROS
        if (curr >= end) {
          data.push_back(value1);
          index.push_back({idx3.m, idx3.k1});
          break;
        }
        idx3.k2 = k;
        const T value2 = *curr;
        NEXT_ELEM;

        EAT_ZEROS
        if (curr >= end) {
          data.push_back(value2);
          index.push_back({idx3.m, idx3.k2});
          data.push_back(value1);
          index.push_back({idx3.m, idx3.k1});
          break;
        }
        idx3.k3 = k;
        data3.push_back(value1);
        data3.push_back(value2);
        data3.push_back(*curr);
        NEXT_ELEM;
        index3.push_back(idx3);
#undef NEXT_ELEM
#undef EAT_ZEROS
      }
    }
    col_offset += block_size;
    index3_offset.push_back(index3.size());
    index_offset.push_back(index.size());
  }
  DCHECK_EQ(index3_offset.size(), num_blocks);
  DCHECK_EQ(index_offset.size(), num_blocks);
  DCHECK_EQ(3 * index3.size(), data3.size());
  DCHECK_EQ(index.size(), data.size());
}

template <typename T>
void SparseSlice<T>::Clear() {
  index3_offset.clear();
  index3.clear();
  data3.clear();
  index_offset.clear();
  index.clear();
  data.clear();
}

using Packet = Eigen::internal::packet_traits<float>::type;
const int kNumOperands = (sizeof(Packet) / sizeof(float));
#define LOAD(x) Eigen::internal::pload<Packet>(x);
#define EXPAND_BFLOAT_L(x, y) \
  const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
#define EXPAND_BFLOAT_U(x, y) \
  const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x);
#define STORE(x, y) Eigen::internal::pstore<float>(x, y);
#define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);

ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) {
  float out = 0;
  auto tmp = reinterpret_cast<bfloat16*>(&out);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp[0] = *src;
#else
  tmp[1] = *src;
#endif
  return out;
}
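// For example (illustrative): bfloat16 0x3F80 is the upper half of the float
// 1.0f (bit pattern 0x3F800000), so writing it into the most significant two
// bytes of a zeroed float and leaving the rest zero reconstructs 1.0f
// exactly. The #if above selects whichever half is the most significant one
// for this platform's byte order.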

ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) {
  return Eigen::internal::pload4bf16<Packet>(
      reinterpret_cast<const float*>(src));
}

ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) {
  return Eigen::internal::pload2bf16<Packet>(
      reinterpret_cast<const float*>(src));
}

ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) {
  **out += a * **inp;
  ++*inp;
  ++*out;
}

ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp,
                                float** out) {
  float inp_f = ConvertBfloat16ToFloat(*inp);
  **out += a * inp_f;
  ++*inp;
  ++*out;
}
ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
                                    const float a3, const bfloat16** inp1,
                                    const bfloat16** inp2,
                                    const bfloat16** inp3, float** out) {
  float inp1_f = ConvertBfloat16ToFloat(*inp1);
  float inp2_f = ConvertBfloat16ToFloat(*inp2);
  float inp3_f = ConvertBfloat16ToFloat(*inp3);
  **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f;
  ++*out;
  ++*inp1;
  ++*inp2;
  ++*inp3;
}

ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
                                    const float a3, const float** inp1,
                                    const float** inp2, const float** inp3,
                                    float** out) {
  **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3;
  ++*out;
  ++*inp1;
  ++*inp2;
  ++*inp3;
}

ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) {
  auto tmp = ConvertBfloat16ToFloat(*data);
  *l = Eigen::internal::pset1<Packet>(tmp);
  ++*data;
}

ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1,
                                  Packet* l2) {
  if (kNumOperands >= 2) {
    auto tmp = ConvertTwoBfloat16ToFloat(*data);
    *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
    *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
    *data += 2;
  } else {
    LoadSingleScalar(data, l1);
    LoadSingleScalar(data, l2);
  }
}

ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1,
                                   Packet* l2, Packet* l3, Packet* l4) {
  if (kNumOperands >= 4) {
    auto tmp = ConvertFourBfloat16ToFloat(*data);
    *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
    *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
    *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp);
    *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp);
    *data += 4;
  } else {
    LoadTwoScalars(data, l1, l2);
    LoadTwoScalars(data, l3, l4);
  }
}

ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) {
  *l = Eigen::internal::pload1<Packet>(*data);
  ++(*data);
}

ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) {
  LoadSingleScalar(data, l1);
  LoadSingleScalar(data, l2);
}

ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2,
                                   Packet* l3, Packet* l4) {
  LoadTwoScalars(data, l1, l2);
  LoadTwoScalars(data, l3, l4);
}

template <typename T>
ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2,
                                    Packet* l3) {
  LoadTwoScalars(data, l1, l2);
  LoadSingleScalar(data, l3);
}

template <typename T>
ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2,
                                  Packet* l3, Packet* l4, Packet* l5,
                                  Packet* l6) {
  LoadFourScalars(data, l1, l2, l3, l4);
  LoadTwoScalars(data, l5, l6);
}

// Vectorized version of ScalarMulAdd.
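// Each call consumes one float-packet's worth of bytes from *binp, i.e.
// 2 * kNumOperands bfloat16 values, expands them to floats
// (EXPAND_BFLOAT_L/_U unpack the lower and upper halves), and accumulates
// a * b into 2 * kNumOperands consecutive floats at *out, advancing both
// pointers.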
ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) {
  auto inp = reinterpret_cast<const float*>(*binp);
  const auto b = LOAD(inp);
  EXPAND_BFLOAT_L(b, b_0);
  EXPAND_BFLOAT_U(b, b_1);
  *binp += 2 * kNumOperands;
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  FMA(a, b_0, c1, c1);
  FMA(a, b_1, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
}

// Vectorized version of ScalarMulAdd3Way.
ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
                              const bfloat16** binp1, const bfloat16** binp2,
                              const bfloat16** binp3, float** out) {
  auto inp1 = reinterpret_cast<const float*>(*binp1);
  auto inp2 = reinterpret_cast<const float*>(*binp2);
  auto inp3 = reinterpret_cast<const float*>(*binp3);
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  const auto b1 = LOAD(inp1);
  EXPAND_BFLOAT_L(b1, b1_0);
  EXPAND_BFLOAT_U(b1, b1_1);
  *binp1 += 2 * kNumOperands;
  const auto b2 = LOAD(inp2);
  EXPAND_BFLOAT_L(b2, b2_0);
  EXPAND_BFLOAT_U(b2, b2_1);
  *binp2 += 2 * kNumOperands;
  const auto b3 = LOAD(inp3);
  EXPAND_BFLOAT_L(b3, b3_0);
  EXPAND_BFLOAT_U(b3, b3_1);
  *binp3 += 2 * kNumOperands;
  FMA(a1, b1_0, c1, c1);
  FMA(a1, b1_1, c2, c2);
  FMA(a2, b2_0, c1, c1);
  FMA(a2, b2_1, c2, c2);
  FMA(a3, b3_0, c1, c1);
  FMA(a3, b3_1, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
}

// Unroll MulAdd3Way for two iterations.
ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
                                 const Packet a3, const bfloat16** binp1,
                                 const bfloat16** binp2, const bfloat16** binp3,
                                 float** out) {
  auto inp1 = reinterpret_cast<const float*>(*binp1);
  auto inp2 = reinterpret_cast<const float*>(*binp2);
  auto inp3 = reinterpret_cast<const float*>(*binp3);
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  const auto b1 = LOAD(inp1);
  const auto b2 = LOAD(inp2);
  const auto b3 = LOAD(inp3);

  EXPAND_BFLOAT_L(b1, b1_0);
  EXPAND_BFLOAT_U(b1, b1_1);
  EXPAND_BFLOAT_L(b2, b2_0);
  EXPAND_BFLOAT_U(b2, b2_1);
  EXPAND_BFLOAT_L(b3, b3_0);
  EXPAND_BFLOAT_U(b3, b3_1);
  auto c3 = LOAD(*out + 2 * kNumOperands);
  auto c4 = LOAD(*out + 3 * kNumOperands);
  const auto b4 = LOAD(inp1 + kNumOperands);
  const auto b5 = LOAD(inp2 + kNumOperands);
  const auto b6 = LOAD(inp3 + kNumOperands);

  EXPAND_BFLOAT_L(b4, b4_0);
  EXPAND_BFLOAT_U(b4, b4_1);
  EXPAND_BFLOAT_L(b5, b5_0);
  EXPAND_BFLOAT_U(b5, b5_1);
  EXPAND_BFLOAT_L(b6, b6_0);
  EXPAND_BFLOAT_U(b6, b6_1);

  FMA(a1, b1_0, c1, c1);
  FMA(a1, b1_1, c2, c2);
  FMA(a1, b4_0, c3, c3);
  FMA(a1, b4_1, c4, c4);
  FMA(a2, b2_0, c1, c1);
  FMA(a2, b2_1, c2, c2);
  FMA(a2, b5_0, c3, c3);
  FMA(a2, b5_1, c4, c4);
  FMA(a3, b3_0, c1, c1);
  FMA(a3, b3_1, c2, c2);
  FMA(a3, b6_0, c3, c3);
  FMA(a3, b6_1, c4, c4);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  STORE(*out + 2 * kNumOperands, c3);
  STORE(*out + 3 * kNumOperands, c4);
  *out += 4 * kNumOperands;
  *binp1 += 4 * kNumOperands;
  *binp2 += 4 * kNumOperands;
  *binp3 += 4 * kNumOperands;
}

// Apply MulAdd3Way on 128 operands.
ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
                                 const Packet a3, const bfloat16** inp1,
                                 const bfloat16** inp2, const bfloat16** inp3,
                                 float** out) {
  for (int k = 0; k < 128 / (8 * kNumOperands); ++k) {
    TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  }
}

// Vectorized version of ScalarMulAdd.
ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) {
  const auto b = LOAD(*inp);
  *inp += kNumOperands;
  auto c = LOAD(*out);
  FMA(a, b, c, c);
  STORE(*out, c);
  *out += kNumOperands;
}

// Vectorized version of ScalarMulAdd3Way.
ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
                              const float** inp1, const float** inp2,
                              const float** inp3, float** out) {
  auto c = LOAD(*out);
  const auto b1 = LOAD(*inp1);
  *inp1 += kNumOperands;
  const auto b2 = LOAD(*inp2);
  *inp2 += kNumOperands;
  const auto b3 = LOAD(*inp3);
  *inp3 += kNumOperands;
  FMA(a1, b1, c, c);
  FMA(a2, b2, c, c);
  FMA(a3, b3, c, c);
  STORE(*out, c);
  *out += kNumOperands;
}

// Unroll MulAdd3Way for two iterations.
ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
                                 const Packet a3, const float** inp1,
                                 const float** inp2, const float** inp3,
                                 float** out) {
  auto c1 = LOAD(*out);
  const auto b1 = LOAD(*inp1);
  const auto b2 = LOAD(*inp2);
  const auto b3 = LOAD(*inp3);

  auto c2 = LOAD(*out + kNumOperands);
  const auto b4 = LOAD(*inp1 + kNumOperands);
  const auto b5 = LOAD(*inp2 + kNumOperands);
  const auto b6 = LOAD(*inp3 + kNumOperands);

  FMA(a1, b1, c1, c1);
  FMA(a1, b4, c2, c2);
  FMA(a2, b2, c1, c1);
  FMA(a2, b5, c2, c2);
  FMA(a3, b3, c1, c1);
  FMA(a3, b6, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
  *inp1 += 2 * kNumOperands;
  *inp2 += 2 * kNumOperands;
  *inp3 += 2 * kNumOperands;
}

// Unroll MulAdd3Way for four iterations.
ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2,
                                  const Packet a3, const float** inp1,
                                  const float** inp2, const float** inp3,
                                  float** out) {
  TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
}

// Apply MulAdd3Way on 128 operands.
ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
                                 const Packet a3, const float** inp1,
                                 const float** inp2, const float** inp3,
                                 float** out) {
  if (kNumOperands == 8) {
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  } else {
    DCHECK_LE(4 * kNumOperands, 128);
    for (int i = 0; i < 128 / (4 * kNumOperands); ++i) {
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    }
  }
}
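
// Illustrative arithmetic for the unrolling above: with 128-bit packets
// (kNumOperands == 4) the loop runs 128 / (4 * 4) = 8 times and each
// iteration issues four MulAdd3Way calls of 4 floats each; with 256-bit
// packets (kNumOperands == 8) the four FourMulAdd3Way calls each cover
// 4 * 8 = 32 floats. Either way exactly 128 output floats are updated.
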
// Computes the product of "left_slices" with "num_cols" columns of "right",
// and stores the output in *"output".
// Note that left_slices is a list of SparseSlices, which are conceptually
// assumed to be concatenated along the column dimension. Also, each
// SparseSlice is encoded as a list of blocks with up to N columns. See
// SparseSlice for more details.
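// The template parameter Cols fixes the column count at compile time when the
// caller knows it (in particular Cols == N == 128 enables the MulAdd3Way128
// fast path below); Cols == -1 selects the dynamic path driven by num_cols.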
template <typename TL, typename TR, int Cols>
inline void GEPP(
    const std::vector<SparseSlice<TL>*>& left_slices,
    const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
                           Eigen::Aligned>& right,
    const int num_cols, Matrix* output) {
  const int cols = (Cols == -1) ? num_cols : Cols;
  DCHECK_EQ(num_cols, cols);
  const int right_num_cols = right.dimension(1);
  const int output_num_cols = output->dimension(1);
  static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR);
  const int cols_mod = cols % kNumOperandsR;
  int k_offset = 0;
  // Pre-compute pointers for the output matrix.
  float* out_ptrs[M];
  float* const out_start = &(*output)(0, 0);
  for (int j = 0; j < M; ++j) {
    out_ptrs[j] = out_start + output_num_cols * j;
  }
  for (const auto* left_slice : left_slices) {
    const auto& left = *left_slice;
    const auto* data3 = (!left.data3.empty()) ? &left.data3[0] : nullptr;
    const auto* data = (!left.data.empty()) ? &left.data[0] : nullptr;
    const int num_blocks = left.index3_offset.size();
    int begin3 = 0;
    int begin = 0;
    for (int i = 0; i < num_blocks; ++i) {
      // Pre-compute pointers for the right matrix.
      const TR* right_ptrs[K];
      const auto* const right_start = &right(k_offset, 0);
      DCHECK_LT(k_offset, right.dimension(0));
      for (int j = 0; j < K; ++j) {
        right_ptrs[j] = right_start + right_num_cols * j;
      }

      const int end3 = left.index3_offset[i];
      int j = begin3;
      // Loop unrolled for 2 iterations.
      for (; j + 1 < end3; j += 2) {
        Packet l1, l2, l3, nl1, nl2, nl3;
        LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3);
        const auto& index = left.index3[j];
        const auto& nindex = left.index3[j + 1];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];

        const auto* nr1 = right_ptrs[nindex.k1];
        const auto* nr2 = right_ptrs[nindex.k2];
        const auto* nr3 = right_ptrs[nindex.k3];
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
          MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
            MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
          }

          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          const float nsl1 = Eigen::internal::pfirst<Packet>(nl1);
          const float nsl2 = Eigen::internal::pfirst<Packet>(nl2);
          const float nsl3 = Eigen::internal::pfirst<Packet>(nl3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
            ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout);
          }
        }
      }
      if (j < end3) {
        Packet l1, l2, l3;
        LoadThreeScalars(&data3, &l1, &l2, &l3);

        const auto& index = left.index3[j];
        float* out = out_ptrs[index.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
          }
          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
          }
        }
      }
      begin3 = end3;
      int end = left.index_offset[i];
      // Loop unrolled for 4 iterations.
      j = begin;
      for (; j + 3 < end; j += 4) {
        Packet l, nl, n2l, n3l;
        LoadFourScalars(&data, &l, &nl, &n2l, &n3l);

        const auto& index = left.index[j];
        const auto& nindex = left.index[j + 1];
        const auto& n2index = left.index[j + 2];
        const auto& n3index = left.index[j + 3];
        const auto* r = right_ptrs[index.k];
        const auto* nr = right_ptrs[nindex.k];
        const auto* n2r = right_ptrs[n2index.k];
        const auto* n3r = right_ptrs[n3index.k];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        float* n2out = out_ptrs[n2index.m];
        float* n3out = out_ptrs[n3index.m];

        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
          MulAdd(nl, &nr, &nout);
          MulAdd(n2l, &n2r, &n2out);
          MulAdd(n3l, &n3r, &n3out);
        }

        const float sl1 = Eigen::internal::pfirst<Packet>(l);
        const float sl2 = Eigen::internal::pfirst<Packet>(nl);
        const float sl3 = Eigen::internal::pfirst<Packet>(n2l);
        const float sl4 = Eigen::internal::pfirst<Packet>(n3l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl1, &r, &out);
          ScalarMulAdd(sl2, &nr, &nout);
          ScalarMulAdd(sl3, &n2r, &n2out);
          ScalarMulAdd(sl4, &n3r, &n3out);
        }
      }
      while (j < end) {
        Packet l;
        LoadSingleScalar(&data, &l);
        const auto& index = left.index[j];
        const auto* r = right_ptrs[index.k];
        float* out = out_ptrs[index.m];
        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
        }
        const float sl = Eigen::internal::pfirst<Packet>(l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl, &r, &out);
        }
        j++;
      }
      k_offset += left.block_size;
      begin = end;
    }
  }
}

#undef LOAD
#undef EXPAND_BFLOAT_L
#undef EXPAND_BFLOAT_U
#undef STORE
#undef FMA

}  // namespace

template <typename TL, typename TR>
class SparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // Not used; added to match interface of LibxsmmSparseMatMul
  struct TensorInfoCache {};

  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  // Computes multiplication of left and num_cols columns of right, and stores
  // the output block in *"output" at offsets "output_row_offset" and
  // "output_col_offset". If assign is true, assigns the value to that block,
  // else adds the values to the existing values.
  static inline void ComputeOutputBlock(
      const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right,
      int num_cols, int output_row_offset, int output_col_offset, bool assign,
      bool transpose_output, MatrixMap* output);

  // Encodes "mat" using a sparse representation and stores that in
  // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
  // "slice_num_cols", each grid element is converted into a SparseSlice and
  // stored in mat_slices. "slice_block_size" is used to perform further column
  // blocking of each slice.
  static inline std::unique_ptr<BlockingCounter> CreateSparseSlices(
      const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
      int slice_block_size, int slice_num_cols,
      std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
      const DeviceBase::CpuWorkerThreads* thread_pool);

  // This function chops "mat" along the column dimension into pieces with at
  // most N columns, concatenates the pieces one after the other in "buffer",
  // and returns the list of pieces in "slices". The returned BlockingCounter
  // should be used to wait for the shuffle operations to complete.
  static inline std::unique_ptr<BlockingCounter> CreateDenseSlices(
      const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
      int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
      MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);

  // Helper function for CreateDenseSlices to move the data around. It returns
  // a BlockingCounter which should be used to wait for the shuffle operations
  // to complete.
  static inline BlockingCounter* ShuffleMatrix(
      const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows,
      int slice_col_start, int slice_num_cols, const int N,
      const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer);

  // Helper function for CreateDenseSlices to create slices.
  static inline void SliceMatrix(const MatrixR& mat, const int num_rows,
                                 const int num_slices,
                                 std::vector<ConstMatrixMapR*>* slices);

  // Heuristics to compute various block sizes.
  // KR, NR: block sizes for "right". We run blocking iterations that operate
  // on matrices with at most this size.
  // KL: grid size along the column dimension used while encoding left.
  // IB, JB: number of left and right slices to multiply together. This is
  // used for ordering different ComputeOutputBlock operations inside each
  // blocking iteration so as to potentially reduce the working set size.
  static inline void ComputeBlockSizes(const ConstMatrixMapL& left,
                                       const ConstMatrixMapR& right,
                                       bool transpose_left, int num_threads,
                                       int* KR, int* NR, int* KL, int* JB,
                                       int* IB);

  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul);
};

#ifdef TENSORFLOW_USE_LIBXSMM
template <typename TL, typename TR>
class LibxsmmSparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // This structure contains a set of libxsmm kernels for sizes that have been
  // encountered previously by this operator so that libxsmm does not need to
  // reallocate its scratchpad memory each time (which hurts performance
  // substantially).
  struct TensorInfoCache {
    struct TensorInfoCacheEntry {
      // Parameters for kernel
      int M;
      int K;
      int N;
      int max_threads;
      // libxsmm handle and matrix data
      libxsmm_spmdm_handle handle;
      libxsmm_CSR_sparseslice* output_csr;
      // Chain to non-libxsmm implementation's cache in case that ever becomes
      // useful (it is an empty struct right now)
      typename SparseMatMul<TL, TR>::TensorInfoCache
          non_libxsmm_cache;  // Currently not used
    };
    // protects entries; invariant: entries is a valid std::multimap
    tensorflow::mutex lock;
    // Because there could be multiple matrix multiplies with the same sizes
    // going on at the same time, we need to allow multiple cache entries for a
    // given set of parameters. Taking and returning entries is used to make
    // sure the same cache entry is not used from two threads at a time.
    std::multimap<std::tuple<int, int, int, int>,
                  std::unique_ptr<TensorInfoCacheEntry>>
        entries TF_GUARDED_BY(lock);

    TensorInfoCache() : lock(), entries() {}
    // Look up and remove first entry with these parameters, creating one if
    // there isn't one
    std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N,
                                                           int max_threads)
        TF_LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(M, K, N, max_threads);
      auto it = entries.find(key);
      if (it != entries.end()) {
        auto val = std::move(it->second);
        entries.erase(it);
        return val;
      } else {
        std::unique_ptr<TensorInfoCacheEntry> e{
            new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
        // Set up a scoped allocator, which uses cpu_allocator() for this
        // scope.
        const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
        libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
        return e;
      }
    }
    // Return an entry to the cache so that it can be reused later.
    void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e)
        TF_LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads);
      entries.insert(std::make_pair(key, std::move(e)));
    }
    ~TensorInfoCache() {
      tensorflow::mutex_lock ml(lock);
      for (auto& p : entries) {
        libxsmm_spmdm_destroy(&p.second->handle);
      }
      entries.clear();
    }

   private:
    TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache);
  };
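
  // Typical lifecycle of a cache entry, as used by Compute() below
  // (illustrative sketch):
  //
  //   auto entry = cache->take_cache_entry(M, K, N, max_threads);
  //   // ... use entry->handle / entry->output_csr for one multiply ...
  //   cache->return_cache_entry(std::move(entry));
  //
  // take_cache_entry either reuses a previously returned handle with the same
  // (M, K, N, max_threads) key or initializes a fresh one via
  // libxsmm_spmdm_init.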

  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(LibxsmmSparseMatMul);
};
#endif

template <typename TL, typename TR,
          template <typename TL2, typename TR2> class DoMatMul>
class SparseMatMulOp : public OpKernel {
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;

 public:
  explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("a is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("b is not a matrix"));

    const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0);
    const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1);
    const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1);
    const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);

    OP_REQUIRES(ctx, k == k2,
                errors::InvalidArgument(
                    "Matrix size incompatible: a: ", a.shape().DebugString(),
                    ", b: ", b.shape().DebugString()));
    OP_REQUIRES(ctx, m >= 0 && n >= 0 && k >= 0,
                errors::InvalidArgument(
                    "Matrix dimensions cannot be negative: a: ",
                    a.shape().DebugString(), ", b: ", b.shape().DebugString()));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));

    // Return early if at least one of the output dimensions is 0.
    if (m == 0 || n == 0) {
      return;
    }

    if (k == 0) {
      // If the inner dimension k in the matrix multiplication is zero, we
      // fill the output with zeros.
      functor::SetZeroFunctor<CPUDevice, float> f;
      f(ctx->eigen_device<CPUDevice>(), output->flat<float>());
      return;
    }

    auto out = output->matrix<float>();

    std::unique_ptr<Tensor> a_float;
    std::unique_ptr<Tensor> b_float;
    if (!a_is_sparse_ && !b_is_sparse_) {
      auto left = &a;
      auto right = &b;
      // TODO(agarwal): multi-thread the conversions from bfloat16 to float.
      if (std::is_same<TL, bfloat16>::value) {
        a_float.reset(new Tensor(DT_FLOAT, a.shape()));
        BFloat16ToFloat(a.flat<bfloat16>().data(),
                        a_float->flat<float>().data(), a.NumElements());
        left = a_float.get();
      }
      if (std::is_same<TR, bfloat16>::value) {
        b_float.reset(new Tensor(DT_FLOAT, b.shape()));
        BFloat16ToFloat(b.flat<bfloat16>().data(),
                        b_float->flat<float>().data(), b.NumElements());
        right = b_float.get();
      }
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0].first = transpose_a_ ? 0 : 1;
      dim_pair[0].second = transpose_b_ ? 1 : 0;

      out.device(ctx->template eigen_device<CPUDevice>()) =
          left->matrix<float>().contract(right->matrix<float>(), dim_pair);
      return;
    }

    auto left = &a;
    auto right = &b;
    bool transpose_output = false;
    bool transpose_a = transpose_a_;
    bool transpose_b = transpose_b_;
    if (!a_is_sparse_) {
      // Swap the order of multiplications using the identity:
      // A * B = (B' * A')'.
      std::swap(left, right);
      std::swap(transpose_a, transpose_b);
      transpose_a = !transpose_a;
      transpose_b = !transpose_b;
      transpose_output = !transpose_output;
    }

    std::unique_ptr<Tensor> right_tr;
    if (transpose_b) {
      // TODO(agarwal): avoid transposing the matrix here and directly handle
      // transpose in CreateDenseSlices.
      OP_REQUIRES(ctx, right->dim_size(0) != 0,
                  errors::InvalidArgument("b has an entry 0 in its shape."));
      OP_REQUIRES(ctx, right->dim_size(1) != 0,
                  errors::InvalidArgument("b has an entry 0 in its shape."));
      right_tr.reset(
          new Tensor(right->dtype(),
                     TensorShape({right->dim_size(1), right->dim_size(0)})));

      const auto perm = dsizes_10();
      if (transpose_output) {
        right_tr->matrix<TL>().device(ctx->template eigen_device<CPUDevice>()) =
            right->matrix<TL>().shuffle(perm);
      } else {
        right_tr->matrix<TR>().device(ctx->template eigen_device<CPUDevice>()) =
            right->matrix<TR>().shuffle(perm);
      }
      right = right_tr.get();
    }

    if (transpose_output) {
      DoMatMul<TR, TL>::Compute(&this->cache_tr_, left->matrix<TR>(),
                                right->matrix<TL>(), transpose_a,
                                ctx->device()->tensorflow_cpu_worker_threads(),
                                transpose_output, &out);
    } else {
      DoMatMul<TL, TR>::Compute(&this->cache_nt_, left->matrix<TL>(),
                                right->matrix<TR>(), transpose_a,
                                ctx->device()->tensorflow_cpu_worker_threads(),
                                transpose_output, &out);
    }
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
  bool a_is_sparse_;
  bool b_is_sparse_;

  // Cache for non-transposed-output multiply
  typename DoMatMul<TL, TR>::TensorInfoCache cache_nt_;
  // Cache for transposed-output multiply
  typename DoMatMul<TR, TL>::TensorInfoCache cache_tr_;

  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
};

template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
    const std::vector<SparseSlice<TL>*>& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
    int output_row_offset, int output_col_offset, bool assign,
    bool transpose_output, MatrixMap* output) {
  const auto perm = dsizes_10();
  int num_rows = left[0]->num_rows;
  const int rhs_num_cols = right.dimension(1);
  DCHECK_LE(num_cols, rhs_num_cols);
  Matrix out(num_rows, rhs_num_cols);
  out.setZero();
  if (num_cols == N) {
    GEPP<TL, TR, N>(left, right, num_cols, &out);
  } else {
    GEPP<TL, TR, -1>(left, right, num_cols, &out);
  }
  if (!assign) {
    const DSizes begin(output_row_offset, output_col_offset);
    const DSizes sizes(num_rows, num_cols);
    if (transpose_output) {
      if (num_cols == rhs_num_cols) {
        output->shuffle(perm).slice(begin, sizes) += out;
      } else {
        const auto zero = dsizes_00();
        output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes);
      }
    } else {
      if (num_cols == rhs_num_cols) {
        output->slice(begin, sizes) += out;
      } else {
        const auto zero = dsizes_00();
        output->slice(begin, sizes) += out.slice(zero, sizes);
      }
    }
  } else {
    std::unique_ptr<Matrix> out_tr;
    if (transpose_output) {
      out_tr.reset(new Matrix(rhs_num_cols, num_rows));
      *out_tr = out.shuffle(perm);
      std::swap(output_row_offset, output_col_offset);
      std::swap(num_rows, num_cols);
    }
    const Matrix& final_out = transpose_output ? *out_tr : out;
    for (int i = 0; i < num_rows; ++i) {
      memcpy(&(*output)(output_row_offset + i, output_col_offset),
             &final_out(i, 0), num_cols * sizeof(float));
    }
  }
}

template <typename TL, typename TR>
inline std::unique_ptr<BlockingCounter>
SparseMatMul<TL, TR>::CreateSparseSlices(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
    int slice_num_rows, int slice_block_size, int slice_num_cols,
    std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
    const DeviceBase::CpuWorkerThreads* thread_pool) {
  const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
  const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
  const int num_slices_dim0 =
      std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows);
  const int num_slices_dim1 =
      std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols);
  mat_slices->resize(num_slices_dim0);
  BlockingCounter* counter =
      new BlockingCounter(num_slices_dim0 * num_slices_dim1);
  auto work = [counter, transpose](SparseSlice<TL>* sparse_slice,
                                   SparseMatMul<TL, TR>::ConstMatrixMapL* slice,
                                   int col_offset) {
    if (transpose) {
      sparse_slice->template Initialize<true>(*slice, col_offset);
    } else {
      sparse_slice->template Initialize<false>(*slice, col_offset);
    }
    delete slice;
    counter->DecrementCount();
  };
  for (int i = 0; i < num_slices_dim0; ++i) {
    (*mat_slices)[i].resize(num_slices_dim1);
    int num_rows =
        std::min<int>(slice_num_rows, mat_num_rows - i * slice_num_rows);
    for (int j = 0; j < num_slices_dim1; ++j) {
      int num_cols =
          std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
      SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr;
      if (transpose) {
        slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
            &mat(0, i * slice_num_rows), mat.dimensions());
      } else {
        DSizes d(num_rows, mat_num_cols);
        slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
            &mat(i * slice_num_rows, 0), d);
      }
      auto* sparse_slice =
          new SparseSlice<TL>(num_rows, num_cols, slice_block_size);
      (*mat_slices)[i][j] = sparse_slice;
      thread_pool->workers->Schedule(
          [=]() { work(sparse_slice, slice, slice_num_cols * j); });
    }
  }
  return std::unique_ptr<BlockingCounter>(counter);
}
#define LOAD(x) Eigen::internal::ploadu<Packet>((x));
#define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x);
#define STORE(x, y) Eigen::internal::pstoreu<float>(x, y);

template <int NUM_ELEM = -1>
ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc,
                                                  int num_elements) {
  DCHECK_GE(kNumOperands, 8);
  static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16);
  const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM;
  DCHECK_EQ(num, num_elements);
  const float* src = reinterpret_cast<const float*>(bsrc);
  float* dst = reinterpret_cast<float*>(bdst);
  for (int index = 0; index + kStep <= num; index += kStep) {
    auto in = LOAD(src);
    auto tmp = INTERLEAVE(in);
    STORE(dst, tmp);
    src += kNumOperands;
    dst += kNumOperands;
  }
  if (num % kStep != 0) {
    memcpy(dst, src, (num % kStep) * sizeof(bfloat16));
  }
}
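
// The INTERLEAVE step above rearranges 64-bit groups of bfloat16 values so
// that a later packet load followed by EXPAND_BFLOAT_L/_U (see MulAdd) yields
// the expanded floats in their original order; tails shorter than kStep are
// copied verbatim without interleaving.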

template <typename T>
ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
                                          int num_elements) {
  if (std::is_same<T, float>::value || kNumOperands < 8) {
    memcpy(dst, src, num_elements * sizeof(T));
  } else if (std::is_same<T, bfloat16>::value) {
    if (num_elements == N) {
      CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements);
    } else {
      CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements);
    }
  } else {
    LOG(FATAL) << "Unsupported type";
  }
}

#undef LOAD
#undef INTERLEAVE
#undef STORE

template <typename TL, typename TR>
inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat,
    int slice_row_start, int slice_num_rows, int slice_col_start,
    int slice_num_cols, const int N,
    const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
  DCHECK_EQ(N % 2, 0);
  DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N);
  // Note(nikhilsarda): This heuristic is optimal in benchmarks as of
  // Jan 21, 2020.
  int num_threads = std::min(thread_pool->num_threads, 8);
  BlockingCounter* counter = new BlockingCounter(num_threads);
  DCHECK_EQ(N, buffer->dimension(1));
  auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start,
                       slice_num_cols, N, buffer, counter](int s, int e) {
    const int row_start = s % slice_num_rows + slice_row_start;
    const int col_start = s / slice_num_rows * N + slice_col_start;
    auto* out_start = &(*buffer)(s, 0);
    const auto* input_start = &mat(row_start, col_start);
    const auto* input_end = &mat(slice_row_start + slice_num_rows - 1,
                                 slice_col_start + slice_num_cols - 1);
    const int mat_num_cols = mat.dimension(1);
    const int row_slice_size = slice_num_rows * mat_num_cols;

    const int aligned_end = slice_num_cols / N * slice_num_rows;
    const int e1 = std::min(e, aligned_end);
    while (s < e1) {
      CopyAndMayBeInterleave<TR>(out_start, input_start, N);
      out_start += N;
      input_start += mat_num_cols;
      if (input_start > input_end) {
        input_start = input_start - row_slice_size + N;
      }
      ++s;
    }
    int s1 = std::max(s, aligned_end);
    const int copy_num_cols = slice_num_cols % N;
    while (s1 < e) {
      CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols);
      out_start += N;
      input_start += mat_num_cols;
      ++s1;
    }
    if (counter) counter->DecrementCount();
  };

  int start = 0;
  int end = 0;
  int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows;
  DCHECK_LE(num_out_rows, buffer->dimension(0));
  for (int i = std::max(1, num_threads); i > 0; --i) {
    end = start + num_out_rows / i;
    thread_pool->workers->Schedule([=]() { shuffle_work(start, end); });
    num_out_rows -= (end - start);
    start = end;
  }
  return counter;
}

template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::SliceMatrix(
    const MatrixR& mat, const int num_rows, const int num_slices,
    std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
  slices->resize(num_slices);
  DSizes d(num_rows, mat.dimension(1));
  DCHECK_LE(num_rows * num_slices, mat.dimension(0));
  for (int i = 0; i < num_slices; ++i) {
    (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d);
  }
}

template <typename TL, typename TR>
inline std::unique_ptr<BlockingCounter> SparseMatMul<TL, TR>::CreateDenseSlices(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
    int num_rows, int col_start, int num_cols,
    const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
    std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
  std::unique_ptr<BlockingCounter> shuffle_counter(ShuffleMatrix(
      mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer));
  const int num_slices = (num_cols + N - 1) / N;
  SliceMatrix(*buffer, num_rows, num_slices, slices);
  return shuffle_counter;
}

template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, int num_threads, int* KR, int* NR, int* KL, int* JB,
    int* IB) {
  // Heuristics for calculating block sizes.
  // Assume two hyperthreads per core.
  const int est_num_cores = std::max(1, (num_threads + 1) / 2);
  // Use block of rhs with at most 128K floats per core.
  const int mem = est_num_cores * 128 * 1024;
  *KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
  *NR = right.dimension(1);
  if (*KR * *NR > mem) {
    // 4096 may be enough to amortize the cost of writes.
    *KR = std::min<int>(*KR, 4096);
  }
  // Use sizes that are multiples of K and 256.
  *KR = std::max(1, *KR / K) * K;
  *NR = std::max(1, *NR / 256) * 256;
  if (*KR * *NR > mem) {
    *NR = mem / *KR;
  }
  *NR = std::max(1, *NR / 256) * 256;

  const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
  const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
  for (*KL = 1024; *KL > K; *KL /= 2) {
    if (*KR % *KL == 0 &&
        std::max<int>(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) {
      break;
    }
  }
  DCHECK_EQ(*KL % K, 0);
  DCHECK_GE(*KR, *KL);
  if (*KR < right.dimension(0)) {
    CHECK_EQ(*KR % *KL, 0);
  }

  *JB = std::max(1, static_cast<int>(sqrt(num_threads) / 2.0));
  *IB = 8 * *JB;
  DCHECK_EQ(N * sizeof(float) % 64, size_t{0});
}
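
// Worked example of the heuristics above (illustrative numbers only): with
// num_threads = 16 and a 1024 x 1024 "right", est_num_cores = 8 and
// mem = 8 * 128 * 1024 floats, so *KR = min(1024, mem / 256) = 1024 and
// *NR = 1024, which already fit the budget after rounding to multiples of K
// and 256. For a 1024 x 1024 non-transposed "left", the *KL loop stops at
// 1024 (16 row-slices x 1 column-slice > 8 cores), and *JB =
// max(1, sqrt(16) / 2) = 2, *IB = 8 * *JB = 16.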
1368 | |
1369 | #ifdef TENSORFLOW_USE_LIBXSMM |
1370 | |
1371 | template <typename F> |
1372 | void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool, |
1373 | const F& f) { |
1374 | int num_threads = thread_pool->num_threads; |
1375 | if (num_threads == 0) { |
1376 | LOG(FATAL) << "Have 0 threads in thread pool" ; |
1377 | } else if (num_threads == 1) { |
1378 | f(0); |
1379 | } else { |
1380 | BlockingCounter counter(num_threads - 1); |
1381 | for (int i = 1; i < num_threads; ++i) { |
1382 | thread_pool->workers->Schedule([&, i]() { |
1383 | f(i); |
1384 | counter.DecrementCount(); |
1385 | }); |
1386 | } |
1387 | f(0); |
1388 | counter.Wait(); |
1389 | } |
1390 | } |
1391 | |
1392 | template <typename T> |
1393 | struct empty_type_wrapper {}; |
1394 | |
1395 | // Copies of interface to libxsmm_spmdm_createSparseSlice_*_notrans_thread to |
1396 | // allow overloading |
1397 | void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread( |
1398 | empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA, |
1399 | const float* A, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, |
1400 | int tid, int nthreads) { |
1401 | return libxsmm_spmdm_createSparseSlice_fp32_thread( |
1402 | handle, transA, A, libxsmm_output_csr_a, block_id, tid, nthreads); |
1403 | } |
1404 | void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread( |
1405 | empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle, |
1406 | char transA, const bfloat16* A, |
1407 | libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid, |
1408 | int nthreads) { |
1409 | return libxsmm_spmdm_createSparseSlice_bfloat16_thread( |
1410 | handle, transA, reinterpret_cast<const libxsmm_bfloat16*>(A), |
1411 | libxsmm_output_csr_a, block_id, tid, nthreads); |
1412 | } |
1413 | |
1414 | void wrapper_libxsmm_spmdm_compute_generic_thread( |
1415 | empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle, |
1416 | char transA, char transB, const bfloat16* alpha, |
1417 | libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC, |
1418 | const bfloat16* beta, float* C, int block_id, int tid, int nthreads) { |
1419 | return libxsmm_spmdm_compute_bfloat16_thread( |
1420 | handle, transA, transB, reinterpret_cast<const libxsmm_bfloat16*>(alpha), |
1421 | A_sparse, reinterpret_cast<const libxsmm_bfloat16*>(B), transC, |
1422 | reinterpret_cast<const libxsmm_bfloat16*>(beta), C, block_id, tid, |
1423 | nthreads); |
1424 | } |
1425 | void wrapper_libxsmm_spmdm_compute_generic_thread( |
1426 | empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA, |
1427 | char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse, |
1428 | const float* B, char transC, const float* beta, float* C, int block_id, |
1429 | int tid, int nthreads) { |
1430 | return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha, |
1431 | A_sparse, B, transC, beta, C, |
1432 | block_id, tid, nthreads); |
1433 | } |
1434 | |
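// Computes output = left * right (with optional transposes) using libxsmm,
// where "left" is sparse. Phase 1 converts "left" into libxsmm's CSR slices;
// phase 2 multiplies those slices with "right". In both phases each thread
// repeatedly claims the next block from a shared atomic counter, so work is
// balanced dynamically across the pool.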
1435 | template <typename TL, typename TR> |
1436 | inline void LibxsmmSparseMatMul<TL, TR>::Compute( |
1437 | typename LibxsmmSparseMatMul<TL, TR>::TensorInfoCache* cache, |
1438 | const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapL& left, |
1439 | const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right, |
1440 | bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, |
1441 | bool transpose_output, MatrixMap* output) { |
1442 | const int num_threads = thread_pool->num_threads; |
1443 | const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0); |
1444 | const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1); |
1445 | const int right_dim0 = right.dimension(0); |
1446 | const int right_dim1 = right.dimension(1); |
1447 | CHECK_EQ(left_dim1, right_dim0); |
1448 | CHECK_EQ(left_dim0, |
1449 | (transpose_output ? output->dimension(1) : output->dimension(0))); |
1450 | CHECK_EQ(right_dim1, |
1451 | (transpose_output ? output->dimension(0) : output->dimension(1))); |
1452 | auto left_data = left.data(); |
1453 | auto right_data = right.data(); |
1454 | auto output_data = output->data(); |
  // Initialize libxsmm for this matrix; taking the cache entry ensures that
  // no other thread uses this handle concurrently.
1457 | auto entry = |
1458 | cache->take_cache_entry(left_dim0, right_dim0, right_dim1, num_threads); |
1459 | // Convert the left matrix to compressed sparse row (CSR) format |
1460 | ptrdiff_t total_num_creation_blocks = |
1461 | libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle); |
  std::atomic<int> cur_create_block_number(0);
1464 | do_on_all_threads(thread_pool, [&](int i) { |
1465 | while (true) { |
1466 | int work_item = cur_create_block_number.fetch_add(1); |
1467 | if (work_item >= total_num_creation_blocks) break; |
1468 | wrapper_libxsmm_spmdm_createSparseSlice_generic_thread( |
1469 | empty_type_wrapper<TL>{}, &entry->handle, |
1470 | (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item, |
1471 | i, num_threads); |
1472 | } |
1473 | }); |
1474 | // Do matrix-matrix multiplication |
1475 | ptrdiff_t total_num_mult_blocks = |
1476 | libxsmm_spmdm_get_num_compute_blocks(&entry->handle); |
  std::atomic<int> cur_mult_block_number(0);
1479 | do_on_all_threads(thread_pool, [&](int i) { |
1480 | while (true) { |
1481 | int work_item = cur_mult_block_number.fetch_add(1); |
1482 | if (work_item >= total_num_mult_blocks) break; |
1483 | const TL alpha(1.0); // Stored in a variable so we can get a pointer |
1484 | const TL beta(0.0); // Stored in a variable so we can get a pointer |
1485 | wrapper_libxsmm_spmdm_compute_generic_thread( |
1486 | empty_type_wrapper<TL>{}, &entry->handle, |
1487 | (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr, |
1488 | right_data, (transpose_output ? 'T' : 'N'), &beta, output_data, |
1489 | work_item, i, num_threads); |
1490 | } |
1491 | }); |
1492 | // Put handle + CSR storage back into cache |
1493 | cache->return_cache_entry(std::move(entry)); |
1494 | } |
1495 | |
1496 | #endif // TENSORFLOW_USE_LIBXSMM |
1497 | |
1498 | // Here is an overview of the SparseMatMul code. Note that we assume that the |
1499 | // left matrix is sparse. |
1500 | // |
1501 | // The matrix "left" is divided into a grid with blocksize of (M, KL). Each |
1502 | // block is encoded as a SparseSlice. These grid elements are stored as |
1503 | // std::vector<std::vector<SparseSlice>>. Each element of the outer vector |
1504 | // represents M rows of the left matrix. Lets call these elements l_i and lets |
1505 | // call each element of the inner vector L_mk. |
1506 | // |
1507 | // The matrix "right" is divided into a grid with block size KR * NR. Lets |
1508 | // denote the blocks on the right as R_kn. Note that we ensure that KL divides |
1509 | // KR so that for each element R_kn, we don't need to multiply it with any |
1510 | // partial L_mk blocks. |
1511 | // |
1512 | // We then multiply each right side block R_kn with the full "left" matrix and |
1513 | // update the output. These iterations are run sequentially since R_kn are |
1514 | // packed into the same underlying temporary buffer. |
1515 | // |
1516 | // In each iteration we do the following: |
// 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N
//    (=128) columns each and then concatenate these slices into a buffer.
//    This is done so that each slice r_j of R_kn is stored contiguously in
//    memory. Note that if R_kn has dimensions (KR, NR), we create NR / N
//    slices, and the buffer has dimensions (KR * NR / N, N) (assuming N
//    divides NR).
1522 | // 2. For each (l_i, r_j), we compute the inner product using the GEPP function |
1523 | // and update the output block o_ij. These calls are further blocked to |
1524 | // reduce the working set size. In each iteration we take IB elements from |
1525 | // {l_i} and JB elements from {r_j} and compute the IB * JB inner products. |
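//
// Schematically, the loop nest below is (pseudocode, for orientation only):
//
//   for nb in [0, num_n_blocks):         // NR-wide column strips of "right"
//     for kb in [0, num_k_blocks):       // KR-tall row strips of "right"
//       pack R_kn into `buffer` as NR / N contiguous (KR, N) slices r_j
//       for each (IB, JB)-sized group of (l_i, r_j) pairs:
//         schedule ComputeOutputBlock(l_i, r_j) to update o_ij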
1526 | template <typename TL, typename TR> |
1527 | inline void SparseMatMul<TL, TR>::Compute( |
1528 | typename SparseMatMul<TL, TR>::TensorInfoCache* /*cache*/, |
1529 | const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left, |
1530 | const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, |
1531 | bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, |
1532 | bool transpose_output, MatrixMap* output) { |
1533 | const int num_threads = thread_pool->num_threads; |
1534 | int KR, NR, KL, JB, IB; |
1535 | ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL, |
1536 | &JB, &IB); |
1537 | // Slice the left matrix |
1538 | std::vector<std::vector<SparseSlice<TL>*>> left_slices; |
1539 | std::unique_ptr<BlockingCounter> sparse_slice_counter = |
1540 | CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()), |
1541 | transpose_left, M, K, KL, &left_slices, thread_pool); |
1542 | const int num_left_slices = left_slices.size(); |
1543 | |
1544 | const int right_dim0 = right.dimension(0); |
1545 | const int right_dim1 = right.dimension(1); |
  // Allocate a buffer for storing slices of the right matrix.
  // Note that the buffer needs enough space to hold at most a KR * NR matrix,
  // since that is the block size per iteration.
1549 | const int buffer_num_rows = |
1550 | std::min(KR, right_dim0) * ((std::min(NR, right_dim1) + N - 1) / N); |
1551 | MatrixR buffer(buffer_num_rows, N); |
1552 | std::vector<ConstMatrixMapR*> right_slices; |
1553 | |
1554 | std::vector<SparseSlice<TL>*> block_left_slices; |
  std::vector<std::function<void()>> tasks;
  // Number of blocks, given the KR * NR block size.
1557 | const int num_k_blocks = (right_dim0 + KR - 1) / KR; |
1558 | const int num_n_blocks = (right_dim1 + NR - 1) / NR; |
1559 | std::unique_ptr<BlockingCounter> dense_slice_counter; |
1560 | |
1561 | for (int nb = 0; nb < num_n_blocks; ++nb) { |
1562 | const int right_num_cols = |
1563 | std::min(NR, static_cast<int>(right_dim1 - NR * nb)); |
1564 | for (int kb = 0; kb < num_k_blocks; ++kb) { |
1565 | const int right_num_rows = |
1566 | std::min(KR, static_cast<int>(right_dim0 - KR * kb)); |
1567 | dense_slice_counter = CreateDenseSlices( |
1568 | right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool, |
1569 | &buffer, &right_slices); |
1570 | const int num_right_slices = right_slices.size(); |
1571 | tasks.reserve(num_left_slices * num_right_slices); |
1572 | for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) { |
1573 | for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) { |
1574 | for (int j_inner = j_outer; |
1575 | j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) { |
1576 | const int num_cols = std::min(N, right_num_cols - N * j_inner); |
1577 | for (int i_inner = i_outer; |
1578 | i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) { |
1579 | block_left_slices.clear(); |
1580 | int begin = kb * KR / KL; |
1581 | int end = std::min<int>((kb + 1) * KR / KL, |
1582 | (right.dimension(0) + KL - 1) / KL); |
1583 | DCHECK_LT(begin, end); |
1584 | block_left_slices.insert(block_left_slices.begin(), |
1585 | left_slices[i_inner].begin() + begin, |
1586 | left_slices[i_inner].begin() + end); |
1587 | tasks.push_back(std::bind( |
1588 | &ComputeOutputBlock, block_left_slices, |
1589 | std::ref(*right_slices[j_inner]), num_cols, M * i_inner, |
1590 | N * j_inner + nb * NR, kb == 0, transpose_output, output)); |
1591 | } |
1592 | } |
1593 | } |
1594 | } |
1595 | if (sparse_slice_counter) { |
1596 | sparse_slice_counter->Wait(); |
1597 | sparse_slice_counter.reset(nullptr); |
1598 | } |
1599 | if (dense_slice_counter) { |
1600 | dense_slice_counter->Wait(); |
1601 | dense_slice_counter.reset(nullptr); |
1602 | } |
1603 | BlockingCounter bc(tasks.size()); |
1604 | for (const auto& t : tasks) { |
1605 | thread_pool->workers->Schedule([&bc, &t]() { |
1606 | t(); |
1607 | bc.DecrementCount(); |
1608 | }); |
1609 | } |
1610 | bc.Wait(); |
1611 | tasks.clear(); |
1612 | for (auto& temp : right_slices) { |
1613 | delete temp; |
1614 | } |
1615 | right_slices.clear(); |
1616 | } |
1617 | } |
1618 | for (auto& left_slice : left_slices) { |
1619 | for (auto& temp : left_slice) { |
1620 | delete temp; |
1621 | } |
1622 | left_slice.clear(); |
1623 | } |
1624 | } |
1625 | |
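// Kernel registration. Each macro invocation registers one (Ta, Tb)
// instantiation of SparseMatMulOp for the CPU device.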
1626 | #define REGISTER_SPARSE_MATMUL(TA, TB) \ |
1627 | REGISTER_KERNEL_BUILDER(Name("SparseMatMul") \ |
1628 | .Device(DEVICE_CPU) \ |
1629 | .TypeConstraint<TA>("Ta") \ |
1630 | .TypeConstraint<TB>("Tb"), \ |
1631 | SparseMatMulOp<TA, TB, SparseMatMul>); |
1632 | #ifdef TENSORFLOW_USE_LIBXSMM |
1633 | #define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB) \ |
1634 | REGISTER_KERNEL_BUILDER(Name("SparseMatMul") \ |
1635 | .Device(DEVICE_CPU) \ |
1636 | .TypeConstraint<TA>("Ta") \ |
1637 | .TypeConstraint<TB>("Tb"), \ |
1638 | SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>); |
1639 | #endif |
1640 | |
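// Mixed-precision operand combinations always use the default SparseMatMul
// path; libxsmm, when enabled, handles only the homogeneous float and
// bfloat16 cases.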
1641 | REGISTER_SPARSE_MATMUL(float, bfloat16); |
1642 | REGISTER_SPARSE_MATMUL(bfloat16, float); |
1643 | |
1644 | #ifdef TENSORFLOW_USE_LIBXSMM |
1645 | REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16); |
1646 | REGISTER_SPARSE_MATMUL_LIBXSMM(float, float); |
1647 | #else |
1648 | REGISTER_SPARSE_MATMUL(bfloat16, bfloat16); |
1649 | REGISTER_SPARSE_MATMUL(float, float); |
1650 | #endif |
1651 | |
#undef REGISTER_SPARSE_MATMUL
#ifdef TENSORFLOW_USE_LIBXSMM
#undef REGISTER_SPARSE_MATMUL_LIBXSMM
#endif
1653 | |
1654 | } // end namespace tensorflow |
1655 | |