1 | // Copyright 2016 The Gemmlowp Authors. All Rights Reserved. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #ifndef GEMMLOWP_META_MULTI_THREAD_GEMM_H_ |
16 | #define GEMMLOWP_META_MULTI_THREAD_GEMM_H_ |
17 | |
#include <algorithm>
#include <cstdint>
#include <vector>

#include "multi_thread_common.h"
#include "single_thread_gemm.h"
20 | |
21 | namespace gemmlowp { |
22 | namespace meta { |
23 | namespace internal { |
24 | |
// Minimum element-count product (m * n * k) for which splitting the GEMM into
// multiple tasks pays off; below this, threading overhead dominates.
const std::int32_t kMinGemmTaskSize = 16000;
// Minimum extent of the m or n dimension for a per-task slice; slices thinner
// than this are not created.
const std::int32_t kMinGemmTaskDimension = 4;
27 | |
28 | template <typename Executor, typename Params> |
29 | std::uint8_t* PrepareGemmTask(const Params& params, int kernel_m, int kernel_n, |
30 | int kernel_k, std::uint8_t* scratch, int m_start, |
31 | int m, int n_start, int n, |
32 | std::vector<Params>* tasks) { |
33 | tasks->push_back(params); |
34 | Params& task = tasks->back(); |
35 | task.scratch = scratch; |
36 | |
37 | task.m = m; |
38 | task.lhs = |
39 | StreamUtil<typename Params::InType, typename Params::LeftStream>::Offset( |
40 | params.left_stream, params.lhs, m_start, 0); |
41 | |
42 | task.n = n; |
43 | task.rhs = |
44 | StreamUtil<typename Params::InType, typename Params::RightStream>::Offset( |
45 | params.right_stream, params.rhs, n_start, 0); |
46 | |
47 | task.result = |
48 | StreamUtil<typename Params::OutType, typename Params::OutputStream>:: |
49 | Offset(params.fused_kernel.output_stream, params.result, m_start, |
50 | n_start); |
51 | |
52 | return scratch + Executor::template EstimateScratchSize<Params>( |
53 | task, kernel_m, kernel_n, kernel_k); |
54 | } |
55 | |
56 | template <typename MultiThreadingContext, typename Executor, typename Params> |
57 | bool PrepareGemmTasks(MultiThreadingContext* context, const Params& params, |
58 | int kernel_m, int kernel_n, int kernel_k, |
59 | std::vector<Params>* task_params) { |
60 | const int max_threads = ResolveMaxThreads(context->max_num_threads()); |
61 | const int max_tasks_by_size = |
62 | (params.m * params.n * params.k) / kMinGemmTaskSize; |
63 | const int max_tasks_m = params.m / kMinGemmTaskDimension; |
64 | const int max_tasks_n = params.n / kMinGemmTaskDimension; |
65 | const int max_tasks_dimension = std::max(max_tasks_m, max_tasks_n); |
66 | |
67 | const int real_tasks = std::max( |
68 | 1, |
69 | std::min(max_threads, std::min(max_tasks_by_size, max_tasks_dimension))); |
70 | |
71 | if (real_tasks == 1) { |
72 | return false; |
73 | } |
74 | |
75 | std::uint8_t* scratch = params.scratch; |
76 | |
77 | if (max_tasks_m > max_tasks_n) { |
78 | const int m_chunk = params.m / real_tasks; |
79 | for (int i = 0; i < real_tasks - 1; ++i) { |
80 | scratch = PrepareGemmTask<Executor, Params>( |
81 | params, kernel_m, kernel_n, kernel_k, scratch, i * m_chunk, m_chunk, |
82 | 0, params.n, task_params); |
83 | } |
84 | const int sum_m = (real_tasks - 1) * m_chunk; |
85 | PrepareGemmTask<Executor, Params>(params, kernel_m, kernel_n, kernel_k, |
86 | scratch, sum_m, params.m - sum_m, 0, |
87 | params.n, task_params); |
88 | } else { |
89 | const int n_chunk = params.n / real_tasks; |
90 | for (int i = 0; i < real_tasks - 1; ++i) { |
91 | scratch = PrepareGemmTask<Executor, Params>( |
92 | params, kernel_m, kernel_n, kernel_k, scratch, 0, params.m, |
93 | i * n_chunk, n_chunk, task_params); |
94 | } |
95 | int sum_n = (real_tasks - 1) * n_chunk; |
96 | PrepareGemmTask<Executor, Params>(params, kernel_m, kernel_n, kernel_k, |
97 | scratch, 0, params.m, sum_n, |
98 | params.n - sum_n, task_params); |
99 | } |
100 | |
101 | return true; |
102 | } |
103 | |
104 | template <typename Executor, typename Params, int kernel_m, int kernel_n, |
105 | int kernel_k> |
106 | struct GemmTaskRunner : gemmlowp::Task { |
107 | GemmTaskRunner(const Params& params) : params(params) {} |
108 | |
109 | void Run() override { |
110 | Gemm<Executor, Params, kernel_m, kernel_n, kernel_k>(params); |
111 | } |
112 | |
113 | Params params; |
114 | }; |
115 | |
116 | } // namespace internal |
117 | |
118 | template <typename MultiThreadingContext, typename Executor, typename Params, |
119 | int kernel_m, int kernel_n, int kernel_k> |
120 | inline void MultiThreadGemm(MultiThreadingContext* context, |
121 | const Params& params) { |
122 | typedef internal::GemmTaskRunner<Executor, Params, kernel_m, kernel_n, |
123 | kernel_k> |
124 | TaskRunnerType; |
125 | |
126 | std::vector<Params> task_params; |
127 | if (!internal::PrepareGemmTasks<MultiThreadingContext, Executor, Params>( |
128 | context, params, kernel_m, kernel_n, kernel_k, &task_params)) { |
129 | Gemm<Executor, Params, kernel_m, kernel_n, kernel_k>(params); |
130 | return; |
131 | } |
132 | |
133 | auto workers_pool = context->workers_pool(); |
134 | std::vector<Task*> tasks; |
135 | for (auto& task_param : task_params) { |
136 | tasks.push_back(new TaskRunnerType(task_param)); |
137 | }; |
138 | workers_pool->Execute(tasks); |
139 | } |
140 | |
141 | } // namespace meta |
142 | } // namespace gemmlowp |
143 | |
144 | #endif // GEMMLOWP_META_MULTI_THREAD_GEMM_H_ |
145 | |