1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_SIMPLE_SUM_HPP |
18 | #define CPU_SIMPLE_SUM_HPP |
19 | |
20 | #include "common/dnnl_thread.hpp" |
21 | #include "common/primitive.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | |
24 | #include "cpu/cpu_sum_pd.hpp" |
25 | #include "cpu/platform.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | |
31 | struct sum_xf16_params_t { |
32 | dim_t ws_cvt_elements_per_thread_; |
33 | dim_t ws_acc_elements_per_thread_; |
34 | dim_t ws_elements_per_thread_; |
35 | dim_t acc_loop_step_; |
36 | }; |
37 | |
38 | template <data_type_t src_data_type, data_type_t dst_data_type = src_data_type> |
39 | struct simple_sum_t : public primitive_t { |
40 | struct pd_t : public cpu_sum_pd_t { |
41 | using cpu_sum_pd_t::cpu_sum_pd_t; |
42 | |
43 | DECLARE_SUM_PD_T("simple:any" , simple_sum_t); |
44 | |
45 | status_t init(engine_t *engine) { |
46 | const int n = n_inputs(); |
47 | |
48 | bool ok = platform::has_data_type_support(src_data_type) |
49 | && platform::has_data_type_support(dst_data_type) |
50 | && cpu_sum_pd_t::init(engine) == status::success |
51 | && n <= max_num_arrs; |
52 | if (!ok) return status::unimplemented; |
53 | |
54 | const memory_desc_wrapper o_d(dst_md()); |
55 | ok = ok && o_d.data_type() == dst_data_type && o_d.is_dense(); |
56 | if (!ok) return status::unimplemented; |
57 | |
58 | for (int i = 0; i < n; ++i) { |
59 | const memory_desc_wrapper i_d(src_md(i)); |
60 | ok = true && utils::everyone_is(src_data_type, i_d.data_type()) |
61 | && o_d.similar_to(i_d, true, false, 0) |
62 | && i_d.is_dense(); |
63 | if (!ok) return status::unimplemented; |
64 | } |
65 | nthr_ = dnnl_get_max_threads(); |
66 | compute_blocking(); |
67 | init_scratchpad(); |
68 | return status::success; |
69 | } |
70 | int nthr_ = 1; |
71 | sum_xf16_params_t xf16_params_; |
72 | dim_t block_size_ = 0, nelems_ = 0, blocks_number_ = 0, tail_ = 0; |
73 | |
74 | private: |
75 | void compute_blocking() { |
76 | const int block_size_bytes |
77 | = utils::one_of( |
78 | src_data_type, data_type::bf16, data_type::f16) |
79 | ? 16 * platform::get_cache_line_size() |
80 | : platform::get_per_core_cache_size(1) / 2; |
81 | block_size_ = block_size_bytes / (int)sizeof(src_data_type); |
82 | const memory_desc_wrapper o_d(dst_md()); |
83 | nelems_ = o_d.nelems(); |
84 | blocks_number_ = nelems_ / block_size_; |
85 | tail_ = nelems_ % block_size_; |
86 | } |
87 | |
88 | void init_scratchpad() { |
89 | if (utils::one_of(src_data_type, data_type::bf16, data_type::f16)) { |
90 | const bool is_dst_xf16 = utils::one_of( |
91 | dst_data_type, data_type::bf16, data_type::f16); |
92 | xf16_params_.ws_cvt_elements_per_thread_ |
93 | = platform::get_cache_line_size() |
94 | / (int)sizeof(acc_data_t); |
95 | |
96 | xf16_params_.ws_acc_elements_per_thread_ = is_dst_xf16 |
97 | ? xf16_params_.ws_cvt_elements_per_thread_ |
98 | : 0; |
99 | |
100 | xf16_params_.acc_loop_step_ = is_dst_xf16 |
101 | ? xf16_params_.ws_cvt_elements_per_thread_ |
102 | : 1; |
103 | |
104 | xf16_params_.ws_elements_per_thread_ |
105 | = xf16_params_.ws_cvt_elements_per_thread_ |
106 | + xf16_params_.ws_acc_elements_per_thread_; |
107 | const dim_t cvt_buf_sz |
108 | = xf16_params_.ws_elements_per_thread_ * nthr_; |
109 | auto scratchpad = scratchpad_registry().registrar(); |
110 | scratchpad.template book<acc_data_t>( |
111 | memory_tracking::names::key_sum_srcs_cvt, cvt_buf_sz); |
112 | } |
113 | } |
114 | }; |
115 | |
116 | simple_sum_t(const pd_t *apd) : primitive_t(apd) {} |
117 | |
118 | status_t execute(const exec_ctx_t &ctx) const override; |
119 | |
120 | enum { max_num_arrs = 16 }; |
121 | typedef typename prec_traits<src_data_type>::type src_data_t; |
122 | typedef typename prec_traits<dst_data_type>::type dst_data_t; |
123 | typedef typename prec_traits<data_type::f32>::type acc_data_t; |
124 | |
125 | private: |
126 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
127 | }; |
128 | |
129 | } // namespace cpu |
130 | } // namespace impl |
131 | } // namespace dnnl |
132 | |
133 | #endif |
134 | |
135 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
136 | |