1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
// A cache of TensorSliceReaders (the utility to read checkpoints for google
// brain tensor ops and v3 checkpoints for dist_belief), plus a wrapper that
// allocates the cache lazily.
18
19#ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_
20#define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_
21
#include <set>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
28
29namespace tensorflow {
30
31namespace checkpoint {
32
33class TensorSliceReaderCache;
34
35// Wrapper to a lazily allocated TensorSliceReaderCache.
36class TensorSliceReaderCacheWrapper {
37 public:
38 TensorSliceReaderCacheWrapper();
39 ~TensorSliceReaderCacheWrapper();
40
41 // Same as TensorSliceReaderCache::GetReader().
42 const TensorSliceReader* GetReader(
43 const string& filepattern,
44 TensorSliceReader::OpenTableFunction open_function,
45 int preferred_shard) const;
46
47 private:
48 mutable mutex mu_;
49 mutable TensorSliceReaderCache* cache_ = nullptr;
50};
51
52// A cache of TensorSliceReaders.
53class TensorSliceReaderCache {
54 public:
55 TensorSliceReaderCache();
56 ~TensorSliceReaderCache();
57
58 // Returns the TensorSliceReader corresponding to 'filepattern' and the
59 // open_function. May return nullptr if we can not create a new
60 // TensorSliceReader for the filepattern/open_function combination.
61 const TensorSliceReader* GetReader(
62 const string& filepattern,
63 TensorSliceReader::OpenTableFunction open_function, int preferred_shard);
64
65 private:
66 // Need to use a regular function type in the key map as std::function does
67 // not support ==.
68 typedef Status (*OpenFuncType)(const string&, TensorSliceReader::Table**);
69
70 // Protects attributes below.
71 mutex mu_;
72
73 // Maps of opened readers.
74 std::unordered_map<string, std::pair<OpenFuncType, TensorSliceReader*>>
75 readers_;
76
77 // Set of keys that a previous GetReader() call is still trying to populate.
78 std::set<string> still_opening_;
79
80 // Condition variable to notify when a reader has been created.
81 condition_variable cv_;
82};
83
84} // namespace checkpoint
85
86} // namespace tensorflow
87
88#endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_
89