1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_
18
19// Operations calling functions are becoming ubiquitous in TF 2.0.
20// Examples include PartitionedCallOp, functional If/While, and Dataset ops.
21// Such operations might require deep inspection - looking at the body of the
22// called function - to place them and surrounding ops correctly.
23
// This file contains utilities that help Placer place such ops correctly,
// including:
//  - PlacerInspectionRequiredOpChecker: A simple class with a single
//    IsPlacerInspectionRequired method.
//  - IsolatePlacerInspectionRequiredOps: This function adds Identity ops for
//    each input/output of ops requiring placer inspection. Doing so greatly
//    simplifies the implementation of placing such ops.
31
32#include <vector>
33
34#include "absl/types/optional.h"
35#include "tensorflow/core/framework/function.h"
36#include "tensorflow/core/graph/graph.h"
37#include "tensorflow/core/lib/core/status.h"
38
39namespace tensorflow {
40
41// PlacerInspectionRequiredOpChecker allows one to check if Placer needs to
42// look deeply into the op to place ops consuming the outputs correctly.
43//
44// It is a class instead of a standalone method because checking whether
45// a function returns a resource takes non-trivial time and we cache the
46// results.
47class PlacerInspectionRequiredOpChecker {
48 public:
49 // Constructs a PlacerInspectionRequiredOpChecker for nodes of `graph`.
50 // The functions referenced by nodes in `graph` will be looked up in
51 // `flib_def`
52 PlacerInspectionRequiredOpChecker(const Graph* graph,
53 const FunctionLibraryDefinition* flib_def);
54
55 // If `node` is considered a deep op, sets `*is_deep` to true and returns
56 // OkStatus(). If an error occurs, returns that error, and the value of
57 // `*is_deep` is undefined.
58 // Currently, an op is considered deep, if it is a calling a function
59 // returning a resource. This definition is driven by Placer's need to
60 // look inside the op.
61 // REQUIRES: `node` is part of `graph` passed into constructor.
62 Status IsPlacerInspectionRequired(const Node& node, bool* is_deep);
63
64 private:
65 const Graph& graph_;
66 const FunctionLibraryDefinition& flib_def_;
67 // Indexed by the node id.
68 // If cache_[node_id] is empty, the deepness of the node with id `node_id` has
69 // not been computed yet. Else, it contains the value already computed.
70 std::vector<absl::optional<bool>> cache_;
71};
72
73// Extracts `fdef` and `func` from `flib_def` for the function identified
74// in "f" attribute of `node`.
75Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def,
76 const Node& node, const FunctionDef** fdef,
77 NameAttrList* func);
78
79// The "call" stack of functions.
80// Useful for better error messages as well as for detecting recursion.
81// Stores references to graph nodes. These references must outlive this.
82class FunctionStack {
83 public:
84 explicit FunctionStack(const string& function_name);
85
86 // `node_in_current_function` must outlive this.
87 FunctionStack Push(const Node* node_in_current_function,
88 const string& new_current_function) const;
89
90 // Returns true iff this stack already includes `function_name`.
91 bool HasFunction(const string& function_name) const;
92
93 const string& current_function_name() const { return current_function_name_; }
94
95 // Format's this suitable for error interpolation that retrieves
96 // Python files and line numbers.
97 string FormatForError() const;
98
99 private:
100 struct Frame {
101 Frame(const string& function, const Node* node)
102 : function_name(function), node(node) {}
103
104 string function_name;
105 const Node* node;
106 };
107
108 // The function at the top of the stack. In other words, the function
109 // that is currently being inspected for placement.
110 string current_function_name_;
111
112 // The stack of frames that got the placement to the current_function_name_.
113 // frames_[0].function_name is the top function that Placer was constructed
114 // with. frames_[0].function_name can be empty if placer was constructed with
115 // a nameless graph, not a function. frames_[0].node_name is a name of a node
116 // in frames_[0].function_name that required deep inspection (e.g. a
117 // PartitionedCallOp). The function that this node invoked is
118 // frames_[1].function_name, if frames_.size() > 1. Else, the function that
119 // this node invoked is current_function_name_.
120 std::vector<Frame> frames_;
121};
122
123// Adds Identities for each input and output of function-calling ops in `graph`
124//
125// For example, the following graph calling a function on inputs `a` and `b`
126// and producing output `y` will be rewritten to include identities on all
127// edges:
128//
129// a b
130// | |
131// v v
132// f (PartitionedCallOp)
133// |
134// v
135// y
136//
137// is transformed to
138//
139// a b
140// | |
141// a_f (Identity) b_f (Identity)
142// | |
143// v v
144// f (PartitionedCallOp)
145// |
146// f_y (Identity)
147// |
148// v
149// y
150//
151Status IsolatePlacerInspectionRequiredOps(
152 const FunctionLibraryDefinition& flib_def, Graph* graph);
153
154} // namespace tensorflow
155
156#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_
157