1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_GPU_CONVOLUTION_PD_HPP
18#define GPU_GPU_CONVOLUTION_PD_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/convolution_pd.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace gpu {
30
31struct gpu_convolution_fwd_pd_t : public convolution_fwd_pd_t {
32 using convolution_fwd_pd_t::convolution_fwd_pd_t;
33
34protected:
35 bool arg_scales_ok() const {
36 std::vector<int> supported_args
37 = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
38 if (!attr()->scales_.has_default_values(supported_args)) return false;
39 for (int arg : supported_args) {
40 auto &scales = attr()->scales_.get(arg);
41 if (scales.has_default_values()) continue;
42 int mask = scales.mask_;
43 if (arg == DNNL_ARG_WEIGHTS) {
44 if (!utils::one_of(mask, 0, 1 << (int)with_groups()))
45 return false;
46 } else {
47 if (mask != 0) return false;
48 }
49 }
50 return true;
51 }
52
53 // TODO: consider either moving this method to primitive_conf.hpp or making
54 // it static, or removing the 'attr' argument accessible via attr()
55 bool zero_points_ok(const primitive_attr_t *attr) const {
56 using namespace data_type;
57 const auto src_type = invariant_src_md()->data_type;
58 int mask_src = 0, mask_dst = 0;
59 attr->zero_points_.get(DNNL_ARG_SRC, &mask_src);
60 attr->zero_points_.get(DNNL_ARG_DST, &mask_dst);
61
62 return IMPLICATION(!utils::one_of(src_type, s8, u8),
63 attr->zero_points_.has_default_values())
64 && attr->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
65 && (mask_src == 0 || mask_src == 1 << 1)
66 && (mask_dst == 0 || mask_dst == 1 << 1);
67 }
68};
69
70struct gpu_convolution_bwd_data_pd_t : public convolution_bwd_data_pd_t {
71 using convolution_bwd_data_pd_t::convolution_bwd_data_pd_t;
72
73protected:
74 bool arg_scales_ok() const {
75 std::vector<int> supported_args
76 = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
77 if (!attr()->scales_.has_default_values(supported_args)) return false;
78 for (int arg : supported_args) {
79 auto &scales = attr()->scales_.get(arg);
80 if (scales.has_default_values()) continue;
81 int mask = scales.mask_;
82 if (arg == DNNL_ARG_WEIGHTS) {
83 // XXX: per_oc for BWD_D is treated as per_ic assuming it's
84 // called from deconvolution.
85 if (!utils::one_of(mask, 0, 1 << (int)with_groups()))
86 return false;
87 } else {
88 if (mask != 0) return false;
89 }
90 }
91 return true;
92 }
93
94 // TODO: consider either moving this method to primitive_conf.hpp or making
95 // it static, or removing the 'attr' argument accessible via attr()
96 bool zero_points_ok(const primitive_attr_t *attr) const {
97 using namespace data_type;
98 const auto dst_type = invariant_dst_md()->data_type;
99 int mask_src = 0, mask_dst = 0;
100 attr->zero_points_.get(DNNL_ARG_SRC, &mask_src);
101 attr->zero_points_.get(DNNL_ARG_DST, &mask_dst);
102
103 return IMPLICATION(!utils::one_of(dst_type, s8, u8),
104 attr->zero_points_.has_default_values())
105 && attr->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
106 && (mask_src == 0 || mask_src == 1 << 1)
107 && (mask_dst == 0 || mask_dst == 1 << 1);
108 }
109
110 // TODO: consider either moving this method to primitive_conf.hpp or making
111 // it static, or removing the 'attr' argument accessible via attr()
112 bool post_ops_ok(const primitive_attr_t *attr) const {
113 const auto &p = attr->post_ops_;
114
115 auto is_eltwise
116 = [&](int idx) { return p.entry_[idx].is_eltwise(false); };
117 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
118
119 switch (p.len()) {
120 case 0: return true; // no post_ops
121 case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
122 case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
123 default: return false;
124 }
125
126 return false;
127 }
128};
129
130struct gpu_convolution_bwd_weights_pd_t : public convolution_bwd_weights_pd_t {
131 using convolution_bwd_weights_pd_t::convolution_bwd_weights_pd_t;
132};
133
134} // namespace gpu
135} // namespace impl
136} // namespace dnnl
137
138#endif
139
140// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
141