1#include <c10/core/DefaultDtype.h>
2#include <c10/util/typeid.h>
3
4namespace c10 {
5static auto default_dtype = caffe2::TypeMeta::Make<float>();
6static auto default_dtype_as_scalartype = default_dtype.toScalarType();
7static auto default_complex_dtype =
8 caffe2::TypeMeta::Make<c10::complex<float>>();
9
10void set_default_dtype(caffe2::TypeMeta dtype) {
11 default_dtype = dtype;
12 default_dtype_as_scalartype = default_dtype.toScalarType();
13 switch (default_dtype_as_scalartype) {
14 case ScalarType::Half:
15 default_complex_dtype = ScalarType::ComplexHalf;
16 break;
17 case ScalarType::Double:
18 default_complex_dtype = ScalarType::ComplexDouble;
19 break;
20 default:
21 default_complex_dtype = ScalarType::ComplexFloat;
22 break;
23 }
24}
25
26const caffe2::TypeMeta get_default_dtype() {
27 return default_dtype;
28}
29ScalarType get_default_dtype_as_scalartype() {
30 return default_dtype_as_scalartype;
31}
32const caffe2::TypeMeta get_default_complex_dtype() {
33 return default_complex_dtype;
34}
35} // namespace c10
36