1
2/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15==============================================================================*/
16
17#ifndef TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
18#define TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
19#ifdef INTEL_MKL
20
#include <algorithm>
#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
27
28#include "dnnl_threadpool.hpp"
29#include "dnnl.hpp"
30#include "tensorflow/core/framework/op_kernel.h"
31#include "tensorflow/core/platform/threadpool.h"
32#define EIGEN_USE_THREADS
33
34namespace tensorflow {
35
36#ifndef ENABLE_ONEDNN_OPENMP
37using dnnl::threadpool_interop::threadpool_iface;
38
// Divides `n` units of work as evenly as possible among `team` workers and
// returns, through `n_start`/`n_end`, the half-open range [*n_start, *n_end)
// owned by worker `tid`. If `n` is not divisible by `team` and leaves a
// remainder `r`, the first `r` workers receive one extra unit each.
//
// Parameters
//   n        Total number of work units.
//   team     Number of workers. A value <= 1 assigns all work to the caller.
//   tid      Index of the current worker, expected in [0, team).
//   n_start  Receives the first unit owned by `tid`.
//   n_end    Receives one past the last unit owned by `tid`.
//
// `T` (work count) and `U` (worker index type) may differ, e.g. an int64 work
// count split among int-indexed threads; both are normalised to `T` before
// mixing them in arithmetic or std::min.
template <typename T, typename U>
inline void balance211(T n, U team, U tid, T* n_start, T* n_end) {
  if (team <= 1 || n == 0) {
    *n_start = 0;
    *n_end = n;
    return;
  }
  const T teams = static_cast<T>(team);
  const T id = static_cast<T>(tid);
  const T min_per_team = n / teams;
  const T remainder = n - min_per_team * teams;  // i.e., n % teams.
  *n_start = id * min_per_team + std::min(id, remainder);
  *n_end = *n_start + min_per_team + static_cast<T>(id < remainder);
}
62
63struct MklDnnThreadPool : public threadpool_iface {
64 MklDnnThreadPool() = default;
65
66 MklDnnThreadPool(OpKernelContext* ctx, int num_threads = -1) {
67 eigen_interface_ = ctx->device()
68 ->tensorflow_cpu_worker_threads()
69 ->workers->AsEigenThreadPool();
70 num_threads_ =
71 (num_threads == -1) ? eigen_interface_->NumThreads() : num_threads;
72 }
73 virtual int get_num_threads() const override { return num_threads_; }
74 virtual bool get_in_parallel() const override {
75 return (eigen_interface_->CurrentThreadId() != -1) ? true : false;
76 }
77 virtual uint64_t get_flags() const override { return ASYNCHRONOUS; }
78 virtual void parallel_for(int n,
79 const std::function<void(int, int)>& fn) override {
80 // Should never happen (handled by DNNL)
81 if (n == 0) return;
82
83 // Should never happen (handled by DNNL)
84 if (n == 1) {
85 fn(0, 1);
86 return;
87 }
88
89 int nthr = get_num_threads();
90 int njobs = std::min(n, nthr);
91 bool balance = (nthr < n);
92 for (int i = 0; i < njobs; i++) {
93 eigen_interface_->ScheduleWithHint(
94 [balance, i, n, njobs, fn]() {
95 if (balance) {
96 int start, end;
97 balance211(n, njobs, i, &start, &end);
98 for (int j = start; j < end; j++) fn(j, n);
99 } else {
100 fn(i, n);
101 }
102 },
103 i, i + 1);
104 }
105 }
106 ~MklDnnThreadPool() {}
107
108 private:
109 Eigen::ThreadPoolInterface* eigen_interface_ = nullptr;
110 int num_threads_ = 1; // Execute in caller thread.
111};
112
113#else
114
115// This struct was just added to enable successful OMP-based build.
116struct MklDnnThreadPool {
117 MklDnnThreadPool() = default;
118 MklDnnThreadPool(OpKernelContext* ctx) {}
119 MklDnnThreadPool(OpKernelContext* ctx, int num_threads) {}
120};
121
122#endif // !ENABLE_ONEDNN_OPENMP
123
124} // namespace tensorflow
125
126#endif // INTEL_MKL
127#endif // TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
128