1 | |
2 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
3 | |
4 | Licensed under the Apache License, Version 2.0 (the "License"); |
5 | you may not use this file except in compliance with the License. |
6 | You may obtain a copy of the License at |
7 | |
8 | http://www.apache.org/licenses/LICENSE-2.0 |
9 | |
10 | Unless required by applicable law or agreed to in writing, software |
11 | distributed under the License is distributed on an "AS IS" BASIS, |
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | See the License for the specific language governing permissions and |
14 | limitations under the License. |
15 | ==============================================================================*/ |
16 | |
17 | #ifndef TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_ |
18 | #define TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_ |
19 | #ifdef INTEL_MKL |
20 | |
#include <algorithm>
#include <functional>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "dnnl.hpp"
#include "dnnl_threadpool.hpp"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/threadpool.h"
32 | #define EIGEN_USE_THREADS |
33 | |
34 | namespace tensorflow { |
35 | |
36 | #ifndef ENABLE_ONEDNN_OPENMP |
37 | using dnnl::threadpool_interop::threadpool_iface; |
38 | |
// Divide 'n' units of work equally among 'team' threads. If 'n' is not
// divisible by 'team' and has a remainder 'r', the first 'r' threads have one
// unit of work more than the rest. Writes the half-open range
// [*n_start, *n_end) of work that belongs to thread 'tid'.
// Parameters
//   n        Total number of jobs.
//   team     Number of workers.
//   tid      Current thread id, expected in [0, team).
//   n_start  Output: start of the range operated on by the thread.
//   n_end    Output: end (exclusive) of the range operated on by the thread.

template <typename T, typename U>
inline void balance211(T n, U team, U tid, T* n_start, T* n_end) {
  // Degenerate cases: zero/one worker (or no work) — the caller owns the
  // entire (possibly empty) range.
  if (team <= 1 || n == 0) {
    *n_start = 0;
    *n_end = n;
    return;
  }
  const T min_per_team = n / team;
  const T remainder = n - min_per_team * team;  // i.e., n % team.
  // Cast 'tid' once to T so std::min and the comparison below operate on a
  // single type; mixing T and U fails to compile when they differ.
  const T tid_t = static_cast<T>(tid);
  *n_start = tid_t * min_per_team + std::min(tid_t, remainder);
  *n_end = *n_start + min_per_team + (tid_t < remainder ? 1 : 0);
}
62 | |
63 | struct MklDnnThreadPool : public threadpool_iface { |
64 | MklDnnThreadPool() = default; |
65 | |
66 | MklDnnThreadPool(OpKernelContext* ctx, int num_threads = -1) { |
67 | eigen_interface_ = ctx->device() |
68 | ->tensorflow_cpu_worker_threads() |
69 | ->workers->AsEigenThreadPool(); |
70 | num_threads_ = |
71 | (num_threads == -1) ? eigen_interface_->NumThreads() : num_threads; |
72 | } |
73 | virtual int get_num_threads() const override { return num_threads_; } |
74 | virtual bool get_in_parallel() const override { |
75 | return (eigen_interface_->CurrentThreadId() != -1) ? true : false; |
76 | } |
77 | virtual uint64_t get_flags() const override { return ASYNCHRONOUS; } |
78 | virtual void parallel_for(int n, |
79 | const std::function<void(int, int)>& fn) override { |
80 | // Should never happen (handled by DNNL) |
81 | if (n == 0) return; |
82 | |
83 | // Should never happen (handled by DNNL) |
84 | if (n == 1) { |
85 | fn(0, 1); |
86 | return; |
87 | } |
88 | |
89 | int nthr = get_num_threads(); |
90 | int njobs = std::min(n, nthr); |
91 | bool balance = (nthr < n); |
92 | for (int i = 0; i < njobs; i++) { |
93 | eigen_interface_->ScheduleWithHint( |
94 | [balance, i, n, njobs, fn]() { |
95 | if (balance) { |
96 | int start, end; |
97 | balance211(n, njobs, i, &start, &end); |
98 | for (int j = start; j < end; j++) fn(j, n); |
99 | } else { |
100 | fn(i, n); |
101 | } |
102 | }, |
103 | i, i + 1); |
104 | } |
105 | } |
106 | ~MklDnnThreadPool() {} |
107 | |
108 | private: |
109 | Eigen::ThreadPoolInterface* eigen_interface_ = nullptr; |
110 | int num_threads_ = 1; // Execute in caller thread. |
111 | }; |
112 | |
113 | #else |
114 | |
115 | // This struct was just added to enable successful OMP-based build. |
116 | struct MklDnnThreadPool { |
117 | MklDnnThreadPool() = default; |
118 | MklDnnThreadPool(OpKernelContext* ctx) {} |
119 | MklDnnThreadPool(OpKernelContext* ctx, int num_threads) {} |
120 | }; |
121 | |
122 | #endif // !ENABLE_ONEDNN_OPENMP |
123 | |
124 | } // namespace tensorflow |
125 | |
126 | #endif // INTEL_MKL |
127 | #endif // TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_ |
128 | |