1#pragma once
2#ifndef C10_UTIL_CPP17_H_
3#define C10_UTIL_CPP17_H_
4
#include <c10/macros/Macros.h>
#include <cstdlib>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
13
// Reject GCC older than 5. Clang and MSVC-compatible compilers may also define
// __GNUC__, so they are excluded from this check.
#if defined(__GNUC__) && !defined(__clang__) && !defined(_MSC_VER)
#if __GNUC__ < 5
#error \
    "You're trying to build PyTorch with a too old version of GCC. We need GCC 5 or later."
#endif
#endif

// Reject Clang older than 4.
#if defined(__clang__)
#if __clang_major__ < 4
#error \
    "You're trying to build PyTorch with a too old version of Clang. We need Clang 4 or later."
#endif
#endif

// Require at least C++14. With _MSC_VER defined the language level is read
// from _MSVC_LANG; everywhere else it is read from __cplusplus.
#if defined(_MSC_VER)
#if !defined(_MSVC_LANG) || _MSVC_LANG < 201402L
#error You need C++14 to compile PyTorch
#endif
#elif __cplusplus < 201402L
#error You need C++14 to compile PyTorch
#endif

// On Windows, function-like `min`/`max` macros would clash with the functions
// of the same name declared in this header.
#if defined(_WIN32)
#if defined(min) || defined(max)
#error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows
#endif
#endif
33
34/*
35 * This header adds some polyfills with C++17 functionality
36 */
37
38namespace c10 {
39
// In C++17, std::result_of was superseded by std::invoke_result. std::result_of
// was removed in C++20.
42template <typename F, typename... args>
43#if defined(__cpp_lib_is_invocable) && __cpp_lib_is_invocable >= 201703L
44using invoke_result = typename std::invoke_result<F, args...>;
45#else
46using invoke_result = typename std::result_of<F && (args && ...)>;
47#endif
48
49template <typename F, typename... args>
50using invoke_result_t = typename invoke_result<F, args...>::type;
51
52// std::is_pod is deprecated in C++20, std::is_standard_layout and
53// std::is_trivial are introduced in C++11, std::conjunction has been introduced
54// in C++17.
// is_pod<T> is a type trait whose `value` tells whether T is a POD type. When
// the standard library provides std::conjunction it is spelled as
// "standard layout AND trivial"; otherwise it falls back to std::is_pod.
#if defined(__cpp_lib_logical_traits) && __cpp_lib_logical_traits >= 201510L
template <typename T>
using is_pod =
    std::conjunction<std::is_standard_layout<T>, std::is_trivial<T>>;
#else
template <typename T>
using is_pod = std::is_pod<T>;
#endif

// Variable-template shorthand for is_pod<T>::value.
template <typename T>
constexpr bool is_pod_v = is_pod<T>::value;
64
65namespace guts {
66
// Constructs a `Child` from `args` on the heap and returns it as a
// std::unique_ptr<Base>.
//
// Only participates in overload resolution when neither type is an array type
// and std::is_base_of<Base, Child> holds.
//
// NOTE(review): the object is later destroyed through a Base*, so Base
// presumably needs a virtual destructor whenever Child differs from Base --
// this is not checked here.
template <typename Base, typename Child, typename... Args>
std::enable_if_t<
    !std::is_array<Base>::value && !std::is_array<Child>::value &&
        std::is_base_of<Base, Child>::value,
    std::unique_ptr<Base>>
make_unique_base(Args&&... args) {
  Base* object = new Child(std::forward<Args>(args)...);
  return std::unique_ptr<Base>(object);
}
75
// Logical type traits (C++17's std::conjunction, std::disjunction,
// std::bool_constant and std::negation). The standard versions are re-exported
// when available (except on MSVC older than 19.20); otherwise equivalent
// implementations are provided here.
#if defined(__cpp_lib_logical_traits) && !(defined(_MSC_VER) && _MSC_VER < 1920)

template <bool B>
using bool_constant = std::bool_constant<B>;
template <class B>
using negation = std::negation<B>;
template <class... B>
using conjunction = std::conjunction<B...>;
template <class... B>
using disjunction = std::disjunction<B...>;

#else

// Implementation taken from
// http://en.cppreference.com/w/cpp/types/integral_constant
template <bool B>
using bool_constant = std::integral_constant<bool, B>;

// Implementation taken from http://en.cppreference.com/w/cpp/types/negation
template <class B>
struct negation : bool_constant<!static_cast<bool>(B::value)> {};

// Implementation taken from http://en.cppreference.com/w/cpp/types/conjunction
// Empty pack: true. One trait: that trait. Otherwise: the first trait whose
// value is false, or the last trait if there is none.
template <class...>
struct conjunction : std::true_type {};
template <class Only>
struct conjunction<Only> : Only {};
template <class First, class... Rest>
struct conjunction<First, Rest...>
    : std::conditional_t<
          static_cast<bool>(First::value),
          conjunction<Rest...>,
          First> {};

// Implementation taken from http://en.cppreference.com/w/cpp/types/disjunction
// Empty pack: false. One trait: that trait. Otherwise: the first trait whose
// value is true, or the last trait if there is none.
template <class...>
struct disjunction : std::false_type {};
template <class Only>
struct disjunction<Only> : Only {};
template <class First, class... Rest>
struct disjunction<First, Rest...>
    : std::conditional_t<
          static_cast<bool>(First::value),
          First,
          disjunction<Rest...>> {};

#endif
117
// void_t<Ts...> is `void` for any (possibly empty) list of well-formed types;
// it is the usual building block for SFINAE-based detection.
//
// Both branches are variadic so that code using void_t with zero or several
// type arguments compiles the same way regardless of which branch is active.
#ifdef __cpp_lib_void_t

template <class... Ts>
using void_t = std::void_t<Ts...>;

#else

// Implementation taken from http://en.cppreference.com/w/cpp/types/void_t
// (it takes CWG1558 into account and also works for older compilers)
template <typename... Ts>
struct make_void {
  using type = void;
};
template <typename... Ts>
using void_t = typename make_void<Ts...>::type;

#endif
135
136#if defined(USE_ROCM)
137// rocm doesn't like the C10_HOST_DEVICE
138#define CUDA_HOST_DEVICE
139#else
140#define CUDA_HOST_DEVICE C10_HOST_DEVICE
141#endif
142
143#if defined(__cpp_lib_apply) && !defined(__CUDA_ARCH__)
144
145template <class F, class Tuple>
146CUDA_HOST_DEVICE inline constexpr decltype(auto) apply(F&& f, Tuple&& t) {
147 return std::apply(std::forward<F>(f), std::forward<Tuple>(t));
148}
149
150#else
151
152// Implementation from http://en.cppreference.com/w/cpp/utility/apply (but
153// modified)
154// TODO This is an incomplete implementation of std::apply, not working for
155// member functions.
156namespace detail {
157template <class F, class Tuple, std::size_t... INDEX>
158#if defined(_MSC_VER)
159// MSVC has a problem with the decltype() return type, but it also doesn't need
160// it
161C10_HOST_DEVICE constexpr auto apply_impl(
162 F&& f,
163 Tuple&& t,
164 std::index_sequence<INDEX...>)
165#else
166// GCC/Clang need the decltype() return type
167CUDA_HOST_DEVICE constexpr decltype(auto) apply_impl(
168 F&& f,
169 Tuple&& t,
170 std::index_sequence<INDEX...>)
171#endif
172{
173 return std::forward<F>(f)(std::get<INDEX>(std::forward<Tuple>(t))...);
174}
175} // namespace detail
176
177template <class F, class Tuple>
178CUDA_HOST_DEVICE constexpr decltype(auto) apply(F&& f, Tuple&& t) {
179 return detail::apply_impl(
180 std::forward<F>(f),
181 std::forward<Tuple>(t),
182 std::make_index_sequence<
183 std::tuple_size<std::remove_reference_t<Tuple>>::value>{});
184}
185
186#endif
187
188#undef CUDA_HOST_DEVICE
189
190template <typename Functor, typename... Args>
191typename std::enable_if<
192 std::is_member_pointer<typename std::decay<Functor>::type>::value,
193 typename c10::invoke_result_t<Functor, Args...>>::type
194invoke(Functor&& f, Args&&... args) {
195 return std::mem_fn(std::forward<Functor>(f))(std::forward<Args>(args)...);
196}
197
198template <typename Functor, typename... Args>
199typename std::enable_if<
200 !std::is_member_pointer<typename std::decay<Functor>::type>::value,
201 typename c10::invoke_result_t<Functor, Args...>>::type
202invoke(Functor&& f, Args&&... args) {
203 return std::forward<Functor>(f)(std::forward<Args>(args)...);
204}
205
206namespace detail {
// Identity function object: calling it with an argument yields that same
// argument, perfectly forwarded (lvalues come back as lvalue references,
// rvalues as rvalue references).
//
// The call operator is const, constexpr and noexcept so the object can also be
// invoked through a const reference and inside constant expressions.
struct _identity final {
  // Type-level identity: type_identity<T> is T.
  template <class T>
  using type_identity = T;

  template <class T>
  constexpr decltype(auto) operator()(T&& arg) const noexcept {
    return std::forward<T>(arg);
  }
};
216
217template <class Func, class Enable = void>
218struct function_takes_identity_argument : std::false_type {};
219#if defined(_MSC_VER)
220// For some weird reason, MSVC shows a compiler error when using guts::void_t
221// instead of std::void_t. But we're only building on MSVC versions that have
222// std::void_t, so let's just use that one.
223template <class Func>
224struct function_takes_identity_argument<
225 Func,
226 std::void_t<decltype(std::declval<Func>()(_identity()))>> : std::true_type {
227};
228#else
229template <class Func>
230struct function_takes_identity_argument<
231 Func,
232 void_t<decltype(std::declval<Func>()(_identity()))>> : std::true_type {};
233#endif
234
235template <bool Condition>
236struct _if_constexpr;
237
238template <>
239struct _if_constexpr<true> final {
240 template <
241 class ThenCallback,
242 class ElseCallback,
243 std::enable_if_t<
244 function_takes_identity_argument<ThenCallback>::value,
245 void*> = nullptr>
246 static decltype(auto) call(
247 ThenCallback&& thenCallback,
248 ElseCallback&& /* elseCallback */) {
249 // The _identity instance passed in can be used to delay evaluation of an
250 // expression, because the compiler can't know that it's just the identity
251 // we're passing in.
252 return thenCallback(_identity());
253 }
254
255 template <
256 class ThenCallback,
257 class ElseCallback,
258 std::enable_if_t<
259 !function_takes_identity_argument<ThenCallback>::value,
260 void*> = nullptr>
261 static decltype(auto) call(
262 ThenCallback&& thenCallback,
263 ElseCallback&& /* elseCallback */) {
264 return thenCallback();
265 }
266};
267
268template <>
269struct _if_constexpr<false> final {
270 template <
271 class ThenCallback,
272 class ElseCallback,
273 std::enable_if_t<
274 function_takes_identity_argument<ElseCallback>::value,
275 void*> = nullptr>
276 static decltype(auto) call(
277 ThenCallback&& /* thenCallback */,
278 ElseCallback&& elseCallback) {
279 // The _identity instance passed in can be used to delay evaluation of an
280 // expression, because the compiler can't know that it's just the identity
281 // we're passing in.
282 return elseCallback(_identity());
283 }
284
285 template <
286 class ThenCallback,
287 class ElseCallback,
288 std::enable_if_t<
289 !function_takes_identity_argument<ElseCallback>::value,
290 void*> = nullptr>
291 static decltype(auto) call(
292 ThenCallback&& /* thenCallback */,
293 ElseCallback&& elseCallback) {
294 return elseCallback();
295 }
296};
297} // namespace detail
298
299/*
300 * Get something like C++17 if constexpr in C++14.
301 *
302 * Example 1: simple constexpr if/then/else
303 * template<int arg> int increment_absolute_value() {
304 * int result = arg;
305 * if_constexpr<(arg > 0)>(
 *       [&] { ++result; }, // then-case
307 * [&] { --result; } // else-case
308 * );
309 * return result;
310 * }
311 *
312 * Example 2: without else case (i.e. conditionally prune code from assembly)
313 * template<int arg> int decrement_if_positive() {
314 * int result = arg;
315 * if_constexpr<(arg > 0)>(
316 * // This decrement operation is only present in the assembly for
317 * // template instances with arg > 0.
318 * [&] { --result; }
319 * );
320 * return result;
321 * }
322 *
323 * Example 3: branch based on type (i.e. replacement for SFINAE)
324 * struct MyClass1 {int value;};
 *   struct MyClass2 {int val;};
326 * template <class T>
327 * int func(T t) {
328 * return if_constexpr<std::is_same<T, MyClass1>::value>(
 *       // This branch is invalid for T == MyClass2, so a regular
 *       // non-constexpr if statement wouldn't compile.
 *       [&](auto _) { return _(t).value; },
 *       // This branch is invalid for T == MyClass1.
 *       [&](auto _) { return _(t).val; }
333 * );
334 * }
335 *
336 * Note: The _ argument passed in Example 3 is the identity function, i.e. it
337 * does nothing. It is used to force the compiler to delay type checking,
338 * because the compiler doesn't know what kind of _ is passed in. Without it,
339 * the compiler would fail when you try to access t.value but the member doesn't
340 * exist.
341 *
342 * Note: In Example 3, both branches return int, so func() returns int. This is
343 * not necessary. If func() had a return type of "auto", then both branches
344 * could return different types, say func<MyClass1>() could return int and
345 * func<MyClass2>() could return string.
346 *
347 * Note: if_constexpr<cond, t, f> is *eager* w.r.t. template expansion - meaning
348 * this polyfill does not behave like a true "if statement at compilation time".
349 * The `_` trick above only defers typechecking, which happens after
350 * templates have been expanded. (Of course this is all that's necessary for
351 * many use cases).
352 */
353template <bool Condition, class ThenCallback, class ElseCallback>
354decltype(auto) if_constexpr(
355 ThenCallback&& thenCallback,
356 ElseCallback&& elseCallback) {
357#if defined(__cpp_if_constexpr)
358 // If we have C++17, just use it's "if constexpr" feature instead of wrapping
359 // it. This will give us better error messages.
360 if constexpr (Condition) {
361 if constexpr (detail::function_takes_identity_argument<
362 ThenCallback>::value) {
363 // Note that we use static_cast<T&&>(t) instead of std::forward (or
364 // ::std::forward) because using the latter produces some compilation
365 // errors about ambiguous `std` on MSVC when using C++17. This static_cast
366 // is just what std::forward is doing under the hood, and is equivalent.
367 return static_cast<ThenCallback&&>(thenCallback)(detail::_identity());
368 } else {
369 return static_cast<ThenCallback&&>(thenCallback)();
370 }
371 } else {
372 if constexpr (detail::function_takes_identity_argument<
373 ElseCallback>::value) {
374 return static_cast<ElseCallback&&>(elseCallback)(detail::_identity());
375 } else {
376 return static_cast<ElseCallback&&>(elseCallback)();
377 }
378 }
379#else
380 // C++14 implementation of if constexpr
381 return detail::_if_constexpr<Condition>::call(
382 static_cast<ThenCallback&&>(thenCallback),
383 static_cast<ElseCallback&&>(elseCallback));
384#endif
385}
386
387template <bool Condition, class ThenCallback>
388decltype(auto) if_constexpr(ThenCallback&& thenCallback) {
389#if defined(__cpp_if_constexpr)
390 // If we have C++17, just use it's "if constexpr" feature instead of wrapping
391 // it. This will give us better error messages.
392 if constexpr (Condition) {
393 if constexpr (detail::function_takes_identity_argument<
394 ThenCallback>::value) {
395 // Note that we use static_cast<T&&>(t) instead of std::forward (or
396 // ::std::forward) because using the latter produces some compilation
397 // errors about ambiguous `std` on MSVC when using C++17. This static_cast
398 // is just what std::forward is doing under the hood, and is equivalent.
399 return static_cast<ThenCallback&&>(thenCallback)(detail::_identity());
400 } else {
401 return static_cast<ThenCallback&&>(thenCallback)();
402 }
403 }
404#else
405 // C++14 implementation of if constexpr
406 return if_constexpr<Condition>(
407 static_cast<ThenCallback&&>(thenCallback), [](auto) {});
408#endif
409}
410
411// GCC 4.8 doesn't define std::to_string, even though that's in C++11. Let's
412// define it.
namespace detail {
// Empty tag type. Its only use in this file is as the parameter type of a
// placeholder std::to_string overload, so that the name std::to_string is
// always declared for the SFINAE detection done in to_string_.
class DummyClassForToString final {
};
} // namespace detail
416} // namespace guts
417} // namespace c10
418namespace std {
419// We use SFINAE to detect if std::to_string exists for a type, but that only
420// works if the function name is defined. So let's define a std::to_string for a
421// dummy type. If you're getting an error here saying that this overload doesn't
422// match your std::to_string() call, then you're calling std::to_string() but
423// should be calling c10::guts::to_string().
424inline std::string to_string(c10::guts::detail::DummyClassForToString) {
425 return "";
426}
427
428} // namespace std
429namespace c10 {
430namespace guts {
431namespace detail {
432
433template <class T, class Enable = void>
434struct to_string_ final {
435 static std::string call(T value) {
436 std::ostringstream str;
437 str << value;
438 return str.str();
439 }
440};
441// If a std::to_string exists, use that instead
442template <class T>
443struct to_string_<T, void_t<decltype(std::to_string(std::declval<T>()))>>
444 final {
445 static std::string call(T value) {
446 return std::to_string(value);
447 }
448};
449} // namespace detail
450template <class T>
451inline std::string to_string(T value) {
452 return detail::to_string_<T>::call(value);
453}
454
// Returns a reference to the smaller of `a` and `b`, compared with operator<.
// When neither is less than the other, `a` is returned.
template <class T>
constexpr const T& min(const T& a, const T& b) {
  if (b < a) {
    return b;
  }
  return a;
}
459
// Returns a reference to the larger of `a` and `b`, compared with operator<.
// When neither is less than the other, `a` is returned.
template <class T>
constexpr const T& max(const T& a, const T& b) {
  if (a < b) {
    return b;
  }
  return a;
}
464
465} // namespace guts
466} // namespace c10
467
468#endif // C10_UTIL_CPP17_H_
469