1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_DEBUG_DEBUG_CALLBACK_REGISTRY_H_ |
17 | #define TENSORFLOW_CORE_DEBUG_DEBUG_CALLBACK_REGISTRY_H_ |
18 | |
19 | #include <functional> |
20 | #include <map> |
21 | #include <string> |
22 | #include <vector> |
23 | |
24 | #include "tensorflow/core/debug/debug_node_key.h" |
25 | #include "tensorflow/core/framework/tensor.h" |
26 | #include "tensorflow/core/platform/mutex.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | // Supports exporting observed debug events to clients using registered |
31 | // callbacks. Users can register a callback for each debug_url stored using |
32 | // DebugTensorWatch. The callback key be equivalent to what follows |
33 | // "memcbk:///". |
34 | // |
35 | // All events generated for a watched node will be sent to the call back in the |
36 | // order that they are observed. |
37 | // |
38 | // This callback router should not be used in production or training steps. It |
39 | // is optimized for deep inspection of graph state rather than performance. |
40 | class DebugCallbackRegistry { |
41 | public: |
42 | using EventCallback = std::function<void(const DebugNodeKey&, const Tensor&)>; |
43 | |
44 | // Provides singleton access to the in memory event store. |
45 | static DebugCallbackRegistry* singleton(); |
46 | |
47 | // Returns the registered callback, or nullptr, for key. |
48 | EventCallback* GetCallback(const string& key); |
49 | |
50 | // Associates callback with key. This must be called by clients observing |
51 | // nodes to be exported by this callback router before running a session. |
52 | void RegisterCallback(const string& key, EventCallback callback); |
53 | |
54 | // Removes the callback associated with key. |
55 | void UnregisterCallback(const string& key); |
56 | |
57 | private: |
58 | DebugCallbackRegistry(); |
59 | |
60 | // Mutex to ensure that keyed events are never updated in parallel. |
61 | mutex mu_; |
62 | |
63 | // Maps debug_url keys to callbacks for routing observed tensors. |
64 | std::map<string, EventCallback> keyed_callback_ TF_GUARDED_BY(mu_); |
65 | |
66 | static DebugCallbackRegistry* instance_; |
67 | }; |
68 | |
69 | } // namespace tensorflow |
70 | |
71 | #endif // TENSORFLOW_CORE_DEBUG_DEBUG_CALLBACK_REGISTRY_H_ |
72 | |