/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_
16#define TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_
17
18#include <string>
19
20#include "tensorflow/core/lib/gtl/flatset.h"
21#include "tensorflow/core/platform/types.h"
22
23namespace tensorflow {
24
25// TensorFlow runtime (both eager and graph) will aim to colocate ops with
26// their resource inputs so that the ops can access the resource state. In some
27// cases, such as tf.data ops, this is not desirable as the ops themselves might
28// not have a kernel registered for the device on which the resource is placed
29// and instead use a mechanism, such as a multi-device function, to access the
30// resource state.
31//
32// This registry can be used to register and list ops that should be exempt from
33// the input colocation described above.
34//
35// Example usage:
36// REGISTER_INPUT_COLOCATION_EXEMPTION("MapDataset");
37class InputColocationExemptionRegistry {
38 public:
39 // Returns a pointer to a global InputColocationExemptionRegistry object.
40 static InputColocationExemptionRegistry* Global();
41
42 // Returns the set of ops exempt from the input colocation constraints.
43 const gtl::FlatSet<string>& Get() { return ops_; }
44
45 // Registers an op to be excluded from the input colocation constraints.
46 void Register(const string& op);
47
48 private:
49 gtl::FlatSet<string> ops_;
50};
51
52namespace input_colocation_exemption_registration {
53
54class InputColocationExemptionRegistration {
55 public:
56 explicit InputColocationExemptionRegistration(const string& op) {
57 InputColocationExemptionRegistry::Global()->Register(op);
58 }
59};
60
61} // namespace input_colocation_exemption_registration
62
// Marks the op named by `op` (a string expression) as exempt from input
// colocation. Expands to a uniquely named static registration object whose
// constructor calls InputColocationExemptionRegistry::Global()->Register(op).
// Intended for use at namespace scope, where that object is constructed during
// static initialization. The registration type is fully qualified so the macro
// also works when expanded outside of `namespace tensorflow`.
#define REGISTER_INPUT_COLOCATION_EXEMPTION(op) \
  REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ_HELPER(__COUNTER__, op)

// Extra level of indirection so that __COUNTER__ is expanded to its numeric
// value before it is token-pasted in REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ.
#define REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ_HELPER(ctr, op) \
  REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ(ctr, op)

// Declares the static registration object; `ctr` makes its name unique per
// expansion within a translation unit.
#define REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ(ctr, op)        \
  static ::tensorflow::input_colocation_exemption_registration:: \
      InputColocationExemptionRegistration                       \
          input_colocation_exemption_registration_fn_##ctr(op)
73
74} // namespace tensorflow
75
76#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_
77