1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/simple_sum.hpp" |
18 | #include "common/bfloat16.hpp" |
19 | #include "common/dnnl_thread.hpp" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace cpu { |
24 | |
/* Reference CPU sum: computes dst[e] = sum_a scales[a] * src[a][e] over all
 * elements, parallelized over fixed-size blocks.
 *
 * Two element-wise kernels are used:
 *  - sum_block:      direct f32 path (src is f32; dst may be f32 or, per the
 *                    instantiations below, a narrower type via dst_data_t
 *                    conversion on store).
 *  - sum_block_xf16: bf16/f16 source path that up-converts each source chunk
 *                    to f32 in a per-thread scratchpad, accumulates in f32,
 *                    and down-converts the result only when dst itself is
 *                    bf16/f16.
 *
 * @param ctx execution context providing the src/dst memories, the
 *            scratchpad grantor, and the runtime arguments.
 * @return status::success (no runtime failure paths in this kernel).
 */
template <data_type_t src_data_type, data_type_t dst_data_type>
status_t simple_sum_t<src_data_type, dst_data_type>::execute(
        const exec_ctx_t &ctx) const {
    auto output = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);

    // Advance past any layout offset so 'output' points at logical element 0.
    const memory_desc_wrapper o_d(pd()->dst_md());
    output += o_d.blk_off(0);
    const int num_arrs = pd()->n_inputs();
    const src_data_t *input_ptrs[max_num_arrs];

    // Resolve a base pointer (offset-adjusted, like dst) for every source.
    for (int a = 0; a < num_arrs; ++a) {
        const memory_desc_wrapper i_d(pd()->src_md(a));
        input_ptrs[a]
                = CTX_IN_MEM(const src_data_t *, DNNL_ARG_MULTIPLE_SRC + a)
                + i_d.blk_off(0);
    }

    // Work partitioning precomputed at pd-creation time:
    // nelems = blocks_number * block_size + tail.
    const dim_t nelems = pd()->nelems_;
    const dim_t block_size = pd()->block_size_;
    const dim_t blocks_number = pd()->blocks_number_;
    const dim_t tail = pd()->tail_;

    const auto scales = pd()->scales();

    // bf16/f16 source path. Processes [start, end) in chunks of
    // acc_loop_step_ elements, converting each source chunk to f32 before
    // accumulating.
    auto sum_block_xf16 = [&](dim_t start, dim_t end, int ithr) {
        const bool is_dst_xf16
                = utils::one_of(dst_data_type, data_type::bf16, data_type::f16);
        const auto xf16_params = pd()->xf16_params_;
        const auto scratchpad = ctx.get_scratchpad_grantor();
        acc_data_t *wspace = scratchpad.template get<acc_data_t>(
                memory_tracking::names::key_sum_srcs_cvt);
        // Each thread owns a disjoint slice of the conversion workspace.
        acc_data_t *my_ws = &wspace[ithr * xf16_params.ws_elements_per_thread_];

        for (dim_t b = start; b < end; b += xf16_params.acc_loop_step_) {
            // When dst is bf16/f16 we cannot accumulate in-place, so the
            // accumulator lives in the second part of the thread's workspace
            // (past the first ws_cvt_elements_per_thread_ elements used for
            // source conversion). For f32 dst, accumulate directly into dst.
            acc_data_t *my_acc = is_dst_xf16
                    ? &my_ws[xf16_params.ws_cvt_elements_per_thread_]
                    : (acc_data_t *)&output[b];
            dim_t current_block
                    = nstl::min(xf16_params.acc_loop_step_, end - b);
            // First source initializes the accumulator...
            types::cvt_to_float(my_ws, &input_ptrs[0][b], current_block);
            for (dim_t e = 0; e < current_block; e++)
                my_acc[e] = scales[0] * my_ws[e];

            // ...remaining sources are added on top.
            for (int a = 1; a < num_arrs; a++) {
                types::cvt_to_float(my_ws, &input_ptrs[a][b], current_block);
                for (dim_t e = 0; e < current_block; e++)
                    my_acc[e] += scales[a] * my_ws[e];
            }

            // Down-convert the f32 accumulator into the bf16/f16 dst.
            if (is_dst_xf16)
                types::cvt_from_float(&output[b], my_acc, current_block);
        }
    };

    // f32 source path over [start, end): initialize dst from the first
    // source, then accumulate the rest. SIMD-friendly flat loops.
    auto sum_block = [&](dim_t start, dim_t end, int ithr) {
        PRAGMA_OMP_SIMD()
        for (dim_t e = start; e < end; e++) {
            output[e] = dst_data_t(scales[0] * input_ptrs[0][e]);
        }
        for (int a = 1; a < num_arrs; a++) {
            PRAGMA_OMP_SIMD()
            for (dim_t e = start; e < end; e++) {
                output[e] += dst_data_t(scales[a] * input_ptrs[a][e]);
            }
        }
    };

    const int max_nthr = pd()->nthr_;
    parallel(max_nthr, [&](const int ithr, const int nthr) {
        // Distribute whole blocks across threads.
        dim_t start {0}, end {0};
        balance211(blocks_number, nthr, ithr, start, end);

        for (dim_t nb = start; nb < end; ++nb) {
            dim_t start_e = nb * block_size;
            dim_t end_e = start_e + block_size;
            // src_data_type is a compile-time constant, so this branch is
            // resolved per template instantiation.
            if (src_data_type == data_type::f32)
                sum_block(start_e, end_e, ithr);
            else
                sum_block_xf16(start_e, end_e, ithr);
        }

        // The last thread also handles the remainder that does not fill a
        // whole block.
        if (tail != 0 && ithr == nthr - 1) {
            dim_t start_e = nelems - tail;
            dim_t end_e = nelems;
            if (src_data_type == data_type::f32)
                sum_block(start_e, end_e, ithr);
            else
                sum_block_xf16(start_e, end_e, ithr);
        }
    });

    return status::success;
}
118 | |
// Explicit instantiations for the supported src/dst data-type combinations.
// The single-argument forms presumably default dst_data_type to the src type
// via the class template's default argument in simple_sum.hpp (not visible
// here): f32->f32, bf16->bf16, f16->f16, plus mixed bf16->f32 and f16->f32.
template struct simple_sum_t<data_type::f32>;
template struct simple_sum_t<data_type::bf16>;
template struct simple_sum_t<data_type::bf16, data_type::f32>;
template struct simple_sum_t<data_type::f16>;
template struct simple_sum_t<data_type::f16, data_type::f32>;
124 | } // namespace cpu |
125 | } // namespace impl |
126 | } // namespace dnnl |
127 | |