1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <float.h>
18#include <math.h>
19#include <stdio.h>
20#include <stdlib.h>
21#include <string.h>
22
23#include "oneapi/dnnl/dnnl.h"
24
25#include "dnnl_common.hpp"
26
27#include "pool/pool.hpp"
28
29namespace pool {
30
31/* cfgs definition
32 * arrays: SRC, UNUSED, UNUSED, DST
33 * params: {data_type, min, max, f_min, f_max, eps}
34 */
35
36// though integers are expected, eps is needed to cover division error
37const dt_conf_t conf_entry_f32
38 = {dnnl_f32, -FLT_MAX, FLT_MAX, -2048, 2048, 5e-7};
39const dt_conf_t conf_entry_s32 = {dnnl_s32, INT_MIN, INT_MAX, -2048, 2048, 0.};
40const dt_conf_t conf_entry_s8
41 = {dnnl_s8, INT8_MIN, INT8_MAX, INT8_MIN, INT8_MAX, 0.};
42const dt_conf_t conf_entry_u8 = {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0.};
43
44const float16_t flt16_max = dnnl::impl::nstl::numeric_limits<float16_t>::max();
45const dt_conf_t conf_entry_f16
46 = {dnnl_f16, -flt16_max, flt16_max, -32, 32, 2e-2};
47
48#define BFLT16_MAX 3.38953138925153547590470800371487866880e+38F
49/* Although integers are expected, eps is needed to cover
50 * for the division error */
51const dt_conf_t conf_entry_bf16
52 = {dnnl_bf16, -BFLT16_MAX, BFLT16_MAX, -32, 32, 5e-2};
53#undef BFLT16_MAX
54
55// Configurations with same SRC and DST datatypes
56const _dt_conf_t conf_f32 = {conf_entry_f32, {}, {}, conf_entry_f32};
57const _dt_conf_t conf_s32 = {conf_entry_s32, {}, {}, conf_entry_s32};
58const _dt_conf_t conf_f16 = {conf_entry_f16, {}, {}, conf_entry_f16};
59const _dt_conf_t conf_bf16 = {conf_entry_bf16, {}, {}, conf_entry_bf16};
60const _dt_conf_t conf_s8 = {conf_entry_s8, {}, {}, conf_entry_s8};
61const _dt_conf_t conf_u8 = {conf_entry_u8, {}, {}, conf_entry_u8};
62
63// Configurations with different SRC and DST datatypes
64const _dt_conf_t conf_s8u8 {conf_entry_s8, {}, {}, conf_entry_u8};
65const _dt_conf_t conf_u8s8 {conf_entry_u8, {}, {}, conf_entry_s8};
66const _dt_conf_t conf_s8f32 {conf_entry_s8, {}, {}, conf_entry_f32};
67const _dt_conf_t conf_f32s8 {conf_entry_f32, {}, {}, conf_entry_s8};
68const _dt_conf_t conf_u8f32 {conf_entry_u8, {}, {}, conf_entry_f32};
69const _dt_conf_t conf_f32u8 {conf_entry_f32, {}, {}, conf_entry_u8};
70const _dt_conf_t conf_s8f16 {conf_entry_s8, {}, {}, conf_entry_f16};
71const _dt_conf_t conf_f16s8 {conf_entry_f16, {}, {}, conf_entry_s8};
72const _dt_conf_t conf_u8f16 {conf_entry_u8, {}, {}, conf_entry_f16};
73const _dt_conf_t conf_f16u8 {conf_entry_f16, {}, {}, conf_entry_u8};
74
75const dt_conf_t *str2cfg(const char *str) {
76#define CASE(cfg) \
77 if (!strcasecmp(STRINGIFY(cfg), str)) return CONCAT2(conf_, cfg)
78 CASE(f32);
79 CASE(s32);
80 CASE(f16);
81 CASE(bf16);
82 CASE(s8);
83 CASE(u8);
84 CASE(s8u8);
85 CASE(u8s8);
86 CASE(s8f32);
87 CASE(f32s8);
88 CASE(u8f32);
89 CASE(f32u8);
90 CASE(s8f16);
91 CASE(f16s8);
92 CASE(u8f16);
93 CASE(f16u8);
94
95#undef CASE
96 []() {
97 SAFE(FAIL, CRIT);
98 return 0;
99 }();
100 return (const dt_conf_t *)1;
101}
102
103std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg) {
104#define CASE(_cfg) \
105 if (cfg == CONCAT2(conf_, _cfg)) return s << STRINGIFY(_cfg)
106 CASE(f32);
107 CASE(s32);
108 CASE(f16);
109 CASE(bf16);
110 CASE(s8);
111 CASE(u8);
112 CASE(s8u8);
113 CASE(u8s8);
114 CASE(s8f32);
115 CASE(f32s8);
116 CASE(u8f32);
117 CASE(f32u8);
118 CASE(s8f16);
119 CASE(f16s8);
120 CASE(u8f16);
121 CASE(f16u8);
122#undef CASE
123 SAFE_V(FAIL);
124 return s;
125}
126
127} // namespace pool
128