1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_
16#define TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_
17
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status.h"
24
25namespace tensorflow {
26
// Given a `device_set` and a `graph`, partitions the `graph` into
// `subgraphs`. `subgraphs` maps device names to the graph assigned to that
// device. `graph` must have been placed (e.g. by running Placer),
// i.e. all nodes must have an assigned_device set.
// `graph` is non-const because the underlying Partition() function transforms
// the graph to correctly partition distributed control flow.
// `get_tensor_name_attr` computes the "tensor_name" attr value of Send/Recv
// ops inserted during partitioning. If not set, a default implementation is
// used. The callable must be thread-safe if it is shared across multiple
// threads.
Status PartitionFunctionGraph(
    const DeviceSet& device_set, std::unique_ptr<Graph> graph,
    std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs,
    std::function<string(const Edge*)> get_tensor_name_attr = nullptr);
40
// Inserts send/recv ops to `graph` if nodes are assigned to multiple devices.
// Returns the new graph with the added nodes. Moreover, the dependency between
// a send/recv pair is made explicit by adding a control dependency between
// them.
// Note that the returned graph is intended to be used by the TF MLIR importer.
// The dependencies between send/recv pairs ensure the importer will generate
// TF MLIR ops in a valid order.
StatusOr<std::unique_ptr<Graph>> InsertTransferOps(
    const DeviceSet& device_set, std::unique_ptr<Graph> graph);
50
// This function performs bookkeeping to track which `Arg` and `Retval` nodes
// were placed on a particular device / graph.
//
// More specifically, this function
//
// (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be
//     consecutive.
//
//     These indices might not be consecutive after grappler's pruning
//     optimization (e.g. removing redundant Args), or graph partitioning. In
//     the latter case, the nodes in `graph` are placed on `device_type`, and
//     each such graph partition gets a subset of the arguments and return
//     values. The `index` attributes of these _Arg and _Retval nodes reflect
//     the indices of these parameters in the original function. To convert
//     `subgraph` to a function, we need to replace their original indices with
//     0, 1, 2, ... .
//
//     The argument and return value order in `graph` is determined by the
//     argument and return value order in the original function. This stability
//     is important because it enables us to treat a single-partition function
//     as having the same signature as the subgraph.
//
// (2) records the subsets of `Arg` and `Retval` nodes assigned to the
//     device in `*_indices`, and
// (3) records which `Arg` and `Retval` nodes live in host memory in
//     `*_alloc_attrs`. If these vectors are NULL, do nothing here. If
//     `ints_on_device` is false, int32 `Arg` and `Retval` nodes are placed on
//     host; otherwise they are not. This is needed because in certain special
//     cases, e.g. when the graph is placed on a TPU/XLA device or when the
//     `Retval` is an output of an iterator, int32 tensors live on device.
Status UpdateArgAndRetvalMetadata(
    Graph* graph, std::vector<FunctionArgIndex>* arg_indices,
    std::vector<int>* ret_indices,
    std::vector<AllocatorAttributes>* arg_alloc_attrs,
    std::vector<AllocatorAttributes>* ret_alloc_attrs, bool ints_on_device);
86
87// Utility for generating function names not present in `flib_def`, using
88// given `name` as the base for the name.
89class FunctionNameGenerator {
90 public:
91 // `flib_def` must outlive this.
92 FunctionNameGenerator(const FunctionLibraryDefinition* flib_def,
93 const string& name)
94 : flib_def_(flib_def), name_(name), counter_(0) {}
95
96 // Returns a function name not present in `flib_def` using `name` as
97 // the base and appending a numeric suffix.
98 string GetName();
99
100 private:
101 const FunctionLibraryDefinition* flib_def_;
102 const string name_;
103 uint32 counter_;
104};
105
106} // namespace tensorflow
107
108#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_
109