1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_CPU_PRIMITIVE_HPP |
18 | #define CPU_CPU_PRIMITIVE_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "oneapi/dnnl/dnnl_types.h" |
23 | |
24 | #include "common/c_types_map.hpp" |
25 | #include "common/primitive_attr.hpp" |
26 | #include "common/primitive_exec_types.hpp" |
27 | #include "common/utils.hpp" |
28 | #include "common/z_magic.hpp" |
29 | |
// Declares `const float *scales` together with a 16-float, 16-byte-aligned
// scratch array named `<scales>_buf16` in the current scope, then resolves
// the output scales described by `attr`.
//
// Must be expanded inside a function that has `ctx` in scope and whose
// return type accepts `status::invalid_arguments`, because the macro returns
// early on bad input.
//
// - `attr` is null: `scales` stays nullptr.
// - `attr->output_scales_` has default values: the scratch array is set to
//   1.0f through utils::array_set and `scales` points to it.
// - otherwise: `scales` is read from `ctx` under `arg`. A null pointer, or a
//   memory descriptor that is not 1D f32, makes the enclosing function
//   return status::invalid_arguments. When the descriptor holds exactly one
//   value, that value is replicated into the scratch array and `scales` is
//   redirected to it.
#define DEFINE_SCALES_BUFFER_ATTR_ARG(attr, scales, arg) \
    const float *scales {nullptr}; \
    alignas(16) float CONCAT2(scales, _buf16)[16] = {0}; \
    if ((attr)) { \
        if (!(attr)->output_scales_.has_default_values()) { \
            scales = CTX_IN_MEM(const float *, arg); \
            if (scales == nullptr) return status::invalid_arguments; \
            const auto scales_mdw = ctx.memory_mdw(arg); \
            if (scales_mdw.data_type() != data_type::f32 \
                    || scales_mdw.ndims() != 1) \
                return status::invalid_arguments; \
            if (scales_mdw.dims()[0] == 1) { \
                /* Single common scale: broadcast it into the buffer. */ \
                utils::array_set(CONCAT2(scales, _buf16), scales[0], 16); \
                scales = CONCAT2(scales, _buf16); \
            } \
        } else { \
            utils::array_set(CONCAT2(scales, _buf16), 1.0f, 16); \
            scales = CONCAT2(scales, _buf16); \
        } \
    } \
    MAYBE_UNUSED(scales);
51 | |
// Shorthand for DEFINE_SCALES_BUFFER_ATTR_ARG that reads the scales from the
// DNNL_ARG_ATTR_OUTPUT_SCALES execution argument.
#define DEFINE_SCALES_BUFFER_ATTR(attr, scales) \
    DEFINE_SCALES_BUFFER_ATTR_ARG( \
            attr, scales, DNNL_ARG_ATTR_OUTPUT_SCALES);
54 | |
// Shorthand for DEFINE_SCALES_BUFFER_ATTR that takes the attributes from
// pd()->attr(); `pd()` must therefore be callable at the expansion site.
#define DEFINE_SCALES_BUFFER(scales) \
    DEFINE_SCALES_BUFFER_ATTR( \
            pd()->attr(), scales)
57 | |
// Declares `const float *scales` together with a 16-float, 16-byte-aligned
// scratch array named `<scales>_buf16` in the current scope, then resolves
// the per-argument scales that `attr` holds for `arg`.
//
// Must be expanded inside a function that has `ctx` in scope and whose
// return type accepts `status::invalid_arguments`, because the macro returns
// early on bad input.
//
// - `attr` is null: `scales` stays nullptr.
// - `attr->scales_.get(arg)` has default values: the scratch array is set to
//   1.0f through utils::array_set and `scales` points to it.
// - otherwise: `scales` is read from `ctx` under DNNL_ARG_ATTR_SCALES | arg.
//   A null pointer, or a memory descriptor that is not 1D f32, makes the
//   enclosing function return status::invalid_arguments. When the descriptor
//   holds exactly one value, that value is replicated into the scratch array
//   and `scales` is redirected to it. For DNNL_ARG_DST, with or without
//   DNNL_ARG_ATTR_POST_OP_DW, the reciprocal 1 / scales[0] is stored instead.
//   Scales with more than one element are passed through unchanged.
//
// `arg` is parenthesized wherever it is combined with DNNL_ARG_ATTR_SCALES so
// that an argument expression with lower precedence than `|` (for example a
// conditional expression) is not regrouped.
#define DEFINE_ARG_SCALES_BUFFER_ATTR(attr, scales, arg) \
    alignas(16) float CONCAT2(scales, _buf16)[16] = {0}; \
    const float *scales {nullptr}; \
    if ((attr)) { \
        if ((attr)->scales_.get(arg).has_default_values()) { \
            utils::array_set(CONCAT2(scales, _buf16), 1.0f, 16); \
            scales = CONCAT2(scales, _buf16); \
        } else { \
            scales = CTX_IN_MEM( \
                    const float *, DNNL_ARG_ATTR_SCALES | (arg)); \
            if (scales == nullptr) return status::invalid_arguments; \
            const auto scales_d \
                    = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | (arg)); \
            bool ok = scales_d.data_type() == data_type::f32 \
                    && scales_d.ndims() == 1; \
            if (!ok) return status::invalid_arguments; \
            if (scales_d.dims()[0] == 1) { \
                if (utils::one_of(arg, DNNL_ARG_DST, \
                            DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST)) { \
                    /* Destination scale is stored inverted. */ \
                    utils::array_set( \
                            CONCAT2(scales, _buf16), 1.f / scales[0], 16); \
                } else { \
                    utils::array_set( \
                            CONCAT2(scales, _buf16), scales[0], 16); \
                } \
                scales = CONCAT2(scales, _buf16); \
            } \
        } \
    } \
    MAYBE_UNUSED(scales);
85 | |
// Shorthand for DEFINE_ARG_SCALES_BUFFER_ATTR that takes the attributes from
// pd()->attr(); `pd()` must therefore be callable at the expansion site.
#define DEFINE_ARG_SCALES_BUFFER(scales, arg) \
    DEFINE_ARG_SCALES_BUFFER_ATTR( \
            pd()->attr(), scales, arg)
88 | |
// Declares `const int32_t *zero_points_ptr` for `mem_arg`, based on the
// attributes returned by pd()->attr().
//
// A local int32_t whose name is CONCAT2(default_zero_point_, mem_arg) is
// created and set to 0. When zero_points_.defined(mem_arg) is true the
// pointer targets that local; otherwise it is read from `ctx` under
// DNNL_ARG_ATTR_ZERO_POINTS | mem_arg. A null pointer makes the enclosing
// function return status::invalid_arguments, so the macro needs `ctx` and
// `pd()` in scope and a compatible return type.
//
// `mem_arg` is pasted into an identifier, so it has to be a single token
// (presumably expanded by CONCAT2 first -- confirm against z_magic.hpp).
// NOTE(review): the meaning of defined() is not visible here; confirm that a
// true result really corresponds to "use the zero default".
#define DEFINE_ZERO_POINTS_BUFFER(zero_points_ptr, mem_arg) \
    int32_t CONCAT2(default_zero_point_, mem_arg) = 0; \
    const int32_t *zero_points_ptr \
            = &CONCAT2(default_zero_point_, mem_arg); \
    if (!pd()->attr()->zero_points_.defined(mem_arg)) \
        zero_points_ptr = CTX_IN_MEM( \
                const int32_t *, DNNL_ARG_ATTR_ZERO_POINTS | mem_arg); \
    if (zero_points_ptr == nullptr) return status::invalid_arguments; \
    MAYBE_UNUSED(zero_points_ptr);
98 | |
// Assigns to an already declared pointer `scale` the scale that
// pd()->attr() holds for `mem_arg`.
//
// A 16-float, 16-byte-aligned scratch array is declared in the current
// scope. Its name is built from the literal token `scales`, `_buf16` and
// `mem_arg` -- not from the `scale` parameter.
//
// - scales for `mem_arg` have default values: the scratch array is set to
//   1.0f through utils::array_set and `scale` points to it.
// - otherwise: the memory descriptor under DNNL_ARG_ATTR_SCALES | mem_arg
//   must be 1D f32 with exactly one element and the matching pointer from
//   `ctx` must be non-null; `scale` then points at that user memory. Any
//   violation makes the enclosing function return
//   status::invalid_arguments.
//
// Needs `ctx` and `pd()` in scope. `mem_arg` is pasted into an identifier,
// so it has to be a single token.
#define ASSIGN_ARG_SCALE_VALUE(scale, mem_arg) \
    alignas(16) float CONCAT2(CONCAT2(scales, _buf16), mem_arg)[16] = {0}; \
    if (!pd()->attr()->scales_.get(mem_arg).has_default_values()) { \
        const auto scale_mdw \
                = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | mem_arg); \
        if (scale_mdw.data_type() != data_type::f32 \
                || scale_mdw.ndims() != 1 || scale_mdw.dims()[0] != 1) \
            return status::invalid_arguments; \
        const float *scale_ptr \
                = CTX_IN_MEM(const float *, DNNL_ARG_ATTR_SCALES | mem_arg); \
        if (scale_ptr == nullptr) return status::invalid_arguments; \
        scale = scale_ptr; \
    } else { \
        utils::array_set( \
                CONCAT2(CONCAT2(scales, _buf16), mem_arg), 1.0f, 16); \
        scale = CONCAT2(CONCAT2(scales, _buf16), mem_arg); \
    }
114 | |
// Declares `int32_t zero_point` in the current scope and loads into it the
// single zero-point value that `attr` describes for `mem_arg`.
//
// - zero points for `mem_arg` have default values: `zero_point` stays 0.
// - otherwise: the memory descriptor under
//   DNNL_ARG_ATTR_ZERO_POINTS | mem_arg must be 1D s32 with exactly one
//   element and the matching pointer from `ctx` must be non-null; the value
//   it points to is copied into `zero_point`. Any violation makes the
//   enclosing function return status::invalid_arguments.
//
// Needs `ctx` in scope and a return type that accepts
// `status::invalid_arguments`. `attr` is dereferenced unconditionally, so it
// must not be null. It is parenthesized so that any pointer-valued
// expression (for example `&local_attr`) can be passed.
#define DEFINE_ZERO_POINT_VALUE_ATTR(attr, zero_point, mem_arg) \
    int32_t zero_point = 0; \
    if (!(attr)->zero_points_.has_default_values(mem_arg)) { \
        const auto zero_points_d \
                = ctx.memory_mdw(DNNL_ARG_ATTR_ZERO_POINTS | mem_arg); \
        bool ok = zero_points_d.data_type() == data_type::s32 \
                && zero_points_d.ndims() == 1 \
                && zero_points_d.dims()[0] == 1; \
        if (!ok) return status::invalid_arguments; \
        const int32_t *zero_points_ptr = CTX_IN_MEM( \
                const int32_t *, DNNL_ARG_ATTR_ZERO_POINTS | mem_arg); \
        if (zero_points_ptr == nullptr) return status::invalid_arguments; \
        zero_point = *zero_points_ptr; \
    } \
    MAYBE_UNUSED(zero_point);
129 | |
// Shorthand for DEFINE_ZERO_POINT_VALUE_ATTR that takes the attributes from
// pd()->attr(); `pd()` must therefore be callable at the expansion site.
#define DEFINE_ZERO_POINT_VALUE(zero_point, mem_arg) \
    DEFINE_ZERO_POINT_VALUE_ATTR( \
            pd()->attr(), zero_point, mem_arg)
132 | |
133 | #endif // CPU_CPU_PRIMITIVE_HPP |
134 | |