1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include "cpu/zero_point_utils.hpp" |
17 | #include "common/utils.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace cpu { |
22 | |
23 | static void adjust_zero_pad_comp_dims(const dim_t output_dim_size, |
24 | dim_t &estimated_dim_size, dim_t &begin_pad, dim_t &mid_pad, |
25 | dim_t &end_pad) { |
26 | |
27 | if (output_dim_size < estimated_dim_size) { |
28 | const auto diff = estimated_dim_size - output_dim_size; |
29 | estimated_dim_size = output_dim_size; |
30 | |
31 | end_pad -= diff; |
32 | |
33 | if (end_pad < 0) { |
34 | if (mid_pad) { |
35 | mid_pad = 0; |
36 | end_pad += 1; |
37 | } |
38 | if (end_pad < 0) { |
39 | begin_pad += end_pad; |
40 | end_pad = 0; |
41 | } |
42 | } |
43 | } |
44 | } |
45 | |
46 | zero_point_pad_comp_config_t::zero_point_pad_comp_config_t( |
47 | const dim_t front_pad, const dim_t back_pad, const dim_t top_pad, |
48 | const dim_t bottom_pad, const dim_t left_pad, const dim_t right_pad, |
49 | const dim_t stride_d, const dim_t stride_h, const dim_t stride_w, |
50 | const dim_t od, const dim_t oh, const dim_t ow) |
51 | |
52 | : top_pad(utils::div_up(top_pad, stride_h)) |
53 | , bottom_pad(utils::div_up(bottom_pad, stride_h)) |
54 | , left_pad(utils::div_up(left_pad, stride_w)) |
55 | , right_pad(utils::div_up(right_pad, stride_w)) |
56 | , front_pad(utils::div_up(front_pad, stride_d)) |
57 | , back_pad(utils::div_up(back_pad, stride_d)) |
58 | , mid_h((oh - this->top_pad - this->bottom_pad > 0) |
59 | && (this->left_pad > 0 || this->right_pad > 0 |
60 | || this->front_pad > 0 || this->back_pad) |
61 | ? 1 |
62 | : 0) |
63 | , mid_w((ow - this->left_pad - this->right_pad > 0) |
64 | && (this->bottom_pad > 0 || this->top_pad > 0 |
65 | || this->front_pad > 0 || this->back_pad) |
66 | ? 1 |
67 | : 0) |
68 | , mid_d((od - this->front_pad - this->back_pad > 0) |
69 | && (this->top_pad > 0 || this->bottom_pad > 0 |
70 | || this->right_pad > 0 || this->left_pad) |
71 | ? 1 |
72 | : 0) |
73 | , h(this->top_pad + this->bottom_pad + this->mid_h) |
74 | , w(this->left_pad + this->right_pad + this->mid_w) |
75 | , d(this->front_pad + this->back_pad + this->mid_d) { |
76 | |
77 | adjust_zero_pad_comp_dims( |
78 | oh, this->h, this->top_pad, this->mid_h, this->bottom_pad); |
79 | adjust_zero_pad_comp_dims( |
80 | ow, this->w, this->left_pad, this->mid_w, this->right_pad); |
81 | adjust_zero_pad_comp_dims( |
82 | od, this->d, this->front_pad, this->mid_d, this->back_pad); |
83 | } |
84 | |
85 | zero_point_config_t::zero_point_config_t(const primitive_attr_t &attr) |
86 | : src_exists(!attr.zero_points_.has_default_values(DNNL_ARG_SRC)) |
87 | , dst_exists(!attr.zero_points_.has_default_values(DNNL_ARG_DST)) |
88 | , src_is_common(attr.zero_points_.common(DNNL_ARG_SRC)) {} |
89 | |
90 | bool zero_point_config_t::zp_exists() const noexcept { |
91 | return src_exists || dst_exists; |
92 | } |
93 | |
94 | zero_point_call_params_t::zero_point_call_params_t(const int32_t *src, |
95 | const int32_t *dst, const int32_t *src_comp, |
96 | const int32_t *src_pad_comp) |
97 | : src(src), dst(dst), src_comp(src_comp), src_pad_comp(src_pad_comp) {} |
98 | |
99 | bool zero_points_valid( |
100 | const primitive_attr_t *attr, bool per_oc_bcast_accepted) noexcept { |
101 | |
102 | int mask_src = -1, mask_dst = -1; |
103 | static constexpr int common_mask = 0x0, |
104 | per_oc_mask = 0x2; // mask for common and per_oc_bcast |
105 | |
106 | attr->zero_points_.get(DNNL_ARG_SRC, &mask_src); |
107 | attr->zero_points_.get(DNNL_ARG_DST, &mask_dst); |
108 | |
109 | const bool src_mask_valid = per_oc_bcast_accepted |
110 | ? utils::one_of(mask_src, common_mask, per_oc_mask) |
111 | : mask_src == 0; |
112 | const bool dst_mask_valid = per_oc_bcast_accepted |
113 | ? utils::one_of(mask_dst, common_mask, per_oc_mask) |
114 | : mask_dst == 0; |
115 | |
116 | return attr->zero_points_.has_default_values(DNNL_ARG_WEIGHTS) |
117 | && src_mask_valid && dst_mask_valid; |
118 | } |
119 | |
120 | void set_zp_src_comp_flags(memory_desc_t &weights_md, bool with_groups) { |
121 | weights_md.extra.flags |
122 | |= memory_extra_flags::compensation_conv_asymmetric_src; |
123 | weights_md.extra.asymm_compensation_mask |
124 | = (1 << 0) + (with_groups ? (1 << 1) : 0); |
125 | } |
126 | |
127 | const int32_t *get_src_zp_comp_from_wei(const int8_t *weights, |
128 | const memory_desc_wrapper &weights_md, bool signed_input, dim_t ngroups, |
129 | dim_t oc) { |
130 | |
131 | const auto comp_offset |
132 | = weights_md.size() - weights_md.additional_buffer_size(); |
133 | const auto src_zp_com_offset = signed_input ? ngroups * oc : 0; |
134 | return reinterpret_cast<const int32_t *>(&weights[comp_offset]) |
135 | + src_zp_com_offset; |
136 | } |
137 | |
138 | } // namespace cpu |
139 | } // namespace impl |
140 | } // namespace dnnl |
141 | |