1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
18
19#include <memory>
20
21#include "tensorflow/core/common_runtime/device.h"
22#include "tensorflow/core/graph/graph.h"
23#include "tensorflow/core/lib/core/status.h"
24#include "tensorflow/core/platform/macros.h"
25#include "tensorflow/core/platform/protobuf.h"
26#include "tensorflow/core/protobuf/debug.pb.h"
27
28namespace tensorflow {
29
30// Returns a summary string for the list of debug tensor watches.
31const string SummarizeDebugTensorWatches(
32 const protobuf::RepeatedPtrField<DebugTensorWatch>& watches);
33
34// An abstract interface for storing and retrieving debugging information.
35class DebuggerStateInterface {
36 public:
37 virtual ~DebuggerStateInterface() {}
38
39 // Publish metadata about the debugged Session::Run() call.
40 //
41 // Args:
42 // global_step: A global step count supplied by the caller of
43 // Session::Run().
44 // session_run_index: A chronologically sorted index for calls to the Run()
45 // method of the Session object.
46 // executor_step_index: A chronologically sorted index of invocations of the
47 // executor charged to serve this Session::Run() call.
48 // input_names: Name of the input Tensors (feed keys).
49 // output_names: Names of the fetched Tensors.
50 // target_names: Names of the target nodes.
51 virtual Status PublishDebugMetadata(
52 const int64_t global_step, const int64_t session_run_index,
53 const int64_t executor_step_index, const std::vector<string>& input_names,
54 const std::vector<string>& output_names,
55 const std::vector<string>& target_nodes) = 0;
56};
57
58class DebugGraphDecoratorInterface {
59 public:
60 virtual ~DebugGraphDecoratorInterface() {}
61
62 // Insert special-purpose debug nodes to graph and dump the graph for
63 // record. See the documentation of DebugNodeInserter::InsertNodes() for
64 // details.
65 virtual Status DecorateGraph(Graph* graph, Device* device) = 0;
66
67 // Publish Graph to debug URLs.
68 virtual Status PublishGraph(const Graph& graph,
69 const string& device_name) = 0;
70};
71
72typedef std::function<std::unique_ptr<DebuggerStateInterface>(
73 const DebugOptions& options)>
74 DebuggerStateFactory;
75
76// Contains only static methods for registering DebuggerStateFactory.
77// We don't expect to create any instances of this class.
78// Call DebuggerStateRegistry::RegisterFactory() at initialization time to
79// define a global factory that creates instances of DebuggerState, then call
80// DebuggerStateRegistry::CreateState() to create a single instance.
81class DebuggerStateRegistry {
82 public:
83 // Registers a function that creates a concrete DebuggerStateInterface
84 // implementation based on DebugOptions.
85 static void RegisterFactory(const DebuggerStateFactory& factory);
86
87 // If RegisterFactory() has been called, creates and supplies a concrete
88 // DebuggerStateInterface implementation using the registered factory,
89 // owned by the caller and return an OK Status. Otherwise returns an error
90 // Status.
91 static Status CreateState(const DebugOptions& debug_options,
92 std::unique_ptr<DebuggerStateInterface>* state);
93
94 private:
95 static DebuggerStateFactory* factory_;
96
97 TF_DISALLOW_COPY_AND_ASSIGN(DebuggerStateRegistry);
98};
99
100typedef std::function<std::unique_ptr<DebugGraphDecoratorInterface>(
101 const DebugOptions& options)>
102 DebugGraphDecoratorFactory;
103
104class DebugGraphDecoratorRegistry {
105 public:
106 static void RegisterFactory(const DebugGraphDecoratorFactory& factory);
107
108 static Status CreateDecorator(
109 const DebugOptions& options,
110 std::unique_ptr<DebugGraphDecoratorInterface>* decorator);
111
112 private:
113 static DebugGraphDecoratorFactory* factory_;
114
115 TF_DISALLOW_COPY_AND_ASSIGN(DebugGraphDecoratorRegistry);
116};
117
118} // end namespace tensorflow
119
120#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
121