1 | /******************************************************************************* |
2 | * Copyright 2020-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_ZERO_POINT_UTILS_HPP |
18 | #define CPU_ZERO_POINT_UTILS_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/primitive_attr.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
/*
 * Structure describing the size of the zero point padding compensation
 * buffer. The size of the buffer is h * w * d * output channels. The product
 * h * w * d is the number of unique applications of the filter over the input
 * spatial domain in which the filter overlaps the padding; each of them needs
 * its own zero point padding compensation. The pad variables hold the number
 * of unique zp padding compensation values resulting from the participation
 * of a given type of padding (for example, top padding) over a given axis
 * (for example, h). The mid points over a given axis represent the
 * compensation that arises when there is no padding over that axis, but
 * padding over another axis exists. Example (2D conv): mid_point_w = true
 * when the filter overlaps the top padding, but not the right or left
 * padding, i.e. the spatial filter w size fits in the w range of the input
 * image.
 */
39 | struct zero_point_pad_comp_config_t { |
40 | zero_point_pad_comp_config_t() = default; |
41 | zero_point_pad_comp_config_t(const dim_t front_pad, const dim_t back_pad, |
42 | const dim_t top_pad, const dim_t bottom_pad, const dim_t left_pad, |
43 | const dim_t right_pad, const dim_t stride_d, const dim_t stride_h, |
44 | const dim_t stride_w, const dim_t od, const dim_t oh, |
45 | const dim_t ow); |
46 | |
47 | dim_t top_pad = 0; |
48 | dim_t bottom_pad = 0; |
49 | dim_t left_pad = 0; |
50 | dim_t right_pad = 0; |
51 | dim_t front_pad = 0; |
52 | dim_t back_pad = 0; |
53 | |
54 | dim_t mid_h = 0; |
55 | dim_t mid_w = 0; |
56 | dim_t mid_d = 0; |
57 | |
58 | dim_t h = 0; |
59 | dim_t w = 0; |
60 | dim_t d = 0; |
61 | }; |
62 | |
63 | struct zero_point_config_t { |
64 | zero_point_config_t() = default; |
65 | zero_point_config_t(const primitive_attr_t &attr); |
66 | |
67 | bool src_exists = false; |
68 | bool dst_exists = false; |
69 | bool src_is_common = false; |
70 | zero_point_pad_comp_config_t src_pad_comp; |
71 | |
72 | bool zp_exists() const noexcept; |
73 | }; |
74 | |
// Bundle of read-only int32 pointers related to zero points. Every pointer is
// null after default construction; the four-argument constructor is defined
// out of line. Nothing in this declaration suggests that the struct owns the
// pointed-to data.
struct zero_point_call_params_t {
    zero_point_call_params_t() = default;
    zero_point_call_params_t(const int32_t *src, const int32_t *dst,
            const int32_t *src_comp, const int32_t *src_pad_comp);

    // NOTE(review): presumably the source and destination zero point values
    // -- confirm against the callers.
    const int32_t *src {nullptr};
    const int32_t *dst {nullptr};
    // NOTE(review): presumably the source zero point compensation and its
    // padding compensation counterpart -- confirm against the callers.
    const int32_t *src_comp {nullptr};
    const int32_t *src_pad_comp {nullptr};
};
85 | |
/*
 * Predicate over the zero point settings of the given primitive attributes.
 * The acceptance criteria live in the out-of-line definition.
 * NOTE(review): judging by its name, per_oc_bcast_accepted (false unless the
 * caller says otherwise) opts in to per-output-channel zero points -- confirm
 * against the definition. attr is taken by pointer; whether a null pointer is
 * tolerated is not visible from this declaration.
 */
bool zero_points_valid(const primitive_attr_t *attr,
        bool per_oc_bcast_accepted = false) noexcept;
88 | |
/*
 * Takes the weights memory descriptor by non-const reference, so it may be
 * modified in place.
 * NOTE(review): presumably marks the descriptor as carrying source zero point
 * compensation, with with_groups selecting the grouped-weights variant --
 * confirm against the definition.
 */
void set_zp_src_comp_flags(memory_desc_t &weights_md, bool with_groups);
/*
 * Returns a pointer to read-only int32 data obtained from the weights buffer
 * and its descriptor.
 * NOTE(review): presumably this is the source zero point compensation stored
 * alongside the weights, with signed_input, ngroups and oc used to locate it
 * -- confirm against the definition.
 */
const int32_t *get_src_zp_comp_from_wei(const int8_t *weights,
        const memory_desc_wrapper &weights_md, bool signed_input, dim_t ngroups,
        dim_t oc);
93 | |
94 | } // namespace cpu |
95 | } // namespace impl |
96 | } // namespace dnnl |
97 | |
98 | #endif |
99 | |