1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include "oneapi/dnnl/dnnl.h"
19#include "opdesc.hpp"
20#include "primitive_desc_iface.hpp"
21
22#include "c_types_map.hpp"
23#include "type_helpers.hpp"
24#include "utils.hpp"
25
26using namespace dnnl::impl;
27using namespace dnnl::impl::utils;
28using namespace dnnl::impl::status;
29using namespace dnnl::impl::prop_kind;
30using namespace dnnl::impl::alg_kind;
31using namespace dnnl::impl::types;
32
33namespace {
34status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc,
35 prop_kind_t prop_kind, const memory_desc_t *src_desc,
36 const memory_desc_t *dst_desc, const memory_desc_t *diff_src_desc,
37 const memory_desc_t *diff_dst_desc, float epsilon, unsigned flags) {
38 const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
39 bool args_ok = !any_null(bnrm_desc, src_desc)
40 && one_of(prop_kind, forward_training, forward_inference,
41 backward_data, backward)
42 && IMPLICATION(is_fwd, dst_desc != nullptr)
43 && IMPLICATION(!is_fwd, !any_null(diff_src_desc, diff_dst_desc))
44 && IMPLICATION(is_fwd, !memory_desc_wrapper(src_desc).format_any());
45 if (!args_ok) return invalid_arguments;
46
47 unsigned bnorm_flags = normalization_flags::use_global_stats
48 | normalization_flags::fuse_norm_relu
49 | normalization_flags::fuse_norm_add_relu
50 | normalization_flags::use_scale | normalization_flags::use_shift;
51 if ((~bnorm_flags & flags) != 0) return invalid_arguments;
52
53 auto bd = batch_normalization_desc_t();
54 bd.primitive_kind = primitive_kind::batch_normalization;
55 bd.prop_kind = prop_kind;
56
57 bool runtime_dims_or_strides
58 = memory_desc_wrapper(src_desc).has_runtime_dims_or_strides();
59 if (is_fwd) {
60 runtime_dims_or_strides = runtime_dims_or_strides
61 || memory_desc_wrapper(dst_desc).has_runtime_dims_or_strides();
62 } else {
63 runtime_dims_or_strides = runtime_dims_or_strides
64 || memory_desc_wrapper(diff_src_desc)
65 .has_runtime_dims_or_strides()
66 || memory_desc_wrapper(diff_dst_desc)
67 .has_runtime_dims_or_strides();
68 }
69 if (runtime_dims_or_strides) return unimplemented;
70
71 bd.src_desc = *src_desc;
72 if (is_fwd) bd.dst_desc = *dst_desc;
73 if (!is_fwd) {
74 bd.diff_src_desc = *diff_src_desc;
75 bd.diff_dst_desc = *diff_dst_desc;
76 }
77
78 const bool has_scale_or_shift = flags
79 & (normalization_flags::use_scale | normalization_flags::use_shift);
80 if (has_scale_or_shift) {
81 dims_t scaleshift_dims = {src_desc->dims[1]};
82 memory_desc_init_by_tag(bd.scaleshift_desc, 1, scaleshift_dims,
83 data_type::f32, format_tag::a);
84 if (!is_fwd) bd.diff_scaleshift_desc = bd.scaleshift_desc;
85 }
86
87 dims_t stats_dims = {src_desc->dims[1]};
88 memory_desc_init_by_tag(
89 bd.stat_desc, 1, stats_dims, data_type::f32, format_tag::a);
90
91 bd.batch_norm_epsilon = epsilon;
92 bd.flags = flags;
93
94 bool consistency = bd.src_desc.ndims >= 2;
95 if (consistency && is_fwd) {
96 consistency = bd.dst_desc.ndims == bd.src_desc.ndims
97 && array_cmp(
98 bd.dst_desc.dims, bd.src_desc.dims, bd.src_desc.ndims);
99 }
100 if (consistency && !is_fwd) {
101 consistency = bd.diff_dst_desc.ndims == bd.src_desc.ndims
102 && bd.diff_dst_desc.ndims == bd.diff_src_desc.ndims
103 && array_cmp(bd.diff_dst_desc.dims, bd.src_desc.dims,
104 bd.src_desc.ndims)
105 && array_cmp(bd.diff_src_desc.dims, bd.diff_dst_desc.dims,
106 bd.diff_dst_desc.ndims);
107 }
108 if (!consistency) return invalid_arguments;
109
110 *bnrm_desc = bd;
111 return success;
112}
113} // namespace
114
115status_t dnnl_batch_normalization_forward_primitive_desc_create(
116 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
117 prop_kind_t prop_kind, const memory_desc_t *src_desc,
118 const memory_desc_t *dst_desc, float epsilon, unsigned flags,
119 const primitive_attr_t *attr) {
120 if (!one_of(prop_kind, forward_training, forward_inference))
121 return invalid_arguments;
122
123 auto bnrm_desc = batch_normalization_desc_t();
124 CHECK(bnrm_desc_init(&bnrm_desc, prop_kind, src_desc, dst_desc, nullptr,
125 nullptr, epsilon, flags));
126 return primitive_desc_create(primitive_desc_iface, engine,
127 (const op_desc_t *)&bnrm_desc, nullptr, attr);
128}
129
130status_t dnnl_batch_normalization_backward_primitive_desc_create(
131 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
132 prop_kind_t prop_kind, const memory_desc_t *diff_src_desc,
133 const memory_desc_t *diff_dst_desc, const memory_desc_t *src_desc,
134 float epsilon, unsigned flags,
135 const primitive_desc_iface_t *hint_fwd_pd,
136 const primitive_attr_t *attr) {
137 if (!one_of(prop_kind, backward, backward_data)) return invalid_arguments;
138
139 auto bnrm_desc = batch_normalization_desc_t();
140 CHECK(bnrm_desc_init(&bnrm_desc, prop_kind, src_desc, nullptr,
141 diff_src_desc, diff_dst_desc, epsilon, flags));
142 return primitive_desc_create(primitive_desc_iface, engine,
143 (const op_desc_t *)&bnrm_desc, hint_fwd_pd, attr);
144}
145
146// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
147