1/**
2 * Copyright (c) 2017-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
#include <algorithm>
#include <cstddef>
#include <cstdint>

#include "gloo/types.h"
12
13namespace gloo {
14
/// Elementwise sum over type-erased buffers: c[i] = a[i] + b[i], i in [0, n).
///
/// @param c_ output buffer of at least n objects of type T.
/// @param a_ left operand buffer of at least n objects of type T.
/// @param b_ right operand buffer of at least n objects of type T.
/// @param n  number of elements.
///
/// `c_` may be exactly the same buffer as `a_` or `b_`: each element is read
/// before the element at the same index is written. The in-place overload
/// relies on this.
template <typename T>
void sum(void* c_, const void* a_, const void* b_, size_t n) {
  T* c = static_cast<T*>(c_);
  const T* a = static_cast<const T*>(a_);
  const T* b = static_cast<const T*>(b_);
  // Index with size_t to match `n`; a deduced `int` index would mix
  // signed/unsigned and overflow for n > INT_MAX.
  for (size_t i = 0; i < n; i++) {
    c[i] = a[i] + b[i];
  }
}
24
/// In-place elementwise sum: a[i] = a[i] + b[i] for i in [0, n).
///
/// Delegates to the three-pointer overload, passing `a` as both the output
/// and the left operand, so any specialization of that overload applies.
template <typename T>
void sum(T* a, const T* b, size_t n) {
  T* const inout = a;  // output aliases the left operand
  sum<T>(inout, inout, b, n);
}
29
/// Elementwise product over type-erased buffers: c[i] = a[i] * b[i],
/// i in [0, n).
///
/// @param c_ output buffer of at least n objects of type T.
/// @param a_ left operand buffer of at least n objects of type T.
/// @param b_ right operand buffer of at least n objects of type T.
/// @param n  number of elements.
///
/// `c_` may be exactly the same buffer as `a_` or `b_`: each element is read
/// before the element at the same index is written. The in-place overload
/// relies on this.
template <typename T>
void product(void* c_, const void* a_, const void* b_, size_t n) {
  T* c = static_cast<T*>(c_);
  const T* a = static_cast<const T*>(a_);
  const T* b = static_cast<const T*>(b_);
  // Index with size_t to match `n`; a deduced `int` index would mix
  // signed/unsigned and overflow for n > INT_MAX.
  for (size_t i = 0; i < n; i++) {
    c[i] = a[i] * b[i];
  }
}
39
/// In-place elementwise product: a[i] = a[i] * b[i] for i in [0, n).
///
/// Delegates to the three-pointer overload, passing `a` as both the output
/// and the left operand, so any specialization of that overload applies.
template <typename T>
void product(T* a, const T* b, size_t n) {
  T* const inout = a;  // output aliases the left operand
  product<T>(inout, inout, b, n);
}
44
/// Elementwise maximum over type-erased buffers:
/// c[i] = std::max(a[i], b[i]), i in [0, n).
///
/// @param c_ output buffer of at least n objects of type T.
/// @param a_ left operand buffer of at least n objects of type T.
/// @param b_ right operand buffer of at least n objects of type T.
/// @param n  number of elements.
///
/// `c_` may be exactly the same buffer as `a_` or `b_`: each element is read
/// before the element at the same index is written. The in-place overload
/// relies on this.
template <typename T>
void max(void* c_, const void* a_, const void* b_, size_t n) {
  T* c = static_cast<T*>(c_);
  const T* a = static_cast<const T*>(a_);
  const T* b = static_cast<const T*>(b_);
  // Index with size_t to match `n`; a deduced `int` index would mix
  // signed/unsigned and overflow for n > INT_MAX.
  for (size_t i = 0; i < n; i++) {
    c[i] = std::max(a[i], b[i]);
  }
}
54
/// In-place elementwise maximum: a[i] = std::max(a[i], b[i]) for i in [0, n).
///
/// Delegates to the three-pointer overload, passing `a` as both the output
/// and the left operand, so any specialization of that overload applies.
template <typename T>
void max(T* a, const T* b, size_t n) {
  T* const inout = a;  // output aliases the left operand
  max<T>(inout, inout, b, n);
}
59
/// Elementwise minimum over type-erased buffers:
/// c[i] = std::min(a[i], b[i]), i in [0, n).
///
/// @param c_ output buffer of at least n objects of type T.
/// @param a_ left operand buffer of at least n objects of type T.
/// @param b_ right operand buffer of at least n objects of type T.
/// @param n  number of elements.
///
/// `c_` may be exactly the same buffer as `a_` or `b_`: each element is read
/// before the element at the same index is written. The in-place overload
/// relies on this.
template <typename T>
void min(void* c_, const void* a_, const void* b_, size_t n) {
  T* c = static_cast<T*>(c_);
  const T* a = static_cast<const T*>(a_);
  const T* b = static_cast<const T*>(b_);
  // Index with size_t to match `n`; a deduced `int` index would mix
  // signed/unsigned and overflow for n > INT_MAX.
  for (size_t i = 0; i < n; i++) {
    c[i] = std::min(a[i], b[i]);
  }
}
69
/// In-place elementwise minimum: a[i] = std::min(a[i], b[i]) for i in [0, n).
///
/// Delegates to the three-pointer overload, passing `a` as both the output
/// and the left operand, so any specialization of that overload applies.
template <typename T>
void min(T* a, const T* b, size_t n) {
  T* const inout = a;  // output aliases the left operand
  min<T>(inout, inout, b, n);
}
74
/// Rounds `value` up to the next multiple of `multiple`; a value that is
/// already a multiple is returned unchanged.
///
/// Preconditions (not checked): `multiple` is nonzero, since the computation
/// takes `value % multiple`, and the rounded result fits in T.
/// NOTE(review): intended for non-negative values. For a negative signed
/// `value`, `%` yields a negative remainder and the result overshoots
/// (e.g. roundUp(-3, 4) == 4, not 0).
template <typename T>
T roundUp(T value, T multiple) {
  const T excess = value % multiple;
  if (excess != 0) {
    return value + multiple - excess;
  }
  return value;
}
83
/// Returns ceil(log2(value)): the smallest `dim` with (1 << dim) >= value.
///
/// Returns 0 for both 0 and 1; the result is always in [0, 32].
inline uint32_t log2ceil(uint32_t value) {
  if (value <= 1) {
    return 0;
  }
#if defined(__GNUC__)
  // value - 1 is nonzero here, so __builtin_clz (undefined for 0) is safe.
  return static_cast<uint32_t>(32 - __builtin_clz(value - 1));
#else
  // Keep the candidate power of two in 64 bits: for value > 2^31 a 32-bit
  // counter would shift from 2^31 to 0 and never reach `value`, looping
  // forever.
  uint32_t dim = 0;
  for (uint64_t size = 1; size < value; size <<= 1) {
    ++dim;
  }
  return dim;
#endif // defined(__GNUC__)
}
95
#if GLOO_USE_AVX

// float16 specializations of the three-pointer (void*) reduction kernels
// declared above. An undefined GLOO_USE_AVX evaluates to 0 here, which
// excludes this section.
//
// Each specialization is declared without a body, and its implicit
// instantiation is suppressed, so the definitions must come from another
// translation unit. Presumably that is an AVX-based implementation, given the
// guard — TODO confirm against the corresponding .cc file.
//
// A specialization must be declared before any use of e.g. sum<float16> that
// would otherwise implicitly instantiate the generic loop above. The in-place
// overloads forward to these three-pointer forms, so they get the float16
// specializations too.
//
// NOTE(review): an explicit instantiation that follows an explicit
// specialization for the same template arguments has no effect
// ([temp.explicit]), so the `extern template` lines look redundant. They are
// left unchanged.

template <>
void sum<float16>(void* c, const void* a, const void* b, size_t n);
extern template void
sum<float16>(void* c, const void* a, const void* b, size_t n);

template <>
void product<float16>(void* c, const void* a, const void* b, size_t n);
extern template void
product<float16>(void* c, const void* a, const void* b, size_t n);

template <>
void max<float16>(void* c, const void* a, const void* b, size_t n);
extern template void
max<float16>(void* c, const void* a, const void* b, size_t n);

template <>
void min<float16>(void* c, const void* a, const void* b, size_t n);
extern template void
min<float16>(void* c, const void* a, const void* b, size_t n);

#endif // GLOO_USE_AVX
119
120} // namespace gloo
121