1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/core/util/device_name_utils.h"

#include <algorithm>
#include <limits>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
24
25namespace tensorflow {
26
// Returns true iff "c" is an ASCII letter ('a'-'z' or 'A'-'Z'). Unlike
// std::isalpha this does not depend on the current locale.
static bool IsAlpha(char c) {
  const bool is_lower = 'a' <= c && c <= 'z';
  const bool is_upper = 'A' <= c && c <= 'Z';
  return is_lower || is_upper;
}
30
31static bool IsAlphaNumOrUnderscore(char c) {
32 return IsAlpha(c) || (c >= '0' && c <= '9') || c == '_';
33}
34
35// Returns true iff "in" is a valid job name.
36static bool IsJobName(StringPiece in) {
37 return !in.empty() && IsAlpha(in.front()) &&
38 std::all_of(in.begin(), in.end(), IsAlphaNumOrUnderscore);
39}
40
41static bool ConsumePrefix(StringPiece* in, string* out,
42 StringPiece prefix_terminators) {
43 if (in->empty() || !IsAlpha(in->front())) return false;
44 const auto end_it =
45 std::find_first_of(in->begin(), in->end(), prefix_terminators.begin(),
46 prefix_terminators.end());
47 if (!std::all_of(in->begin(), end_it, IsAlphaNumOrUnderscore)) {
48 return false;
49 }
50 out->assign(in->begin(), end_it);
51 in->remove_prefix(end_it - in->begin());
52 return true;
53}
54
55// Returns true and fills in "*job" iff "*in" starts with a job name.
56static bool ConsumeJobName(StringPiece* in, string* job) {
57 return ConsumePrefix(in, job, "/");
58}
59
60// Returns true and fills in "*device_type" iff "*in" starts with a device type
61// name.
62static bool ConsumeDeviceType(StringPiece* in, string* device_type) {
63 return ConsumePrefix(in, device_type, "/:");
64}
65
66// Returns true and fills in "*val" iff "*in" starts with a decimal
67// number.
68static bool ConsumeNumber(StringPiece* in, int* val) {
69 uint64 tmp;
70 if (str_util::ConsumeLeadingDigits(in, &tmp)) {
71 *val = tmp;
72 return true;
73 } else {
74 return false;
75 }
76}
77
78// Returns a fully qualified device name given the parameters.
79static string DeviceName(const string& job, int replica, int task,
80 const string& device_prefix, const string& device_type,
81 int id) {
82 CHECK(IsJobName(job)) << job;
83 CHECK_LE(0, replica);
84 CHECK_LE(0, task);
85 CHECK(!device_type.empty());
86 CHECK_LE(0, id);
87 return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task,
88 device_prefix, device_type, ":", id);
89}
90
91/* static */
92string DeviceNameUtils::FullName(const string& job, int replica, int task,
93 const string& type, int id) {
94 return DeviceName(job, replica, task, "/device:", type, id);
95}
96
97namespace {
98string LegacyName(const string& job, int replica, int task, const string& type,
99 int id) {
100 return DeviceName(job, replica, task, "/", absl::AsciiStrToLower(type), id);
101}
102} // anonymous namespace
103
// Parses a (possibly partial) fully qualified device name such as
// "/job:foo/replica:0/task:1/device:GPU:2" into "*p".
//
// "*p" is cleared first. Every component is optional:
//   "/job:<name>", "/replica:<n>", "/task:<n>", "/device:<type>[:<id>]".
// A "*" in place of a value leaves the corresponding has_* flag false.
// The legacy spellings "/cpu:<id>", "/CPU:<id>", "/gpu:<id>" and "/GPU:<id>"
// are also accepted; they are stored with type "CPU" or "GPU" respectively.
// Both "/" and the empty string parse successfully to an empty name.
//
// Returns false in two cases: a component's value is malformed, or the
// remaining input does not start with any recognized component. On failure
// "*p" may already hold components parsed before the error was found.
bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
  p->Clear();
  if (fullname == "/") {
    return true;
  }
  // Each pass tries every component once, in a fixed order. Components may
  // therefore appear in any order in "fullname". A repeated component
  // replaces the value parsed earlier.
  while (!fullname.empty()) {
    // Becomes true once this pass consumes at least one component. A pass
    // that consumes nothing means the input is unrecognized.
    bool progress = false;
    if (absl::ConsumePrefix(&fullname, "/job:")) {
      p->has_job = !absl::ConsumePrefix(&fullname, "*");
      if (p->has_job && !ConsumeJobName(&fullname, &p->job)) {
        return false;
      }
      progress = true;
    }
    if (absl::ConsumePrefix(&fullname, "/replica:")) {
      p->has_replica = !absl::ConsumePrefix(&fullname, "*");
      if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) {
        return false;
      }
      progress = true;
    }
    if (absl::ConsumePrefix(&fullname, "/task:")) {
      p->has_task = !absl::ConsumePrefix(&fullname, "*");
      if (p->has_task && !ConsumeNumber(&fullname, &p->task)) {
        return false;
      }
      progress = true;
    }
    if (absl::ConsumePrefix(&fullname, "/device:")) {
      p->has_type = !absl::ConsumePrefix(&fullname, "*");
      if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) {
        return false;
      }
      // The ":<id>" suffix is optional here. Without it the id is left
      // unspecified.
      if (!absl::ConsumePrefix(&fullname, ":")) {
        p->has_id = false;
      } else {
        p->has_id = !absl::ConsumePrefix(&fullname, "*");
        if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
          return false;
        }
      }
      progress = true;
    }

    // Handle legacy naming convention for cpu and gpu. Here the ":<id>" part
    // is mandatory: it must be a number or "*".
    if (absl::ConsumePrefix(&fullname, "/cpu:") ||
        absl::ConsumePrefix(&fullname, "/CPU:")) {
      p->has_type = true;
      p->type = "CPU";  // Treat '/cpu:..' as uppercase '/device:CPU:...'
      p->has_id = !absl::ConsumePrefix(&fullname, "*");
      if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
        return false;
      }
      progress = true;
    }
    if (absl::ConsumePrefix(&fullname, "/gpu:") ||
        absl::ConsumePrefix(&fullname, "/GPU:")) {
      p->has_type = true;
      p->type = "GPU";  // Treat '/gpu:..' as uppercase '/device:GPU:...'
      p->has_id = !absl::ConsumePrefix(&fullname, "*");
      if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
        return false;
      }
      progress = true;
    }

    // No component matched the remaining input, so the name is malformed.
    // This also guarantees the loop terminates.
    if (!progress) {
      return false;
    }
  }
  return true;
}
176
177bool DeviceNameUtils::ParseFullOrLocalName(StringPiece fullname,
178 ParsedName* p) {
179 return ParseFullName(fullname, p) || ParseLocalName(fullname, p);
180}
181
182namespace {
183
184void CompleteName(const DeviceNameUtils::ParsedName& parsed_basename,
185 DeviceNameUtils::ParsedName* parsed_name) {
186 if (!parsed_name->has_job) {
187 parsed_name->job = parsed_basename.job;
188 parsed_name->has_job = true;
189 }
190 if (!parsed_name->has_replica) {
191 parsed_name->replica = parsed_basename.replica;
192 parsed_name->has_replica = true;
193 }
194 if (!parsed_name->has_task) {
195 parsed_name->task = parsed_basename.task;
196 parsed_name->has_task = true;
197 }
198 if (!parsed_name->has_type) {
199 parsed_name->type = parsed_basename.type;
200 parsed_name->has_type = true;
201 }
202 if (!parsed_name->has_id) {
203 parsed_name->id = parsed_basename.id;
204 parsed_name->has_id = true;
205 }
206}
207
208} // namespace
209
210/* static */
211Status DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname,
212 StringPiece basename,
213 string* canonical_name) {
214 *canonical_name = "";
215 ParsedName parsed_basename;
216 if (!ParseFullName(basename, &parsed_basename)) {
217 return errors::InvalidArgument("Could not parse basename: ", basename,
218 " into a device specification.");
219 }
220 if (!(parsed_basename.has_job && parsed_basename.has_replica &&
221 parsed_basename.has_task && parsed_basename.has_type &&
222 parsed_basename.has_id)) {
223 return errors::InvalidArgument("Basename: ", basename,
224 " should be fully "
225 "specified.");
226 }
227 ParsedName parsed_name;
228 if (ParseLocalName(fullname, &parsed_name)) {
229 CompleteName(parsed_basename, &parsed_name);
230 *canonical_name = ParsedNameToString(parsed_name);
231 return OkStatus();
232 }
233 if (ParseFullName(fullname, &parsed_name)) {
234 CompleteName(parsed_basename, &parsed_name);
235 *canonical_name = ParsedNameToString(parsed_name);
236 return OkStatus();
237 }
238 return errors::InvalidArgument("Could not parse ", fullname,
239 " into a device "
240 "specification.");
241}
242
243/* static */
244string DeviceNameUtils::ParsedNameToString(const ParsedName& pn) {
245 string buf;
246 if (pn.has_job) strings::StrAppend(&buf, "/job:", pn.job);
247 if (pn.has_replica) strings::StrAppend(&buf, "/replica:", pn.replica);
248 if (pn.has_task) strings::StrAppend(&buf, "/task:", pn.task);
249 if (pn.has_type) {
250 strings::StrAppend(&buf, "/device:", pn.type, ":");
251 if (pn.has_id) {
252 strings::StrAppend(&buf, pn.id);
253 } else {
254 strings::StrAppend(&buf, "*");
255 }
256 }
257 return buf;
258}
259
260/* static */
261bool DeviceNameUtils::IsSpecification(const ParsedName& less_specific,
262 const ParsedName& more_specific) {
263 if (less_specific.has_job &&
264 (!more_specific.has_job || (less_specific.job != more_specific.job))) {
265 return false;
266 }
267 if (less_specific.has_replica &&
268 (!more_specific.has_replica ||
269 (less_specific.replica != more_specific.replica))) {
270 return false;
271 }
272 if (less_specific.has_task &&
273 (!more_specific.has_task || (less_specific.task != more_specific.task))) {
274 return false;
275 }
276 if (less_specific.has_type &&
277 (!more_specific.has_type || (less_specific.type != more_specific.type))) {
278 return false;
279 }
280 if (less_specific.has_id &&
281 (!more_specific.has_id || (less_specific.id != more_specific.id))) {
282 return false;
283 }
284 return true;
285}
286
287/* static */
288bool DeviceNameUtils::AreCompatibleDevNames(const ParsedName& a,
289 const ParsedName& b) {
290 if (a.has_job && b.has_job && (a.job != b.job)) {
291 return false;
292 }
293 if (a.has_replica && b.has_replica && (a.replica != b.replica)) {
294 return false;
295 }
296 if (a.has_task && b.has_task && (a.task != b.task)) {
297 return false;
298 }
299 if (a.has_type && b.has_type && (a.type != b.type)) {
300 return false;
301 }
302 if (a.has_id && b.has_id && (a.id != b.id)) {
303 return false;
304 }
305 return true;
306}
307
308void DeviceNameUtils::EnsureSpecification(ParsedName* more_specific,
309 const ParsedName& less_specific) {
310 if (less_specific.has_job) {
311 more_specific->has_job = true;
312 more_specific->job = less_specific.job;
313 }
314 if (less_specific.has_replica) {
315 more_specific->has_replica = true;
316 more_specific->replica = less_specific.replica;
317 }
318 if (less_specific.has_task) {
319 more_specific->has_task = true;
320 more_specific->task = less_specific.task;
321 }
322 if (less_specific.has_type) {
323 more_specific->has_type = true;
324 more_specific->type = less_specific.type;
325 }
326 if (less_specific.has_id) {
327 more_specific->has_id = true;
328 more_specific->id = less_specific.id;
329 }
330}
331
332/* static */
333bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern,
334 const ParsedName& name) {
335 CHECK(name.has_job && name.has_replica && name.has_task && name.has_type &&
336 name.has_id);
337
338 if (pattern.has_job && (pattern.job != name.job)) return false;
339 if (pattern.has_replica && (pattern.replica != name.replica)) return false;
340 if (pattern.has_task && (pattern.task != name.task)) return false;
341 if (pattern.has_type && (pattern.type != name.type)) return false;
342 if (pattern.has_id && (pattern.id != name.id)) return false;
343 return true;
344}
345
namespace {
// Merges the components set in "other" into "*target".
//
// Job, replica and task: if a component is set in both names with different
// values, an InvalidArgument error is returned.
//
// Type and id: a conflicting value is handled as follows:
//   - if !allow_soft_placement, an InvalidArgument error is returned;
//   - else if override_conflicts, the value from "other" wins;
//   - otherwise the conflicting component is cleared in "*target" and OK is
//     returned immediately, without merging any later component. A type
//     conflict clears both type and id.
//
// Components unset in "other" leave "*target" unchanged. On error, "*target"
// may already contain the components merged before the conflict was found.
Status MergeDevNamesImpl(DeviceNameUtils::ParsedName* target,
                         const DeviceNameUtils::ParsedName& other,
                         bool allow_soft_placement, bool override_conflicts) {
  // Local alias to keep the error-message construction below short.
  const auto& ParsedNameToString = DeviceNameUtils::ParsedNameToString;
  if (other.has_job) {
    if (target->has_job && target->job != other.job) {
      return errors::InvalidArgument(
          "Cannot merge devices with incompatible jobs: '",
          ParsedNameToString(*target), "' and '", ParsedNameToString(other),
          "'");
    } else {
      target->has_job = other.has_job;
      target->job = other.job;
    }
  }

  if (other.has_replica) {
    if (target->has_replica && target->replica != other.replica) {
      return errors::InvalidArgument(
          "Cannot merge devices with incompatible replicas: '",
          ParsedNameToString(*target), "' and '", ParsedNameToString(other),
          "'");
    } else {
      target->has_replica = other.has_replica;
      target->replica = other.replica;
    }
  }

  if (other.has_task) {
    if (target->has_task && target->task != other.task) {
      return errors::InvalidArgument(
          "Cannot merge devices with incompatible tasks: '",
          ParsedNameToString(*target), "' and '", ParsedNameToString(other),
          "'");
    } else {
      target->has_task = other.has_task;
      target->task = other.task;
    }
  }

  if (other.has_type) {
    if (target->has_type && target->type != other.type) {
      if (!allow_soft_placement) {
        return errors::InvalidArgument(
            "Cannot merge devices with incompatible types: '",
            ParsedNameToString(*target), "' and '", ParsedNameToString(other),
            "'");
      } else if (override_conflicts) {
        target->type = other.type;
      } else {
        // Soft placement without override: drop the type (and the id, which
        // is only meaningful for a specific type). Return early so that the
        // id from "other" is not merged below.
        target->has_id = false;
        target->has_type = false;
        return OkStatus();
      }
    } else {
      target->has_type = other.has_type;
      target->type = other.type;
    }
  }

  if (other.has_id) {
    if (target->has_id && target->id != other.id) {
      if (!allow_soft_placement) {
        return errors::InvalidArgument(
            "Cannot merge devices with incompatible ids: '",
            ParsedNameToString(*target), "' and '", ParsedNameToString(other),
            "'");
      } else if (override_conflicts) {
        target->id = other.id;
      } else {
        // Soft placement without override: leave the id unspecified.
        target->has_id = false;
        return OkStatus();
      }
    } else {
      target->has_id = other.has_id;
      target->id = other.id;
    }
  }

  return OkStatus();
}

}  // namespace
430
431/* static */
432Status DeviceNameUtils::MergeDevNames(ParsedName* target,
433 const ParsedName& other,
434 bool allow_soft_placement) {
435 return MergeDevNamesImpl(target, other, allow_soft_placement,
436 /*override_conflicts=*/false);
437}
438
439/* static */
440Status DeviceNameUtils::MergeOverrideDevNames(ParsedName* target,
441 const ParsedName& other) {
442 return MergeDevNamesImpl(target, other, /*allow_soft_placement=*/true,
443 /*override_conflicts=*/true);
444}
445
446/* static */
447void DeviceNameUtils::MergeUnsetDevNames(ParsedName* target,
448 const ParsedName& other) {
449 if (other.has_job && !target->has_job) {
450 target->has_job = other.has_job;
451 target->job = other.job;
452 }
453
454 if (other.has_replica && !target->has_replica) {
455 target->has_replica = other.has_replica;
456 target->replica = other.replica;
457 }
458
459 if (other.has_task && !target->has_task) {
460 target->has_task = other.has_task;
461 target->task = other.task;
462 }
463
464 if (other.has_type && !target->has_type) {
465 target->has_type = other.has_type;
466 target->type = other.type;
467 }
468
469 if (other.has_id && !target->has_id) {
470 target->has_id = other.has_id;
471 target->id = other.id;
472 }
473}
474
475/* static */
476bool DeviceNameUtils::IsSameAddressSpace(const ParsedName& a,
477 const ParsedName& b) {
478 return (a.has_job && b.has_job && (a.job == b.job)) &&
479 (a.has_replica && b.has_replica && (a.replica == b.replica)) &&
480 (a.has_task && b.has_task && (a.task == b.task));
481}
482
483/* static */
484bool DeviceNameUtils::IsSameAddressSpace(StringPiece src, StringPiece dst) {
485 ParsedName x;
486 ParsedName y;
487 return ParseFullName(src, &x) && ParseFullName(dst, &y) &&
488 IsSameAddressSpace(x, y);
489}
490
491/* static */
492bool DeviceNameUtils::IsDifferentAddressSpace(const ParsedName& a,
493 const ParsedName& b) {
494 return (a.has_job && b.has_job && (a.job != b.job)) ||
495 (a.has_replica && b.has_replica && (a.replica != b.replica)) ||
496 (a.has_task && b.has_task && (a.task != b.task));
497}
498
499/* static */
500const DeviceNameUtils::ParsedName DeviceNameUtils::AddressSpace(
501 const ParsedName& name) {
502 ParsedName address_space;
503 address_space.has_job = name.has_job;
504 address_space.has_replica = name.has_replica;
505 address_space.has_task = name.has_task;
506 address_space.job = name.job;
507 address_space.replica = name.replica;
508 address_space.task = name.task;
509 return address_space;
510}
511
512/* static */
513string DeviceNameUtils::LocalName(StringPiece type, int id) {
514 return strings::StrCat("/device:", type, ":", id);
515}
516
517namespace {
518// Returns the legacy local device name given its "type" and "id" (which is
519// '/device:type:id').
520string LegacyLocalName(StringPiece type, int id) {
521 return strings::StrCat(type, ":", id);
522}
523} // anonymous namespace
524
525/* static */
526string DeviceNameUtils::LocalName(StringPiece fullname) {
527 ParsedName x;
528 CHECK(ParseFullName(fullname, &x)) << fullname;
529 return LocalName(x.type, x.id);
530}
531
532/* static */
533bool DeviceNameUtils::ParseLocalName(StringPiece name, ParsedName* p) {
534 if (!ConsumeDeviceType(&name, &p->type)) {
535 return false;
536 }
537 p->has_type = true;
538 if (!absl::ConsumePrefix(&name, ":")) {
539 return false;
540 }
541 if (!ConsumeNumber(&name, &p->id)) {
542 return false;
543 }
544 p->has_id = true;
545 return name.empty();
546}
547
548/* static */
549bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
550 string* device) {
551 ParsedName pn;
552 if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) {
553 task->clear();
554 task->reserve(
555 (pn.has_job ? (5 + pn.job.size()) : 0) +
556 (pn.has_replica ? (9 + 4 /*estimated UB for # replica digits*/) : 0) +
557 (pn.has_task ? (6 + 4 /*estimated UB for # task digits*/) : 0));
558 if (pn.has_job) {
559 strings::StrAppend(task, "/job:", pn.job);
560 }
561 if (pn.has_replica) {
562 strings::StrAppend(task, "/replica:", pn.replica);
563 }
564 if (pn.has_task) {
565 strings::StrAppend(task, "/task:", pn.task);
566 }
567 device->clear();
568 strings::StrAppend(device, pn.type, ":", pn.id);
569 return true;
570 }
571 return false;
572}
573
574/* static */
575bool DeviceNameUtils::GetTaskName(const ParsedName& pn, string* task) {
576 if (pn.has_job && pn.has_replica && pn.has_task) {
577 task->clear();
578 task->reserve((5 + pn.job.size()) +
579 (9 + 4 /*estimated UB for # replica digits*/) +
580 (6 + 4 /*estimated UB for # task digits*/));
581 strings::StrAppend(task, "/job:", pn.job);
582 strings::StrAppend(task, "/replica:", pn.replica);
583 strings::StrAppend(task, "/task:", pn.task);
584 return true;
585 }
586 return false;
587}
588
589std::vector<string> DeviceNameUtils::GetNamesForDeviceMappings(
590 const ParsedName& pn) {
591 if (pn.has_job && pn.has_replica && pn.has_task && pn.has_type && pn.has_id) {
592 return {
593 DeviceNameUtils::FullName(pn.job, pn.replica, pn.task, pn.type, pn.id),
594 LegacyName(pn.job, pn.replica, pn.task, pn.type, pn.id)};
595 } else {
596 return {};
597 }
598}
599
600std::vector<string> DeviceNameUtils::GetLocalNamesForDeviceMappings(
601 const ParsedName& pn) {
602 if (pn.has_type && pn.has_id) {
603 return {DeviceNameUtils::LocalName(pn.type, pn.id),
604 LegacyLocalName(pn.type, pn.id)};
605 } else {
606 return {};
607 }
608}
609
610/*static*/ Status DeviceNameUtils::DeviceNameToCpuDeviceName(
611 const string& device_name, string* host_device_name) {
612 DeviceNameUtils::ParsedName device;
613 if (!DeviceNameUtils::ParseFullName(device_name, &device)) {
614 return errors::Internal("Could not parse device name ", device_name);
615 }
616 device.type = "CPU";
617 device.has_type = true;
618 device.id = 0;
619 device.has_id = true;
620 *host_device_name = DeviceNameUtils::ParsedNameToString(device);
621 return OkStatus();
622}
623
624std::ostream& operator<<(std::ostream& os,
625 const DeviceNameUtils::ParsedName& x) {
626 os << DeviceNameUtils::ParsedNameToString(x);
627 return os;
628}
629
630} // namespace tensorflow
631