1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_CPU_PRIMITIVE_HPP
18#define CPU_CPU_PRIMITIVE_HPP
19
20#include <assert.h>
21
22#include "oneapi/dnnl/dnnl_types.h"
23
24#include "common/c_types_map.hpp"
25#include "common/primitive_attr.hpp"
26#include "common/primitive_exec_types.hpp"
27#include "common/utils.hpp"
28#include "common/z_magic.hpp"
29
// Declares `const float *scales` in the enclosing scope together with a
// zero-initialized, 16-byte-aligned array of 16 floats named
// `<scales>_buf16`, then resolves `scales` as follows:
//   - `attr` evaluates to false (null): `scales` stays nullptr.
//   - `attr->output_scales_` reports default values: the array is filled
//     with 1.0f and `scales` points at it.
//   - otherwise `scales` is fetched from execution argument `arg` via
//     CTX_IN_MEM. The enclosing function returns status::invalid_arguments
//     if that pointer is null, or if ctx.memory_mdw(arg) does not report
//     data type f32 with exactly one dimension. When the first dimension
//     is 1, the single value is replicated into the array and `scales` is
//     redirected to it; for any other length `scales` keeps pointing at the
//     argument's data.
// Requirements visible from the body: a variable `ctx` must be in scope at
// the expansion site and the enclosing function must be able to return
// `status::invalid_arguments`.
#define DEFINE_SCALES_BUFFER_ATTR_ARG(attr, scales, arg) \
    alignas(16) float CONCAT2(scales, _buf16)[16] = {0}; \
    const float *scales {nullptr}; \
    if ((attr)) { \
        if (!(attr)->output_scales_.has_default_values()) { \
            scales = CTX_IN_MEM(const float *, arg); \
            if (scales == nullptr) return status::invalid_arguments; \
            const auto scales_md = ctx.memory_mdw(arg); \
            if (scales_md.data_type() != data_type::f32 \
                    || scales_md.ndims() != 1) \
                return status::invalid_arguments; \
            if (scales_md.dims()[0] == 1) { \
                utils::array_set(CONCAT2(scales, _buf16), scales[0], 16); \
                scales = CONCAT2(scales, _buf16); \
            } \
        } else { \
            utils::array_set(CONCAT2(scales, _buf16), 1.0f, 16); \
            scales = CONCAT2(scales, _buf16); \
        } \
    } \
    MAYBE_UNUSED(scales);
51
// Shorthand for DEFINE_SCALES_BUFFER_ATTR_ARG that always reads the scales
// from the DNNL_ARG_ATTR_OUTPUT_SCALES execution argument.
#define DEFINE_SCALES_BUFFER_ATTR(attr, scales) \
    DEFINE_SCALES_BUFFER_ATTR_ARG(attr, scales, DNNL_ARG_ATTR_OUTPUT_SCALES);
54
// Shorthand for DEFINE_SCALES_BUFFER_ATTR that takes the attributes from
// `pd()->attr()`; `pd()` must therefore be callable at the expansion site.
#define DEFINE_SCALES_BUFFER(scales) \
    DEFINE_SCALES_BUFFER_ATTR(pd()->attr(), scales)
57
// Per-argument scales helper. Declares `const float *scales` in the
// enclosing scope together with a zero-initialized, 16-byte-aligned array
// of 16 floats named `<scales>_buf16`, then resolves `scales`:
//   - `attr` evaluates to false (null): `scales` stays nullptr.
//   - `attr->scales_.get(arg)` reports default values: the array is filled
//     with 1.0f and `scales` points at it.
//   - otherwise `scales` is fetched from execution argument
//     `DNNL_ARG_ATTR_SCALES | (arg)`. The enclosing function returns
//     status::invalid_arguments if that pointer is null, or if the
//     corresponding ctx.memory_mdw(...) does not report data type f32 with
//     exactly one dimension.
//     When the first dimension is 1 the single value is replicated into the
//     array and `scales` is redirected to it. For `arg` equal to
//     DNNL_ARG_DST or DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST the replicated
//     value is the reciprocal `1.f / scales[0]`; for every other `arg` it is
//     `scales[0]` itself.
//     NOTE: for any other length `scales` keeps pointing at the argument's
//     data unchanged, i.e. no reciprocal is applied on that path.
// `arg` is parenthesized wherever it is combined with `|` so that an
// argument expression containing lower-precedence operators is not
// re-associated by the expansion.
// Requirements visible from the body: a variable `ctx` must be in scope at
// the expansion site and the enclosing function must be able to return
// `status::invalid_arguments`.
#define DEFINE_ARG_SCALES_BUFFER_ATTR(attr, scales, arg) \
    alignas(16) float CONCAT2(scales, _buf16)[16] = {0}; \
    const float *scales {nullptr}; \
    if ((attr)) { \
        if ((attr)->scales_.get(arg).has_default_values()) { \
            utils::array_set(CONCAT2(scales, _buf16), 1.0f, 16); \
            scales = CONCAT2(scales, _buf16); \
        } else { \
            scales = CTX_IN_MEM( \
                    const float *, DNNL_ARG_ATTR_SCALES | (arg)); \
            if (scales == nullptr) return status::invalid_arguments; \
            const auto scales_d \
                    = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | (arg)); \
            const bool ok = scales_d.data_type() == data_type::f32 \
                    && scales_d.ndims() == 1; \
            if (!ok) return status::invalid_arguments; \
            if (scales_d.dims()[0] == 1) { \
                if (utils::one_of(arg, DNNL_ARG_DST, \
                            DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST)) { \
                    utils::array_set( \
                            CONCAT2(scales, _buf16), 1.f / scales[0], 16); \
                } else { \
                    utils::array_set( \
                            CONCAT2(scales, _buf16), scales[0], 16); \
                } \
                scales = CONCAT2(scales, _buf16); \
            } \
        } \
    } \
    MAYBE_UNUSED(scales);
85
// Shorthand for DEFINE_ARG_SCALES_BUFFER_ATTR that takes the attributes
// from `pd()->attr()`; `pd()` must therefore be callable at the expansion
// site.
#define DEFINE_ARG_SCALES_BUFFER(scales, arg) \
    DEFINE_ARG_SCALES_BUFFER_ATTR(pd()->attr(), scales, arg)
88
// Declares `const int32_t *zero_points_ptr` in the enclosing scope, plus an
// int32_t named `default_zero_point_<mem_arg>` initialized to 0.
//   - If `pd()->attr()->zero_points_.defined(mem_arg)` is true, the pointer
//     refers to that local zero value.
//   - Otherwise the pointer is fetched from execution argument
//     `DNNL_ARG_ATTR_ZERO_POINTS | mem_arg`; a null result makes the
//     enclosing function return status::invalid_arguments.
// `mem_arg` is token-pasted into the local variable name, so it has to be
// something that forms a valid identifier suffix after macro expansion.
#define DEFINE_ZERO_POINTS_BUFFER(zero_points_ptr, mem_arg) \
    int32_t CONCAT2(default_zero_point_, mem_arg) = 0; \
    const int32_t *zero_points_ptr = nullptr; \
    if (pd()->attr()->zero_points_.defined(mem_arg)) { \
        zero_points_ptr = &CONCAT2(default_zero_point_, mem_arg); \
    } else { \
        zero_points_ptr = CTX_IN_MEM( \
                const int32_t *, DNNL_ARG_ATTR_ZERO_POINTS | mem_arg); \
        if (zero_points_ptr == nullptr) return status::invalid_arguments; \
    } \
    MAYBE_UNUSED(zero_points_ptr);
98
// Assigns a scale pointer to the already-declared lvalue `scale`.
// A zero-initialized, 16-byte-aligned array of 16 floats named
// `scales_buf16<mem_arg>` is declared in the enclosing scope as backing
// storage for the default case.
//   - If `pd()->attr()->scales_.get(mem_arg)` reports default values, the
//     array is filled with 1.0f and `scale` points at it.
//   - Otherwise ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | mem_arg) must report
//     data type f32, one dimension, and a first dimension of exactly 1, and
//     the pointer fetched through CTX_IN_MEM must be non-null; any
//     violation makes the enclosing function return
//     status::invalid_arguments. On success `scale` points directly at the
//     argument's data.
// `mem_arg` is token-pasted into the array name, so it has to be something
// that forms a valid identifier suffix after macro expansion.
#define ASSIGN_ARG_SCALE_VALUE(scale, mem_arg) \
    alignas(16) float CONCAT2(CONCAT2(scales, _buf16), mem_arg)[16] = {0}; \
    if (!pd()->attr()->scales_.get(mem_arg).has_default_values()) { \
        const auto scale_md \
                = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | mem_arg); \
        if (scale_md.data_type() != data_type::f32 \
                || scale_md.ndims() != 1 || scale_md.dims()[0] != 1) \
            return status::invalid_arguments; \
        const float *scale_ptr \
                = CTX_IN_MEM(const float *, DNNL_ARG_ATTR_SCALES | mem_arg); \
        if (scale_ptr == nullptr) return status::invalid_arguments; \
        scale = scale_ptr; \
    } else { \
        utils::array_set( \
                CONCAT2(CONCAT2(scales, _buf16), mem_arg), 1.0f, 16); \
        scale = CONCAT2(CONCAT2(scales, _buf16), mem_arg); \
    }
114
// Declares `int32_t zero_point` in the enclosing scope, initialized to 0.
// If `attr->zero_points_.has_default_values(mem_arg)` is false, the value
// is loaded from execution argument `DNNL_ARG_ATTR_ZERO_POINTS | (mem_arg)`:
//   - ctx.memory_mdw(...) for that argument must report data type s32, one
//     dimension, and a first dimension of exactly 1;
//   - the pointer fetched through CTX_IN_MEM must be non-null.
// Any violation makes the enclosing function return
// status::invalid_arguments.
// `attr` is dereferenced unconditionally (no null check is performed here).
// Both `attr` and `mem_arg` are parenthesized where they are combined with
// operators so that expression arguments are not re-associated by the
// expansion.
// Requirements visible from the body: a variable `ctx` must be in scope at
// the expansion site.
#define DEFINE_ZERO_POINT_VALUE_ATTR(attr, zero_point, mem_arg) \
    int32_t zero_point = 0; \
    if (!(attr)->zero_points_.has_default_values(mem_arg)) { \
        const auto zero_points_d \
                = ctx.memory_mdw(DNNL_ARG_ATTR_ZERO_POINTS | (mem_arg)); \
        const bool ok = zero_points_d.data_type() == data_type::s32 \
                && zero_points_d.ndims() == 1 \
                && zero_points_d.dims()[0] == 1; \
        if (!ok) return status::invalid_arguments; \
        const int32_t *zero_points_ptr = CTX_IN_MEM( \
                const int32_t *, DNNL_ARG_ATTR_ZERO_POINTS | (mem_arg)); \
        if (zero_points_ptr == nullptr) return status::invalid_arguments; \
        zero_point = *zero_points_ptr; \
    } \
    MAYBE_UNUSED(zero_point);
129
// Shorthand for DEFINE_ZERO_POINT_VALUE_ATTR that takes the attributes from
// `pd()->attr()`; `pd()` must therefore be callable at the expansion site.
#define DEFINE_ZERO_POINT_VALUE(zero_point, mem_arg) \
    DEFINE_ZERO_POINT_VALUE_ATTR(pd()->attr(), zero_point, mem_arg)
132
133#endif // CPU_CPU_PRIMITIVE_HPP
134