1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_REF_BATCH_NORMALIZATION_HPP
18#define CPU_REF_BATCH_NORMALIZATION_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/platform.hpp"
28
29#include "cpu/cpu_batch_normalization_pd.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace cpu {
34
35template <data_type_t d_type>
36struct ref_batch_normalization_fwd_t : public primitive_t {
37 struct pd_t : public cpu_batch_normalization_fwd_pd_t {
38 pd_t(const batch_normalization_desc_t *adesc,
39 const primitive_attr_t *attr,
40 const batch_normalization_fwd_pd_t *hint_fwd_pd)
41 : cpu_batch_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
42
43 DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_fwd_t);
44
45 status_t init(engine_t *engine) {
46 using namespace data_type;
47 bool ok = is_fwd()
48 && utils::everyone_is(
49 d_type, src_md()->data_type, dst_md()->data_type)
50 && platform::has_data_type_support(d_type)
51 && IMPLICATION(is_training(),
52 platform::has_training_support(d_type))
53 && check_scale_shift_data_type()
54 && (attr()->has_default_values()
55 || with_relu_post_op(is_training()))
56 && set_default_formats_common()
57 && memory_desc_wrapper(src_md())
58 == memory_desc_wrapper(dst_md());
59 if (!ok) return status::unimplemented;
60
61 // BN+Add+Relu fusion is not currently implemented
62 if (fuse_norm_add_relu()) return status::unimplemented;
63
64 if (src_md()->data_type == s8 && !stats_is_src())
65 return status::unimplemented;
66
67 if (is_training() && fuse_norm_relu()) init_default_ws(8);
68
69 return status::success;
70 }
71 };
72
73 ref_batch_normalization_fwd_t(const pd_t *apd) : primitive_t(apd) {}
74
75 typedef typename prec_traits<d_type>::type data_t;
76
77 status_t execute(const exec_ctx_t &ctx) const override {
78 return execute_forward(ctx);
79 }
80
81private:
82 status_t execute_forward(const exec_ctx_t &ctx) const;
83 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
84};
85
86template <data_type_t d_type>
87struct ref_batch_normalization_bwd_t : public primitive_t {
88 struct pd_t : public cpu_batch_normalization_bwd_pd_t {
89 pd_t(const batch_normalization_desc_t *adesc,
90 const primitive_attr_t *attr,
91 const batch_normalization_fwd_pd_t *hint_fwd_pd)
92 : cpu_batch_normalization_bwd_pd_t(adesc, attr, hint_fwd_pd) {}
93
94 DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_bwd_t);
95
96 status_t init(engine_t *engine) {
97 using namespace data_type;
98
99 bool ok = !is_fwd()
100 && utils::everyone_is(d_type, src_md()->data_type,
101 diff_dst_md()->data_type, diff_src_md()->data_type)
102 && platform::has_data_type_support(d_type)
103 && platform::has_training_support(d_type)
104 && check_scale_shift_data_type()
105 && attr()->has_default_values()
106 && set_default_formats_common()
107 && memory_desc_wrapper(diff_src_md())
108 == memory_desc_wrapper(diff_dst_md());
109 if (!ok) return status::unimplemented;
110
111 // BN+Add+Relu fusion is not currently implemented
112 if (fuse_norm_add_relu()) return status::unimplemented;
113
114 if (fuse_norm_relu()) {
115 init_default_ws(8);
116 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
117 }
118
119 return status::success;
120 }
121 };
122
123 ref_batch_normalization_bwd_t(const pd_t *apd) : primitive_t(apd) {}
124 typedef typename prec_traits<d_type>::type data_t;
125
126 status_t execute(const exec_ctx_t &ctx) const override {
127 return execute_backward(ctx);
128 }
129
130private:
131 status_t execute_backward(const exec_ctx_t &ctx) const;
132 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
133};
134
135} // namespace cpu
136} // namespace impl
137} // namespace dnnl
138
139#endif
140
141// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
142