/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/split_utils.h"

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace data {
namespace {
// Keys used to checkpoint split provider state.
constexpr char kNumToSkip[] = "num_to_skip";
constexpr char kSplitProvider[] = "split_provider";
constexpr char kSlash[] = "/";
constexpr char kIndex[] = "index";
}  // namespace

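// IndexSplitProvider produces the splits 0, 1, ..., n - 1 in order, each as a
// scalar int64 Tensor.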
IndexSplitProvider::IndexSplitProvider(int64_t n) : i_(0), n_(n) {}

Status IndexSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
  mutex_lock l(mu_);
  if (i_ >= n_) {
    *end_of_splits = true;
    return OkStatus();
  }
  *end_of_splits = false;
  *split = Tensor(DT_INT64, TensorShape{});
  split->scalar<int64_t>()() = i_++;
  return OkStatus();
}

Status IndexSplitProvider::Reset() {
  mutex_lock l(mu_);
  i_ = 0;
  return OkStatus();
}

Status IndexSplitProvider::Save(
    std::function<std::string(std::string)> full_name,
    IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  return writer->WriteScalar(full_name(kIndex), i_);
}

Status IndexSplitProvider::Restore(
    std::function<std::string(std::string)> full_name,
    IteratorStateReader* reader) {
  mutex_lock l(mu_);
  return reader->ReadScalar(full_name(kIndex), &i_);
}

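// ShardingSplitProvider yields every num_shards-th split from the wrapped
// provider, starting at shard_index. For example, with num_shards = 3 and
// shard_index = 1, a source producing splits 0, 1, 2, 3, 4, 5 yields 1 and 4.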
ShardingSplitProvider::ShardingSplitProvider(
    int64_t num_shards, int64_t shard_index,
    std::shared_ptr<SplitProvider> split_provider)
    : num_shards_(num_shards),
      shard_index_(shard_index),
      split_provider_(std::move(split_provider)),
      num_to_skip_(shard_index_) {}

Status ShardingSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
  mutex_lock l(mu_);
  // Discard the splits that belong to other shards.
  while (num_to_skip_ > 0) {
    TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
    if (*end_of_splits) {
      return OkStatus();
    }
    num_to_skip_--;
  }
  // After producing a split, skip the next num_shards_ - 1 splits so that
  // consecutive results stay num_shards_ apart.
  num_to_skip_ = num_shards_ - 1;
  TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
  return OkStatus();
}

Status ShardingSplitProvider::Reset() {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(split_provider_->Reset());
  num_to_skip_ = shard_index_;
  return OkStatus();
}

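// Save and Restore checkpoint the wrapped provider's state under the
// "split_provider/" key prefix, along with the number of splits still to be
// skipped.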
Status ShardingSplitProvider::Save(
    std::function<std::string(std::string)> full_name,
    IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(split_provider_->Save(
      [&](const std::string& key) {
        return full_name(absl::StrCat(kSplitProvider, kSlash, key));
      },
      writer));
  return writer->WriteScalar(full_name(kNumToSkip), num_to_skip_);
}

Status ShardingSplitProvider::Restore(
    std::function<std::string(std::string)> full_name,
    IteratorStateReader* reader) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(split_provider_->Restore(
      [&](const std::string& key) {
        return full_name(absl::StrCat(kSplitProvider, kSlash, key));
      },
      reader));
  TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumToSkip), &num_to_skip_));
  return OkStatus();
}

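// Returns the sole split provider from `ctx`, or a FailedPrecondition error
// if `ctx` does not hold exactly one split provider.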
StatusOr<std::shared_ptr<SplitProvider>> GetSingleSplitProvider(
    IteratorContext* ctx, const DatasetBase* dataset) {
  if (ctx->split_providers().size() != 1) {
    return errors::FailedPrecondition(
        "Failed to get single split provider for dataset ",
        dataset->DebugString(), ". Found ", ctx->split_providers().size(),
        " split providers");
  }
  return ctx->split_providers()[0];
}

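// Collects the split providers created by each input of `dataset` into a
// single flat list.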
StatusOr<std::vector<std::unique_ptr<SplitProvider>>> GetSplitProviders(
    const DatasetBase* dataset) {
  std::vector<std::unique_ptr<SplitProvider>> result;
  std::vector<const DatasetBase*> inputs;
  TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs));
  for (const auto& input : inputs) {
    std::vector<std::unique_ptr<SplitProvider>> providers;
    TF_RETURN_IF_ERROR(input->MakeSplitProviders(&providers));
    for (auto& provider : providers) {
      result.push_back(std::move(provider));
    }
  }
  return result;
}

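// Creates one IteratorContext per input of `dataset`, partitioning the split
// providers in `ctx` among the inputs according to each input's number of
// source datasets. The total number of sources must match the number of split
// providers in `ctx`.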
StatusOr<std::vector<IteratorContext>> CreateInputIteratorContexts(
    IteratorContext* ctx, const DatasetBase* dataset) {
  std::vector<const DatasetBase*> inputs;
  TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs));
  std::vector<IteratorContext> result;
  if (ctx->split_providers().empty()) {
    // Without split providers, every input shares the parent context.
    for (size_t i = 0; i < inputs.size(); ++i) {
      result.emplace_back(ctx);
    }
    return result;
  }
  int64_t num_sources = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (inputs[i]->num_sources() < 0) {
      return errors::FailedPrecondition(
          "Failed to determine the number of sources for dataset of type ",
          inputs[i]->type_string());
    }
    num_sources += inputs[i]->num_sources();
  }
  if (num_sources != ctx->split_providers().size()) {
    return errors::FailedPrecondition(
        "Attempted to feed ", ctx->split_providers().size(),
        " split providers into a dataset with ", num_sources, " sources");
  }
  // Assign each input a contiguous range of split providers, one per source.
  int64_t split_provider_index = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    IteratorContext::Params params(ctx);
    params.split_providers.clear();
    for (int j = 0; j < inputs[i]->num_sources(); ++j) {
      params.split_providers.push_back(
          ctx->split_providers()[split_provider_index + j]);
    }
    split_provider_index += inputs[i]->num_sources();
    result.emplace_back(std::move(params));
  }
  return result;
}

}  // namespace data
}  // namespace tensorflow
187