1 | #include <c10/core/DefaultDtype.h> |
---|---|
2 | #include <c10/util/typeid.h> |
3 | |
4 | namespace c10 { |
5 | static auto default_dtype = caffe2::TypeMeta::Make<float>(); |
6 | static auto default_dtype_as_scalartype = default_dtype.toScalarType(); |
7 | static auto default_complex_dtype = |
8 | caffe2::TypeMeta::Make<c10::complex<float>>(); |
9 | |
10 | void set_default_dtype(caffe2::TypeMeta dtype) { |
11 | default_dtype = dtype; |
12 | default_dtype_as_scalartype = default_dtype.toScalarType(); |
13 | switch (default_dtype_as_scalartype) { |
14 | case ScalarType::Half: |
15 | default_complex_dtype = ScalarType::ComplexHalf; |
16 | break; |
17 | case ScalarType::Double: |
18 | default_complex_dtype = ScalarType::ComplexDouble; |
19 | break; |
20 | default: |
21 | default_complex_dtype = ScalarType::ComplexFloat; |
22 | break; |
23 | } |
24 | } |
25 | |
26 | const caffe2::TypeMeta get_default_dtype() { |
27 | return default_dtype; |
28 | } |
29 | ScalarType get_default_dtype_as_scalartype() { |
30 | return default_dtype_as_scalartype; |
31 | } |
32 | const caffe2::TypeMeta get_default_complex_dtype() { |
33 | return default_complex_dtype; |
34 | } |
35 | } // namespace c10 |
36 |