/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
#define TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"

#if defined(PLATFORM_WINDOWS)
#include "tensorflow/tsl/platform/windows/intrinsics_port.h"
#endif

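// Eigen packet-level helpers for expanding packed bfloat16 values (stored two
// per 32-bit float slot) into float packets, plus broadcast and interleave
// utilities over packets. These primitives support the sparse matmul kernel.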
namespace Eigen {
namespace internal {

// Return the float representation of the bfloat16 value
// in the lower 16 bits of the input
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) {
  tensorflow::uint32 tmp;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
#else
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
#endif
  return reinterpret_cast<const float&>(tmp);
}

// Return the float representation of the bfloat16 value
// in the upper 16 bits of the input
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_u(const Packet& from) {
  tensorflow::uint32 tmp;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
#else
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
#endif
  return reinterpret_cast<const float&>(tmp);
}
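
// Worked example (little-endian): if `from` has bit pattern 0xAAAABBBB, where
// 0xBBBB and 0xAAAA are two packed bfloat16 values, then pexpand_bf16_l
// returns the float with bits 0xBBBB0000 and pexpand_bf16_u the float with
// bits 0xAAAA0000 (a bfloat16 is the upper 16 bits of the float it encodes).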

// Specializations of the non-scalar versions for non-SSE platforms.
// Enable vectorization on z13 and higher.
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
    defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
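// In these specializations a Packet4f carries 8 packed bfloat16 values, two
// per 32-bit lane: pexpand_bf16_l expands the four values held in the first
// two lanes, pexpand_bf16_u the four held in the last two lanes.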
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
  float r[4];
  tensorflow::uint32 p[4];
  pstoreu(r, from);
  tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
  p[0] = (ir[0] << 16) & 0xffff0000;
  p[1] = ir[0] & 0xffff0000;
  p[2] = (ir[1] << 16) & 0xffff0000;
  p[3] = ir[1] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
  float r[4];
  tensorflow::uint32 p[4];
  pstoreu(r, from);
  tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
  p[0] = (ir[2] << 16) & 0xffff0000;
  p[1] = ir[2] & 0xffff0000;
  p[2] = (ir[3] << 16) & 0xffff0000;
  p[3] = ir[3] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}
#endif

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pinterleave4x64(const Packet& from) {
  return from;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_first(const Packet& a) {
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_second(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_third(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_fourth(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload4bf16(
    const typename unpacket_traits<Packet>::type* from) {
  assert(false && "Not applicable to Scalar Values");
  return Packet();
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload2bf16(
    const typename unpacket_traits<Packet>::type* from) {
  assert(false && "Not applicable to Scalar Values");
  return Packet();
}

// Specializations of pload4bf16 and pload2bf16 for non-SSE platforms.
// Enable vectorization on z13 and higher.
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
    defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
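// pload4bf16 reads four consecutive bfloat16 values (64 bits) from `from` and
// expands them to four floats; pload2bf16 reads two bfloat16 values (32 bits),
// expands them into the first two lanes, and repeats the pair in the last two.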
template <>
EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
  tensorflow::uint32 p[4];
  const tensorflow::uint32* ir =
      reinterpret_cast<const tensorflow::uint32*>(from);
  p[0] = (ir[0] << 16) & 0xffff0000;
  p[1] = ir[0] & 0xffff0000;
  p[2] = (ir[1] << 16) & 0xffff0000;
  p[3] = ir[1] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}

template <>
EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
  tensorflow::uint32 p[4];
  const tensorflow::uint32* ir =
      reinterpret_cast<const tensorflow::uint32*>(from);
  p[0] = (ir[0] << 16) & 0xffff0000;
  p[1] = ir[0] & 0xffff0000;
  p[2] = (ir[0] << 16) & 0xffff0000;
  p[3] = ir[0] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}
#endif

#if defined(EIGEN_VECTORIZE_NEON)
// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return pset1<Packet4f>(pfirst(a));
}
template <>
EIGEN_STRONG_INLINE Packet2f pbroadcast_first<Packet2f>(const Packet2f& a) {
  return pset1<Packet2f>(pfirst(a));
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return pset1<Packet4f>(vgetq_lane_f32(a, 1));
}
template <>
EIGEN_STRONG_INLINE Packet2f pbroadcast_second<Packet2f>(const Packet2f& a) {
  return pset1<Packet2f>(vget_lane_f32(a, 1));
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return pset1<Packet4f>(vgetq_lane_f32(a, 2));
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return pset1<Packet4f>(vgetq_lane_f32(a, 3));
}
#endif

#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 0);
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 1);
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 2);
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 3);
}
#endif

#ifdef EIGEN_VECTORIZE_SSE2
// For PacketSize of 4 floats the Packet is not modified
template <>
EIGEN_STRONG_INLINE Packet4f pinterleave4x64<Packet4f>(const Packet4f& from) {
  return from;
}

// Return a Packet with 4 floats loaded from 4 bfloat16 values
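// (_mm_load_pd1 fetches all four bfloat16 values as one 64-bit load; the
// _mm_unpacklo_epi16 with a zero register then places each 16-bit value in
// the upper half of a 32-bit lane, which is exactly the float it encodes.
// pload2bf16 below applies the same zero-interleave trick to a 32-bit load.)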
template <>
EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}

// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}

// Return a Packet with 4 floats expanded from 4 bfloat16 values
// in the lower half of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(from);
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}

// Return a Packet with 4 floats expanded from 4 bfloat16 values
// in the upper half of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(from);
  return _mm_castsi128_ps(_mm_unpackhi_epi16(zero, tmp));
}

// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(pfirst<Packet4f>(a));
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 1)));
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 2)));
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 3)));
}

#endif

#ifdef EIGEN_VECTORIZE_AVX512
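// 512-bit broadcasts: extract the requested element from the low 128 (or 256)
// bits of the input and replicate it across every lane of the result.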
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_first<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_second<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1)));
}
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_third<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2)));
}
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_fourth<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3)));
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_first<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm512_castpd512_pd128(a_in);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm_permute_pd(_mm512_castpd512_pd128(a_in), 3);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) {
  Packet2d a =
      _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_first<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(a);
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_second<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(1, 1, 1, 1)));
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_third<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(2, 2, 2, 2)));
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_fourth<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 3, 3)));
}
#endif

#ifdef EIGEN_VECTORIZE_AVX
// For a Packet of size 8 floats (256 bits), swap the 2nd and 3rd quadwords
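// e.g. [f0 f1 f2 f3 f4 f5 f6 f7] -> [f0 f1 f4 f5 f2 f3 f6 f7]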
template <>
EIGEN_STRONG_INLINE Packet8f pinterleave4x64<Packet8f>(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  return _mm256_castsi256_ps(_mm256_permute4x64_epi64(_mm256_castps_si256(from),
                                                      _MM_SHUFFLE(3, 1, 2, 0)));
#else
  auto tmp1 = _mm256_extract_epi32(_mm256_castps_si256(from), 2);
  auto tmp2 = _mm256_extract_epi32(_mm256_castps_si256(from), 3);
  auto tmp3 = _mm256_extract_epi32(_mm256_castps_si256(from), 4);
  auto tmp4 = _mm256_extract_epi32(_mm256_castps_si256(from), 5);
  auto tmp5 = _mm256_insert_epi32(_mm256_castps_si256(from), tmp1, 4);
  tmp5 = _mm256_insert_epi32(tmp5, tmp2, 5);
  tmp5 = _mm256_insert_epi32(tmp5, tmp3, 2);
  tmp5 = _mm256_insert_epi32(tmp5, tmp4, 3);
  return _mm256_castsi256_ps(tmp5);
#endif
}
// Return a Packet with 4 floats loaded from 4 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet8f pload4bf16<Packet8f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm256_castps128_ps256(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet8f pload2bf16<Packet8f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm256_castps128_ps256(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}

#ifdef EIGEN_VECTORIZE_AVX512
// Return a Packet with 4 floats loaded from 4 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet16f pload4bf16<Packet16f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm512_castps128_ps512(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet16f pload2bf16<Packet16f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm512_castps128_ps512(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
#endif

// For each 128-bit lane, convert 4 bfloat16 values to 4 float values from the
// lower half of the 128-bit lane
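// (for a Packet8f holding packed bfloat16 values b0..b15 this yields
// [b0..b3 b8..b11]; note that pexpand_bf16_l(pinterleave4x64(x)) recovers
// b0..b7 in order and pexpand_bf16_u(pinterleave4x64(x)) recovers b8..b15)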
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_l(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  __m256i zero = _mm256_setzero_si256();
  __m256i tmp = _mm256_castps_si256(from);
  return _mm256_castsi256_ps(_mm256_unpacklo_epi16(zero, tmp));
#else
  __m128i zero = _mm_setzero_si128();
  __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
  __m128i res_l = _mm_unpacklo_epi16(zero, low);
  __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
  __m128i res_h = _mm_unpacklo_epi16(zero, high);
  __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
  res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
  return res;
#endif
}

// For each 128-bit lane, convert 4 bfloat16 values to 4 float values from the
// upper half of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_u(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  __m256i zero = _mm256_setzero_si256();
  __m256i tmp = _mm256_castps_si256(from);
  return _mm256_castsi256_ps(_mm256_unpackhi_epi16(zero, tmp));
#else
  __m128i zero = _mm_setzero_si128();
  __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
  __m128i res_l = _mm_unpackhi_epi16(zero, low);
  __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
  __m128i res_h = _mm_unpackhi_epi16(zero, high);
  __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
  res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
  return res;
#endif
}

// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_first<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(pfirst<Packet8f>(a));
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_second<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 1))));
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_third<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 2))));
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 3))));
}

#endif

#ifdef EIGEN_VECTORIZE_AVX512

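// pexpand_bf16_l zero-extends the 16 bfloat16 values in the lower 256 bits to
// 32 bits each and shifts them into the high half of every lane;
// pexpand_bf16_u first rotates the upper 256 bits into the lower half
// (_mm512_alignr_epi32 by 8 dwords) and then performs the same expansion.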
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) {
  return _mm512_castsi512_ps(_mm512_slli_epi32(
      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))),
      16));
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
  Packet16i tmp = _mm512_castps_si512(from);
  Packet16i tmp2 = _mm512_alignr_epi32(tmp, tmp, 8);
  return _mm512_castsi512_ps(_mm512_slli_epi32(
      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(tmp2)), 16));
}

#endif
}  // namespace internal
}  // namespace Eigen
#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_