1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#include "cpu/zero_point_utils.hpp"
17#include "common/utils.hpp"
18
19namespace dnnl {
20namespace impl {
21namespace cpu {
22
23static void adjust_zero_pad_comp_dims(const dim_t output_dim_size,
24 dim_t &estimated_dim_size, dim_t &begin_pad, dim_t &mid_pad,
25 dim_t &end_pad) {
26
27 if (output_dim_size < estimated_dim_size) {
28 const auto diff = estimated_dim_size - output_dim_size;
29 estimated_dim_size = output_dim_size;
30
31 end_pad -= diff;
32
33 if (end_pad < 0) {
34 if (mid_pad) {
35 mid_pad = 0;
36 end_pad += 1;
37 }
38 if (end_pad < 0) {
39 begin_pad += end_pad;
40 end_pad = 0;
41 }
42 }
43 }
44}
45
46zero_point_pad_comp_config_t::zero_point_pad_comp_config_t(
47 const dim_t front_pad, const dim_t back_pad, const dim_t top_pad,
48 const dim_t bottom_pad, const dim_t left_pad, const dim_t right_pad,
49 const dim_t stride_d, const dim_t stride_h, const dim_t stride_w,
50 const dim_t od, const dim_t oh, const dim_t ow)
51
52 : top_pad(utils::div_up(top_pad, stride_h))
53 , bottom_pad(utils::div_up(bottom_pad, stride_h))
54 , left_pad(utils::div_up(left_pad, stride_w))
55 , right_pad(utils::div_up(right_pad, stride_w))
56 , front_pad(utils::div_up(front_pad, stride_d))
57 , back_pad(utils::div_up(back_pad, stride_d))
58 , mid_h((oh - this->top_pad - this->bottom_pad > 0)
59 && (this->left_pad > 0 || this->right_pad > 0
60 || this->front_pad > 0 || this->back_pad)
61 ? 1
62 : 0)
63 , mid_w((ow - this->left_pad - this->right_pad > 0)
64 && (this->bottom_pad > 0 || this->top_pad > 0
65 || this->front_pad > 0 || this->back_pad)
66 ? 1
67 : 0)
68 , mid_d((od - this->front_pad - this->back_pad > 0)
69 && (this->top_pad > 0 || this->bottom_pad > 0
70 || this->right_pad > 0 || this->left_pad)
71 ? 1
72 : 0)
73 , h(this->top_pad + this->bottom_pad + this->mid_h)
74 , w(this->left_pad + this->right_pad + this->mid_w)
75 , d(this->front_pad + this->back_pad + this->mid_d) {
76
77 adjust_zero_pad_comp_dims(
78 oh, this->h, this->top_pad, this->mid_h, this->bottom_pad);
79 adjust_zero_pad_comp_dims(
80 ow, this->w, this->left_pad, this->mid_w, this->right_pad);
81 adjust_zero_pad_comp_dims(
82 od, this->d, this->front_pad, this->mid_d, this->back_pad);
83}
84
85zero_point_config_t::zero_point_config_t(const primitive_attr_t &attr)
86 : src_exists(!attr.zero_points_.has_default_values(DNNL_ARG_SRC))
87 , dst_exists(!attr.zero_points_.has_default_values(DNNL_ARG_DST))
88 , src_is_common(attr.zero_points_.common(DNNL_ARG_SRC)) {}
89
90bool zero_point_config_t::zp_exists() const noexcept {
91 return src_exists || dst_exists;
92}
93
94zero_point_call_params_t::zero_point_call_params_t(const int32_t *src,
95 const int32_t *dst, const int32_t *src_comp,
96 const int32_t *src_pad_comp)
97 : src(src), dst(dst), src_comp(src_comp), src_pad_comp(src_pad_comp) {}
98
99bool zero_points_valid(
100 const primitive_attr_t *attr, bool per_oc_bcast_accepted) noexcept {
101
102 int mask_src = -1, mask_dst = -1;
103 static constexpr int common_mask = 0x0,
104 per_oc_mask = 0x2; // mask for common and per_oc_bcast
105
106 attr->zero_points_.get(DNNL_ARG_SRC, &mask_src);
107 attr->zero_points_.get(DNNL_ARG_DST, &mask_dst);
108
109 const bool src_mask_valid = per_oc_bcast_accepted
110 ? utils::one_of(mask_src, common_mask, per_oc_mask)
111 : mask_src == 0;
112 const bool dst_mask_valid = per_oc_bcast_accepted
113 ? utils::one_of(mask_dst, common_mask, per_oc_mask)
114 : mask_dst == 0;
115
116 return attr->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
117 && src_mask_valid && dst_mask_valid;
118}
119
120void set_zp_src_comp_flags(memory_desc_t &weights_md, bool with_groups) {
121 weights_md.extra.flags
122 |= memory_extra_flags::compensation_conv_asymmetric_src;
123 weights_md.extra.asymm_compensation_mask
124 = (1 << 0) + (with_groups ? (1 << 1) : 0);
125}
126
127const int32_t *get_src_zp_comp_from_wei(const int8_t *weights,
128 const memory_desc_wrapper &weights_md, bool signed_input, dim_t ngroups,
129 dim_t oc) {
130
131 const auto comp_offset
132 = weights_md.size() - weights_md.additional_buffer_size();
133 const auto src_zp_com_offset = signed_input ? ngroups * oc : 0;
134 return reinterpret_cast<const int32_t *>(&weights[comp_offset])
135 + src_zp_com_offset;
136}
137
138} // namespace cpu
139} // namespace impl
140} // namespace dnnl
141