1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_DW_CONVOLUTION_UTILS_HPP
18#define CPU_DW_CONVOLUTION_UTILS_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/convolution_pd.hpp"
22#include "common/primitive_desc_iterator.hpp"
23#include "common/type_helpers.hpp"
24#include "common/utils.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace cpu {
29
30inline status_t get_depthwise_conv_desc(convolution_desc_t &cd_dw,
31 const memory_desc_t &src_dw_md, const primitive_attr_t &attr_1x1,
32 primitive_attr_t &attr_dw, int dw_po_index) {
33
34 const memory_desc_wrapper src_dw_d(src_dw_md);
35 const int ndims = src_dw_d.ndims();
36 if (ndims != 4) return status::unimplemented;
37
38 if (dw_po_index == -1 || dw_po_index >= attr_1x1.post_ops_.len()
39 || !attr_1x1.post_ops_.entry_[dw_po_index].is_convolution())
40 return status::invalid_arguments;
41
42 // Create new attributes with scales from depthwise post-op and copy
43 // post-ops after depthwise post-op.
44 auto &dw_po = attr_1x1.post_ops_.entry_[dw_po_index].depthwise_conv;
45
46 // erase 1x1 conv scales
47 for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
48 auto &scale = attr_dw.scales_.get(arg);
49 if (!scale.has_default_values()) attr_dw.scales_.reset(arg);
50 }
51
52 const auto &dw_src_scales = attr_1x1.scales_.get(DNNL_ARG_DST);
53 const auto &dw_wei_scales
54 = attr_1x1.scales_.get(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
55 const auto &dw_dst_scales
56 = attr_1x1.scales_.get(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST);
57 if (!dw_src_scales.has_default_values())
58 attr_dw.scales_.set(DNNL_ARG_SRC, dw_src_scales.mask_);
59 if (!dw_wei_scales.has_default_values())
60 attr_dw.scales_.set(DNNL_ARG_WEIGHTS, dw_wei_scales.mask_);
61 if (!dw_dst_scales.has_default_values())
62 attr_dw.scales_.set(DNNL_ARG_DST, dw_dst_scales.mask_);
63
64 auto dw_po_len = attr_1x1.post_ops_.len() - (dw_po_index + 1);
65 attr_dw.post_ops_.entry_.resize(dw_po_len);
66 for (int i = 0; i < dw_po_len; ++i) {
67 CHECK(attr_dw.post_ops_.entry_[i].copy_from(
68 attr_1x1.post_ops_.entry_[i + dw_po_index + 1]));
69 }
70
71 attr_dw.scratchpad_mode_ = attr_1x1.scratchpad_mode_;
72
73 const bool with_bias = dw_po.bias_dt != data_type::undef;
74
75 const auto n = src_dw_d.dims()[0];
76 const auto oc = src_dw_d.dims()[1];
77 const auto g = src_dw_d.dims()[1];
78 const auto ih = src_dw_d.dims()[ndims - 2];
79 const auto iw = src_dw_d.dims()[ndims - 1];
80 const auto kernel = dw_po.kernel;
81 const auto stride = dw_po.stride;
82 const auto padding = dw_po.padding;
83
84 const dims_t weights_tz = {g, 1, 1, kernel, kernel};
85
86 // Not following standard convolution formula for output shapes since
87 // right/top padding might be greated than left/top one.
88 const dim_t oh = utils::div_up(ih, stride);
89 const dim_t ow = utils::div_up(iw, stride);
90 const dims_t dst_tz = {n, oc, oh, ow};
91
92 const dims_t bias_tz = {oc};
93 const dims_t pad_tz = {padding, padding};
94 const dims_t stride_tz = {stride, stride};
95
96 const dim_t pad_h_r = (oh - 1) * stride - ih + kernel - padding;
97 const dim_t pad_w_r = (ow - 1) * stride - iw + kernel - padding;
98 const dims_t pad_r_tz = {pad_h_r, pad_w_r};
99
100 memory_desc_t src_md, weights_md, bias_md, dst_md;
101
102 const auto src_dw_tag = src_dw_d.matches_one_of_tag(
103 format_tag::nChw16c, format_tag::nChw8c, format_tag::nhwc);
104 const auto data_tag
105 = (src_dw_tag == format_tag::undef) ? format_tag::any : src_dw_tag;
106
107 memory_desc_init_by_tag(
108 src_md, ndims, src_dw_md.dims, src_dw_md.data_type, data_tag);
109
110 memory_desc_init_by_tag(
111 weights_md, ndims + 1, weights_tz, dw_po.wei_dt, format_tag::any);
112
113 if (with_bias)
114 memory_desc_init_by_tag(
115 bias_md, 1, bias_tz, dw_po.bias_dt, format_tag::a);
116
117 memory_desc_init_by_tag(dst_md, ndims, dst_tz, dw_po.dst_dt, data_tag);
118
119 CHECK(conv_desc_init(&cd_dw, prop_kind::forward_inference,
120 alg_kind::convolution_auto, &src_md, &weights_md,
121 with_bias ? &bias_md : nullptr, &dst_md, stride_tz, nullptr, pad_tz,
122 pad_r_tz));
123
124 return status::success;
125}
126
127} // namespace cpu
128} // namespace impl
129} // namespace dnnl
130
131#endif
132