1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_CONTROL_FLOW_DEPS_TO_CHAINS_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_CONTROL_FLOW_DEPS_TO_CHAINS_H_
18
19#include "tensorflow/core/common_runtime/optimization_registry.h"
20
21namespace tensorflow {
22
23// Move control flow dependencies in functional control flow to chains.
24// Chains are extra loop variables that serve as tokens for wiring control
25// dependencies across loop iterations at a finer granularity, compared to just
26// a single barrier at the end of each iteration. This enables the
27// parallel_iterations feature for tf.while_loop.
28//
29// One separate chain is added for each of the body function's `control_ret`.
30//
31// For example:
32//
33// while i > 0:
34// r = v.read_value()
35// s += expensive_operation(r)
36// assign = v.assign_add(1) # control: r
37// i += 1
38//
39// The loop above can safely compute `r` and `assign` ahead of `s`, by the
40// as-if rule. The separate switch/merge nodes that the loop lowers into support
41// that.
42// This transformation enables that to happen by rewriting the loop as follows:
43//
44// chain = 0.0
45// while i > 0:
46// r = v.read_value() # control: chain
47// s += expensive_operation(r)
48// assign = v.assign_add(1) # control: r
49// i += 1
50// chain = identity(chain) # control: assign
51//
52// This only rewires dependencies which need to cross scope boundaries, as the
53// switch/merge lowering process has no other way of dealing correctly with
54// those.
55//
56// This pass is best-effort and conservative, requiring attributes set by
57// tf.while_loop and automatic_control_dependencies. When the required
58// attributes are missing for a particular While node, no change is made to
59// that node. Other While nodes are still processed if they do have the needed
60// annotations.
61// The pass can also be toggled by omitting the `_stateful_parallelism=True`
62// attribute on the While node.
63// When the pass returns with error, the graph is left in an invalid state.
64// If successful, this pass also clears the body function's control_ret,
65// which in effect removes the hard barrier that gates each loop iteration.
66//
67//
68// TODO(mdan): Can we define that more formally?
69class ControlFlowDepsToChainsPass : public GraphOptimizationPass {
70 public:
71 Status Run(const GraphOptimizationPassOptions& options) override;
72};
73
74} // namespace tensorflow
75
76#endif // TENSORFLOW_CORE_COMMON_RUNTIME_CONTROL_FLOW_DEPS_TO_CHAINS_H_
77