1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/common_runtime/single_threaded_cpu_device.h"

#define EIGEN_USE_THREADS

#include <memory>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/threadpool.h"
26 | |
27 | namespace tensorflow { |
28 | |
29 | namespace { |
30 | |
31 | static constexpr int kNumThreads = 1; |
32 | |
33 | thread::ThreadPool* GraphRunnerThreadPool() { |
34 | static thread::ThreadPool* thread_pool = |
35 | new thread::ThreadPool(Env::Default(), "graph_runner" , kNumThreads); |
36 | return thread_pool; |
37 | } |
38 | |
39 | // A simple single-threaded CPU device. This can be used to run inexpensive |
40 | // computations. In particular, using this avoids initializing the global thread |
41 | // pools in LocalDevice. |
42 | class SingleThreadedCpuDevice : public Device { |
43 | public: |
44 | explicit SingleThreadedCpuDevice(Env* env) |
45 | : Device(env, Device::BuildDeviceAttributes("/device:CPU:0" , DEVICE_CPU, |
46 | Bytes(256 << 20), |
47 | DeviceLocality())) { |
48 | eigen_worker_threads_.num_threads = kNumThreads; |
49 | eigen_worker_threads_.workers = GraphRunnerThreadPool(); |
50 | eigen_device_.reset(new Eigen::ThreadPoolDevice( |
51 | eigen_worker_threads_.workers->AsEigenThreadPool(), |
52 | eigen_worker_threads_.num_threads)); |
53 | set_tensorflow_cpu_worker_threads(&eigen_worker_threads_); |
54 | set_eigen_cpu_device(eigen_device_.get()); |
55 | } |
56 | |
57 | ~SingleThreadedCpuDevice() override { eigen_device_.reset(); } |
58 | |
59 | Status Sync() override { return OkStatus(); } |
60 | |
61 | Status MakeTensorFromProto(const TensorProto& tensor_proto, |
62 | const AllocatorAttributes alloc_attrs, |
63 | Tensor* tensor) override { |
64 | Tensor parsed(tensor_proto.dtype()); |
65 | if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { |
66 | return errors::InvalidArgument("Cannot parse tensor from tensor_proto." ); |
67 | } |
68 | *tensor = parsed; |
69 | return OkStatus(); |
70 | } |
71 | |
72 | void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor, |
73 | const DeviceContext*, |
74 | StatusCallback done) override { |
75 | if (input_tensor->NumElements() != output_tensor->NumElements()) { |
76 | done(errors::Internal( |
77 | "SingleThreadedCPU->SingleThreadedCPU copy shape mismatch: input=" , |
78 | input_tensor->shape(), ", output=" , output_tensor->shape())); |
79 | return; |
80 | } |
81 | tensor::DeepCopy(*input_tensor, output_tensor); |
82 | done(OkStatus()); |
83 | } |
84 | |
85 | Allocator* GetAllocator(AllocatorAttributes attr) override { |
86 | return cpu_allocator(); |
87 | } |
88 | |
89 | private: |
90 | DeviceBase::CpuWorkerThreads eigen_worker_threads_; |
91 | std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_; |
92 | }; |
93 | |
94 | } // namespace |
95 | |
96 | Device* NewSingleThreadedCpuDevice(Env* env) { |
97 | return new SingleThreadedCpuDevice(env); |
98 | } |
99 | |
100 | } // namespace tensorflow |
101 | |