1// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
16#define GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
17
18#include <iostream>
19#include "base.h"
20
21namespace gemmlowp {
22namespace meta {
23
// Forward declaration of the dispatching entry point. The cache-friendly
// executors below call Gemm recursively on sub-problems, so it has to be
// visible before its definition further down in this header.
template <class Executor, class Params, int kernel_m, int kernel_n,
          int kernel_k>
void Gemm(const Params& params);
27
// Single-threaded GEMM executor that packs the whole RHS once, up front, and
// then streams the LHS through one chunk at a time.
//
// Scratch layout used by ExecuteDispatch3D (byte offsets into params.scratch):
//   [ one packed LHS chunk | packed RHS chunk 0 | chunk 1 | ... | leftover ]
// The single LHS slot is overwritten for every LHS chunk; the packed RHS
// chunks are written once and re-read for every LHS chunk.
class GemmExecutorPackRHS {
 public:
  // Scratch size needed by ExecuteDispatch3D: room for one packed LHS chunk
  // plus ceil(params.n / kernel_n) packed RHS chunks. The RHS leftover chunk
  // is therefore budgeted as a full kernel_n chunk. The sum is passed through
  // AlignTo<64 * 1024> (presumably rounding up to a 64 KiB multiple).
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    const int lhs_scratch =
        StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
            params.left_stream, kernel_m, kernel_k);
    // Round up so the leftover chunk gets a full slot.
    const int rhs_chunks = ((params.n + kernel_n - 1) / kernel_n);
    const int rhs_scratch =
        rhs_chunks *
        StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
            params.right_stream, kernel_n, kernel_k);
    return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
  }

  // Runs the full product for one compile-time shape.
  //   m, n, k:
  //     Kernel chunk sizes.
  //   m_leftovers, n_leftovers, k_leftovers:
  //     The remainders params.m % m, params.n % n and params.k % k, as
  //     selected by the Dispatch3D stages.
  // Typedef suffixes: F = full chunk, L = leftover chunk. For the two-letter
  // suffixes, the first letter refers to the LHS (m) side and the second to
  // the RHS (n) side.
  // NOTE(review): assumes params.scratch holds at least EstimateScratchSize()
  // bytes; nothing here checks that.
  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    // Shorthand typedefs for streams and multiply kernels.
    typedef typename P::InType InType;
    typedef typename P::OutType OutType;

    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamF;
    typedef Stream<typename P::InType, m_leftovers, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamL;

    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStreamF;
    typedef Stream<typename P::InType, n_leftovers, k, k_leftovers,
                   typename P::RightStream>
        RightStreamL;

    typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream>
        OutputStreamFF;
    typedef Stream<typename P::OutType, m_leftovers, n, 0,
                   typename P::OutputStream>
        OutputStreamLF;

    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m, n, k>
        KernelFF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m,
                      n_leftovers, k>
        KernelFL;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n, k>
        KernelLF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n_leftovers, k>
        KernelLL;

    // Verbose tracing, compiled only when both DEBUG and
    // DEBUG_METAGEMM_VERBOSE are defined.
    // NOTE(review): typeid needs <typeinfo>, which this header does not
    // include directly; presumably it arrives transitively. Confirm.
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
              << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
              << k_leftovers << " -- " << params.m << "x" << params.n << "x"
              << params.k << std::endl;
    LeftStreamF::Debug(params.left_stream);
    LeftStreamL::Debug(params.left_stream);

    RightStreamF::Debug(params.right_stream);
    RightStreamL::Debug(params.right_stream);

    OutputStreamFF::Debug(params.fused_kernel.output_stream);
    OutputStreamLF::Debug(params.fused_kernel.output_stream);

    KernelFF::Debug(params.fused_kernel);
    KernelFL::Debug(params.fused_kernel);
    KernelLF::Debug(params.fused_kernel);
    KernelLL::Debug(params.fused_kernel);
#endif
#endif

    // Number of full-size chunks on each side. The remainders are handled by
    // the *L streams and kernels.
    int lhs_chunks = params.m / m;
    int rhs_chunks = params.n / n;

    // Scratch memory for packed LHS & RHS chunks.
    // The LHS slot comes first; the RHS chunks start right after it.

    std::uint8_t* packed_lhs = params.scratch;
    std::uint8_t* packed_rhs =
        params.scratch + LeftStreamF::Scratch(params.left_stream);

    // Pack full RHS first.
    // Each full RHS chunk occupies PackedStride bytes; the leftover chunk is
    // packed immediately after the last full one.

    std::uint8_t* packed_rhs_chunk = packed_rhs;
    const int packed_rhs_chunk_size =
        RightStreamF::PackedStride(params.right_stream);

    {
      const std::uint8_t* rhs_chunk =
          reinterpret_cast<const std::uint8_t*>(params.rhs);
      const int rhs_chunk_size =
          RightStreamF::UnpackedStride(params.right_stream);

      for (int i = 0; i < rhs_chunks; ++i) {
        RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                           params.right_stream,
                           reinterpret_cast<InType*>(packed_rhs_chunk));

        rhs_chunk += rhs_chunk_size;
        packed_rhs_chunk += packed_rhs_chunk_size;
      }

      // Leftover RHS chunk. This is called unconditionally, even when
      // n_leftovers == 0.
      // NOTE(review): presumably a no-op for a zero-size stream. Confirm.
      RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                         params.right_stream,
                         reinterpret_cast<InType*>(packed_rhs_chunk));
    }

    // Multiply RHS by LHS one LHS chunk at a time.
    // lhs_chunk and result_strip are declared outside the loop scope below
    // because the leftover LHS pass continues from where they stop.

    const std::uint8_t* lhs_chunk =
        reinterpret_cast<const std::uint8_t*>(params.lhs);
    std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result);
    std::uint8_t* result_chunk = result_strip;

    {
      const int lhs_chunk_size =
          LeftStreamF::UnpackedStride(params.left_stream);
      // result_strip moves once per LHS chunk; result_chunk moves once per
      // RHS chunk within a strip.
      const int result_strip_size =
          OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);
      const int result_chunk_size =
          OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);

      for (int i = 0; i < lhs_chunks; ++i) {
        // Re-pack the current LHS chunk into the single LHS scratch slot.
        LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                          params.left_stream,
                          reinterpret_cast<InType*>(packed_lhs));

        result_chunk = result_strip;
        packed_rhs_chunk = packed_rhs;

        for (int j = 0; j < rhs_chunks; ++j) {
          KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                             reinterpret_cast<const InType*>(packed_rhs_chunk),
                             params.fused_kernel,
                             reinterpret_cast<OutType*>(result_chunk));

          result_chunk += result_chunk_size;
          packed_rhs_chunk += packed_rhs_chunk_size;
        }

        // Full LHS chunk x leftover RHS chunk.
        KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                           reinterpret_cast<const InType*>(packed_rhs_chunk),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        lhs_chunk += lhs_chunk_size;
        result_strip += result_strip_size;
      }
    }

    // Leftover LHS chunk.
    // m_leftovers is a template parameter, so this condition is a
    // compile-time constant.
    if (m_leftovers > 0) {  // static if
      const int result_chunk_size =
          OutputStreamLF::UnpackedAdvance(params.fused_kernel.output_stream);

      LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                        params.left_stream,
                        reinterpret_cast<InType*>(packed_lhs));

      result_chunk = result_strip;
      packed_rhs_chunk = packed_rhs;

      for (int i = 0; i < rhs_chunks; ++i) {
        KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                           reinterpret_cast<const InType*>(packed_rhs_chunk),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        result_chunk += result_chunk_size;
        packed_rhs_chunk += packed_rhs_chunk_size;
      }

      // Leftover LHS chunk x leftover RHS chunk.
      KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                         reinterpret_cast<const InType*>(packed_rhs_chunk),
                         params.fused_kernel,
                         reinterpret_cast<OutType*>(result_chunk));
    }
  }
};
216
// Single-threaded GEMM executor that packs the whole LHS once, up front, and
// then streams the RHS through one chunk at a time. It mirrors
// GemmExecutorPackRHS with the roles of the operands swapped.
//
// Scratch layout used by ExecuteDispatch3D (byte offsets into params.scratch):
//   [ one packed RHS chunk | packed LHS chunk 0 | chunk 1 | ... | leftover ]
// The single RHS slot is overwritten for every RHS chunk; the packed LHS
// chunks are written once and re-read for every RHS chunk.
class GemmExecutorPackLHS {
 public:
  // Scratch size needed by ExecuteDispatch3D: room for ceil(params.m /
  // kernel_m) packed LHS chunks plus one packed RHS chunk. The LHS leftover
  // chunk is therefore budgeted as a full kernel_m chunk. The sum is passed
  // through AlignTo<64 * 1024> (presumably rounding up to a 64 KiB multiple).
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    // Round up so the leftover chunk gets a full slot.
    const int lhs_chunks = ((params.m + kernel_m - 1) / kernel_m);
    const int lhs_scratch =
        lhs_chunks *
        StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
            params.left_stream, kernel_m, kernel_k);
    const int rhs_scratch =
        StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
            params.right_stream, kernel_n, kernel_k);
    return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
  }

  // Runs the full product for one compile-time shape.
  //   m, n, k:
  //     Kernel chunk sizes.
  //   m_leftovers, n_leftovers, k_leftovers:
  //     The remainders params.m % m, params.n % n and params.k % k, as
  //     selected by the Dispatch3D stages.
  // Typedef suffixes: F = full chunk, L = leftover chunk. For the two-letter
  // suffixes, the first letter refers to the LHS (m) side and the second to
  // the RHS (n) side.
  // NOTE(review): assumes params.scratch holds at least EstimateScratchSize()
  // bytes; nothing here checks that.
  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    // Shorthand typedefs for streams and multiply kernels.
    typedef typename P::InType InType;
    typedef typename P::OutType OutType;

    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamF;
    typedef Stream<typename P::InType, m_leftovers, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamL;

    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStreamF;
    typedef Stream<typename P::InType, n_leftovers, k, k_leftovers,
                   typename P::RightStream>
        RightStreamL;

    typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream>
        OutputStreamFF;
    typedef Stream<typename P::OutType, m, n_leftovers, 0,
                   typename P::OutputStream>
        OutputStreamFL;

    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m, n, k>
        KernelFF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m,
                      n_leftovers, k>
        KernelFL;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n, k>
        KernelLF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n_leftovers, k>
        KernelLL;
    // Verbose tracing, compiled only when both DEBUG and
    // DEBUG_METAGEMM_VERBOSE are defined.
    // NOTE(review): typeid needs <typeinfo>, which this header does not
    // include directly; presumably it arrives transitively. Confirm.
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
              << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
              << k_leftovers << " -- " << params.m << "x" << params.n << "x"
              << params.k << std::endl;
    LeftStreamF::Debug(params.left_stream);
    LeftStreamL::Debug(params.left_stream);

    RightStreamF::Debug(params.right_stream);
    RightStreamL::Debug(params.right_stream);

    OutputStreamFF::Debug(params.fused_kernel.output_stream);
    OutputStreamFL::Debug(params.fused_kernel.output_stream);

    KernelFF::Debug(params.fused_kernel);
    KernelFL::Debug(params.fused_kernel);
    KernelLF::Debug(params.fused_kernel);
    KernelLL::Debug(params.fused_kernel);
#endif
#endif

    // Number of full-size chunks on each side. The remainders are handled by
    // the *L streams and kernels.
    int lhs_chunks = params.m / m;
    int rhs_chunks = params.n / n;

    // Scratch memory for packed LHS & RHS chunks.
    // The RHS slot comes first; the LHS chunks start right after it.
    std::uint8_t* packed_rhs = params.scratch;
    std::uint8_t* packed_lhs =
        params.scratch + RightStreamF::Scratch(params.right_stream);

    // Pack full LHS first.
    // Each full LHS chunk occupies PackedStride bytes; the leftover chunk is
    // packed immediately after the last full one.

    std::uint8_t* packed_lhs_chunk = packed_lhs;
    const int packed_lhs_chunk_size =
        LeftStreamF::PackedStride(params.left_stream);

    {
      const std::uint8_t* lhs_chunk =
          reinterpret_cast<const std::uint8_t*>(params.lhs);
      const int lhs_chunk_size =
          LeftStreamF::UnpackedStride(params.left_stream);

      for (int i = 0; i < lhs_chunks; ++i) {
        LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                          params.left_stream,
                          reinterpret_cast<InType*>(packed_lhs_chunk));

        lhs_chunk += lhs_chunk_size;
        packed_lhs_chunk += packed_lhs_chunk_size;
      }

      // Leftover LHS chunk. This is called unconditionally, even when
      // m_leftovers == 0.
      // NOTE(review): presumably a no-op for a zero-size stream. Confirm.
      LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                        params.left_stream,
                        reinterpret_cast<InType*>(packed_lhs_chunk));
    }

    // Multiply RHS by LHS one RHS chunk at a time.
    // rhs_chunk and result_strip are declared outside the loop scope below
    // because the leftover RHS pass continues from where they stop.

    const std::uint8_t* rhs_chunk =
        reinterpret_cast<const std::uint8_t*>(params.rhs);
    std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result);
    std::uint8_t* result_chunk = result_strip;

    {
      const int rhs_chunk_size =
          RightStreamF::UnpackedStride(params.right_stream);
      // Compared with GemmExecutorPackRHS, the two output offsets swap roles:
      // the strip advances by UnpackedAdvance per RHS chunk, and the chunk
      // advances by UnpackedStride per LHS chunk.
      const int result_strip_size =
          OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);
      const int result_chunk_size =
          OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);

      for (int i = 0; i < rhs_chunks; ++i) {
        // Re-pack the current RHS chunk into the single RHS scratch slot.
        RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                           params.right_stream,
                           reinterpret_cast<InType*>(packed_rhs));

        result_chunk = result_strip;
        packed_lhs_chunk = packed_lhs;

        for (int j = 0; j < lhs_chunks; ++j) {
          KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                             reinterpret_cast<const InType*>(packed_rhs),
                             params.fused_kernel,
                             reinterpret_cast<OutType*>(result_chunk));

          result_chunk += result_chunk_size;
          packed_lhs_chunk += packed_lhs_chunk_size;
        }

        // Leftover LHS chunk x full RHS chunk.
        KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                           reinterpret_cast<const InType*>(packed_rhs),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        rhs_chunk += rhs_chunk_size;
        result_strip += result_strip_size;
      }
    }

    // Leftover RHS chunk.
    // n_leftovers is a template parameter, so this condition is a
    // compile-time constant.
    if (n_leftovers > 0) {  // static if
      const int result_chunk_size =
          OutputStreamFL::UnpackedStride(params.fused_kernel.output_stream);

      RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                         params.right_stream,
                         reinterpret_cast<InType*>(packed_rhs));

      result_chunk = result_strip;
      packed_lhs_chunk = packed_lhs;

      for (int i = 0; i < lhs_chunks; ++i) {
        KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                           reinterpret_cast<const InType*>(packed_rhs),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        result_chunk += result_chunk_size;
        packed_lhs_chunk += packed_lhs_chunk_size;
      }

      // Leftover LHS chunk x leftover RHS chunk.
      KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                         reinterpret_cast<const InType*>(packed_rhs),
                         params.fused_kernel,
                         reinterpret_cast<OutType*>(result_chunk));
    }
  }
};
403
404namespace internal {
405
// Returns how many sequential sub-tasks a dimension of size `total_dim` must
// be split into so that each task's scratch fits in `cache_size`.
//
// One task holds `constant_memory` bytes permanently, plus `per_chunk_memory`
// bytes for each `chunk_dim`-sized chunk of the split dimension.
//
// The result is always >= 1, because callers divide by it. An empty
// dimension (total_dim == 0) is therefore one (trivial) task, not zero.
// In NDEBUG builds the budget assert disappears, so the divisor is clamped
// rather than allowed to reach zero or go negative. If the budget is
// exhausted, this degrades to one chunk per task.
inline int CalculateCacheFriendlyTasksCount(int cache_size, int constant_memory,
                                            int per_chunk_memory, int total_dim,
                                            int chunk_dim) {
  assert(constant_memory + per_chunk_memory < cache_size);
  assert(chunk_dim > 0);

  // Chunks that cost nothing always fit: a single task suffices.
  if (per_chunk_memory <= 0) {
    return 1;
  }

  const int available_cache = cache_size - constant_memory;
  int available_chunks = available_cache / per_chunk_memory;
  if (available_chunks < 1) {
    available_chunks = 1;
  }

  const int chunks_count = (total_dim + chunk_dim - 1) / chunk_dim;
  const int tasks_count =
      (chunks_count + available_chunks - 1) / available_chunks;
  return tasks_count < 1 ? 1 : tasks_count;
}
415
416template <typename Params>
417inline void UpdateCacheFriendlyTask(int m_offset, int m, int n_offset, int n,
418 const Params& params, Params* task_params) {
419 task_params->m = m;
420 task_params->lhs =
421 StreamUtil<typename Params::InType, typename Params::LeftStream>::Offset(
422 params.left_stream, params.lhs, m_offset, 0);
423
424 task_params->n = n;
425 task_params->rhs =
426 StreamUtil<typename Params::InType, typename Params::RightStream>::Offset(
427 params.right_stream, params.rhs, n_offset, 0);
428
429 task_params->result =
430 StreamUtil<typename Params::OutType, typename Params::OutputStream>::
431 Offset(params.fused_kernel.output_stream, params.result, m_offset,
432 n_offset);
433}
434
435} // namespace internal
436
437template <int cache_size = 256 * 1024>
438class GemmExecutorPackRHSCacheFriendly {
439 public:
440 template <typename P>
441 static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
442 int kernel_k) {
443 return cache_size;
444 }
445
446 template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
447 int k_leftovers>
448 static void ExecuteDispatch3D(const P& params) {
449 typedef Stream<typename P::InType, m, k, k_leftovers,
450 typename P::LeftStream>
451 LeftStream;
452
453 typedef Stream<typename P::InType, n, k, k_leftovers,
454 typename P::RightStream>
455 RightStream;
456
457 const int lhs_scratch = LeftStream::Scratch(params.left_stream);
458 const int rhs_scratch = RightStream::Scratch(params.right_stream);
459
460 const int cache_friendly_tasks_count =
461 internal::CalculateCacheFriendlyTasksCount(cache_size, lhs_scratch,
462 rhs_scratch, params.n, n);
463
464 if (cache_friendly_tasks_count == 1) {
465 GemmExecutorPackRHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
466 n_leftovers, k_leftovers>(params);
467 return;
468 }
469
470 const int cache_friendly_dim = params.n / cache_friendly_tasks_count;
471
472 P task_params = params;
473 for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
474 internal::UpdateCacheFriendlyTask(0, params.m, i * cache_friendly_dim,
475 cache_friendly_dim, params,
476 &task_params);
477 Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
478 }
479 const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
480 internal::UpdateCacheFriendlyTask(0, params.m, dim_sum, params.n - dim_sum,
481 params, &task_params);
482 Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
483 }
484};
485
486template <int cache_size = 256 * 1024>
487class GemmExecutorPackLHSCacheFriendly {
488 public:
489 template <typename P>
490 static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
491 int kernel_k) {
492 return cache_size;
493 }
494
495 template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
496 int k_leftovers>
497 static void ExecuteDispatch3D(const P& params) {
498 typedef Stream<typename P::InType, m, k, k_leftovers,
499 typename P::LeftStream>
500 LeftStream;
501
502 typedef Stream<typename P::InType, n, k, k_leftovers,
503 typename P::RightStream>
504 RightStream;
505
506 const int lhs_scratch = LeftStream::Scratch(params.left_stream);
507 const int rhs_scratch = RightStream::Scratch(params.right_stream);
508
509 const int cache_friendly_tasks_count =
510 internal::CalculateCacheFriendlyTasksCount(cache_size, rhs_scratch,
511 lhs_scratch, params.m, m);
512
513 if (cache_friendly_tasks_count == 1) {
514 GemmExecutorPackLHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
515 n_leftovers, k_leftovers>(params);
516 return;
517 }
518
519 const int cache_friendly_dim = params.m / cache_friendly_tasks_count;
520
521 P task_params = params;
522 for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
523 internal::UpdateCacheFriendlyTask(i * cache_friendly_dim,
524 cache_friendly_dim, 0, params.n, params,
525 &task_params);
526 Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
527 }
528 const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
529 internal::UpdateCacheFriendlyTask(dim_sum, params.m - dim_sum, 0, params.n,
530 params, &task_params);
531 Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
532 }
533};
534
535namespace internal {
536
// Stage 3: match the runtime k leftover against compile-time candidates.
538
template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
          int fixed_n, int variable_k>
struct Dispatch3DStage3 {
  // Final dispatch stage. The m and n leftovers are already compile-time
  // constants (fixed_m, fixed_n). If the runtime k leftover equals the
  // candidate `variable_k`, this hands off to the executor with all three
  // leftovers fixed; otherwise it retries with `variable_k - 1`.
  static void Execute(const P& params, int k) {
    typedef Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                             variable_k - 1>
        NextCandidate;
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << fixed_n << "x" << variable_k
              << std::endl
              << std::flush;
#endif
#endif
    if (k != variable_k) {
      NextCandidate::Execute(params, k);
      return;
    }
    E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                                  variable_k>(params);
  }
};
560
561template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
562 int fixed_n>
563struct Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n, 0> {
564 static void Execute(const P& params, int k) {
565#ifdef DEBUG
566#ifdef DEBUG_METAGEMM_VERBOSE
567 std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
568 << " : " << fixed_m << "x" << fixed_n << "x" << 0 << std::endl
569 << std::flush;
570#endif
571#endif
572 if (k == 0) {
573 E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
574 0>(params);
575 } else {
576 std::cerr << "FATAL: dispatch3DStage3 failed: ran out of cases."
577 << std::endl
578 << std::flush;
579 std::exit(1);
580 }
581 }
582};
583
// Stage 2: match the runtime n leftover against compile-time candidates.
585
586template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
587 int variable_n>
588struct Dispatch3DStage2 {
589 static void Execute(const P& params, int n, int k) {
590#ifdef DEBUG
591#ifdef DEBUG_METAGEMM_VERBOSE
592 std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k
593 << " : " << fixed_m << "x" << variable_n << std::endl
594 << std::flush;
595#endif
596#endif
597 if (n == variable_n) {
598 Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, variable_n,
599 dim_k - 1>::Execute(params, k);
600 } else {
601 Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m,
602 variable_n - 1>::Execute(params, n, k);
603 }
604 }
605};
606
607template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m>
608struct Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m, 0> {
609 static void Execute(const P& params, int n, int k) {
610#ifdef DEBUG
611#ifdef DEBUG_METAGEMM_VERBOSE
612 std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k
613 << " : " << fixed_m << "x" << 0 << std::endl
614 << std::flush;
615#endif
616#endif
617 if (n == 0) {
618 Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, 0,
619 dim_k - 1>::Execute(params, k);
620 } else {
621 std::cerr << "FATAL: dispatch3DStage2 failed: ran out of cases."
622 << std::endl
623 << std::flush;
624 std::exit(1);
625 }
626 }
627};
628
// Stage 1: match the runtime m leftover against compile-time candidates.
630
631template <typename E, typename P, int dim_m, int dim_n, int dim_k,
632 int variable_m>
633struct Dispatch3DStage1 {
634 static void Execute(const P& params, int m, int n, int k) {
635#ifdef DEBUG
636#ifdef DEBUG_METAGEMM_VERBOSE
637 std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k
638 << " : " << variable_m << std::endl
639 << std::flush;
640#endif
641#endif
642 if (m == variable_m) {
643 Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, variable_m,
644 dim_n - 1>::Execute(params, n, k);
645 } else {
646 Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, variable_m - 1>::Execute(
647 params, m, n, k);
648 }
649 }
650};
651
652template <typename E, typename P, int dim_m, int dim_n, int dim_k>
653struct Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, 0> {
654 static void Execute(const P& params, int m, int n, int k) {
655#ifdef DEBUG
656#ifdef DEBUG_METAGEMM_VERBOSE
657 std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k
658 << " : " << 0 << std::endl
659 << std::flush;
660#endif
661#endif
662 if (m == 0) {
663 Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, 0, dim_n - 1>::Execute(params,
664 n, k);
665 } else {
666 std::cerr << "FATAL: dispatch3DStage1 failed: ran out of cases."
667 << std::endl
668 << std::flush;
669 std::exit(1);
670 }
671 }
672};
673
674} // namespace internal
675
676template <typename Executor, typename Params, int kernel_m, int kernel_n,
677 int kernel_k>
678inline void Gemm(const Params& params) {
679 internal::Dispatch3DStage1<Executor, Params, kernel_m, kernel_n, kernel_k,
680 kernel_m - 1>::Execute(params, params.m % kernel_m,
681 params.n % kernel_n,
682 params.k % kernel_k);
683}
684
685} // namespace meta
686} // namespace gemmlowp
687
688#endif // GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
689