1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_ |
17 | |
18 | #include <unordered_map> |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/common_runtime/device_set.h" |
22 | #include "tensorflow/core/framework/function.h" |
23 | #include "tensorflow/core/lib/core/status.h" |
24 | |
25 | namespace tensorflow { |
26 | |
// Given a `device_set` and a `graph`, partitions the `graph` into
// `subgraphs`. `subgraphs` maps device names to the graph assigned to that
// device. `graph` must have been placed (e.g. by running Placer),
// i.e. all nodes must have an assigned_device set.
// `graph` is non-const because the underlying Partition() function transforms
// the graph to correctly partition distributed control flow.
// `get_tensor_name_attr` computes the "tensor_name" attr value of Send/Recv ops
// inserted during partitioning. If not set, a default implementation is used.
// It needs to be thread safe if it's shared across multiple threads.
Status PartitionFunctionGraph(
    const DeviceSet& device_set, std::unique_ptr<Graph> graph,
    std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs,
    std::function<string(const Edge*)> get_tensor_name_attr = nullptr);
40 | |
// Inserts send/recv ops to `graph` if nodes are assigned to multiple devices.
// Returns the new graph with the added nodes. Moreover, the dependency between
// a send/recv pair is made explicit by adding a control dependency between
// them.
// Note that the returned graph is intended to be used by the TF MLIR importer.
// The dependencies between send/recv pairs ensure the importer will generate TF
// MLIR ops in a valid order.
StatusOr<std::unique_ptr<Graph>> InsertTransferOps(
    const DeviceSet& device_set, std::unique_ptr<Graph> graph);
50 | |
// This function performs bookkeeping to track which `Arg` and `Retval` nodes
// were placed on a particular device / graph.
//
// More specifically, this function
//
// (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be
//     consecutive.
//
//     These indices might not be consecutive after grappler's pruning
//     optimization (e.g. removing redundant Args), or graph partitioning. In
//     the latter case, the nodes in `graph` are placed on `device_type`, and
//     each such graph partition gets a subset of the arguments and return
//     values. The `index` attributes of these _Arg and _Retval nodes reflect
//     the indices of these parameters in the original function. To convert
//     `subgraph` to a function, we need to replace their original indices with
//     0, 1, 2, ... .
//
//     The argument and return value order in `graph` is determined by the
//     argument and return value order in the original function. This stability
//     is important because it enables us to treat a single-partition function
//     as having the same signature as the subgraph.
//
// (2) records the subsets of `Arg` and `Retval` nodes assigned to the
//     device in `*_indices`, and
// (3) records which `Arg` and `Retval` nodes live in host memory in
//     `*_alloc_attrs`. If these vectors are NULL, do nothing here. If
//     `ints_on_device` is false, int32 `Arg` and `Retval` nodes are placed on
//     the host; otherwise they are not. This is needed because in certain
//     special cases, e.g. when the graph is placed on a TPU/XLA device or when
//     the `Retval` is an output of an iterator, int32 tensors live on device.
Status UpdateArgAndRetvalMetadata(
    Graph* graph, std::vector<FunctionArgIndex>* arg_indices,
    std::vector<int>* ret_indices,
    std::vector<AllocatorAttributes>* arg_alloc_attrs,
    std::vector<AllocatorAttributes>* ret_alloc_attrs, bool ints_on_device);
86 | |
87 | // Utility for generating function names not present in `flib_def`, using |
88 | // given `name` as the base for the name. |
89 | class FunctionNameGenerator { |
90 | public: |
91 | // `flib_def` must outlive this. |
92 | FunctionNameGenerator(const FunctionLibraryDefinition* flib_def, |
93 | const string& name) |
94 | : flib_def_(flib_def), name_(name), counter_(0) {} |
95 | |
96 | // Returns a function name not present in `flib_def` using `name` as |
97 | // the base and appending a numeric suffix. |
98 | string GetName(); |
99 | |
100 | private: |
101 | const FunctionLibraryDefinition* flib_def_; |
102 | const string name_; |
103 | uint32 counter_; |
104 | }; |
105 | |
106 | } // namespace tensorflow |
107 | |
108 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_ |
109 | |