1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_
17#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_
18
#include <string>
#include <unordered_set>
#include <utility>

#include "tensorflow/core/lib/core/status.h"
21
22namespace tensorflow {
23namespace data {
24// Registry for stateful ops that need to be used in dataset functions.
25// See below macro for usage details.
26class AllowlistedStatefulOpRegistry {
27 public:
28 Status Add(string op_name) {
29 op_names_.insert(std::move(op_name));
30 return OkStatus();
31 }
32
33 Status Remove(string op_name) {
34 op_names_.erase(op_name);
35 return OkStatus();
36 }
37
38 bool Contains(const string& op_name) { return op_names_.count(op_name); }
39
40 static AllowlistedStatefulOpRegistry* Global() {
41 static auto* reg = new AllowlistedStatefulOpRegistry;
42 return reg;
43 }
44
45 private:
46 AllowlistedStatefulOpRegistry() = default;
47 AllowlistedStatefulOpRegistry(AllowlistedStatefulOpRegistry const& copy) =
48 delete;
49 AllowlistedStatefulOpRegistry operator=(
50 AllowlistedStatefulOpRegistry const& copy) = delete;
51
52 std::unordered_set<string> op_names_;
53};
54
55} // namespace data
56
// Allowlists an op that is marked stateful but must be usable inside a map_fn
// of an input pipeline. This is only needed if you want to be able to
// checkpoint the state of the input pipeline: stateful ops are otherwise
// rejected inside map_fns, because their state cannot be saved.
//
// The state of allowlisted ops inside functions is NOT saved during
// checkpointing. Use this only when the op is marked stateful for reasons such
// as avoiding constant folding during graph optimization, but is not actually
// stateful. If possible, remove the stateful flag from the op instead.
//
// Example usage:
//
//   ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LegacyStatefulReader");
//
// Implementation note: the intermediate _UNIQ_HELPER macro forces __COUNTER__
// to be expanded to a number before it is token-pasted, so every use defines a
// distinct static Status named allowlist_op<N>. That variable's initializer
// registers the op name with AllowlistedStatefulOpRegistry::Global().
#define ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS(op_name) \
  ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, op_name)
#define ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(counter, op_name) \
  ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(counter, op_name)
#define ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(counter, op_name)   \
  static ::tensorflow::Status allowlist_op##counter TF_ATTRIBUTE_UNUSED = \
      ::tensorflow::data::AllowlistedStatefulOpRegistry::Global()->Add(op_name)
78
79} // namespace tensorflow
80
81#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_
82