1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <float.h> |
18 | #include <math.h> |
19 | #include <stdio.h> |
20 | #include <stdlib.h> |
21 | #include <string.h> |
22 | |
23 | #include "oneapi/dnnl/dnnl.h" |
24 | |
25 | #include "dnnl_common.hpp" |
26 | |
27 | #include "pool/pool.hpp" |
28 | |
29 | namespace pool { |
30 | |
31 | /* cfgs definition |
32 | * arrays: SRC, UNUSED, UNUSED, DST |
33 | * params: {data_type, min, max, f_min, f_max, eps} |
34 | */ |
35 | |
36 | // though integers are expected, eps is needed to cover division error |
37 | const dt_conf_t conf_entry_f32 |
38 | = {dnnl_f32, -FLT_MAX, FLT_MAX, -2048, 2048, 5e-7}; |
39 | const dt_conf_t conf_entry_s32 = {dnnl_s32, INT_MIN, INT_MAX, -2048, 2048, 0.}; |
40 | const dt_conf_t conf_entry_s8 |
41 | = {dnnl_s8, INT8_MIN, INT8_MAX, INT8_MIN, INT8_MAX, 0.}; |
42 | const dt_conf_t conf_entry_u8 = {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0.}; |
43 | |
44 | const float16_t flt16_max = dnnl::impl::nstl::numeric_limits<float16_t>::max(); |
45 | const dt_conf_t conf_entry_f16 |
46 | = {dnnl_f16, -flt16_max, flt16_max, -32, 32, 2e-2}; |
47 | |
48 | #define BFLT16_MAX 3.38953138925153547590470800371487866880e+38F |
49 | /* Although integers are expected, eps is needed to cover |
50 | * for the division error */ |
51 | const dt_conf_t conf_entry_bf16 |
52 | = {dnnl_bf16, -BFLT16_MAX, BFLT16_MAX, -32, 32, 5e-2}; |
53 | #undef BFLT16_MAX |
54 | |
55 | // Configurations with same SRC and DST datatypes |
56 | const _dt_conf_t conf_f32 = {conf_entry_f32, {}, {}, conf_entry_f32}; |
57 | const _dt_conf_t conf_s32 = {conf_entry_s32, {}, {}, conf_entry_s32}; |
58 | const _dt_conf_t conf_f16 = {conf_entry_f16, {}, {}, conf_entry_f16}; |
59 | const _dt_conf_t conf_bf16 = {conf_entry_bf16, {}, {}, conf_entry_bf16}; |
60 | const _dt_conf_t conf_s8 = {conf_entry_s8, {}, {}, conf_entry_s8}; |
61 | const _dt_conf_t conf_u8 = {conf_entry_u8, {}, {}, conf_entry_u8}; |
62 | |
63 | // Configurations with different SRC and DST datatypes |
64 | const _dt_conf_t conf_s8u8 {conf_entry_s8, {}, {}, conf_entry_u8}; |
65 | const _dt_conf_t conf_u8s8 {conf_entry_u8, {}, {}, conf_entry_s8}; |
66 | const _dt_conf_t conf_s8f32 {conf_entry_s8, {}, {}, conf_entry_f32}; |
67 | const _dt_conf_t conf_f32s8 {conf_entry_f32, {}, {}, conf_entry_s8}; |
68 | const _dt_conf_t conf_u8f32 {conf_entry_u8, {}, {}, conf_entry_f32}; |
69 | const _dt_conf_t conf_f32u8 {conf_entry_f32, {}, {}, conf_entry_u8}; |
70 | const _dt_conf_t conf_s8f16 {conf_entry_s8, {}, {}, conf_entry_f16}; |
71 | const _dt_conf_t conf_f16s8 {conf_entry_f16, {}, {}, conf_entry_s8}; |
72 | const _dt_conf_t conf_u8f16 {conf_entry_u8, {}, {}, conf_entry_f16}; |
73 | const _dt_conf_t conf_f16u8 {conf_entry_f16, {}, {}, conf_entry_u8}; |
74 | |
75 | const dt_conf_t *str2cfg(const char *str) { |
76 | #define CASE(cfg) \ |
77 | if (!strcasecmp(STRINGIFY(cfg), str)) return CONCAT2(conf_, cfg) |
78 | CASE(f32); |
79 | CASE(s32); |
80 | CASE(f16); |
81 | CASE(bf16); |
82 | CASE(s8); |
83 | CASE(u8); |
84 | CASE(s8u8); |
85 | CASE(u8s8); |
86 | CASE(s8f32); |
87 | CASE(f32s8); |
88 | CASE(u8f32); |
89 | CASE(f32u8); |
90 | CASE(s8f16); |
91 | CASE(f16s8); |
92 | CASE(u8f16); |
93 | CASE(f16u8); |
94 | |
95 | #undef CASE |
96 | []() { |
97 | SAFE(FAIL, CRIT); |
98 | return 0; |
99 | }(); |
100 | return (const dt_conf_t *)1; |
101 | } |
102 | |
103 | std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg) { |
104 | #define CASE(_cfg) \ |
105 | if (cfg == CONCAT2(conf_, _cfg)) return s << STRINGIFY(_cfg) |
106 | CASE(f32); |
107 | CASE(s32); |
108 | CASE(f16); |
109 | CASE(bf16); |
110 | CASE(s8); |
111 | CASE(u8); |
112 | CASE(s8u8); |
113 | CASE(u8s8); |
114 | CASE(s8f32); |
115 | CASE(f32s8); |
116 | CASE(u8f32); |
117 | CASE(f32u8); |
118 | CASE(s8f16); |
119 | CASE(f16s8); |
120 | CASE(u8f16); |
121 | CASE(f16u8); |
122 | #undef CASE |
123 | SAFE_V(FAIL); |
124 | return s; |
125 | } |
126 | |
127 | } // namespace pool |
128 | |