1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_DNNL_TRAITS_HPP
18#define COMMON_DNNL_TRAITS_HPP
19
20#include <assert.h>
21#include <stdint.h>
22
23#include "oneapi/dnnl/dnnl.h"
24
25#include "bfloat16.hpp"
26#include "c_types_map.hpp"
27#include "float16.hpp"
28#include "nstl.hpp"
29#include "opdesc.hpp"
30#include "utils.hpp"
31#include "z_magic.hpp"
32
33namespace dnnl {
34namespace impl {
35
36template <data_type_t>
37struct prec_traits {}; /* ::type -> float */
38template <typename>
39struct data_traits {}; /* ::data_type -> f32 */
40template <int>
41struct typesize_traits {}; /* ::data_type_size -> f32 */
42template <primitive_kind_t>
43struct pkind_traits {}; /* ::desc_type, ::query_d */
44
45template <>
46struct prec_traits<data_type::f16> {
47 typedef float16_t type;
48};
49template <>
50struct prec_traits<data_type::bf16> {
51 typedef bfloat16_t type;
52};
53template <>
54struct prec_traits<data_type::f32> {
55 typedef float type;
56};
57template <>
58struct prec_traits<data_type::f64> {
59 typedef double type;
60};
61template <>
62struct prec_traits<data_type::s32> {
63 typedef int32_t type;
64};
65template <>
66struct prec_traits<data_type::s8> {
67 typedef int8_t type;
68};
69template <>
70struct prec_traits<data_type::u8> {
71 typedef uint8_t type;
72};
73
74template <>
75struct data_traits<float16_t> {
76 static constexpr data_type_t data_type = data_type::f16;
77};
78template <>
79struct data_traits<bfloat16_t> {
80 static constexpr data_type_t data_type = data_type::bf16;
81};
82template <>
83struct data_traits<float> {
84 static constexpr data_type_t data_type = data_type::f32;
85};
86template <>
87struct data_traits<int32_t> {
88 static constexpr data_type_t data_type = data_type::s32;
89};
90template <>
91struct data_traits<int8_t> {
92 static constexpr data_type_t data_type = data_type::s8;
93};
94template <>
95struct data_traits<uint8_t> {
96 static constexpr data_type_t data_type = data_type::u8;
97};
98
99template <>
100struct typesize_traits<4> {
101 typedef float type;
102};
103template <>
104struct typesize_traits<2> {
105 typedef int16_t type;
106};
107template <>
108struct typesize_traits<1> {
109 typedef uint8_t type;
110};
111
112#define PKIND_TRAITS_INST(op) \
113 template <> \
114 struct pkind_traits<primitive_kind::op> { \
115 typedef CONCAT2(op, _desc_t) desc_type; \
116 }
117PKIND_TRAITS_INST(convolution);
118PKIND_TRAITS_INST(deconvolution);
119PKIND_TRAITS_INST(shuffle);
120PKIND_TRAITS_INST(eltwise);
121PKIND_TRAITS_INST(softmax);
122PKIND_TRAITS_INST(pooling);
123PKIND_TRAITS_INST(prelu);
124PKIND_TRAITS_INST(lrn);
125PKIND_TRAITS_INST(batch_normalization);
126PKIND_TRAITS_INST(layer_normalization);
127PKIND_TRAITS_INST(inner_product);
128PKIND_TRAITS_INST(rnn);
129PKIND_TRAITS_INST(gemm);
130PKIND_TRAITS_INST(zero_pad);
131PKIND_TRAITS_INST(binary);
132PKIND_TRAITS_INST(matmul);
133PKIND_TRAITS_INST(resampling);
134PKIND_TRAITS_INST(reduction);
135#undef PKIND_TRAITS_INST
136
137} // namespace impl
138} // namespace dnnl
139
140#endif
141
142// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
143