1 | #pragma once |
---|---|
2 | |
3 | #include <c10/core/ScalarType.h> |
4 | #include <c10/util/Optional.h> |
5 | #include <c10/util/typeid.h> |
6 | |
7 | // these just expose TypeMeta/ScalarType bridge functions in c10 |
8 | // TODO move to typeid.h (or codemod away) when TypeMeta et al |
9 | // are moved from caffe2 to c10 (see note at top of typeid.h) |
10 | |
11 | namespace c10 { |
12 | |
13 | /** |
14 | * convert ScalarType enum values to TypeMeta handles |
15 | */ |
16 | static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { |
17 | return caffe2::TypeMeta::fromScalarType(scalar_type); |
18 | } |
19 | |
20 | /** |
21 | * convert TypeMeta handles to ScalarType enum values |
22 | */ |
23 | static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { |
24 | return dtype.toScalarType(); |
25 | } |
26 | |
27 | /** |
28 | * typeMetaToScalarType(), lifted to optional |
29 | */ |
30 | static inline optional<at::ScalarType> optTypeMetaToScalarType( |
31 | optional<caffe2::TypeMeta> type_meta) { |
32 | if (!type_meta.has_value()) { |
33 | return c10::nullopt; |
34 | } |
35 | return type_meta->toScalarType(); |
36 | } |
37 | |
38 | /** |
39 | * convenience: equality across TypeMeta/ScalarType conversion |
40 | */ |
41 | static inline bool operator==(ScalarType t, caffe2::TypeMeta m) { |
42 | return m.isScalarType(t); |
43 | } |
44 | |
45 | static inline bool operator==(caffe2::TypeMeta m, ScalarType t) { |
46 | return t == m; |
47 | } |
48 | |
49 | static inline bool operator!=(ScalarType t, caffe2::TypeMeta m) { |
50 | return !(t == m); |
51 | } |
52 | |
53 | static inline bool operator!=(caffe2::TypeMeta m, ScalarType t) { |
54 | return !(t == m); |
55 | } |
56 | |
57 | } // namespace c10 |
58 |