1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/dnnl_thread.hpp"
19#include "common/utils.hpp"
20
21#include "cpu/platform.hpp"
22
23#include "cpu/cpu_batch_normalization_utils.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace bnorm_utils {
29
30using namespace dnnl::impl::utils;
31
32void cache_balance(size_t working_set_size, dim_t C_blks, dim_t N, int nthr,
33 dim_t &C_blks_per_iter, int64_t &iters) {
34 int l3_size = platform::get_per_core_cache_size(3) * nthr / 2;
35 C_blks_per_iter = saturate<dim_t>(1, C_blks, l3_size / working_set_size);
36
37 // Align C_blks_per_iter with nthr for better balancing implying a
38 // threading approach realized in thread_balance function below.
39 //
40 // TODO: update batchnorm blocking: all blocking stuff should be in one
41 // place
42 int C_nthr = nthr;
43 if (C_blks_per_iter < nthr) {
44 const int N_nthr = (int)nstl::min<dim_t>(N, nthr);
45 C_nthr = (int)nstl::min<dim_t>(C_blks, nthr / N_nthr);
46 }
47
48 if (C_blks_per_iter > C_nthr)
49 C_blks_per_iter = rnd_dn(C_blks_per_iter, C_nthr);
50 else
51 C_blks_per_iter = div_up(C_nthr, div_up(C_nthr, C_blks_per_iter));
52
53 iters = div_up(C_blks, C_blks_per_iter);
54}
55
// Splits the 3D iteration space — channel blocks (C_blks), minibatch (N) and
// spatial elements (SP) — across `nthr` threads and computes the slice owned
// by thread `ithr`.
//
// Parameters:
//   do_blocking         - caller iterates over channel chunks produced by
//                         cache_balance(); selects the N-major split below.
//   spatial_thr_allowed - whether splitting the spatial dimension is
//                         permitted on this call (see note at the end).
//   is_nspc             - channels-last layout; uses its own C_nthr heuristic.
//   ithr, nthr          - this thread's index and the total thread count.
//   N, C_blks, SP       - extents of the three dimensions to split.
// Outputs (by reference): per-dimension thread counts {C,N,S}_nthr, thread
//   ids {C,N,S}_ithr, and half-open ranges [C_blk_s, C_blk_e), [N_s, N_e),
//   [S_s, S_e). A thread left without work gets negative ids and -1 ranges.
//
// Returns the (possibly updated) spatial_thr_allowed flag; callers are
// expected to feed it back into subsequent invocations.
bool thread_balance(bool do_blocking, bool spatial_thr_allowed, bool is_nspc,
        int ithr, int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr,
        int &C_nthr, dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr,
        dim_t &N_s, dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s,
        dim_t &S_e) {
    // Simple case: enough channel blocks for every thread (and, for nspc,
    // no minibatch to split), or the threading runtime cannot synchronize
    // threads — split over channels only; each thread sees all of N and SP.
    if (((nthr <= C_blks) && IMPLICATION(is_nspc, N == 1))
            || !dnnl_thr_syncable()) {
        C_ithr = ithr;
        C_nthr = nthr;
        N_ithr = 0;
        N_nthr = 1;
        S_ithr = 0;
        S_nthr = 1;
        N_s = 0;
        N_e = N;
        S_s = 0;
        S_e = SP;
        balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e);
    } else {
        if (is_nspc) {
            // Heuristic channel-thread count for channels-last layout.
            if (C_blks <= 8)
                C_nthr = 1;
            else if (nthr >= 8 && C_blks <= 32)
                C_nthr = 8;
            else {
                C_nthr = (int)math::gcd((dim_t)nthr, C_blks);
                // Unroll by channels in JIT kernel
                if ((C_nthr == C_blks) || (C_nthr == nthr)) C_nthr = 1;
            }
            // Remaining threads go to minibatch first, then spatial.
            N_nthr = (int)nstl::min<dim_t>(N, nthr / C_nthr);
            S_nthr = (int)nstl::min<dim_t>(SP, nthr / (C_nthr * N_nthr));
        } else {
            if (do_blocking) {
                // Blocked path: favor the minibatch split (channel chunks
                // were already sized by cache_balance()).
                N_nthr = (int)nstl::min<dim_t>(N, nthr);
                C_nthr = (int)nstl::min<dim_t>(C_blks, nthr / N_nthr);
                S_nthr = (int)nstl::min<dim_t>(SP, nthr / (C_nthr * N_nthr));
            } else {
                // Non-blocked path: favor an even channel split first.
                C_nthr = (int)math::gcd((dim_t)nthr, C_blks);
                N_nthr = (int)nstl::min<dim_t>(N, nthr / C_nthr);
                S_nthr = (int)nstl::min<dim_t>(SP, nthr / (C_nthr * N_nthr));
            }
        }

        // Caller forbade spatial threading (e.g. for consistency with a
        // previous invocation) — collapse the spatial split.
        if (!spatial_thr_allowed) S_nthr = 1;

        if (S_nthr < 1) S_nthr = 1;
        // Threads beyond the C*N*S product get sentinel values (no work).
        if (ithr < C_nthr * N_nthr * S_nthr) {
            N_ithr = (ithr / S_nthr) % N_nthr;
            C_ithr = ithr / (N_nthr * S_nthr);
            S_ithr = ithr % S_nthr;
            balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e);
            balance211(N, N_nthr, N_ithr, N_s, N_e);
            balance211(SP, S_nthr, S_ithr, S_s, S_e);
        } else {
            S_ithr = N_ithr = C_ithr = -ithr;
            S_s = S_e = N_s = N_e = C_blk_s = C_blk_e = -1;
        }
    }

    // spatial_thr_allowed is meant to help maintain
    // consistent decisions about spatial threading
    // between multiple invocations of this routine.
    // It is caller's responsibility to check the
    // return value and pass it as a flag to the
    // next call if needed.
    if (S_nthr == 1) spatial_thr_allowed = false;

    return spatial_thr_allowed;
}
125
// Predicts whether thread_balance() above would split the spatial dimension
// (i.e. end up with S_nthr > 1) for the given batch-normalization problem,
// without computing any per-thread ranges. The decision logic here
// deliberately mirrors thread_balance() (and cache_balance() for the blocked
// path) and must be kept in sync with them.
//
// Parameters:
//   bdesc     - batch-normalization primitive descriptor (problem sizes).
//   is_nspc   - channels-last layout flag.
//   simd_w    - SIMD width; channel blocks are C_PADDED / simd_w.
//   data_size - bytes per element, used to estimate the working-set size.
// Returns true iff spatial threading would be used.
bool is_spatial_thr(const batch_normalization_pd_t *bdesc, bool is_nspc,
        int simd_w, int data_size) {
    // Spatial threading requires thread synchronization support.
    if (!dnnl_thr_syncable()) return false;

    dim_t nthr = dnnl_get_max_threads();
    dim_t SP = bdesc->W() * bdesc->D() * bdesc->H();
    dim_t C_PADDED = memory_desc_wrapper(bdesc->src_md()).padded_dims()[1];
    assert(C_PADDED % simd_w == 0);

    dim_t C_blks = C_PADDED / simd_w;
    dim_t N = bdesc->MB();
    dim_t S_nthr {1};

    if (is_nspc) {
        // Mirrors thread_balance()'s channels-only early exit.
        if (nthr <= C_blks && N == 1) return false;

        dim_t C_nthr;

        // Mirrors thread_balance()'s nspc C_nthr heuristic.
        if ((nthr <= C_blks && nthr == 1) || C_blks <= 8)
            C_nthr = 1;
        else if (nthr >= 8 && C_blks <= 32)
            C_nthr = 8;
        else {
            C_nthr = math::gcd((dim_t)nthr, C_blks);
            if ((C_nthr == C_blks) || (C_nthr == nthr)) C_nthr = 1;
        }

        dim_t N_nthr = nstl::min<dim_t>(N, nthr / C_nthr);
        S_nthr = nstl::min<dim_t>(SP, nthr / (C_nthr * N_nthr));
    } else {
        // Estimate whether the blocked-channel path would be taken: data
        // larger than a quarter of the aggregate L3 triggers blocking.
        size_t data = N * C_PADDED * SP * data_size;
        size_t l3_size_ = platform::get_per_core_cache_size(3)
                * dnnl_get_max_threads() / 2;
        bool do_blocking = (data >= l3_size_ / 2 && l3_size_ > 0);
        dim_t C_blks_per_iter {1}, iters {1};

        if (do_blocking) {
            // Backward pass touches two tensors (src and diff_dst).
            int num_tensors = bdesc->is_fwd() ? 1 : 2;
            size_t working_set_size
                    = (N * SP * simd_w * data_size) * num_tensors;
            cache_balance(
                    working_set_size, C_blks, N, nthr, C_blks_per_iter, iters);
        }

        // Spatial threading decision made in this function shall be consistent
        // with thread_balance() behavior.
        C_blks = do_blocking ? C_blks_per_iter : C_blks;

        // Mirrors thread_balance()'s channels-only early exit.
        if (nthr <= C_blks) return false;

        if (do_blocking) {
            dim_t N_nthr = nstl::min(N, nthr);
            dim_t C_nthr = nstl::min(C_blks, nthr / N_nthr);
            S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr));
        } else {
            dim_t C_nthr = math::gcd(nthr, C_blks);
            dim_t N_nthr = nstl::min(N, nthr / C_nthr);
            S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr));
        }
    }

    return S_nthr > 1;
}
189
190} // namespace bnorm_utils
191} // namespace cpu
192} // namespace impl
193} // namespace dnnl
194