1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
17#define TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
18
19#include <deque>
20#include <queue>
21#include <vector>
22
23#include "tensorflow/core/framework/op_kernel.h"
24#include "tensorflow/core/kernels/queue_base.h"
25#include "tensorflow/core/platform/mutex.h"
26
27namespace tensorflow {
28
29// TypedQueue builds on QueueBase, with backing class (SubQueue)
30// known and stored within. Shared methods that need to have access
31// to the backed data sit in this class.
32template <typename SubQueue>
33class TypedQueue : public QueueBase {
34 public:
35 TypedQueue(const int32_t capacity, const DataTypeVector& component_dtypes,
36 const std::vector<TensorShape>& component_shapes,
37 const string& name);
38
39 virtual Status Initialize(); // Must be called before any other method.
40
41 int64_t MemoryUsed() const override;
42
43 protected:
44 std::vector<SubQueue> queues_ TF_GUARDED_BY(mu_);
45}; // class TypedQueue
46
47template <typename SubQueue>
48TypedQueue<SubQueue>::TypedQueue(
49 int32_t capacity, const DataTypeVector& component_dtypes,
50 const std::vector<TensorShape>& component_shapes, const string& name)
51 : QueueBase(capacity, component_dtypes, component_shapes, name) {}
52
53template <typename SubQueue>
54Status TypedQueue<SubQueue>::Initialize() {
55 if (component_dtypes_.empty()) {
56 return errors::InvalidArgument("Empty component types for queue ", name_);
57 }
58 if (!component_shapes_.empty() &&
59 component_dtypes_.size() != component_shapes_.size()) {
60 return errors::InvalidArgument(
61 "Different number of component types. ",
62 "Types: ", DataTypeSliceString(component_dtypes_),
63 ", Shapes: ", ShapeListString(component_shapes_));
64 }
65
66 mutex_lock lock(mu_);
67 queues_.reserve(num_components());
68 for (int i = 0; i < num_components(); ++i) {
69 queues_.push_back(SubQueue());
70 }
71 return OkStatus();
72}
73
// Generic byte-size hook for a single sub-queue. It must never be
// instantiated: every SubQueue type used with TypedQueue needs its own
// specialization or overload, and the static_assert turns a missing one into
// a compile-time error. The condition depends on SubQueue, so it is only
// evaluated when the template is actually instantiated.
template <typename SubQueue>
inline int64_t SizeOf(const SubQueue& sq) {
  static_assert(sizeof(SubQueue) == 0, "SubQueue size unknown.");
  return 0;
}
79
80template <>
81inline int64_t SizeOf(const std::deque<Tensor>& sq) {
82 if (sq.empty()) {
83 return 0;
84 }
85 return sq.size() * sq.front().AllocatedBytes();
86}
87
88template <>
89inline int64_t SizeOf(const std::vector<Tensor>& sq) {
90 if (sq.empty()) {
91 return 0;
92 }
93 return sq.size() * sq.front().AllocatedBytes();
94}
95
96using TensorPair = std::pair<int64_t, Tensor>;
97
98template <typename U, typename V>
99int64_t SizeOf(const std::priority_queue<TensorPair, U, V>& sq) {
100 if (sq.empty()) {
101 return 0;
102 }
103 return sq.size() * (sizeof(TensorPair) + sq.top().second.AllocatedBytes());
104}
105
106template <typename SubQueue>
107inline int64_t TypedQueue<SubQueue>::MemoryUsed() const {
108 int memory_size = 0;
109 mutex_lock l(mu_);
110 for (const auto& sq : queues_) {
111 memory_size += SizeOf(sq);
112 }
113 return memory_size;
114}
115
116} // namespace tensorflow
117
118#endif // TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
119