1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INSPECTING_PLACER_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_INSPECTING_PLACER_H_
18
19#include <vector>
20
21#include "absl/strings/str_join.h"
22#include "tensorflow/core/common_runtime/device.h"
23#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
24#include "tensorflow/core/framework/function.h"
25#include "tensorflow/core/framework/types.h"
26#include "tensorflow/core/lib/core/stringpiece.h"
27#include "tensorflow/core/util/device_name_utils.h"
28#include "tensorflow/core/util/port.h"
29
30namespace tensorflow {
31
32// TODO(iga): Convert this struct into a class to ensure invariants between
33// device names, i.e.
34// DeviceNameUtils::IsSpecification(resource_device_name,
35// requested_device_name)
36// PossibleDevices does not contain assigned_device_name because we don't
37// assign devices to nested functions.
38struct PossibleDevices {
39 // The same as Member::requested_device_name_ in colocation_graph.cc.
40 DeviceNameUtils::ParsedName requested_device_name;
41
42 // The same as Member::resource_device_name_ in colocation_graph.cc.
43 DeviceNameUtils::ParsedName resource_device_name;
44
45 // A device type outside of this set will not be supported by some
46 // internal op.
47 PrioritizedDeviceTypeVector device_types;
48};
49
// A struct for communicating constraints on the devices that can be chosen
// for the inputs and outputs of an op requiring deep placer inspection.
//
// Every input and every output is mapped to a group id. A group id is an
// index into `group_devices`, which holds the possible devices for that group.
// NOTE(review): ids are presumably always in [0, group_devices.size());
// confirm against the producer of this struct.
struct IOColocationGroups {
  // input_groups[i] contains the group id that the i'th input belongs to.
  // List inputs are not supported.
  std::vector<int> input_groups;
  // output_groups[i] contains the group id that the i'th output belongs to.
  // List outputs are not supported.
  std::vector<int> output_groups;
  // group_devices[i] contains the possible devices for the group with id i.
  std::vector<PossibleDevices> group_devices;

  // Returns a human-readable representation of the groups, presumably for
  // logging/debugging (the implementation is not in this header).
  string DebugString() const;
};
64
// Computes device-placement constraints (IOColocationGroups) for the inputs
// and outputs of ops that require deep placer inspection. Presumably these
// are function-calling ops whose bodies are looked up in `flib_def`; confirm
// against the .cc implementation.
class InspectingPlacer {
 public:
  // `flib_def` and `device_set` must not be null and must outlive this
  // InspectingPlacer, because they are held by reference.
  // `default_device` can be null. If it is not null, it must outlive this
  // object, because it is held as a raw pointer.
  // `stack` is copied, so the caller's FunctionStack need not outlive this
  // object.
  // `allow_soft_placement` and `log_device_placement` are stored as given.
  // NOTE(review): they presumably mirror the session config options of the
  // same name; confirm.
  // TODO(iga): Add a "stack trace" to detect recursion and improve log
  // messages. Currently, we will enter an infinite loop for recursive
  // functions.
  // NOTE(review): the constructor already accepts a FunctionStack, so this
  // TODO may be stale; verify against the implementation.
  InspectingPlacer(const FunctionStack& stack,
                   const FunctionLibraryDefinition* flib_def,
                   const DeviceSet* device_set, const Device* default_device,
                   bool allow_soft_placement, bool log_device_placement);

  // Fills `*groups` with the colocation groups of `node`'s inputs and outputs.
  // Precondition: `node` must be an op for which
  // PlacerInspectionRequiredOpsChecker::IsPlacerInspectionRequired holds.
  // Returns a non-OK Status on failure.
  Status ComputeIOColocationGroups(const Node& node,
                                   IOColocationGroups* groups);

 private:
  // Owned copy of the caller's function stack.
  const FunctionStack stack_;
  // Borrowed; must outlive this object.
  const FunctionLibraryDefinition& flib_def_;
  // Borrowed; must outlive this object.
  const DeviceSet& device_set_;
  // Borrowed and nullable.
  const Device* default_device_;
  const bool allow_soft_placement_;
  const bool log_device_placement_;

  // Neither copyable nor assignable.
  TF_DISALLOW_COPY_AND_ASSIGN(InspectingPlacer);
};
92
93} // namespace tensorflow
94
95#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INSPECTING_PLACER_H_
96