1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NODE_FILE_WRITER_H_
16#define TENSORFLOW_CORE_COMMON_RUNTIME_NODE_FILE_WRITER_H_
17
18#include <string>
19#include <unordered_map>
20
21#include "absl/container/flat_hash_set.h"
22#include "tensorflow/core/framework/node_def.pb.h"
23#include "tensorflow/core/framework/op_kernel.h"
24#include "tensorflow/core/platform/env.h"
25#include "tensorflow/core/platform/mutex.h"
26#include "tensorflow/core/util/env_var.h"
27
28namespace tensorflow {
29
30// Writes out the NodeDef and the input shapes/dtypes for an executed node to a
31// file. This allows the set of executed nodes for a model or test to be
32// examined and processed. Currently this is used by an internal tool which
33// checks that ops executed by tests are deterministic.
34class NodeFileWriter {
35 public:
36 // Creates or reuses a NodeFileWriter if environmental variable
37 // TF_NODE_FILE_WRITER_DIRECTORY is set, which specifies the directory where
38 // the node file will be created in. Otherwise, returns nullptr. When called
39 // with the same device_name, the same NodeFileWriter will be returned.
40 static StatusOr<NodeFileWriter*> GetNodeFileWriterIfEnabled(
41 const std::string& device_name, Env* env);
42
43 // Records the execution of a node, if eligible, by writing the node to the
44 // file. Only writes the node if the exact node with the given input
45 // shapes/dtypes hasn't already been written. Should be called once every time
46 // a node is run.
47 Status RecordNodeExecution(OpKernel* op_kernel, OpKernelContext* context);
48
49 const std::string& filename() { return filename_; }
50
51 private:
52 explicit NodeFileWriter(std::string filename)
53 : filename_{std::move(filename)} {}
54
55 Status Init(Env* env) {
56 return env->NewWritableFile(filename_, &node_def_file_);
57 }
58
59 // Writes the NodeDef to a file, if it hasn't already been written yet.
60 Status MaybeWriteNodeDefToFile(const NodeDef& def);
61
62 const std::string filename_;
63 mutex mu_;
64 // Hashes of the NodeDefs already written to the file
65 absl::flat_hash_set<uint64> written_hashes_ TF_GUARDED_BY(mu_);
66
67 std::unique_ptr<WritableFile> node_def_file_ TF_PT_GUARDED_BY(mu_);
68};
69
70} // namespace tensorflow
71#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NODE_FILE_WRITER_H_
72