1/*******************************************************************************
2* Copyright 2016-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_REF_POOLING_HPP
18#define CPU_REF_POOLING_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/platform.hpp"
28#include "cpu/primitive_attr_postops.hpp"
29
30#include "cpu/cpu_pooling_pd.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35
36template <impl::data_type_t data_type, impl::data_type_t acc_type = data_type>
37struct ref_pooling_fwd_t : public primitive_t {
38 struct pd_t : public cpu_pooling_fwd_pd_t {
39 using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
40
41 DECLARE_COMMON_PD_T("ref:any", ref_pooling_fwd_t);
42
43 status_t init(engine_t *engine) {
44 using sm = primitive_attr_t::skip_mask_t;
45
46 bool ok = platform::has_data_type_support(data_type)
47 && set_default_params() == status::success && is_fwd()
48 && utils::everyone_is(
49 data_type, src_md()->data_type, dst_md()->data_type)
50 && desc()->accum_data_type == acc_type
51 && attr()->has_default_values(sm::post_ops)
52 && attr_.set_default_formats(dst_md(0)) == status::success;
53 if (!ok) return status::unimplemented;
54
55 bool is_training = desc_.prop_kind == prop_kind::forward_training;
56 if (desc()->alg_kind == alg_kind::pooling_max && is_training)
57 init_default_ws();
58
59 return status::success;
60 }
61 };
62
63 ref_pooling_fwd_t(const pd_t *apd) : primitive_t(apd) {}
64
65 status_t init(engine_t *engine) override {
66 ref_post_ops
67 = utils::make_unique<ref_post_ops_t>(pd()->attr()->post_ops_);
68 if (!ref_post_ops) return status::out_of_memory;
69 return status::success;
70 }
71
72 using data_t = typename prec_traits<data_type>::type;
73 using acc_data_t = typename prec_traits<acc_type>::type;
74
75 status_t execute(const exec_ctx_t &ctx) const override {
76 return execute_forward(ctx);
77 }
78
79private:
80 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
81 status_t execute_forward(const exec_ctx_t &ctx) const;
82 std::unique_ptr<ref_post_ops_t> ref_post_ops;
83};
84
85template <impl::data_type_t data_type>
86struct ref_pooling_bwd_t : public primitive_t {
87 struct pd_t : public cpu_pooling_bwd_pd_t {
88 using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t;
89
90 DECLARE_COMMON_PD_T("ref:any", ref_pooling_bwd_t);
91
92 status_t init(engine_t *engine) {
93 bool ok = platform::has_data_type_support(data_type)
94 && set_default_params() == status::success && !is_fwd()
95 && utils::everyone_is(data_type, diff_dst_md()->data_type,
96 diff_src_md()->data_type)
97 && attr()->has_default_values();
98 if (!ok) return status::unimplemented;
99
100 if (desc()->alg_kind == alg_kind::pooling_max) {
101 init_default_ws();
102 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
103 }
104
105 return status::success;
106 }
107 };
108
109 ref_pooling_bwd_t(const pd_t *apd) : primitive_t(apd) {}
110 typedef typename prec_traits<data_type>::type data_t;
111
112 status_t execute(const exec_ctx_t &ctx) const override {
113 return execute_backward(ctx);
114 }
115
116private:
117 status_t execute_backward(const exec_ctx_t &ctx) const;
118 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
119};
120
121} // namespace cpu
122} // namespace impl
123} // namespace dnnl
124
125#endif
126
127// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
128