1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
#include "tensorflow/core/data/split_utils.h"

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/errors.h"
23 | namespace tensorflow { |
24 | namespace data { |
namespace {
// Keys used to name entries in the iterator checkpoint written/read by the
// Save/Restore methods below.
constexpr char kIndex[] = "index";
constexpr char kNumToSkip[] = "num_to_skip";
constexpr char kSlash[] = "/";
constexpr char kSplitProvider[] = "split_provider";
}  // namespace
31 | |
32 | IndexSplitProvider::IndexSplitProvider(int64_t n) : i_(0), n_(n) {} |
33 | |
34 | Status IndexSplitProvider::GetNext(Tensor* split, bool* end_of_splits) { |
35 | mutex_lock l(mu_); |
36 | if (i_ >= n_) { |
37 | *end_of_splits = true; |
38 | return OkStatus(); |
39 | } |
40 | *end_of_splits = false; |
41 | *split = Tensor(DT_INT64, TensorShape{}); |
42 | split->scalar<int64_t>()() = i_++; |
43 | return OkStatus(); |
44 | } |
45 | |
46 | Status IndexSplitProvider::Reset() { |
47 | mutex_lock l(mu_); |
48 | i_ = 0; |
49 | return OkStatus(); |
50 | } |
51 | |
52 | Status IndexSplitProvider::Save( |
53 | std::function<std::string(std::string)> full_name, |
54 | IteratorStateWriter* writer) { |
55 | mutex_lock l(mu_); |
56 | return writer->WriteScalar(full_name(kIndex), i_); |
57 | } |
58 | |
59 | Status IndexSplitProvider::Restore( |
60 | std::function<std::string(std::string)> full_name, |
61 | IteratorStateReader* reader) { |
62 | mutex_lock l(mu_); |
63 | return reader->ReadScalar(full_name(kIndex), &i_); |
64 | } |
65 | |
66 | ShardingSplitProvider::ShardingSplitProvider( |
67 | int64_t num_shards, int64_t shard_index, |
68 | std::shared_ptr<SplitProvider> split_provider) |
69 | : num_shards_(num_shards), |
70 | shard_index_(shard_index), |
71 | split_provider_(split_provider), |
72 | num_to_skip_(shard_index_) {} |
73 | |
74 | Status ShardingSplitProvider::GetNext(Tensor* split, bool* end_of_splits) { |
75 | mutex_lock l(mu_); |
76 | while (num_to_skip_ > 0) { |
77 | TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits)); |
78 | if (*end_of_splits) { |
79 | return OkStatus(); |
80 | } |
81 | num_to_skip_--; |
82 | } |
83 | num_to_skip_ = num_shards_ - 1; |
84 | TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits)); |
85 | return OkStatus(); |
86 | } |
87 | |
88 | Status ShardingSplitProvider::Reset() { |
89 | mutex_lock l(mu_); |
90 | TF_RETURN_IF_ERROR(split_provider_->Reset()); |
91 | num_to_skip_ = shard_index_; |
92 | return OkStatus(); |
93 | } |
94 | |
95 | Status ShardingSplitProvider::Save( |
96 | std::function<std::string(std::string)> full_name, |
97 | IteratorStateWriter* writer) { |
98 | mutex_lock l(mu_); |
99 | TF_RETURN_IF_ERROR(split_provider_->Save( |
100 | [&](const std::string& key) { |
101 | return full_name(absl::StrCat(kSplitProvider, kSlash, key)); |
102 | }, |
103 | writer)); |
104 | return writer->WriteScalar(full_name(kNumToSkip), num_to_skip_); |
105 | } |
106 | |
107 | Status ShardingSplitProvider::Restore( |
108 | std::function<std::string(std::string)> full_name, |
109 | IteratorStateReader* reader) { |
110 | mutex_lock l(mu_); |
111 | TF_RETURN_IF_ERROR(split_provider_->Restore( |
112 | [&](const std::string& key) { |
113 | return full_name(absl::StrCat(kSplitProvider, kSlash, key)); |
114 | }, |
115 | reader)); |
116 | TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumToSkip), &num_to_skip_)); |
117 | return OkStatus(); |
118 | } |
119 | |
120 | StatusOr<std::shared_ptr<SplitProvider>> GetSingleSplitProvider( |
121 | IteratorContext* ctx, const DatasetBase* dataset) { |
122 | if (ctx->split_providers().size() != 1) { |
123 | return errors::FailedPrecondition( |
124 | "Failed to get single split provider for dataset " , |
125 | dataset->DebugString(), ". Found " , ctx->split_providers().size(), |
126 | " split providers" ); |
127 | } |
128 | return ctx->split_providers()[0]; |
129 | } |
130 | |
131 | StatusOr<std::vector<std::unique_ptr<SplitProvider>>> GetSplitProviders( |
132 | const DatasetBase* dataset) { |
133 | std::vector<std::unique_ptr<SplitProvider>> result; |
134 | std::vector<const DatasetBase*> inputs; |
135 | TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs)); |
136 | for (const auto& input : inputs) { |
137 | std::vector<std::unique_ptr<SplitProvider>> providers; |
138 | TF_RETURN_IF_ERROR(input->MakeSplitProviders(&providers)); |
139 | for (auto& provider : providers) { |
140 | result.push_back(std::move(provider)); |
141 | } |
142 | } |
143 | return result; |
144 | } |
145 | |
146 | StatusOr<std::vector<IteratorContext>> CreateInputIteratorContexts( |
147 | IteratorContext* ctx, const DatasetBase* dataset) { |
148 | std::vector<const DatasetBase*> inputs; |
149 | TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs)); |
150 | std::vector<IteratorContext> result; |
151 | if (ctx->split_providers().empty()) { |
152 | for (int i = 0; i < inputs.size(); ++i) { |
153 | result.emplace_back(ctx); |
154 | } |
155 | return result; |
156 | } |
157 | int64_t num_sources = 0; |
158 | for (size_t i = 0; i < inputs.size(); ++i) { |
159 | if (inputs[i]->num_sources() < 0) { |
160 | return errors::FailedPrecondition( |
161 | "Failed to determine the number of sources for dataset of type " , |
162 | inputs[i]->type_string()); |
163 | } |
164 | num_sources += inputs[i]->num_sources(); |
165 | } |
166 | if (num_sources != ctx->split_providers().size()) { |
167 | return errors::FailedPrecondition( |
168 | "Attempted to feed " , ctx->split_providers().size(), |
169 | " split providers into a dataset with " , num_sources, " sources" ); |
170 | } |
171 | int64_t split_provider_index = 0; |
172 | for (size_t i = 0; i < inputs.size(); ++i) { |
173 | IteratorContext::Params params(ctx); |
174 | params.split_providers.clear(); |
175 | for (int j = 0; j < inputs[i]->num_sources(); ++j) { |
176 | params.split_providers.push_back( |
177 | ctx->split_providers()[split_provider_index + j]); |
178 | } |
179 | split_provider_index += inputs[i]->num_sources(); |
180 | result.emplace_back(std::move(params)); |
181 | } |
182 | return result; |
183 | } |
184 | |
185 | } // namespace data |
186 | } // namespace tensorflow |
187 | |