1 | // Copyright 2016 The Gemmlowp Authors. All Rights Reserved. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #ifndef GEMMLOWP_META_SINGLE_THREAD_GEMM_H_ |
16 | #define GEMMLOWP_META_SINGLE_THREAD_GEMM_H_ |
17 | |
18 | #include <iostream> |
19 | #include "base.h" |
20 | |
21 | namespace gemmlowp { |
22 | namespace meta { |
23 | |
// Forward declaration of the single-threaded GEMM entry point (defined at the
// bottom of this file). The cache-friendly executors below call it
// recursively to run cache-sized sub-tasks.
template <typename Executor, typename Params, int kernel_m, int kernel_n,
          int kernel_k>
void Gemm(const Params& params);
27 | |
// Executor that packs the entire RHS into scratch memory up front, then walks
// the LHS one kernel-sized chunk at a time, multiplying each packed LHS chunk
// against every packed RHS chunk. Appropriate when all packed RHS chunks fit
// in scratch at once.
class GemmExecutorPackRHS {
 public:
  // Estimates the scratch (bytes) ExecuteDispatch3D needs: one packed
  // kernel_m x kernel_k LHS chunk plus ceil(params.n / kernel_n) packed
  // kernel_n x kernel_k RHS chunks, rounded up to a 64KB multiple.
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    const int lhs_scratch =
        StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
            params.left_stream, kernel_m, kernel_k);
    // Round up so the leftover columns get their own (full-sized) chunk.
    const int rhs_chunks = ((params.n + kernel_n - 1) / kernel_n);
    const int rhs_scratch =
        rhs_chunks *
        StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
            params.right_stream, kernel_n, kernel_k);
    return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
  }

  // Runs the GEMM with compile-time kernel dimensions (m, n, k) and the exact
  // leftover sizes (m_leftovers == params.m % m, etc.) baked in as template
  // arguments. Type-suffix convention below: F = full kernel-sized dimension,
  // L = leftover-sized dimension.
  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    // Shorthand typedefs for streams and multiply kernels.
    typedef typename P::InType InType;
    typedef typename P::OutType OutType;

    // LHS packers: full-height (m rows) and leftover-height chunks.
    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamF;
    typedef Stream<typename P::InType, m_leftovers, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamL;

    // RHS packers: full-width (n columns) and leftover-width chunks.
    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStreamF;
    typedef Stream<typename P::InType, n_leftovers, k, k_leftovers,
                   typename P::RightStream>
        RightStreamL;

    // Output layout for full-height and leftover-height result strips.
    typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream>
        OutputStreamFF;
    typedef Stream<typename P::OutType, m_leftovers, n, 0,
                   typename P::OutputStream>
        OutputStreamLF;

    // Multiply kernels for each (full/leftover m) x (full/leftover n) case.
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m, n, k>
        KernelFF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m,
                      n_leftovers, k>
        KernelFL;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n, k>
        KernelLF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n_leftovers, k>
        KernelLL;

#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
              << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
              << k_leftovers << " -- " << params.m << "x" << params.n << "x"
              << params.k << std::endl;
    LeftStreamF::Debug(params.left_stream);
    LeftStreamL::Debug(params.left_stream);

    RightStreamF::Debug(params.right_stream);
    RightStreamL::Debug(params.right_stream);

    OutputStreamFF::Debug(params.fused_kernel.output_stream);
    OutputStreamLF::Debug(params.fused_kernel.output_stream);

    KernelFF::Debug(params.fused_kernel);
    KernelFL::Debug(params.fused_kernel);
    KernelLF::Debug(params.fused_kernel);
    KernelLL::Debug(params.fused_kernel);
#endif
#endif

    // Counts of full kernel-sized chunks; the leftovers (if any) are handled
    // separately with the *L types after each loop.
    int lhs_chunks = params.m / m;
    int rhs_chunks = params.n / n;

    // Scratch memory for packed LHS & RHS chunks.
    // Layout: [one packed LHS chunk][all packed RHS chunks].
    std::uint8_t* packed_lhs = params.scratch;
    std::uint8_t* packed_rhs =
        params.scratch + LeftStreamF::Scratch(params.left_stream);

    // Pack full RHS first.

    std::uint8_t* packed_rhs_chunk = packed_rhs;
    const int packed_rhs_chunk_size =
        RightStreamF::PackedStride(params.right_stream);

    {
      const std::uint8_t* rhs_chunk =
          reinterpret_cast<const std::uint8_t*>(params.rhs);
      const int rhs_chunk_size =
          RightStreamF::UnpackedStride(params.right_stream);

      for (int i = 0; i < rhs_chunks; ++i) {
        RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                           params.right_stream,
                           reinterpret_cast<InType*>(packed_rhs_chunk));

        rhs_chunk += rhs_chunk_size;
        packed_rhs_chunk += packed_rhs_chunk_size;
      }

      // Pack the trailing leftover-width columns (RightStreamL is
      // specialized on n_leftovers).
      RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                         params.right_stream,
                         reinterpret_cast<InType*>(packed_rhs_chunk));
    }

    // Multiply RHS by LHS one LHS chunk at a time.

    const std::uint8_t* lhs_chunk =
        reinterpret_cast<const std::uint8_t*>(params.lhs);
    std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result);
    std::uint8_t* result_chunk = result_strip;

    {
      const int lhs_chunk_size =
          LeftStreamF::UnpackedStride(params.left_stream);
      // Advance between strips (down m) vs. between chunks in a strip
      // (across n).
      const int result_strip_size =
          OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);
      const int result_chunk_size =
          OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);

      for (int i = 0; i < lhs_chunks; ++i) {
        // Re-pack one LHS chunk into the shared slot, then sweep it across
        // all packed RHS chunks.
        LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                          params.left_stream,
                          reinterpret_cast<InType*>(packed_lhs));

        result_chunk = result_strip;
        packed_rhs_chunk = packed_rhs;

        for (int j = 0; j < rhs_chunks; ++j) {
          KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                             reinterpret_cast<const InType*>(packed_rhs_chunk),
                             params.fused_kernel,
                             reinterpret_cast<OutType*>(result_chunk));

          result_chunk += result_chunk_size;
          packed_rhs_chunk += packed_rhs_chunk_size;
        }

        // Leftover columns of this strip (full-height LHS x leftover RHS).
        KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                           reinterpret_cast<const InType*>(packed_rhs_chunk),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        lhs_chunk += lhs_chunk_size;
        result_strip += result_strip_size;
      }
    }

    // Leftover LHS chunk.
    if (m_leftovers > 0) {  // static if
      const int result_chunk_size =
          OutputStreamLF::UnpackedAdvance(params.fused_kernel.output_stream);

      LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                        params.left_stream,
                        reinterpret_cast<InType*>(packed_lhs));

      result_chunk = result_strip;
      packed_rhs_chunk = packed_rhs;

      for (int i = 0; i < rhs_chunks; ++i) {
        KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                           reinterpret_cast<const InType*>(packed_rhs_chunk),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        result_chunk += result_chunk_size;
        packed_rhs_chunk += packed_rhs_chunk_size;
      }

      // Bottom-right corner: leftover rows x leftover columns.
      KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                         reinterpret_cast<const InType*>(packed_rhs_chunk),
                         params.fused_kernel,
                         reinterpret_cast<OutType*>(result_chunk));
    }
  }
};
216 | |
// Mirror image of GemmExecutorPackRHS: packs the entire LHS into scratch up
// front, then walks the RHS one kernel-sized chunk at a time, multiplying
// every packed LHS chunk against each packed RHS chunk. Appropriate when all
// packed LHS chunks fit in scratch at once.
class GemmExecutorPackLHS {
 public:
  // Estimates the scratch (bytes) ExecuteDispatch3D needs:
  // ceil(params.m / kernel_m) packed kernel_m x kernel_k LHS chunks plus one
  // packed kernel_n x kernel_k RHS chunk, rounded up to a 64KB multiple.
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    // Round up so the leftover rows get their own (full-sized) chunk.
    const int lhs_chunks = ((params.m + kernel_m - 1) / kernel_m);
    const int lhs_scratch =
        lhs_chunks *
        StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
            params.left_stream, kernel_m, kernel_k);
    const int rhs_scratch =
        StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
            params.right_stream, kernel_n, kernel_k);
    return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
  }

  // Runs the GEMM with compile-time kernel dimensions (m, n, k) and the exact
  // leftover sizes baked in as template arguments. Type-suffix convention:
  // F = full kernel-sized dimension, L = leftover-sized dimension.
  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    // Shorthand typedefs for streams and multiply kernels.
    typedef typename P::InType InType;
    typedef typename P::OutType OutType;

    // LHS packers: full-height (m rows) and leftover-height chunks.
    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamF;
    typedef Stream<typename P::InType, m_leftovers, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamL;

    // RHS packers: full-width (n columns) and leftover-width chunks.
    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStreamF;
    typedef Stream<typename P::InType, n_leftovers, k, k_leftovers,
                   typename P::RightStream>
        RightStreamL;

    // Output layout for full-width and leftover-width result strips.
    typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream>
        OutputStreamFF;
    typedef Stream<typename P::OutType, m, n_leftovers, 0,
                   typename P::OutputStream>
        OutputStreamFL;

    // Multiply kernels for each (full/leftover m) x (full/leftover n) case.
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m, n, k>
        KernelFF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m,
                      n_leftovers, k>
        KernelFL;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n, k>
        KernelLF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n_leftovers, k>
        KernelLL;
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
              << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
              << k_leftovers << " -- " << params.m << "x" << params.n << "x"
              << params.k << std::endl;
    LeftStreamF::Debug(params.left_stream);
    LeftStreamL::Debug(params.left_stream);

    RightStreamF::Debug(params.right_stream);
    RightStreamL::Debug(params.right_stream);

    OutputStreamFF::Debug(params.fused_kernel.output_stream);
    OutputStreamFL::Debug(params.fused_kernel.output_stream);

    KernelFF::Debug(params.fused_kernel);
    KernelFL::Debug(params.fused_kernel);
    KernelLF::Debug(params.fused_kernel);
    KernelLL::Debug(params.fused_kernel);
#endif
#endif

    // Counts of full kernel-sized chunks; the leftovers (if any) are handled
    // separately with the *L types after each loop.
    int lhs_chunks = params.m / m;
    int rhs_chunks = params.n / n;

    // Scratch memory for packed LHS & RHS chunks.
    // Layout: [one packed RHS chunk][all packed LHS chunks].
    std::uint8_t* packed_rhs = params.scratch;
    std::uint8_t* packed_lhs =
        params.scratch + RightStreamF::Scratch(params.right_stream);

    // Pack full LHS first.

    std::uint8_t* packed_lhs_chunk = packed_lhs;
    const int packed_lhs_chunk_size =
        LeftStreamF::PackedStride(params.left_stream);

    {
      const std::uint8_t* lhs_chunk =
          reinterpret_cast<const std::uint8_t*>(params.lhs);
      const int lhs_chunk_size =
          LeftStreamF::UnpackedStride(params.left_stream);

      for (int i = 0; i < lhs_chunks; ++i) {
        LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                          params.left_stream,
                          reinterpret_cast<InType*>(packed_lhs_chunk));

        lhs_chunk += lhs_chunk_size;
        packed_lhs_chunk += packed_lhs_chunk_size;
      }

      // Pack the trailing leftover-height rows (LeftStreamL is specialized
      // on m_leftovers).
      LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                        params.left_stream,
                        reinterpret_cast<InType*>(packed_lhs_chunk));
    }

    // Multiply RHS by LHS one RHS chunk at a time.

    const std::uint8_t* rhs_chunk =
        reinterpret_cast<const std::uint8_t*>(params.rhs);
    std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result);
    std::uint8_t* result_chunk = result_strip;

    {
      const int rhs_chunk_size =
          RightStreamF::UnpackedStride(params.right_stream);
      // Note: strip/chunk advances are transposed relative to
      // GemmExecutorPackRHS - strips move across n, chunks move down m.
      const int result_strip_size =
          OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);
      const int result_chunk_size =
          OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);

      for (int i = 0; i < rhs_chunks; ++i) {
        // Re-pack one RHS chunk into the shared slot, then sweep all packed
        // LHS chunks across it.
        RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                           params.right_stream,
                           reinterpret_cast<InType*>(packed_rhs));

        result_chunk = result_strip;
        packed_lhs_chunk = packed_lhs;

        for (int j = 0; j < lhs_chunks; ++j) {
          KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                             reinterpret_cast<const InType*>(packed_rhs),
                             params.fused_kernel,
                             reinterpret_cast<OutType*>(result_chunk));

          result_chunk += result_chunk_size;
          packed_lhs_chunk += packed_lhs_chunk_size;
        }

        // Leftover rows of this strip (leftover LHS x full-width RHS).
        KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                           reinterpret_cast<const InType*>(packed_rhs),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        rhs_chunk += rhs_chunk_size;
        result_strip += result_strip_size;
      }
    }

    // Leftover RHS chunk.
    if (n_leftovers > 0) {  // static if
      const int result_chunk_size =
          OutputStreamFL::UnpackedStride(params.fused_kernel.output_stream);

      RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                         params.right_stream,
                         reinterpret_cast<InType*>(packed_rhs));

      result_chunk = result_strip;
      packed_lhs_chunk = packed_lhs;

      for (int i = 0; i < lhs_chunks; ++i) {
        KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                           reinterpret_cast<const InType*>(packed_rhs),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        result_chunk += result_chunk_size;
        packed_lhs_chunk += packed_lhs_chunk_size;
      }

      // Bottom-right corner: leftover rows x leftover columns.
      KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                         reinterpret_cast<const InType*>(packed_rhs),
                         params.fused_kernel,
                         reinterpret_cast<OutType*>(result_chunk));
    }
  }
};
403 | |
404 | namespace internal { |
405 | |
// Computes how many cache-friendly sub-tasks a GEMM should be split into so
// that the constant-size packed operand plus as many per-chunk packed
// operands as possible fit in |cache_size| bytes.
//
// cache_size: target cache capacity in bytes.
// constant_memory: bytes used by the operand packed once per task.
// per_chunk_memory: bytes used by one packed chunk of the other operand.
// total_dim: full extent of the dimension being divided (params.m or n).
// chunk_dim: kernel-sized chunk extent along that dimension.
//
// Returns ceil(chunks_needed / chunks_that_fit), i.e. at least 1.
inline int CalculateCacheFriendlyTasksCount(int cache_size, int constant_memory,
                                            int per_chunk_memory, int total_dim,
                                            int chunk_dim) {
  assert(constant_memory + per_chunk_memory < cache_size);
  const int available_cache = cache_size - constant_memory;
  // How many packed chunks fit in the cache left over after the constant part.
  const int available_chunks = available_cache / per_chunk_memory;
  // Total chunks needed to cover the whole dimension, rounding up.
  const int chunks_count = (total_dim + chunk_dim - 1) / chunk_dim;
  // Defensive guard: the assert above guarantees available_chunks >= 1, but
  // it is compiled out under NDEBUG; avoid a division by zero on invalid
  // input by degrading to one chunk per task.
  if (available_chunks == 0) {
    return chunks_count;
  }
  // Number of tasks needed so each task's chunks fit in cache, rounding up.
  return (chunks_count + available_chunks - 1) / available_chunks;
}
415 | |
416 | template <typename Params> |
417 | inline void UpdateCacheFriendlyTask(int m_offset, int m, int n_offset, int n, |
418 | const Params& params, Params* task_params) { |
419 | task_params->m = m; |
420 | task_params->lhs = |
421 | StreamUtil<typename Params::InType, typename Params::LeftStream>::Offset( |
422 | params.left_stream, params.lhs, m_offset, 0); |
423 | |
424 | task_params->n = n; |
425 | task_params->rhs = |
426 | StreamUtil<typename Params::InType, typename Params::RightStream>::Offset( |
427 | params.right_stream, params.rhs, n_offset, 0); |
428 | |
429 | task_params->result = |
430 | StreamUtil<typename Params::OutType, typename Params::OutputStream>:: |
431 | Offset(params.fused_kernel.output_stream, params.result, m_offset, |
432 | n_offset); |
433 | } |
434 | |
435 | } // namespace internal |
436 | |
// Cache-aware wrapper around GemmExecutorPackRHS: when the packed RHS chunks
// would overflow |cache_size|, splits the GEMM along the n dimension into
// cache-sized slices and runs each slice as an independent Gemm task.
template <int cache_size = 256 * 1024>
class GemmExecutorPackRHSCacheFriendly {
 public:
  // The scratch budget is the cache size itself, by construction.
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    return cache_size;
  }

  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStream;

    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStream;

    // For PackRHS the LHS chunk is the constant cost and each RHS chunk
    // (one per n-slice) is the per-chunk cost.
    const int lhs_scratch = LeftStream::Scratch(params.left_stream);
    const int rhs_scratch = RightStream::Scratch(params.right_stream);

    const int cache_friendly_tasks_count =
        internal::CalculateCacheFriendlyTasksCount(cache_size, lhs_scratch,
                                                   rhs_scratch, params.n, n);

    // Everything fits: delegate directly to the plain executor.
    if (cache_friendly_tasks_count == 1) {
      GemmExecutorPackRHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
                                             n_leftovers, k_leftovers>(params);
      return;
    }

    const int cache_friendly_dim = params.n / cache_friendly_tasks_count;

    // Run equal-width n-slices; Gemm re-dispatches each slice's leftovers at
    // runtime, so the compile-time leftover arguments need not match.
    P task_params = params;
    for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
      internal::UpdateCacheFriendlyTask(0, params.m, i * cache_friendly_dim,
                                        cache_friendly_dim, params,
                                        &task_params);
      Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
    }
    // The final slice absorbs the rounding remainder of params.n.
    const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
    internal::UpdateCacheFriendlyTask(0, params.m, dim_sum, params.n - dim_sum,
                                      params, &task_params);
    Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
  }
};
485 | |
// Cache-aware wrapper around GemmExecutorPackLHS: when the packed LHS chunks
// would overflow |cache_size|, splits the GEMM along the m dimension into
// cache-sized slices and runs each slice as an independent Gemm task.
template <int cache_size = 256 * 1024>
class GemmExecutorPackLHSCacheFriendly {
 public:
  // The scratch budget is the cache size itself, by construction.
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    return cache_size;
  }

  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStream;

    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStream;

    // For PackLHS the RHS chunk is the constant cost and each LHS chunk
    // (one per m-slice) is the per-chunk cost.
    const int lhs_scratch = LeftStream::Scratch(params.left_stream);
    const int rhs_scratch = RightStream::Scratch(params.right_stream);

    const int cache_friendly_tasks_count =
        internal::CalculateCacheFriendlyTasksCount(cache_size, rhs_scratch,
                                                   lhs_scratch, params.m, m);

    // Everything fits: delegate directly to the plain executor.
    if (cache_friendly_tasks_count == 1) {
      GemmExecutorPackLHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
                                             n_leftovers, k_leftovers>(params);
      return;
    }

    const int cache_friendly_dim = params.m / cache_friendly_tasks_count;

    // Run equal-height m-slices; Gemm re-dispatches each slice's leftovers at
    // runtime, so the compile-time leftover arguments need not match.
    P task_params = params;
    for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
      internal::UpdateCacheFriendlyTask(i * cache_friendly_dim,
                                        cache_friendly_dim, 0, params.n, params,
                                        &task_params);
      Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
    }
    // The final slice absorbs the rounding remainder of params.m.
    const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
    internal::UpdateCacheFriendlyTask(dim_sum, params.m - dim_sum, 0, params.n,
                                      params, &task_params);
    Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
  }
};
534 | |
535 | namespace internal { |
536 | |
537 | // Stage 3. |
538 | |
// Recursive compile-time linear search over the k leftover: when the runtime
// value |k| equals the compile-time candidate |variable_k|, invokes the
// executor with that value baked in as a template argument; otherwise
// recurses with variable_k - 1 (terminating at the 0 specialization below).
template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
          int fixed_n, int variable_k>
struct Dispatch3DStage3 {
  static void Execute(const P& params, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << fixed_n << "x" << variable_k
              << std::endl
              << std::flush;
#endif
#endif
    if (k == variable_k) {
      E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                                    variable_k>(params);
    } else {
      Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                       variable_k - 1>::Execute(params, k);
    }
  }
};
560 | |
// Base case of the stage-3 search (variable_k == 0). Since |k| is a modulus
// in [0, dim_k), reaching this case with k != 0 indicates a logic error and
// aborts the process.
template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
          int fixed_n>
struct Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n, 0> {
  static void Execute(const P& params, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << fixed_n << "x" << 0 << std::endl
              << std::flush;
#endif
#endif
    if (k == 0) {
      E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                                    0>(params);
    } else {
      std::cerr << "FATAL: dispatch3DStage3 failed: ran out of cases."
                << std::endl
                << std::flush;
      std::exit(1);
    }
  }
};
583 | |
584 | // Stage 2. |
585 | |
// Recursive compile-time linear search over the n leftover: when the runtime
// value |n| equals |variable_n|, fixes it and descends to the stage-3 search
// over k (starting from dim_k - 1); otherwise recurses with variable_n - 1.
template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
          int variable_n>
struct Dispatch3DStage2 {
  static void Execute(const P& params, int n, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << variable_n << std::endl
              << std::flush;
#endif
#endif
    if (n == variable_n) {
      Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, variable_n,
                       dim_k - 1>::Execute(params, k);
    } else {
      Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m,
                       variable_n - 1>::Execute(params, n, k);
    }
  }
};
606 | |
// Base case of the stage-2 search (variable_n == 0). Since |n| is a modulus
// in [0, dim_n), reaching this case with n != 0 indicates a logic error and
// aborts the process.
template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m>
struct Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m, 0> {
  static void Execute(const P& params, int n, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << 0 << std::endl
              << std::flush;
#endif
#endif
    if (n == 0) {
      Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, 0,
                       dim_k - 1>::Execute(params, k);
    } else {
      std::cerr << "FATAL: dispatch3DStage2 failed: ran out of cases."
                << std::endl
                << std::flush;
      std::exit(1);
    }
  }
};
628 | |
629 | // Stage 1. |
630 | |
// Recursive compile-time linear search over the m leftover: when the runtime
// value |m| equals |variable_m|, fixes it and descends to the stage-2 search
// over n (starting from dim_n - 1); otherwise recurses with variable_m - 1.
template <typename E, typename P, int dim_m, int dim_n, int dim_k,
          int variable_m>
struct Dispatch3DStage1 {
  static void Execute(const P& params, int m, int n, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << variable_m << std::endl
              << std::flush;
#endif
#endif
    if (m == variable_m) {
      Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, variable_m,
                       dim_n - 1>::Execute(params, n, k);
    } else {
      Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, variable_m - 1>::Execute(
          params, m, n, k);
    }
  }
};
651 | |
// Base case of the stage-1 search (variable_m == 0). Since |m| is a modulus
// in [0, dim_m), reaching this case with m != 0 indicates a logic error and
// aborts the process.
template <typename E, typename P, int dim_m, int dim_n, int dim_k>
struct Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, 0> {
  static void Execute(const P& params, int m, int n, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << 0 << std::endl
              << std::flush;
#endif
#endif
    if (m == 0) {
      Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, 0, dim_n - 1>::Execute(params,
                                                                         n, k);
    } else {
      std::cerr << "FATAL: dispatch3DStage1 failed: ran out of cases."
                << std::endl
                << std::flush;
      std::exit(1);
    }
  }
};
673 | |
674 | } // namespace internal |
675 | |
// Single-threaded GEMM entry point. Computes the runtime leftover sizes
// (params.m % kernel_m, etc.) and walks the three-stage dispatch chain to
// reach an Executor::ExecuteDispatch3D instantiation with those leftovers as
// compile-time template arguments.
template <typename Executor, typename Params, int kernel_m, int kernel_n,
          int kernel_k>
inline void Gemm(const Params& params) {
  internal::Dispatch3DStage1<Executor, Params, kernel_m, kernel_n, kernel_k,
                             kernel_m - 1>::Execute(params, params.m % kernel_m,
                                                    params.n % kernel_n,
                                                    params.k % kernel_k);
}
684 | |
685 | } // namespace meta |
686 | } // namespace gemmlowp |
687 | |
688 | #endif // GEMMLOWP_META_SINGLE_THREAD_GEMM_H_ |
689 | |