1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "cpu/simple_sum.hpp"
18#include "common/bfloat16.hpp"
19#include "common/dnnl_thread.hpp"
20
21namespace dnnl {
22namespace impl {
23namespace cpu {
24
25template <data_type_t src_data_type, data_type_t dst_data_type>
26status_t simple_sum_t<src_data_type, dst_data_type>::execute(
27 const exec_ctx_t &ctx) const {
28 auto output = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
29
30 const memory_desc_wrapper o_d(pd()->dst_md());
31 output += o_d.blk_off(0);
32 const int num_arrs = pd()->n_inputs();
33 const src_data_t *input_ptrs[max_num_arrs];
34
35 for (int a = 0; a < num_arrs; ++a) {
36 const memory_desc_wrapper i_d(pd()->src_md(a));
37 input_ptrs[a]
38 = CTX_IN_MEM(const src_data_t *, DNNL_ARG_MULTIPLE_SRC + a)
39 + i_d.blk_off(0);
40 }
41
42 const dim_t nelems = pd()->nelems_;
43 const dim_t block_size = pd()->block_size_;
44 const dim_t blocks_number = pd()->blocks_number_;
45 const dim_t tail = pd()->tail_;
46
47 const auto scales = pd()->scales();
48
49 auto sum_block_xf16 = [&](dim_t start, dim_t end, int ithr) {
50 const bool is_dst_xf16
51 = utils::one_of(dst_data_type, data_type::bf16, data_type::f16);
52 const auto xf16_params = pd()->xf16_params_;
53 const auto scratchpad = ctx.get_scratchpad_grantor();
54 acc_data_t *wspace = scratchpad.template get<acc_data_t>(
55 memory_tracking::names::key_sum_srcs_cvt);
56 acc_data_t *my_ws = &wspace[ithr * xf16_params.ws_elements_per_thread_];
57
58 for (dim_t b = start; b < end; b += xf16_params.acc_loop_step_) {
59 acc_data_t *my_acc = is_dst_xf16
60 ? &my_ws[xf16_params.ws_cvt_elements_per_thread_]
61 : (acc_data_t *)&output[b];
62 dim_t current_block
63 = nstl::min(xf16_params.acc_loop_step_, end - b);
64 types::cvt_to_float(my_ws, &input_ptrs[0][b], current_block);
65 for (dim_t e = 0; e < current_block; e++)
66 my_acc[e] = scales[0] * my_ws[e];
67
68 for (int a = 1; a < num_arrs; a++) {
69 types::cvt_to_float(my_ws, &input_ptrs[a][b], current_block);
70 for (dim_t e = 0; e < current_block; e++)
71 my_acc[e] += scales[a] * my_ws[e];
72 }
73
74 if (is_dst_xf16)
75 types::cvt_from_float(&output[b], my_acc, current_block);
76 }
77 };
78
79 auto sum_block = [&](dim_t start, dim_t end, int ithr) {
80 PRAGMA_OMP_SIMD()
81 for (dim_t e = start; e < end; e++) {
82 output[e] = dst_data_t(scales[0] * input_ptrs[0][e]);
83 }
84 for (int a = 1; a < num_arrs; a++) {
85 PRAGMA_OMP_SIMD()
86 for (dim_t e = start; e < end; e++) {
87 output[e] += dst_data_t(scales[a] * input_ptrs[a][e]);
88 }
89 }
90 };
91
92 const int max_nthr = pd()->nthr_;
93 parallel(max_nthr, [&](const int ithr, const int nthr) {
94 dim_t start {0}, end {0};
95 balance211(blocks_number, nthr, ithr, start, end);
96
97 for (dim_t nb = start; nb < end; ++nb) {
98 dim_t start_e = nb * block_size;
99 dim_t end_e = start_e + block_size;
100 if (src_data_type == data_type::f32)
101 sum_block(start_e, end_e, ithr);
102 else
103 sum_block_xf16(start_e, end_e, ithr);
104 }
105
106 if (tail != 0 && ithr == nthr - 1) {
107 dim_t start_e = nelems - tail;
108 dim_t end_e = nelems;
109 if (src_data_type == data_type::f32)
110 sum_block(start_e, end_e, ithr);
111 else
112 sum_block_xf16(start_e, end_e, ithr);
113 }
114 });
115
116 return status::success;
117}
118
119template struct simple_sum_t<data_type::f32>;
120template struct simple_sum_t<data_type::bf16>;
121template struct simple_sum_t<data_type::bf16, data_type::f32>;
122template struct simple_sum_t<data_type::f16>;
123template struct simple_sum_t<data_type::f16, data_type::f32>;
124} // namespace cpu
125} // namespace impl
126} // namespace dnnl
127