1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/target/virtual_device.cc
22 * \brief A compile time representation for where data is to be stored at runtime, and how to
23 * compile code to compute it.
24 */
25#include <tvm/node/reflection.h>
26#include <tvm/runtime/device_api.h>
27#include <tvm/target/virtual_device.h>
28
29namespace tvm {
30
31TVM_REGISTER_NODE_TYPE(VirtualDeviceNode);
32
33TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
34 .set_dispatch<VirtualDeviceNode>([](const ObjectRef& ref, ReprPrinter* p) {
35 auto* node = ref.as<VirtualDeviceNode>();
36 p->stream << "VirtualDevice(";
37 if (node->IsFullyUnconstrained()) {
38 p->stream << "?";
39 } else {
40 bool need_sep = false;
41 if (node->device_type() != kInvalidDeviceType) {
42 p->stream << "device_type=" << node->device_type();
43 need_sep = true;
44 }
45 if (node->virtual_device_id >= 0) {
46 if (need_sep) {
47 p->stream << ", ";
48 }
49 p->stream << "virtual_device_id=" << node->virtual_device_id;
50 need_sep = true;
51 }
52 if (node->target.defined()) {
53 if (need_sep) {
54 p->stream << ", ";
55 }
56 p->stream << "target=" << node->target->ToDebugString();
57 need_sep = true;
58 }
59 if (!node->memory_scope.empty()) {
60 if (need_sep) {
61 p->stream << ", ";
62 }
63 p->stream << "memory_scope='" << node->memory_scope << "'";
64 }
65 }
66 p->stream << ")";
67 });
68
69VirtualDevice::VirtualDevice(DLDeviceType device_type, int virtual_device_id, Target target,
70 MemoryScope memory_scope) {
71 ICHECK(!target.defined() || device_type == target->GetTargetDeviceType())
72 << "target " << target->ToDebugString() << " has device type "
73 << target->GetTargetDeviceType() << " but virtual device has device type " << device_type;
74 auto node = make_object<VirtualDeviceNode>();
75 node->device_type_int = device_type;
76 node->virtual_device_id = virtual_device_id;
77 node->target = std::move(target);
78 node->memory_scope = std::move(memory_scope);
79 data_ = std::move(node);
80}
81
82/* static */ VirtualDevice VirtualDevice::FullyUnconstrained() {
83 static const VirtualDevice unconstrained{};
84 return unconstrained;
85}
86
87/* static */
88Optional<VirtualDevice> VirtualDevice::Join(const VirtualDevice& lhs, const VirtualDevice& rhs) {
89 if (lhs == rhs) {
90 return lhs;
91 }
92 DLDeviceType joined_device_type;
93 if (lhs->device_type() != kInvalidDeviceType) {
94 joined_device_type = lhs->device_type();
95 if (rhs->device_type() != kInvalidDeviceType && lhs->device_type() != rhs->device_type()) {
96 return {};
97 }
98 } else {
99 joined_device_type = rhs->device_type();
100 }
101 int joined_virtual_device_id;
102 if (lhs->virtual_device_id >= 0) {
103 joined_virtual_device_id = lhs->virtual_device_id;
104 if (rhs->virtual_device_id >= 0 && lhs->virtual_device_id != rhs->virtual_device_id) {
105 return {};
106 }
107 } else {
108 joined_virtual_device_id = rhs->virtual_device_id;
109 }
110 Target joined_target;
111 if (lhs->target.defined()) {
112 joined_target = lhs->target;
113 if (rhs->target.defined() && lhs->target != rhs->target) {
114 return {};
115 }
116 } else {
117 joined_target = rhs->target;
118 }
119 MemoryScope joined_memory_scope;
120 if (!lhs->memory_scope.empty()) {
121 joined_memory_scope = lhs->memory_scope;
122 if (!rhs->memory_scope.empty() && lhs->memory_scope != rhs->memory_scope) {
123 return {};
124 }
125 } else {
126 joined_memory_scope = rhs->memory_scope;
127 }
128 return VirtualDevice(joined_device_type, joined_virtual_device_id, joined_target,
129 joined_memory_scope);
130}
131
132/* static */
133VirtualDevice VirtualDevice::Default(const VirtualDevice& lhs, const VirtualDevice& rhs) {
134 if (lhs == rhs) {
135 return lhs;
136 }
137 DLDeviceType defaulted_device_type;
138 if (lhs->device_type() != kInvalidDeviceType) {
139 defaulted_device_type = lhs->device_type();
140 } else {
141 defaulted_device_type = rhs->device_type();
142 }
143 int defaulted_virtual_device_id;
144 if (lhs->virtual_device_id >= 0) {
145 defaulted_virtual_device_id = lhs->virtual_device_id;
146 } else {
147 defaulted_virtual_device_id = rhs->virtual_device_id;
148 }
149 Target defaulted_target;
150 if (lhs->target.defined()) {
151 defaulted_target = lhs->target;
152 } else {
153 // We can only default to the rhs's target if it is consistent with the device type
154 if (rhs->target.defined() && rhs->target->GetTargetDeviceType() == defaulted_device_type) {
155 defaulted_target = rhs->target;
156 }
157 // else: leave as null
158 }
159 MemoryScope defaulted_memory_scope;
160 if (!lhs->memory_scope.empty()) {
161 defaulted_memory_scope = lhs->memory_scope;
162 } else {
163 defaulted_memory_scope = rhs->memory_scope;
164 }
165 return VirtualDevice(defaulted_device_type, defaulted_virtual_device_id, defaulted_target,
166 defaulted_memory_scope);
167}
168
169VirtualDevice VirtualDeviceCache::Make(DLDeviceType device_type, int virtual_device_id,
170 Target target, MemoryScope memory_scope) {
171 VirtualDevice prototype(device_type, virtual_device_id, std::move(target),
172 std::move(memory_scope));
173 if (prototype->IsFullyUnconstrained()) {
174 return VirtualDevice::FullyUnconstrained();
175 }
176 auto itr = cache_.find(prototype);
177 if (itr == cache_.end()) {
178 cache_.emplace(prototype);
179 return prototype;
180 } else {
181 ICHECK_EQ(prototype->target.defined(), (*itr)->target.defined());
182 if (prototype->target.defined()) {
183 ICHECK_EQ(prototype->target->host.defined(), (*itr)->target->host.defined());
184 }
185 return *itr;
186 }
187}
188
189VirtualDevice VirtualDeviceCache::Unique(const VirtualDevice& virtual_device) {
190 return Make(virtual_device->device_type(), virtual_device->virtual_device_id,
191 virtual_device->target, virtual_device->memory_scope);
192}
193
194TVM_REGISTER_GLOBAL("target.VirtualDevice_ForDeviceTargetAndMemoryScope")
195 .set_body_typed(VirtualDevice::ForDeviceTargetAndMemoryScope);
196
197} // namespace tvm
198