1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include "oneapi/dnnl/dnnl.h"
19#include "opdesc.hpp"
20#include "primitive_desc_iface.hpp"
21
22#include "c_types_map.hpp"
23#include "math_utils.hpp"
24#include "type_helpers.hpp"
25#include "utils.hpp"
26
27using namespace dnnl::impl;
28using namespace dnnl::impl::utils;
29using namespace dnnl::impl::status;
30using namespace dnnl::impl::prop_kind;
31using namespace dnnl::impl::alg_kind;
32using namespace dnnl::impl::types;
33
34namespace {
35status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind,
36 alg_kind_t alg_kind, const memory_desc_t *src_desc,
37 const memory_desc_t *dst_desc, const memory_desc_t *diff_src_desc,
38 const memory_desc_t *diff_dst_desc, float alpha, float beta) {
39 const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
40 bool args_ok = !any_null(eltwise_desc, src_desc, dst_desc)
41 && one_of(prop_kind, forward_training, forward_inference,
42 backward_data)
43 && math::is_eltwise_ok(src_desc->data_type, alg_kind, alpha, beta)
44 && IMPLICATION(!is_fwd, !any_null(diff_src_desc, diff_dst_desc))
45 && IMPLICATION(alg_kind == eltwise_round, is_fwd)
46 && IMPLICATION(is_fwd, !memory_desc_wrapper(src_desc).format_any());
47 if (!args_ok) return invalid_arguments;
48
49 bool runtime_dims_or_strides
50 = memory_desc_wrapper(src_desc).has_runtime_dims_or_strides()
51 || memory_desc_wrapper(dst_desc).has_runtime_dims_or_strides();
52 if (!is_fwd)
53 runtime_dims_or_strides = runtime_dims_or_strides
54 || memory_desc_wrapper(diff_src_desc)
55 .has_runtime_dims_or_strides()
56 || memory_desc_wrapper(diff_dst_desc)
57 .has_runtime_dims_or_strides();
58 if (runtime_dims_or_strides) return unimplemented;
59
60 auto ed = eltwise_desc_t();
61 ed.primitive_kind = primitive_kind::eltwise;
62 ed.prop_kind = prop_kind;
63 ed.alg_kind = alg_kind;
64
65 ed.src_desc = *src_desc;
66 ed.dst_desc = *dst_desc;
67 if (!is_fwd) {
68 ed.diff_src_desc = *diff_src_desc;
69 ed.diff_dst_desc = *diff_dst_desc;
70 }
71
72 ed.alpha = alpha;
73 ed.beta = beta;
74
75 bool consistency = true;
76 if (consistency && is_fwd) {
77 consistency = ed.dst_desc.ndims == ed.src_desc.ndims
78 && array_cmp(
79 ed.dst_desc.dims, ed.src_desc.dims, ed.src_desc.ndims);
80 }
81 if (consistency && !is_fwd) {
82 consistency = ed.diff_dst_desc.ndims == ed.src_desc.ndims
83 && ed.diff_dst_desc.ndims == ed.diff_src_desc.ndims
84 && array_cmp(ed.diff_dst_desc.dims, ed.src_desc.dims,
85 ed.src_desc.ndims)
86 && array_cmp(ed.diff_src_desc.dims, ed.diff_dst_desc.dims,
87 ed.diff_dst_desc.ndims);
88 }
89 if (!consistency) return invalid_arguments;
90
91 *eltwise_desc = ed;
92 return success;
93}
94} // namespace
95
96status_t dnnl_eltwise_forward_primitive_desc_create(
97 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
98 prop_kind_t prop_kind, alg_kind_t alg_kind,
99 const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
100 float alpha, float beta, const primitive_attr_t *attr) {
101 if (!one_of(prop_kind, forward_training, forward_inference))
102 return invalid_arguments;
103
104 auto eltwise_desc = eltwise_desc_t();
105 CHECK(eltwise_desc_init(&eltwise_desc, prop_kind, alg_kind, src_desc,
106 dst_desc, nullptr, nullptr, alpha, beta));
107 return primitive_desc_create(primitive_desc_iface, engine,
108 (const op_desc_t *)&eltwise_desc, nullptr, attr);
109}
110
111status_t dnnl_eltwise_backward_primitive_desc_create(
112 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
113 alg_kind_t alg_kind, const memory_desc_t *diff_src_desc,
114 const memory_desc_t *diff_dst_desc, const memory_desc_t *data_desc,
115 float alpha, float beta, const primitive_desc_iface_t *hint_fwd_pd,
116 const primitive_attr_t *attr) {
117
118 auto eltwise_desc = eltwise_desc_t();
119 CHECK(eltwise_desc_init(&eltwise_desc, backward_data, alg_kind, data_desc,
120 data_desc, diff_src_desc, diff_dst_desc, alpha, beta));
121 return primitive_desc_create(primitive_desc_iface, engine,
122 (const op_desc_t *)&eltwise_desc, hint_fwd_pd, attr);
123}
124
125// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
126