1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_ |
18 | |
19 | #include <deque> |
20 | #include <queue> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/core/framework/op_kernel.h" |
24 | #include "tensorflow/core/kernels/queue_base.h" |
25 | #include "tensorflow/core/platform/mutex.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | // TypedQueue builds on QueueBase, with backing class (SubQueue) |
30 | // known and stored within. Shared methods that need to have access |
31 | // to the backed data sit in this class. |
32 | template <typename SubQueue> |
33 | class TypedQueue : public QueueBase { |
34 | public: |
35 | TypedQueue(const int32_t capacity, const DataTypeVector& component_dtypes, |
36 | const std::vector<TensorShape>& component_shapes, |
37 | const string& name); |
38 | |
39 | virtual Status Initialize(); // Must be called before any other method. |
40 | |
41 | int64_t MemoryUsed() const override; |
42 | |
43 | protected: |
44 | std::vector<SubQueue> queues_ TF_GUARDED_BY(mu_); |
45 | }; // class TypedQueue |
46 | |
47 | template <typename SubQueue> |
48 | TypedQueue<SubQueue>::TypedQueue( |
49 | int32_t capacity, const DataTypeVector& component_dtypes, |
50 | const std::vector<TensorShape>& component_shapes, const string& name) |
51 | : QueueBase(capacity, component_dtypes, component_shapes, name) {} |
52 | |
53 | template <typename SubQueue> |
54 | Status TypedQueue<SubQueue>::Initialize() { |
55 | if (component_dtypes_.empty()) { |
56 | return errors::InvalidArgument("Empty component types for queue " , name_); |
57 | } |
58 | if (!component_shapes_.empty() && |
59 | component_dtypes_.size() != component_shapes_.size()) { |
60 | return errors::InvalidArgument( |
61 | "Different number of component types. " , |
62 | "Types: " , DataTypeSliceString(component_dtypes_), |
63 | ", Shapes: " , ShapeListString(component_shapes_)); |
64 | } |
65 | |
66 | mutex_lock lock(mu_); |
67 | queues_.reserve(num_components()); |
68 | for (int i = 0; i < num_components(); ++i) { |
69 | queues_.push_back(SubQueue()); |
70 | } |
71 | return OkStatus(); |
72 | } |
73 | |
// Fallback for container types without a dedicated SizeOf implementation.
// The assertion depends on the template parameter and can never hold, so
// instantiating this body is a compile-time error rather than a silent 0.
template <typename SubQueue>
inline int64_t SizeOf(const SubQueue& sq) {
  static_assert(sizeof(SubQueue) == 0, "SubQueue size unknown.");
  return 0;
}
79 | |
80 | template <> |
81 | inline int64_t SizeOf(const std::deque<Tensor>& sq) { |
82 | if (sq.empty()) { |
83 | return 0; |
84 | } |
85 | return sq.size() * sq.front().AllocatedBytes(); |
86 | } |
87 | |
88 | template <> |
89 | inline int64_t SizeOf(const std::vector<Tensor>& sq) { |
90 | if (sq.empty()) { |
91 | return 0; |
92 | } |
93 | return sq.size() * sq.front().AllocatedBytes(); |
94 | } |
95 | |
96 | using TensorPair = std::pair<int64_t, Tensor>; |
97 | |
98 | template <typename U, typename V> |
99 | int64_t SizeOf(const std::priority_queue<TensorPair, U, V>& sq) { |
100 | if (sq.empty()) { |
101 | return 0; |
102 | } |
103 | return sq.size() * (sizeof(TensorPair) + sq.top().second.AllocatedBytes()); |
104 | } |
105 | |
106 | template <typename SubQueue> |
107 | inline int64_t TypedQueue<SubQueue>::MemoryUsed() const { |
108 | int memory_size = 0; |
109 | mutex_lock l(mu_); |
110 | for (const auto& sq : queues_) { |
111 | memory_size += SizeOf(sq); |
112 | } |
113 | return memory_size; |
114 | } |
115 | |
116 | } // namespace tensorflow |
117 | |
118 | #endif // TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_ |
119 | |