1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#include <cassert>
8#include <cmath>
9#include "RefImplementations.h"
10#include "fbgemm/FbgemmEmbedding.h"
11
12#include "fbgemm/Types.h"
13
14namespace fbgemm {
15namespace internal {
16
/// Reference (scalar) embedding SpMDM for embedding rows of width 1.
///
/// For every output bag m this computes
///   out[m] = sum over the bag's indices j of  w_j * float(input[indices[j]])
/// and, if requested, divides a non-empty bag's sum by its length.
///
/// Indices are consumed sequentially from indices[0]. When `use_offsets` is
/// true, only the differences between consecutive offsets are used as bag
/// lengths; the absolute offset values are not used as positions.
///
/// @param output_size          number of bags, i.e. number of floats written to out.
/// @param index_size           number of entries in `indices`.
/// @param data_size            number of rows in `input`; every index must be in
///                             [0, data_size).
/// @param input                embedding table with one element per row.
/// @param indices              row indices into `input`, grouped bag by bag.
/// @param offsets_or_lengths   if `use_offsets`: output_size + 1 offsets, and bag
///                             m has offsets[m + 1] - offsets[m] entries;
///                             otherwise: output_size per-bag lengths.
/// @param weights              optional per-index weights (nullptr means weight
///                             1). Indexed by position inside the bag when
///                             `is_weight_positional`, otherwise by position in
///                             `indices`.
/// @param normalize_by_lengths scale each non-empty bag's sum by 1 / length.
/// @param out                  output buffer holding at least output_size floats.
/// @param is_weight_positional see `weights`.
/// @param use_offsets          see `offsets_or_lengths`.
/// @param is_bf16              forwarded to convert_to_float_ref (presumably
///                             selects bf16 vs fp16 decoding of 16-bit InType).
/// @return true on success. Returns false if a bag length is negative, a bag
///         runs past `index_size`, an index is out of range, or not every
///         index was consumed.
///
/// On failure, out[m] for the failing bag holds either 0 (bad length) or the
/// partial sum accumulated so far (bad index). Later bags are not written.
template <typename InType, typename IndexType, typename OffsetType>
bool EmbeddingSpMDMBlockSize1_(
    const std::int64_t output_size,
    const std::int64_t index_size,
    const std::int64_t data_size, // the number of rows in input
    const InType* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    float* out,
    bool is_weight_positional,
    bool use_offsets,
    bool is_bf16) {
  // Position of the next unread entry in `indices`.
  // Invariant: 0 <= current <= index_size.
  std::int64_t current = 0;
  for (std::int64_t m = 0; m < output_size; ++m) {
    out[m] = 0;

    // Compute the bag length in 64-bit arithmetic. OffsetType may be int64_t,
    // and subtracting two int32_t offsets in `int` could overflow.
    const std::int64_t len = use_offsets
        ? static_cast<std::int64_t>(offsets_or_lengths[m + 1]) -
            static_cast<std::int64_t>(offsets_or_lengths[m])
        : static_cast<std::int64_t>(offsets_or_lengths[m]);

    // A negative length is malformed input. Reject it rather than silently
    // skipping the bag. Writing the bound as `index_size - current` avoids
    // overflow of `current + len`.
    if (len < 0 || len > index_size - current) {
      return false;
    }

    // Note: an AVX2 gather-based variant of this loop was tried and did not
    // speed it up, so only the scalar loop is kept.
    float temp = 0.f;
    for (std::int64_t i = 0; i < len; ++i, ++current) {
      const std::int64_t idx = indices[current];
      if (idx < 0 || idx >= data_size) {
        out[m] = temp; // expose the partial sum accumulated so far
        return false;
      }
      const float w =
          weights ? weights[is_weight_positional ? i : current] : 1.f;
      temp = std::fma(w, convert_to_float_ref(input[idx], is_bf16), temp);
    }

    if (normalize_by_lengths && len != 0) {
      temp *= 1.f / len;
    }
    out[m] = temp;
  }
  // Every index must belong to exactly one bag.
  return current == index_size;
}
129
130#define INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \
131 template bool EmbeddingSpMDMBlockSize1_( \
132 const std::int64_t output_size, \
133 const std::int64_t index_size, \
134 const std::int64_t data_size, \
135 const IN_TYPE* input, \
136 const INDEX_TYPE* indices, \
137 const OFFSET_TYPE* offsets_or_lengths, \
138 const float* weights, \
139 bool normalize_by_lengths, \
140 float* out, \
141 bool is_weight_positional, \
142 bool use_offsets, \
143 bool is_bf16);
144
145#define INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, INDEX_TYPE) \
146 INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, std::int32_t) \
147 INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, std::int64_t)
148
149#define INSTANTIATE_SPMDM_INDEX_T(IN_TYPE) \
150 INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int32_t) \
151 INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int64_t)
152
153INSTANTIATE_SPMDM_INDEX_T(float)
154INSTANTIATE_SPMDM_INDEX_T(float16)
155INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)
156
157#undef INSTANTIATE_SPMDM_INDEX_T
158#undef INSTANTIATE_SPMDM_OFFSET_T
159#undef INSTANTIATE_SPMDM_BASE
160
161} // namespace internal
162} // namespace fbgemm
163