/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/node_file_writer.h"

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_replace.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace {

// Avoid writing to disk very commonly executed ops that are known to be
// deterministic. This reduces the file size.
const absl::flat_hash_set<std::string>* const kOpsToSkipWriting =
    new absl::flat_hash_set<std::string>{"Add",
                                         "AddV2",
                                         "BroadcastTo",
                                         "Cast",
                                         "ConcatV2",
                                         "Const",
                                         "_EagerConst",
                                         "Enter",
                                         "Exit",
                                         "Fill",
                                         "_HostSend",
                                         "Identity",
                                         "Less",
                                         "MatrixDiagV3",
                                         "Merge",
                                         "Mul",
                                         "NextIteration",
                                         "Pack",
                                         "RandomStandardNormal",
                                         "RandomUniform",
                                         "Range",
                                         "RealDiv",
                                         "Reshape",
                                         "_Send",
                                         "Shape",
                                         "StridedSlice",
                                         "Sub",
                                         "Switch",
                                         "Transpose",
                                         "_XlaCompile"};

// If a host int32 input has at most this many elements, the tensor value will
// be written to the file.
const int kMaxInt32Elems = 10;

}  // namespace

namespace tensorflow {

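// Returns the NodeFileWriter for the given device if the
// TF_NODE_FILE_WRITER_DIRECTORY environment variable is set (either to a
// directory path or to the special value 'test_undeclared_outputs_dir');
// otherwise returns nullptr.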
/*static*/ StatusOr<NodeFileWriter*>
tensorflow::NodeFileWriter::GetNodeFileWriterIfEnabled(
    const std::string& device_name, Env* env) {
  // First get the directory from TF_NODE_FILE_WRITER_DIRECTORY.
  static const std::string* const node_dir = [] {
    std::string node_dir;
    TF_CHECK_OK(
        ReadStringFromEnvVar("TF_NODE_FILE_WRITER_DIRECTORY", "", &node_dir));
    if (node_dir == "test_undeclared_outputs_dir") {
      bool env_set = io::GetTestUndeclaredOutputsDir(&node_dir);
      if (!env_set || node_dir.empty()) {
        LOG(WARNING)
            << "TF_NODE_FILE_WRITER_DIRECTORY was set to "
               "'test_undeclared_outputs_dir', but the environment "
               "variable TEST_UNDECLARED_OUTPUTS_DIR does not exist or "
               "is empty. NodeDef collection will be skipped.";
      } else {
        node_dir = io::JoinPath(node_dir, "node_defs");
      }
    }
    return new std::string{node_dir};
  }();
  if (node_dir->empty()) {
    return nullptr;
  }

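  // Guards the device_name_to_writer map below.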
  static mutex mu(LINKER_INITIALIZED);
  // Cache a NodeFileWriter* for each device name, so that different Sessions
  // share the same NodeFileWriters. Sharing NodeFileWriters reduces the total
  // size of the output files, since if multiple Sessions run the same op, the
  // op is only recorded to disk once.
  static auto* device_name_to_writer =
      new absl::flat_hash_map<std::string, NodeFileWriter*>{};
  mutex_lock l(mu);
  auto it = device_name_to_writer->find(device_name);
  if (it == device_name_to_writer->end()) {
    Status s = env->CreateDir(*node_dir);
    if (!s.ok() && s.code() != error::ALREADY_EXISTS) {
      return s;
    }

    // Put the device name in the filename for debugging purposes. Also append
    // a random number in case multiple processes write out nodes concurrently.
    std::string filename = strings::StrCat(
        "node_defs", absl::StrReplaceAll(device_name, {{"/", "_"}, {":", "_"}}),
        "_", random::New64());

    auto* writer = new NodeFileWriter{io::JoinPath(*node_dir, filename)};
    s = writer->Init(env);
    if (!s.ok()) {
      delete writer;
      return s;
    }
    it = device_name_to_writer->insert({device_name, writer}).first;
  }
  return it->second;
}

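// Records the NodeDef of the executed op, along with its input dtypes and
// shapes (and, for small host int32 inputs, the input values themselves), for
// use in determinism testing.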
Status NodeFileWriter::RecordNodeExecution(OpKernel* op_kernel,
                                           OpKernelContext* context) {
  if (kOpsToSkipWriting->count(op_kernel->type_string())) {
    return OkStatus();
  }
  NodeDef def;
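  // A fixed placeholder name is used, presumably so that otherwise-identical
  // NodeDefs from differently named nodes hash the same and are deduplicated
  // in MaybeWriteNodeDefToFile.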
  def.set_name("NodeFileWriter");
  def.set_op(op_kernel->def().op());
  *def.mutable_attr() = op_kernel->def().attr();
  // The input shapes/dtypes are stored in the 'attr' section of the NodeDef.
  AttrValue& input_shapes = (*def.mutable_attr())["_input_shapes"];
  AttrValue& input_dtypes = (*def.mutable_attr())["_input_dtypes"];
  for (int i = 0; i < context->num_inputs(); i++) {
    if (!context->has_input(i) || context->input_is_ref(i)) {
      // Calling context->input(i) requires the input to exist and not be a
      // ref, so return immediately if that is not the case.
      return OkStatus();
    }
    TensorShapeProto* shape_proto = input_shapes.mutable_list()->add_shape();
    const Tensor& input = context->input(i);
    input.shape().AsProto(shape_proto);
    input_dtypes.mutable_list()->add_type(context->input_dtype(i));
    // Store small int32 host inputs, as they often represent shapes.
    if (input.NumElements() <= kMaxInt32Elems && input.dtype() == DT_INT32 &&
        context->input_memory_type(i) == HOST_MEMORY) {
      AttrValue& input_tensor =
          (*def.mutable_attr())[strings::StrCat("_input_tensor_", i)];
      input.AsProtoField(input_tensor.mutable_tensor());
    } else if (!DataTypeIsFloating(input.dtype())) {
      // Skip ops with non-floating-point inputs, since these are not useful
      // when testing determinism.
      return OkStatus();
    }
  }
  return MaybeWriteNodeDefToFile(def);
}

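// Serializes 'def' and appends it to the file, unless a NodeDef with the same
// hash has already been written by this writer.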
Status NodeFileWriter::MaybeWriteNodeDefToFile(const NodeDef& def) {
  std::string def_str = def.SerializeAsString();
  uint64 size = def_str.size();
  std::string size_as_str;
  // The file consists of a series of records, each consisting of a 64-bit
  // little-endian integer representing the size of the serialized NodeDef,
  // followed by the serialized NodeDef.
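  // A reader can therefore loop over the file, reading 8 bytes as a
  // little-endian uint64 'n' and then parsing the next 'n' bytes as a NodeDef.
  // (No such reader is provided in this file.)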
  for (unsigned int i = 0; i < 8; i++) {
    size_as_str.push_back((size >> (i * 8)) & 0xff);
  }

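  // Hash the NodeDef, including internal attrs such as _input_shapes and
  // _input_dtypes, so each distinct op/attr/input-signature combination is
  // written only once.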
  EqualGraphDefOptions options;
  options.ignore_internal_attrs = false;
  uint64 hash = NodeDefHash(def, options);

  mutex_lock l{mu_};
  if (written_hashes_.count(hash) == 0) {
    TF_RETURN_IF_ERROR(node_def_file_->Append(size_as_str));
    TF_RETURN_IF_ERROR(node_def_file_->Append(def_str));
    written_hashes_.insert(hash);
    // Flush after each write, since NodeFileWriters are never destructed, so
    // the file is never closed.
    TF_RETURN_IF_ERROR(node_def_file_->Flush());
  }
  return OkStatus();
}

}  // namespace tensorflow