1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_SIMPLE_SUM_HPP
18#define CPU_SIMPLE_SUM_HPP
19
20#include "common/dnnl_thread.hpp"
21#include "common/primitive.hpp"
22#include "common/type_helpers.hpp"
23
24#include "cpu/cpu_sum_pd.hpp"
25#include "cpu/platform.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30
31struct sum_xf16_params_t {
32 dim_t ws_cvt_elements_per_thread_;
33 dim_t ws_acc_elements_per_thread_;
34 dim_t ws_elements_per_thread_;
35 dim_t acc_loop_step_;
36};
37
38template <data_type_t src_data_type, data_type_t dst_data_type = src_data_type>
39struct simple_sum_t : public primitive_t {
40 struct pd_t : public cpu_sum_pd_t {
41 using cpu_sum_pd_t::cpu_sum_pd_t;
42
43 DECLARE_SUM_PD_T("simple:any", simple_sum_t);
44
45 status_t init(engine_t *engine) {
46 const int n = n_inputs();
47
48 bool ok = platform::has_data_type_support(src_data_type)
49 && platform::has_data_type_support(dst_data_type)
50 && cpu_sum_pd_t::init(engine) == status::success
51 && n <= max_num_arrs;
52 if (!ok) return status::unimplemented;
53
54 const memory_desc_wrapper o_d(dst_md());
55 ok = ok && o_d.data_type() == dst_data_type && o_d.is_dense();
56 if (!ok) return status::unimplemented;
57
58 for (int i = 0; i < n; ++i) {
59 const memory_desc_wrapper i_d(src_md(i));
60 ok = true && utils::everyone_is(src_data_type, i_d.data_type())
61 && o_d.similar_to(i_d, true, false, 0)
62 && i_d.is_dense();
63 if (!ok) return status::unimplemented;
64 }
65 nthr_ = dnnl_get_max_threads();
66 compute_blocking();
67 init_scratchpad();
68 return status::success;
69 }
70 int nthr_ = 1;
71 sum_xf16_params_t xf16_params_;
72 dim_t block_size_ = 0, nelems_ = 0, blocks_number_ = 0, tail_ = 0;
73
74 private:
75 void compute_blocking() {
76 const int block_size_bytes
77 = utils::one_of(
78 src_data_type, data_type::bf16, data_type::f16)
79 ? 16 * platform::get_cache_line_size()
80 : platform::get_per_core_cache_size(1) / 2;
81 block_size_ = block_size_bytes / (int)sizeof(src_data_type);
82 const memory_desc_wrapper o_d(dst_md());
83 nelems_ = o_d.nelems();
84 blocks_number_ = nelems_ / block_size_;
85 tail_ = nelems_ % block_size_;
86 }
87
88 void init_scratchpad() {
89 if (utils::one_of(src_data_type, data_type::bf16, data_type::f16)) {
90 const bool is_dst_xf16 = utils::one_of(
91 dst_data_type, data_type::bf16, data_type::f16);
92 xf16_params_.ws_cvt_elements_per_thread_
93 = platform::get_cache_line_size()
94 / (int)sizeof(acc_data_t);
95
96 xf16_params_.ws_acc_elements_per_thread_ = is_dst_xf16
97 ? xf16_params_.ws_cvt_elements_per_thread_
98 : 0;
99
100 xf16_params_.acc_loop_step_ = is_dst_xf16
101 ? xf16_params_.ws_cvt_elements_per_thread_
102 : 1;
103
104 xf16_params_.ws_elements_per_thread_
105 = xf16_params_.ws_cvt_elements_per_thread_
106 + xf16_params_.ws_acc_elements_per_thread_;
107 const dim_t cvt_buf_sz
108 = xf16_params_.ws_elements_per_thread_ * nthr_;
109 auto scratchpad = scratchpad_registry().registrar();
110 scratchpad.template book<acc_data_t>(
111 memory_tracking::names::key_sum_srcs_cvt, cvt_buf_sz);
112 }
113 }
114 };
115
116 simple_sum_t(const pd_t *apd) : primitive_t(apd) {}
117
118 status_t execute(const exec_ctx_t &ctx) const override;
119
120 enum { max_num_arrs = 16 };
121 typedef typename prec_traits<src_data_type>::type src_data_t;
122 typedef typename prec_traits<dst_data_type>::type dst_data_t;
123 typedef typename prec_traits<data_type::f32>::type acc_data_t;
124
125private:
126 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
127};
128
129} // namespace cpu
130} // namespace impl
131} // namespace dnnl
132
133#endif
134
135// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
136