1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/composite_device.h" |
17 | |
18 | #include "absl/strings/str_join.h" |
19 | #include "tensorflow/core/util/device_name_utils.h" |
20 | |
21 | namespace tensorflow { |
22 | |
// Device type string used for composite devices: CompositeDevice::MakeDevice
// writes it into both the parsed device name and the DeviceAttributes it builds.
23 | const char* const kCompositeDeviceType = "COMPOSITE" ; |
24 | |
25 | std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice( |
26 | const std::vector<string>& underlying_devices, const int unique_device_id, |
27 | const DeviceNameUtils::ParsedName& host_name, Status* status) { |
28 | DeviceNameUtils::ParsedName parsed_name = host_name; |
29 | parsed_name.type = kCompositeDeviceType; |
30 | parsed_name.id = unique_device_id; |
31 | const string device_name = DeviceNameUtils::ParsedNameToString(parsed_name); |
32 | return CompositeDevice::MakeDevice(underlying_devices, device_name, status); |
33 | } |
34 | |
35 | std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice( |
36 | const std::vector<string>& underlying_devices, const string& device_name, |
37 | Status* status) { |
38 | if (underlying_devices.empty()) { |
39 | status->Update( |
40 | errors::InvalidArgument("underlying_devices should not be empty." )); |
41 | return nullptr; |
42 | } |
43 | DeviceNameUtils::ParsedName parsed_name; |
44 | if (!DeviceNameUtils::ParseFullName(underlying_devices.at(0), &parsed_name)) { |
45 | status->Update(tensorflow::errors::InvalidArgument( |
46 | "Cannot parse device name " , underlying_devices.at(0), |
47 | " when creating CompositeDevice." )); |
48 | return nullptr; |
49 | } |
50 | const string& underlying_type = parsed_name.type; |
51 | for (int i = 1; i < underlying_devices.size(); ++i) { |
52 | DeviceNameUtils::ParsedName name; |
53 | if (!DeviceNameUtils::ParseFullName(underlying_devices.at(i), &name)) { |
54 | status->Update(tensorflow::errors::InvalidArgument( |
55 | "Cannot parse device name " , underlying_devices.at(i), |
56 | " when creating CompositeDevice." )); |
57 | return nullptr; |
58 | } |
59 | if (name.type != underlying_type) { |
60 | status->Update(tensorflow::errors::InvalidArgument( |
61 | "Expect device type " , parsed_name.type, "; but got type " , name.type, |
62 | " from device: " , underlying_devices.at(i), |
63 | " when creating CompositeDevice." )); |
64 | return nullptr; |
65 | } |
66 | } |
67 | |
68 | DeviceAttributes device_attributes; |
69 | device_attributes.set_name(device_name); |
70 | device_attributes.set_device_type(kCompositeDeviceType); |
71 | |
72 | return absl::WrapUnique( |
73 | new CompositeDevice(device_attributes, underlying_devices)); |
74 | } |
75 | |
76 | } // namespace tensorflow |
77 | |