1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include "oneapi/dnnl/dnnl.h"
19
20#include "c_types_map.hpp"
21#include "opdesc.hpp"
22#include "primitive_desc_iface.hpp"
23#include "type_helpers.hpp"
24#include "utils.hpp"
25
26using namespace dnnl::impl;
27using namespace dnnl::impl::utils;
28using namespace dnnl::impl::status;
29using namespace dnnl::impl::prop_kind;
30using namespace dnnl::impl::alg_kind;
31using namespace dnnl::impl::types;
32
33namespace {
34status_t pooling_desc_init(pooling_desc_t *pool_desc, prop_kind_t prop_kind,
35 alg_kind_t alg_kind, const memory_desc_t *src_desc,
36 const memory_desc_t *dst_desc, const dims_t strides,
37 const dims_t kernel, const dims_t dilation, const dims_t padding_l,
38 const dims_t padding_r) {
39 bool args_ok = !any_null(pool_desc, src_desc, dst_desc, strides, kernel,
40 padding_l)
41 && one_of(alg_kind, pooling_max, pooling_avg_include_padding,
42 pooling_avg_exclude_padding)
43 && IMPLICATION(
44 one_of(prop_kind, forward_training, forward_inference),
45 !memory_desc_wrapper(src_desc).format_any());
46 if (!args_ok) return invalid_arguments;
47
48 if (padding_r == nullptr) padding_r = padding_l;
49
50 auto pd = pooling_desc_t();
51 pd.primitive_kind = primitive_kind::pooling;
52 pd.prop_kind = prop_kind;
53 pd.alg_kind = alg_kind;
54 pd.src_desc.ndims = src_desc->ndims;
55
56 const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
57
58 bool runtime_dims_or_strides
59 = memory_desc_wrapper(src_desc).has_runtime_dims_or_strides()
60 || memory_desc_wrapper(dst_desc).has_runtime_dims_or_strides();
61 if (runtime_dims_or_strides) return unimplemented;
62
63 pd.diff_src_desc = pd.src_desc = zero_md();
64 pd.diff_dst_desc = pd.dst_desc = zero_md();
65
66 (is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc;
67 (is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc;
68
69 int sp_dims = src_desc->ndims - 2;
70 utils::array_copy(pd.strides, strides, sp_dims);
71 utils::array_copy(pd.kernel, kernel, sp_dims);
72 utils::array_copy(pd.padding[0], padding_l, sp_dims);
73 utils::array_copy(pd.padding[1], padding_r, sp_dims);
74 utils::array_copy(pd.dilation, dilation, sp_dims);
75
76 if (one_of(alg_kind, pooling_max, pooling_avg_include_padding,
77 pooling_avg_exclude_padding)) {
78 pd.accum_data_type = types::default_accum_data_type(
79 src_desc->data_type, dst_desc->data_type, false);
80 if (pd.accum_data_type == data_type::undef) return invalid_arguments;
81 } else {
82 pd.accum_data_type = dst_desc->data_type;
83 }
84
85 if (!utils::one_of(src_desc->ndims, 3, 4, 5)
86 || !utils::one_of(dst_desc->ndims, 3, 4, 5)
87 || src_desc->dims[0] != dst_desc->dims[0]
88 || src_desc->dims[1] != dst_desc->dims[1])
89 return invalid_arguments;
90
91 for (int i = 2; i < src_desc->ndims; ++i) {
92 const int src = src_desc->dims[i];
93 const int dst = dst_desc->dims[i];
94 const int ker = kernel[i - 2];
95 const int dil = dilation ? dilation[i - 2] : 0;
96 const int pad_l = padding_l[i - 2];
97 const int pad_r = padding_r[i - 2];
98 const int str = strides[i - 2];
99 const int ker_range = 1 + (ker - 1) * (dil + 1);
100
101 if (str < 1 || dil < 0 || pad_l < 0 || pad_r + str < 0)
102 return invalid_arguments;
103
104 if ((src - ker_range + pad_l + pad_r) / str + 1 != dst)
105 return invalid_arguments;
106
107 // It's not allowed for pooling window to be totally placed outside
108 // of real source domain for pooling_avg_exclude_padding algorithm
109 // due to 0 / 0 ambiguity
110 if (alg_kind == pooling_avg_exclude_padding
111 && !(pad_l < ker_range && pad_r < ker_range && dil < src))
112 return invalid_arguments;
113 }
114
115 *pool_desc = pd;
116 return success;
117}
118} // namespace
119
120dnnl_status_t dnnl_pooling_forward_primitive_desc_create(
121 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
122 prop_kind_t prop_kind, alg_kind_t alg_kind,
123 const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
124 const dims_t strides, const dims_t kernel, const dims_t dilation,
125 const dims_t padding_l, const dims_t padding_r,
126 const primitive_attr_t *attr) {
127
128 if (!one_of(prop_kind, forward_training, forward_inference))
129 return invalid_arguments;
130
131 auto pool_desc = pooling_desc_t();
132 CHECK(pooling_desc_init(&pool_desc, prop_kind, alg_kind, src_desc, dst_desc,
133 strides, kernel, dilation, padding_l, padding_r));
134 return primitive_desc_create(primitive_desc_iface, engine,
135 (const op_desc_t *)&pool_desc, nullptr, attr);
136}
137
138dnnl_status_t dnnl_pooling_backward_primitive_desc_create(
139 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
140 alg_kind_t alg_kind, const memory_desc_t *diff_src_desc,
141 const memory_desc_t *diff_dst_desc, const dims_t strides,
142 const dims_t kernel, const dims_t dilation, const dims_t padding_l,
143 const dims_t padding_r, const primitive_desc_iface_t *hint_fwd_pd,
144 const primitive_attr_t *attr) {
145
146 auto pool_desc = pooling_desc_t();
147 CHECK(pooling_desc_init(&pool_desc, prop_kind::backward_data, alg_kind,
148 diff_src_desc, diff_dst_desc, strides, kernel, dilation, padding_l,
149 padding_r));
150 return primitive_desc_create(primitive_desc_iface, engine,
151 (const op_desc_t *)&pool_desc, hint_fwd_pd, attr);
152}
153
154// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
155