1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include "oneapi/dnnl/dnnl.h"
19#include "opdesc.hpp"
20#include "primitive_desc_iface.hpp"
21
22#include "c_types_map.hpp"
23#include "type_helpers.hpp"
24#include "utils.hpp"
25
26using namespace dnnl::impl;
27using namespace dnnl::impl::utils;
28using namespace dnnl::impl::status;
29using namespace dnnl::impl::prop_kind;
30using namespace dnnl::impl::alg_kind;
31using namespace dnnl::impl::types;
32
33namespace {
34status_t lrn_desc_init(lrn_desc_t *lrn_desc, prop_kind_t prop_kind,
35 alg_kind_t alg_kind, const memory_desc_t *src_desc,
36 const memory_desc_t *dst_desc, const memory_desc_t *diff_src_desc,
37 const memory_desc_t *diff_dst_desc, dim_t local_size, float alpha,
38 float beta, float k) {
39 const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
40 bool args_ok = !any_null(lrn_desc, src_desc)
41 && one_of(alg_kind, lrn_within_channel, lrn_across_channels)
42 && one_of(prop_kind, forward_training, forward_inference,
43 backward_data)
44 && local_size >= 0 && IMPLICATION(is_fwd, dst_desc != nullptr)
45 && IMPLICATION(!is_fwd, !any_null(diff_src_desc, diff_dst_desc))
46 && IMPLICATION(is_fwd, !memory_desc_wrapper(src_desc).format_any());
47 if (!args_ok) return invalid_arguments;
48
49 auto ld = lrn_desc_t();
50 ld.primitive_kind = primitive_kind::lrn;
51 ld.prop_kind = prop_kind;
52 ld.alg_kind = alg_kind;
53
54 bool runtime_dims_or_strides
55 = memory_desc_wrapper(src_desc).has_runtime_dims_or_strides();
56 if (is_fwd) {
57 runtime_dims_or_strides = runtime_dims_or_strides
58 || memory_desc_wrapper(dst_desc).has_runtime_dims_or_strides();
59 } else {
60 runtime_dims_or_strides = runtime_dims_or_strides
61 || memory_desc_wrapper(diff_src_desc)
62 .has_runtime_dims_or_strides()
63 || memory_desc_wrapper(diff_dst_desc)
64 .has_runtime_dims_or_strides();
65 }
66 if (runtime_dims_or_strides) return unimplemented;
67
68 ld.src_desc = *src_desc;
69 if (is_fwd) ld.dst_desc = *dst_desc;
70 if (!is_fwd) {
71 ld.diff_src_desc = *diff_src_desc;
72 ld.diff_dst_desc = *diff_dst_desc;
73 }
74
75 ld.local_size = local_size;
76 ld.lrn_alpha = alpha;
77 ld.lrn_beta = beta;
78 ld.lrn_k = k;
79
80 bool consistency = ld.src_desc.ndims >= 2;
81 if (consistency && is_fwd) {
82 consistency = ld.dst_desc.ndims == ld.src_desc.ndims
83 && array_cmp(
84 ld.dst_desc.dims, ld.src_desc.dims, ld.src_desc.ndims);
85 }
86 if (consistency && !is_fwd) {
87 consistency = ld.diff_dst_desc.ndims == ld.src_desc.ndims
88 && ld.diff_dst_desc.ndims == ld.diff_src_desc.ndims
89 && array_cmp(ld.diff_dst_desc.dims, ld.src_desc.dims,
90 ld.src_desc.ndims)
91 && array_cmp(ld.diff_src_desc.dims, ld.diff_dst_desc.dims,
92 ld.diff_dst_desc.ndims);
93 }
94 if (!consistency) return invalid_arguments;
95
96 *lrn_desc = ld;
97 return success;
98}
99} // namespace
100
101status_t dnnl_lrn_forward_primitive_desc_create(
102 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
103 prop_kind_t prop_kind, alg_kind_t alg_kind,
104 const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
105 dim_t local_size, float alpha, float beta, float k,
106 const primitive_attr_t *attr) {
107 if (!one_of(prop_kind, forward_training, forward_inference))
108 return invalid_arguments;
109
110 auto lrn_desc = lrn_desc_t();
111 CHECK(lrn_desc_init(&lrn_desc, prop_kind, alg_kind, src_desc, dst_desc,
112 nullptr, nullptr, local_size, alpha, beta, k));
113 return primitive_desc_create(primitive_desc_iface, engine,
114 (const op_desc_t *)&lrn_desc, nullptr, attr);
115}
116
117status_t dnnl_lrn_backward_primitive_desc_create(
118 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
119 alg_kind_t alg_kind, const memory_desc_t *diff_src_desc,
120 const memory_desc_t *diff_dst_desc, const memory_desc_t *src_desc,
121 dim_t local_size, float alpha, float beta, float k,
122 const primitive_desc_iface_t *hint_fwd_pd,
123 const primitive_attr_t *attr) {
124
125 auto lrn_desc = lrn_desc_t();
126 CHECK(lrn_desc_init(&lrn_desc, backward_data, alg_kind, src_desc, nullptr,
127 diff_src_desc, diff_dst_desc, local_size, alpha, beta, k));
128 return primitive_desc_create(primitive_desc_iface, engine,
129 (const op_desc_t *)&lrn_desc, hint_fwd_pd, attr);
130}
131
132// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
133