1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #include <cassert> |
8 | #include <cmath> |
9 | #include "RefImplementations.h" |
10 | #include "fbgemm/FbgemmEmbedding.h" |
11 | |
12 | #include "fbgemm/Types.h" |
13 | |
14 | namespace fbgemm { |
15 | namespace internal { |
16 | |
template <typename InType, typename IndexType, typename OffsetType>
// Reference implementation of embedding-bag (SpMDM) for embedding dimension
// (block size) 1: for each of output_size bags, sums (optionally weighted)
// rows of `input` selected by `indices`, writing one float per bag to `out`.
// Returns true on success; returns false — leaving any already-computed
// prefix of `out` in place — if the lengths/offsets overrun index_size, an
// index is out of [0, data_size), or not all of index_size is consumed.
bool EmbeddingSpMDMBlockSize1_(
    const std::int64_t output_size,
    const std::int64_t index_size,
    const std::int64_t data_size, // the number of rows in input
    const InType* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    float* out,
    bool is_weight_positional,
    bool use_offsets,
    bool is_bf16) {
  int64_t current = 0; // running position into the flat indices array
  for (int64_t m = 0; m < output_size; ++m) {
    // Zero up front so out[m] holds a defined value even if we fail below.
    out[m] = 0;
    // Bag size: difference of consecutive offsets, or the length directly.
    // int64_t (was plain int) avoids overflow when OffsetType is int64.
    int64_t len = use_offsets
        ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
        : offsets_or_lengths[m];
    if (current + len > index_size) {
      return false; // lengths/offsets inconsistent with index_size
    }

    // NOTE: an AVX2 gather-based version of this inner loop was prototyped
    // here and removed because it did not speed up the block-size-1 case.

    // Accumulate into a local; on a bad index we store the partial sum
    // accumulated so far into out[m] before reporting failure.
    float temp = out[m];
    for (int64_t i = 0; i < len; ++i) {
      const int64_t idx = indices[current];
      if (idx < 0 || idx >= data_size) {
        out[m] = temp;
        return false; // row index out of range for the input table
      }

      float w = 1.f;
      if (weights) {
        // Positional weights are indexed within the bag; otherwise the
        // weight array parallels the flat indices array.
        w = weights[is_weight_positional ? i : current];
      }

      // Reuse the bounds-checked idx (original re-read indices[current]).
      const InType* inptr = input + idx;
      temp = std::fma(w, convert_to_float_ref(*inptr, is_bf16), temp);

      ++current;
    }
    // Optionally turn the sum into a mean over the bag.
    if (normalize_by_lengths && len) {
      float scale = 1.f / len;
      temp *= scale;
    }
    out[m] = temp;
  }
  // Success iff exactly index_size indices were consumed.
  return current == index_size;
}
129 | |
// Explicit instantiations of EmbeddingSpMDMBlockSize1_ for every supported
// (input type, index type, offset type) combination, so the template
// definition can stay in this translation unit.

// Instantiate for one fixed (input, index, offset) type triple.
#define INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \
  template bool EmbeddingSpMDMBlockSize1_(                       \
      const std::int64_t output_size,                            \
      const std::int64_t index_size,                             \
      const std::int64_t data_size,                              \
      const IN_TYPE* input,                                      \
      const INDEX_TYPE* indices,                                 \
      const OFFSET_TYPE* offsets_or_lengths,                     \
      const float* weights,                                      \
      bool normalize_by_lengths,                                 \
      float* out,                                                \
      bool is_weight_positional,                                 \
      bool use_offsets,                                          \
      bool is_bf16);

// Fan out over both supported offset types (int32/int64).
#define INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, INDEX_TYPE)     \
  INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, std::int32_t) \
  INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, std::int64_t)

// Fan out over both supported index types (int32/int64).
#define INSTANTIATE_SPMDM_INDEX_T(IN_TYPE)          \
  INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int32_t) \
  INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int64_t)

// Supported input element types: fp32, fp16, and 8-bit quantized rows.
INSTANTIATE_SPMDM_INDEX_T(float)
INSTANTIATE_SPMDM_INDEX_T(float16)
INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)

// Keep the helper macros local to this file.
#undef INSTANTIATE_SPMDM_INDEX_T
#undef INSTANTIATE_SPMDM_OFFSET_T
#undef INSTANTIATE_SPMDM_BASE
160 | |
161 | } // namespace internal |
162 | } // namespace fbgemm |
163 | |