1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_REF_REDUCTION_HPP
18#define CPU_REF_REDUCTION_HPP
19
20#include "common/primitive.hpp"
21#include "common/type_helpers.hpp"
22
23#include "cpu/cpu_reduction_pd.hpp"
24#include "cpu/platform.hpp"
25#include "cpu/primitive_attr_postops.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30
31template <data_type_t src_type, data_type_t dst_type, data_type_t acc_type>
32struct ref_reduction_t : public primitive_t {
33 struct pd_t : public cpu_reduction_pd_t {
34 using cpu_reduction_pd_t::cpu_reduction_pd_t;
35
36 DECLARE_COMMON_PD_T("ref:any", ref_reduction_t);
37
38 status_t init(engine_t *engine) {
39 using sm = primitive_attr_t::skip_mask_t;
40
41 bool ok = src_type == src_md()->data_type
42 && dst_type == dst_md()->data_type
43 && acc_type
44 == types::default_accum_data_type(
45 src_type, dst_type)
46 && platform::has_data_type_support(src_type)
47 && platform::has_data_type_support(dst_type)
48 && set_default_params() == status::success
49 && attr()->has_default_values(sm::post_ops)
50 && attr_.set_default_formats(dst_md(0)) == status::success;
51 if (!ok) return status::unimplemented;
52
53 return status::success;
54 }
55 };
56
57 ref_reduction_t(const pd_t *apd) : primitive_t(apd) {}
58
59 status_t init(engine_t *engine) override {
60 ref_post_ops
61 = utils::make_unique<ref_post_ops_t>(pd()->attr()->post_ops_);
62 if (!ref_post_ops) return status::out_of_memory;
63 return status::success;
64 }
65
66 using src_t = typename prec_traits<src_type>::type;
67 using acc_t = typename prec_traits<acc_type>::type;
68 using dst_t = typename prec_traits<dst_type>::type;
69
70 status_t execute(const exec_ctx_t &ctx) const override {
71 return execute_ref(ctx);
72 }
73
74private:
75 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
76 status_t execute_ref(const exec_ctx_t &ctx) const;
77 std::unique_ptr<ref_post_ops_t> ref_post_ops;
78
79 void accumulate(
80 acc_t &acc, const src_t &src, alg_kind_t alg_kind, float p) const;
81 void finalize(
82 float &acc_f32, alg_kind_t alg, float p, float eps, dim_t n) const;
83 void init_acc(acc_t &acc, alg_kind_t alg) const;
84};
85
86} // namespace cpu
87} // namespace impl
88} // namespace dnnl
89
90#endif
91