1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_DW_CONVOLUTION_UTILS_HPP |
18 | #define CPU_DW_CONVOLUTION_UTILS_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/convolution_pd.hpp" |
22 | #include "common/primitive_desc_iterator.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | |
30 | inline status_t get_depthwise_conv_desc(convolution_desc_t &cd_dw, |
31 | const memory_desc_t &src_dw_md, const primitive_attr_t &attr_1x1, |
32 | primitive_attr_t &attr_dw, int dw_po_index) { |
33 | |
34 | const memory_desc_wrapper src_dw_d(src_dw_md); |
35 | const int ndims = src_dw_d.ndims(); |
36 | if (ndims != 4) return status::unimplemented; |
37 | |
38 | if (dw_po_index == -1 || dw_po_index >= attr_1x1.post_ops_.len() |
39 | || !attr_1x1.post_ops_.entry_[dw_po_index].is_convolution()) |
40 | return status::invalid_arguments; |
41 | |
42 | // Create new attributes with scales from depthwise post-op and copy |
43 | // post-ops after depthwise post-op. |
44 | auto &dw_po = attr_1x1.post_ops_.entry_[dw_po_index].depthwise_conv; |
45 | |
46 | // erase 1x1 conv scales |
47 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { |
48 | auto &scale = attr_dw.scales_.get(arg); |
49 | if (!scale.has_default_values()) attr_dw.scales_.reset(arg); |
50 | } |
51 | |
52 | const auto &dw_src_scales = attr_1x1.scales_.get(DNNL_ARG_DST); |
53 | const auto &dw_wei_scales |
54 | = attr_1x1.scales_.get(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); |
55 | const auto &dw_dst_scales |
56 | = attr_1x1.scales_.get(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST); |
57 | if (!dw_src_scales.has_default_values()) |
58 | attr_dw.scales_.set(DNNL_ARG_SRC, dw_src_scales.mask_); |
59 | if (!dw_wei_scales.has_default_values()) |
60 | attr_dw.scales_.set(DNNL_ARG_WEIGHTS, dw_wei_scales.mask_); |
61 | if (!dw_dst_scales.has_default_values()) |
62 | attr_dw.scales_.set(DNNL_ARG_DST, dw_dst_scales.mask_); |
63 | |
64 | auto dw_po_len = attr_1x1.post_ops_.len() - (dw_po_index + 1); |
65 | attr_dw.post_ops_.entry_.resize(dw_po_len); |
66 | for (int i = 0; i < dw_po_len; ++i) { |
67 | CHECK(attr_dw.post_ops_.entry_[i].copy_from( |
68 | attr_1x1.post_ops_.entry_[i + dw_po_index + 1])); |
69 | } |
70 | |
71 | attr_dw.scratchpad_mode_ = attr_1x1.scratchpad_mode_; |
72 | |
73 | const bool with_bias = dw_po.bias_dt != data_type::undef; |
74 | |
75 | const auto n = src_dw_d.dims()[0]; |
76 | const auto oc = src_dw_d.dims()[1]; |
77 | const auto g = src_dw_d.dims()[1]; |
78 | const auto ih = src_dw_d.dims()[ndims - 2]; |
79 | const auto iw = src_dw_d.dims()[ndims - 1]; |
80 | const auto kernel = dw_po.kernel; |
81 | const auto stride = dw_po.stride; |
82 | const auto padding = dw_po.padding; |
83 | |
84 | const dims_t weights_tz = {g, 1, 1, kernel, kernel}; |
85 | |
86 | // Not following standard convolution formula for output shapes since |
87 | // right/top padding might be greated than left/top one. |
88 | const dim_t oh = utils::div_up(ih, stride); |
89 | const dim_t ow = utils::div_up(iw, stride); |
90 | const dims_t dst_tz = {n, oc, oh, ow}; |
91 | |
92 | const dims_t bias_tz = {oc}; |
93 | const dims_t pad_tz = {padding, padding}; |
94 | const dims_t stride_tz = {stride, stride}; |
95 | |
96 | const dim_t pad_h_r = (oh - 1) * stride - ih + kernel - padding; |
97 | const dim_t pad_w_r = (ow - 1) * stride - iw + kernel - padding; |
98 | const dims_t pad_r_tz = {pad_h_r, pad_w_r}; |
99 | |
100 | memory_desc_t src_md, weights_md, bias_md, dst_md; |
101 | |
102 | const auto src_dw_tag = src_dw_d.matches_one_of_tag( |
103 | format_tag::nChw16c, format_tag::nChw8c, format_tag::nhwc); |
104 | const auto data_tag |
105 | = (src_dw_tag == format_tag::undef) ? format_tag::any : src_dw_tag; |
106 | |
107 | memory_desc_init_by_tag( |
108 | src_md, ndims, src_dw_md.dims, src_dw_md.data_type, data_tag); |
109 | |
110 | memory_desc_init_by_tag( |
111 | weights_md, ndims + 1, weights_tz, dw_po.wei_dt, format_tag::any); |
112 | |
113 | if (with_bias) |
114 | memory_desc_init_by_tag( |
115 | bias_md, 1, bias_tz, dw_po.bias_dt, format_tag::a); |
116 | |
117 | memory_desc_init_by_tag(dst_md, ndims, dst_tz, dw_po.dst_dt, data_tag); |
118 | |
119 | CHECK(conv_desc_init(&cd_dw, prop_kind::forward_inference, |
120 | alg_kind::convolution_auto, &src_md, &weights_md, |
121 | with_bias ? &bias_md : nullptr, &dst_md, stride_tz, nullptr, pad_tz, |
122 | pad_r_tz)); |
123 | |
124 | return status::success; |
125 | } |
126 | |
127 | } // namespace cpu |
128 | } // namespace impl |
129 | } // namespace dnnl |
130 | |
131 | #endif |
132 | |