1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_ |
18 | |
#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/debug.pb.h"
27 | |
28 | namespace tensorflow { |
29 | |
30 | // Returns a summary string for the list of debug tensor watches. |
31 | const string SummarizeDebugTensorWatches( |
32 | const protobuf::RepeatedPtrField<DebugTensorWatch>& watches); |
33 | |
34 | // An abstract interface for storing and retrieving debugging information. |
35 | class DebuggerStateInterface { |
36 | public: |
37 | virtual ~DebuggerStateInterface() {} |
38 | |
39 | // Publish metadata about the debugged Session::Run() call. |
40 | // |
41 | // Args: |
42 | // global_step: A global step count supplied by the caller of |
43 | // Session::Run(). |
44 | // session_run_index: A chronologically sorted index for calls to the Run() |
45 | // method of the Session object. |
46 | // executor_step_index: A chronologically sorted index of invocations of the |
47 | // executor charged to serve this Session::Run() call. |
48 | // input_names: Name of the input Tensors (feed keys). |
49 | // output_names: Names of the fetched Tensors. |
50 | // target_names: Names of the target nodes. |
51 | virtual Status PublishDebugMetadata( |
52 | const int64_t global_step, const int64_t session_run_index, |
53 | const int64_t executor_step_index, const std::vector<string>& input_names, |
54 | const std::vector<string>& output_names, |
55 | const std::vector<string>& target_nodes) = 0; |
56 | }; |
57 | |
58 | class DebugGraphDecoratorInterface { |
59 | public: |
60 | virtual ~DebugGraphDecoratorInterface() {} |
61 | |
62 | // Insert special-purpose debug nodes to graph and dump the graph for |
63 | // record. See the documentation of DebugNodeInserter::InsertNodes() for |
64 | // details. |
65 | virtual Status DecorateGraph(Graph* graph, Device* device) = 0; |
66 | |
67 | // Publish Graph to debug URLs. |
68 | virtual Status PublishGraph(const Graph& graph, |
69 | const string& device_name) = 0; |
70 | }; |
71 | |
72 | typedef std::function<std::unique_ptr<DebuggerStateInterface>( |
73 | const DebugOptions& options)> |
74 | DebuggerStateFactory; |
75 | |
76 | // Contains only static methods for registering DebuggerStateFactory. |
77 | // We don't expect to create any instances of this class. |
78 | // Call DebuggerStateRegistry::RegisterFactory() at initialization time to |
79 | // define a global factory that creates instances of DebuggerState, then call |
80 | // DebuggerStateRegistry::CreateState() to create a single instance. |
81 | class DebuggerStateRegistry { |
82 | public: |
83 | // Registers a function that creates a concrete DebuggerStateInterface |
84 | // implementation based on DebugOptions. |
85 | static void RegisterFactory(const DebuggerStateFactory& factory); |
86 | |
87 | // If RegisterFactory() has been called, creates and supplies a concrete |
88 | // DebuggerStateInterface implementation using the registered factory, |
89 | // owned by the caller and return an OK Status. Otherwise returns an error |
90 | // Status. |
91 | static Status CreateState(const DebugOptions& debug_options, |
92 | std::unique_ptr<DebuggerStateInterface>* state); |
93 | |
94 | private: |
95 | static DebuggerStateFactory* factory_; |
96 | |
97 | TF_DISALLOW_COPY_AND_ASSIGN(DebuggerStateRegistry); |
98 | }; |
99 | |
100 | typedef std::function<std::unique_ptr<DebugGraphDecoratorInterface>( |
101 | const DebugOptions& options)> |
102 | DebugGraphDecoratorFactory; |
103 | |
104 | class DebugGraphDecoratorRegistry { |
105 | public: |
106 | static void RegisterFactory(const DebugGraphDecoratorFactory& factory); |
107 | |
108 | static Status CreateDecorator( |
109 | const DebugOptions& options, |
110 | std::unique_ptr<DebugGraphDecoratorInterface>* decorator); |
111 | |
112 | private: |
113 | static DebugGraphDecoratorFactory* factory_; |
114 | |
115 | TF_DISALLOW_COPY_AND_ASSIGN(DebugGraphDecoratorRegistry); |
116 | }; |
117 | |
118 | } // end namespace tensorflow |
119 | |
120 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_ |
121 | |