1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_H_ |
18 | |
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
28 | |
29 | namespace tensorflow { |
30 | |
31 | // A placement algorithm that assigns the nodes of the given Graph to |
32 | // devices the given DeviceSet, respecting the following constraints: |
33 | // |
34 | // 1. Existing device assignments remain unchanged. |
35 | // 2. Requested (partial or complete) device specifications given by device name |
36 | // for each node are granted. |
37 | // 3. Nodes connected by edges of a reference type are colocated on |
38 | // the same device. |
39 | // 4. Given nodes "A" and "B", if node "B" has a colocation group |
40 | // "@loc:A", nodes "A" and "B" will be colocated on the same device. |
41 | // |
42 | // The implementation builds a constraint graph with the same set of |
43 | // nodes, and edges that represent colocation constraints between |
44 | // nodes. Each connected component in the resulting constraint graph |
45 | // is then assigned to a set of valid devices. |
46 | // |
47 | // Run() will finally assign the device to each node given the list of |
48 | // possible devices. |
49 | // |
50 | // TODO(mrry): "Soft" constraints, such as "place node 'x' as close as |
51 | // possible to node 'y' while respecting the other constraints"? |
52 | // TODO(mrry): Create a common interface for this and the other |
53 | // placement algorithms so that they may be injected into the graph |
54 | // builder. |
55 | class Placer { |
56 | public: |
57 | // Creates an instance of the Placer algorithm for the given |
58 | // Graph "graph" (nodes in which may or may not be assigned) on the |
59 | // given DeviceSet "devices". |
60 | // "function_name" should be set to the name of the function whose body is |
61 | // represented by "graph". If "graph" is not representing a function body, |
62 | // "function_name" should be empty. |
63 | // |
64 | // If non-null, default_local_device is used where possible as a placement for |
65 | // nodes which do not have a device specified, ahead of other devices which |
66 | // would otherwise be higher priority. default_local_device should be on the |
67 | // local host so that its FLR is directly accessible by the current process. |
68 | // |
69 | // The "graph", "devices", and "default_local_device" pointer arguments are |
70 | // borrowed by this Placer, and must outlive it. |
71 | Placer(Graph* graph, const string& function_name, |
72 | const FunctionLibraryDefinition* flib_def, const DeviceSet* devices, |
73 | const Device* default_local_device, bool allow_soft_placement, |
74 | bool log_device_placement); |
75 | Placer(Graph* graph, const string& function_name, |
76 | const FunctionLibraryDefinition* flib_def, const DeviceSet* devices); |
77 | Placer(Graph* graph, const string& function_name, |
78 | const FunctionLibraryDefinition* flib_def, const DeviceSet* devices, |
79 | const Device* default_local_device); |
80 | |
81 | ~Placer(); |
82 | |
83 | // Assigns each node in this Placer's graph to a device in its |
84 | // set of devices. |
85 | // |
86 | // This method is not thread-safe. |
87 | // Run() may be invoked at most once. |
88 | Status Run(); |
89 | |
90 | private: |
91 | // Returns true if the device type of 'candidate_device_name' is |
92 | // found in 'devices'. |
93 | bool CanAssignToDevice(const string& candidate_device_name, |
94 | const std::vector<Device*>& devices) const; |
95 | |
96 | Graph* const graph_; // Not owned. |
97 | const string function_name_; |
98 | const FunctionLibraryDefinition* const flib_def_; // Not owned. |
99 | const DeviceSet* const devices_; // Not owned. |
100 | const Device* default_local_device_; // Not owned. |
101 | const bool allow_soft_placement_; |
102 | const bool log_device_placement_; |
103 | |
104 | TF_DISALLOW_COPY_AND_ASSIGN(Placer); |
105 | }; |
106 | |
107 | } // namespace tensorflow |
108 | |
109 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_H_ |
110 | |