1 | /******************************************************************************* |
---|---|
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_DNNL_TRAITS_HPP |
18 | #define COMMON_DNNL_TRAITS_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <stdint.h> |
22 | |
23 | #include "oneapi/dnnl/dnnl.h" |
24 | |
25 | #include "bfloat16.hpp" |
26 | #include "c_types_map.hpp" |
27 | #include "float16.hpp" |
28 | #include "nstl.hpp" |
29 | #include "opdesc.hpp" |
30 | #include "utils.hpp" |
31 | #include "z_magic.hpp" |
32 | |
33 | namespace dnnl { |
34 | namespace impl { |
35 | |
36 | template <data_type_t> |
37 | struct prec_traits {}; /* ::type -> float */ |
38 | template <typename> |
39 | struct data_traits {}; /* ::data_type -> f32 */ |
40 | template <int> |
41 | struct typesize_traits {}; /* ::data_type_size -> f32 */ |
42 | template <primitive_kind_t> |
43 | struct pkind_traits {}; /* ::desc_type, ::query_d */ |
44 | |
45 | template <> |
46 | struct prec_traits<data_type::f16> { |
47 | typedef float16_t type; |
48 | }; |
49 | template <> |
50 | struct prec_traits<data_type::bf16> { |
51 | typedef bfloat16_t type; |
52 | }; |
53 | template <> |
54 | struct prec_traits<data_type::f32> { |
55 | typedef float type; |
56 | }; |
57 | template <> |
58 | struct prec_traits<data_type::f64> { |
59 | typedef double type; |
60 | }; |
61 | template <> |
62 | struct prec_traits<data_type::s32> { |
63 | typedef int32_t type; |
64 | }; |
65 | template <> |
66 | struct prec_traits<data_type::s8> { |
67 | typedef int8_t type; |
68 | }; |
69 | template <> |
70 | struct prec_traits<data_type::u8> { |
71 | typedef uint8_t type; |
72 | }; |
73 | |
74 | template <> |
75 | struct data_traits<float16_t> { |
76 | static constexpr data_type_t data_type = data_type::f16; |
77 | }; |
78 | template <> |
79 | struct data_traits<bfloat16_t> { |
80 | static constexpr data_type_t data_type = data_type::bf16; |
81 | }; |
82 | template <> |
83 | struct data_traits<float> { |
84 | static constexpr data_type_t data_type = data_type::f32; |
85 | }; |
86 | template <> |
87 | struct data_traits<int32_t> { |
88 | static constexpr data_type_t data_type = data_type::s32; |
89 | }; |
90 | template <> |
91 | struct data_traits<int8_t> { |
92 | static constexpr data_type_t data_type = data_type::s8; |
93 | }; |
94 | template <> |
95 | struct data_traits<uint8_t> { |
96 | static constexpr data_type_t data_type = data_type::u8; |
97 | }; |
98 | |
99 | template <> |
100 | struct typesize_traits<4> { |
101 | typedef float type; |
102 | }; |
103 | template <> |
104 | struct typesize_traits<2> { |
105 | typedef int16_t type; |
106 | }; |
107 | template <> |
108 | struct typesize_traits<1> { |
109 | typedef uint8_t type; |
110 | }; |
111 | |
112 | #define PKIND_TRAITS_INST(op) \ |
113 | template <> \ |
114 | struct pkind_traits<primitive_kind::op> { \ |
115 | typedef CONCAT2(op, _desc_t) desc_type; \ |
116 | } |
117 | PKIND_TRAITS_INST(convolution); |
118 | PKIND_TRAITS_INST(deconvolution); |
119 | PKIND_TRAITS_INST(shuffle); |
120 | PKIND_TRAITS_INST(eltwise); |
121 | PKIND_TRAITS_INST(softmax); |
122 | PKIND_TRAITS_INST(pooling); |
123 | PKIND_TRAITS_INST(prelu); |
124 | PKIND_TRAITS_INST(lrn); |
125 | PKIND_TRAITS_INST(batch_normalization); |
126 | PKIND_TRAITS_INST(layer_normalization); |
127 | PKIND_TRAITS_INST(inner_product); |
128 | PKIND_TRAITS_INST(rnn); |
129 | PKIND_TRAITS_INST(gemm); |
130 | PKIND_TRAITS_INST(zero_pad); |
131 | PKIND_TRAITS_INST(binary); |
132 | PKIND_TRAITS_INST(matmul); |
133 | PKIND_TRAITS_INST(resampling); |
134 | PKIND_TRAITS_INST(reduction); |
135 | #undef PKIND_TRAITS_INST |
136 | |
137 | } // namespace impl |
138 | } // namespace dnnl |
139 | |
140 | #endif |
141 | |
142 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
143 |