1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include "dnnl.h"
19
20#include "common/c_types_map.hpp"
21#include "opdesc.hpp"
22#include "primitive_desc_iface.hpp"
23#include "type_helpers.hpp"
24#include "utils.hpp"
25
26#include "common/broadcast_strategy.hpp"
27
28using namespace dnnl::impl;
29using namespace dnnl::impl::utils;
30using namespace dnnl::impl::status;
31using namespace dnnl::impl::prop_kind;
32using namespace dnnl::impl::types;
33
34namespace {
35status_t prelu_desc_init(prelu_desc_t *prelu_desc, prop_kind_t prop_kind,
36 const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
37 const memory_desc_t *dst_desc, const memory_desc_t *diff_src_desc,
38 const memory_desc_t *diff_weights_desc,
39 const memory_desc_t *diff_dst_desc) {
40 const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
41 bool args_ok = !any_null(prelu_desc, src_desc, weights_desc)
42 && one_of(prop_kind, forward_training, forward_inference, backward)
43 && IMPLICATION(is_fwd, dst_desc != nullptr)
44 && IMPLICATION(!is_fwd,
45 !any_null(diff_src_desc, diff_weights_desc, diff_dst_desc))
46 && IMPLICATION(is_fwd, !memory_desc_wrapper(src_desc).format_any());
47 if (!args_ok) return invalid_arguments;
48
49 if (memory_desc_wrapper(src_desc).has_runtime_dims_or_strides()
50 || memory_desc_wrapper(weights_desc).has_runtime_dims_or_strides())
51 return unimplemented;
52 if (prop_kind == backward
53 && (memory_desc_wrapper(diff_src_desc).has_runtime_dims_or_strides()
54 || memory_desc_wrapper(diff_weights_desc)
55 .has_runtime_dims_or_strides()))
56 return unimplemented;
57
58 auto pd = prelu_desc_t();
59 pd.primitive_kind = primitive_kind::prelu;
60 pd.prop_kind = prop_kind;
61 pd.src_desc = *src_desc;
62 pd.weights_desc = *weights_desc;
63 if (is_fwd) {
64 pd.dst_desc = *dst_desc;
65 } else {
66 pd.diff_src_desc = *diff_src_desc;
67 pd.diff_weights_desc = *diff_weights_desc;
68 pd.diff_dst_desc = *diff_dst_desc;
69 }
70
71 const memory_desc_wrapper src_mdw(*src_desc);
72 if (get_rhs_arg_broadcasting_strategy(pd.weights_desc, src_mdw)
73 == broadcasting_strategy_t::unsupported)
74 return invalid_arguments;
75
76 static constexpr int max_supported_ndims = 5;
77 bool consistency = src_desc->ndims <= max_supported_ndims
78 && src_desc->ndims == weights_desc->ndims;
79 if (consistency && is_fwd) {
80 consistency = pd.dst_desc.ndims == pd.src_desc.ndims
81 && array_cmp(
82 pd.dst_desc.dims, pd.src_desc.dims, pd.src_desc.ndims);
83 }
84 if (consistency && !is_fwd) {
85 consistency = pd.diff_dst_desc.ndims == pd.src_desc.ndims
86 && pd.diff_dst_desc.ndims == pd.diff_src_desc.ndims
87 && array_cmp(pd.diff_dst_desc.dims, pd.src_desc.dims,
88 pd.src_desc.ndims)
89 && array_cmp(pd.diff_src_desc.dims, pd.diff_dst_desc.dims,
90 pd.diff_dst_desc.ndims);
91 }
92 if (!consistency) return invalid_arguments;
93
94 *prelu_desc = pd;
95 return success;
96}
97} // namespace
98
99status_t dnnl_prelu_forward_primitive_desc_create(
100 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
101 prop_kind_t prop_kind, const memory_desc_t *src_desc,
102 const memory_desc_t *weights_desc, const memory_desc_t *dst_desc,
103 const primitive_attr_t *attr) {
104
105 if (!one_of(prop_kind, forward_training, forward_inference))
106 return invalid_arguments;
107
108 auto prelu_desc = prelu_desc_t();
109 CHECK(prelu_desc_init(&prelu_desc, prop_kind, src_desc, weights_desc,
110 dst_desc, nullptr, nullptr, nullptr));
111
112 return primitive_desc_create(primitive_desc_iface, engine,
113 (const op_desc_t *)&prelu_desc, nullptr, attr);
114}
115
116status_t dnnl_prelu_backward_primitive_desc_create(
117 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
118 const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
119 const memory_desc_t *diff_src_desc,
120 const memory_desc_t *diff_weights_desc,
121 const memory_desc_t *diff_dst_desc,
122 const primitive_desc_iface_t *hint_fwd_pd,
123 const primitive_attr_t *attr) {
124
125 auto prelu_desc = prelu_desc_t();
126 CHECK(prelu_desc_init(&prelu_desc, backward, src_desc, weights_desc,
127 nullptr, diff_src_desc, diff_weights_desc, diff_dst_desc));
128
129 return primitive_desc_create(primitive_desc_iface, engine,
130 (const op_desc_t *)&prelu_desc, hint_fwd_pd, attr);
131}
132
133// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
134