1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_ZERO_POINT_UTILS_HPP
18#define CPU_ZERO_POINT_UTILS_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive_attr.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
/*
 * Describes the size of the zero-point padding compensation buffer.
 * The buffer holds h * w * d * output-channels values. The product h * w * d
 * is the number of unique placements of the filter over the input spatial
 * domain in which the filter overlaps padding; each such placement needs its
 * own zero-point padding compensation. Each *_pad variable holds the number
 * of unique compensation values caused by one kind of padding (for example,
 * top padding) along one axis (for example, h). The mid_* variables count
 * compensation values for placements where the filter overlaps no padding
 * along that axis but does overlap padding along another axis. Example for a
 * 2D convolution: the mid_w case covers placements where the filter overlaps
 * top padding but neither left nor right padding, i.e. the filter's w extent
 * fits inside the input's w range.
 */
39struct zero_point_pad_comp_config_t {
40 zero_point_pad_comp_config_t() = default;
41 zero_point_pad_comp_config_t(const dim_t front_pad, const dim_t back_pad,
42 const dim_t top_pad, const dim_t bottom_pad, const dim_t left_pad,
43 const dim_t right_pad, const dim_t stride_d, const dim_t stride_h,
44 const dim_t stride_w, const dim_t od, const dim_t oh,
45 const dim_t ow);
46
47 dim_t top_pad = 0;
48 dim_t bottom_pad = 0;
49 dim_t left_pad = 0;
50 dim_t right_pad = 0;
51 dim_t front_pad = 0;
52 dim_t back_pad = 0;
53
54 dim_t mid_h = 0;
55 dim_t mid_w = 0;
56 dim_t mid_d = 0;
57
58 dim_t h = 0;
59 dim_t w = 0;
60 dim_t d = 0;
61};
62
63struct zero_point_config_t {
64 zero_point_config_t() = default;
65 zero_point_config_t(const primitive_attr_t &attr);
66
67 bool src_exists = false;
68 bool dst_exists = false;
69 bool src_is_common = false;
70 zero_point_pad_comp_config_t src_pad_comp;
71
72 bool zp_exists() const noexcept;
73};
74
// Group of read-only int32 zero-point pointers; each one is nullptr unless
// explicitly provided. Presumably handed to kernels at execution time —
// confirm at the call sites.
struct zero_point_call_params_t {
    zero_point_call_params_t() = default;

    // Defined out of line; presumably copies each argument into the member
    // of the same name.
    zero_point_call_params_t(const int32_t *src, const int32_t *dst,
            const int32_t *src_comp, const int32_t *src_pad_comp);

    // Source and destination zero points.
    const int32_t *src {nullptr};
    const int32_t *dst {nullptr};
    // Source zero-point compensation, and its padding-specific counterpart.
    const int32_t *src_comp {nullptr};
    const int32_t *src_pad_comp {nullptr};
};
85
86bool zero_points_valid(const primitive_attr_t *attr,
87 bool per_oc_bcast_accepted = false) noexcept;
88
89void set_zp_src_comp_flags(memory_desc_t &weights_md, bool with_groups);
90const int32_t *get_src_zp_comp_from_wei(const int8_t *weights,
91 const memory_desc_wrapper &weights_md, bool signed_input, dim_t ngroups,
92 dim_t oc);
93
94} // namespace cpu
95} // namespace impl
96} // namespace dnnl
97
98#endif
99