1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_REF_CONVOLUTION_INT8_HPP
18#define CPU_REF_CONVOLUTION_INT8_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/cpu_convolution_pd.hpp"
28#include "cpu/primitive_attr_postops.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace cpu {
33
34struct ref_convolution_int8_fwd_t : public primitive_t {
35 struct pd_t : public cpu_convolution_fwd_pd_t {
36 using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t;
37
38 DECLARE_COMMON_PD_T("ref:any", ref_convolution_int8_fwd_t);
39
40 status_t init(engine_t *engine) {
41 using namespace data_type;
42 using smask_t = primitive_attr_t::skip_mask_t;
43 const auto src_type = src_md(0)->data_type;
44 const auto wei_type = weights_md(0)->data_type;
45 const auto bia_type = weights_md(1)->data_type;
46 const auto dst_type = dst_md(0)->data_type;
47
48 bool ok = is_fwd()
49 && set_default_alg_kind(alg_kind::convolution_direct)
50 && utils::one_of(src_type, s8, u8) && wei_type == s8
51 && IMPLICATION(with_bias(),
52 utils::one_of(bia_type, f32, bf16, s32, s8, u8))
53 && utils::one_of(dst_type, f32, bf16, s32, s8, u8)
54 && set_default_formats()
55 && attr()->has_default_values(smask_t::scales_runtime
56 | smask_t::zero_points_runtime
57 | smask_t::post_ops | smask_t::sum_dt,
58 dst_type)
59 && attr()->post_ops_.check_sum_consistent_dt(dst_type)
60 && scales_mask_ok() && zero_points_ok() && post_ops_ok()
61 && attr_.set_default_formats(dst_md(0)) == status::success;
62 return ok ? status::success : status::unimplemented;
63 }
64
65 protected:
66 bool set_default_formats() {
67 using namespace format_tag;
68 auto dat_tag = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
69 auto wei_tag = with_groups()
70 ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
71 : utils::pick(ndims() - 3, oiw, oihw, oidhw);
72 return set_default_formats_common(dat_tag, wei_tag, dat_tag);
73 }
74
75 bool scales_mask_ok() const {
76 using namespace data_type;
77 const std::vector<int> supported_args
78 = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
79 bool ok = attr()->scales_.has_default_values(supported_args);
80 for (int arg : supported_args) {
81 const auto &mask = attr()->scales_.get(arg).mask_;
82 if (arg == DNNL_ARG_WEIGHTS)
83 ok = ok && (mask == 0 || mask == (1 << (int)with_groups()));
84 else
85 ok = ok && (mask == 0);
86 }
87 return ok;
88 }
89
90 bool zero_points_ok() const {
91 int mask_src = 0, mask_dst = 0;
92 attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src);
93 attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst);
94
95 return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
96 && (mask_src == 0 || mask_src == 1 << 1)
97 && (mask_dst == 0 || mask_dst == 1 << 1);
98 }
99
100 bool post_ops_ok() const {
101 return attr()->post_ops_.find(primitive_kind::convolution) == -1;
102 }
103 };
104
105 ref_convolution_int8_fwd_t(const pd_t *apd) : primitive_t(apd) {}
106
107 status_t init(engine_t *engine) override {
108 ref_post_ops
109 = utils::make_unique<ref_post_ops_t>(pd()->attr()->post_ops_);
110 if (!ref_post_ops) return status::out_of_memory;
111 return status::success;
112 }
113
114 status_t execute(const exec_ctx_t &ctx) const override {
115 return execute_forward(ctx);
116 }
117
118private:
119 status_t execute_forward(const exec_ctx_t &ctx) const;
120 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
121 std::unique_ptr<ref_post_ops_t> ref_post_ops;
122};
123
124struct ref_convolution_int8_bwd_data_t : public primitive_t {
125 struct pd_t : public cpu_convolution_bwd_data_pd_t {
126 using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t;
127
128 DECLARE_COMMON_PD_T("ref:any", ref_convolution_int8_bwd_data_t);
129
130 status_t init(engine_t *engine) {
131 using namespace data_type;
132 const auto diff_src_type = diff_src_md(0)->data_type;
133 const auto wei_type = weights_md(0)->data_type;
134 const auto diff_dst_type = diff_dst_md(0)->data_type;
135
136 bool ok = desc()->prop_kind == prop_kind::backward_data
137 && set_default_alg_kind(alg_kind::convolution_direct)
138 && utils::one_of(diff_dst_type, s8, u8) && wei_type == s8
139 && utils::one_of(diff_src_type, f32, bf16, s32, s8, u8)
140 && set_default_formats()
141 && attr()->has_default_values(
142 primitive_attr_t::skip_mask_t::scales_runtime)
143 && scales_mask_ok();
144
145 return ok ? status::success : status::unimplemented;
146 }
147
148 protected:
149 bool set_default_formats() {
150 using namespace format_tag;
151 auto dat_tag = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
152 auto wei_tag = with_groups()
153 ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
154 : utils::pick(ndims() - 3, oiw, oihw, oidhw);
155 return set_default_formats_common(dat_tag, wei_tag, dat_tag);
156 }
157
158 bool scales_mask_ok() const {
159 using namespace data_type;
160 const std::vector<int> supported_args
161 = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
162 bool ok = attr()->scales_.has_default_values(supported_args);
163 for (int arg : supported_args) {
164 const auto &mask = attr()->scales_.get(arg).mask_;
165 if (arg == DNNL_ARG_WEIGHTS)
166 ok = ok && (mask == 0 || mask == (1 << (int)with_groups()));
167 else
168 ok = ok && (mask == 0);
169 }
170 return ok;
171 }
172 };
173
174 ref_convolution_int8_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
175
176 status_t execute(const exec_ctx_t &ctx) const override {
177 return execute_backward_data(ctx);
178 }
179
180private:
181 status_t execute_backward_data(const exec_ctx_t &ctx) const;
182 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
183};
184
185} // namespace cpu
186} // namespace impl
187} // namespace dnnl
188
189#endif
190
191// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
192