1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/target/virtual_device.cc |
22 | * \brief A compile time representation for where data is to be stored at runtime, and how to |
23 | * compile code to compute it. |
24 | */ |
25 | #include <tvm/node/reflection.h> |
26 | #include <tvm/runtime/device_api.h> |
27 | #include <tvm/target/virtual_device.h> |
28 | |
29 | namespace tvm { |
30 | |
31 | TVM_REGISTER_NODE_TYPE(VirtualDeviceNode); |
32 | |
33 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
34 | .set_dispatch<VirtualDeviceNode>([](const ObjectRef& ref, ReprPrinter* p) { |
35 | auto* node = ref.as<VirtualDeviceNode>(); |
36 | p->stream << "VirtualDevice(" ; |
37 | if (node->IsFullyUnconstrained()) { |
38 | p->stream << "?" ; |
39 | } else { |
40 | bool need_sep = false; |
41 | if (node->device_type() != kInvalidDeviceType) { |
42 | p->stream << "device_type=" << node->device_type(); |
43 | need_sep = true; |
44 | } |
45 | if (node->virtual_device_id >= 0) { |
46 | if (need_sep) { |
47 | p->stream << ", " ; |
48 | } |
49 | p->stream << "virtual_device_id=" << node->virtual_device_id; |
50 | need_sep = true; |
51 | } |
52 | if (node->target.defined()) { |
53 | if (need_sep) { |
54 | p->stream << ", " ; |
55 | } |
56 | p->stream << "target=" << node->target->ToDebugString(); |
57 | need_sep = true; |
58 | } |
59 | if (!node->memory_scope.empty()) { |
60 | if (need_sep) { |
61 | p->stream << ", " ; |
62 | } |
63 | p->stream << "memory_scope='" << node->memory_scope << "'" ; |
64 | } |
65 | } |
66 | p->stream << ")" ; |
67 | }); |
68 | |
69 | VirtualDevice::VirtualDevice(DLDeviceType device_type, int virtual_device_id, Target target, |
70 | MemoryScope memory_scope) { |
71 | ICHECK(!target.defined() || device_type == target->GetTargetDeviceType()) |
72 | << "target " << target->ToDebugString() << " has device type " |
73 | << target->GetTargetDeviceType() << " but virtual device has device type " << device_type; |
74 | auto node = make_object<VirtualDeviceNode>(); |
75 | node->device_type_int = device_type; |
76 | node->virtual_device_id = virtual_device_id; |
77 | node->target = std::move(target); |
78 | node->memory_scope = std::move(memory_scope); |
79 | data_ = std::move(node); |
80 | } |
81 | |
82 | /* static */ VirtualDevice VirtualDevice::FullyUnconstrained() { |
83 | static const VirtualDevice unconstrained{}; |
84 | return unconstrained; |
85 | } |
86 | |
87 | /* static */ |
88 | Optional<VirtualDevice> VirtualDevice::Join(const VirtualDevice& lhs, const VirtualDevice& rhs) { |
89 | if (lhs == rhs) { |
90 | return lhs; |
91 | } |
92 | DLDeviceType joined_device_type; |
93 | if (lhs->device_type() != kInvalidDeviceType) { |
94 | joined_device_type = lhs->device_type(); |
95 | if (rhs->device_type() != kInvalidDeviceType && lhs->device_type() != rhs->device_type()) { |
96 | return {}; |
97 | } |
98 | } else { |
99 | joined_device_type = rhs->device_type(); |
100 | } |
101 | int joined_virtual_device_id; |
102 | if (lhs->virtual_device_id >= 0) { |
103 | joined_virtual_device_id = lhs->virtual_device_id; |
104 | if (rhs->virtual_device_id >= 0 && lhs->virtual_device_id != rhs->virtual_device_id) { |
105 | return {}; |
106 | } |
107 | } else { |
108 | joined_virtual_device_id = rhs->virtual_device_id; |
109 | } |
110 | Target joined_target; |
111 | if (lhs->target.defined()) { |
112 | joined_target = lhs->target; |
113 | if (rhs->target.defined() && lhs->target != rhs->target) { |
114 | return {}; |
115 | } |
116 | } else { |
117 | joined_target = rhs->target; |
118 | } |
119 | MemoryScope joined_memory_scope; |
120 | if (!lhs->memory_scope.empty()) { |
121 | joined_memory_scope = lhs->memory_scope; |
122 | if (!rhs->memory_scope.empty() && lhs->memory_scope != rhs->memory_scope) { |
123 | return {}; |
124 | } |
125 | } else { |
126 | joined_memory_scope = rhs->memory_scope; |
127 | } |
128 | return VirtualDevice(joined_device_type, joined_virtual_device_id, joined_target, |
129 | joined_memory_scope); |
130 | } |
131 | |
132 | /* static */ |
133 | VirtualDevice VirtualDevice::Default(const VirtualDevice& lhs, const VirtualDevice& rhs) { |
134 | if (lhs == rhs) { |
135 | return lhs; |
136 | } |
137 | DLDeviceType defaulted_device_type; |
138 | if (lhs->device_type() != kInvalidDeviceType) { |
139 | defaulted_device_type = lhs->device_type(); |
140 | } else { |
141 | defaulted_device_type = rhs->device_type(); |
142 | } |
143 | int defaulted_virtual_device_id; |
144 | if (lhs->virtual_device_id >= 0) { |
145 | defaulted_virtual_device_id = lhs->virtual_device_id; |
146 | } else { |
147 | defaulted_virtual_device_id = rhs->virtual_device_id; |
148 | } |
149 | Target defaulted_target; |
150 | if (lhs->target.defined()) { |
151 | defaulted_target = lhs->target; |
152 | } else { |
153 | // We can only default to the rhs's target if it is consistent with the device type |
154 | if (rhs->target.defined() && rhs->target->GetTargetDeviceType() == defaulted_device_type) { |
155 | defaulted_target = rhs->target; |
156 | } |
157 | // else: leave as null |
158 | } |
159 | MemoryScope defaulted_memory_scope; |
160 | if (!lhs->memory_scope.empty()) { |
161 | defaulted_memory_scope = lhs->memory_scope; |
162 | } else { |
163 | defaulted_memory_scope = rhs->memory_scope; |
164 | } |
165 | return VirtualDevice(defaulted_device_type, defaulted_virtual_device_id, defaulted_target, |
166 | defaulted_memory_scope); |
167 | } |
168 | |
169 | VirtualDevice VirtualDeviceCache::Make(DLDeviceType device_type, int virtual_device_id, |
170 | Target target, MemoryScope memory_scope) { |
171 | VirtualDevice prototype(device_type, virtual_device_id, std::move(target), |
172 | std::move(memory_scope)); |
173 | if (prototype->IsFullyUnconstrained()) { |
174 | return VirtualDevice::FullyUnconstrained(); |
175 | } |
176 | auto itr = cache_.find(prototype); |
177 | if (itr == cache_.end()) { |
178 | cache_.emplace(prototype); |
179 | return prototype; |
180 | } else { |
181 | ICHECK_EQ(prototype->target.defined(), (*itr)->target.defined()); |
182 | if (prototype->target.defined()) { |
183 | ICHECK_EQ(prototype->target->host.defined(), (*itr)->target->host.defined()); |
184 | } |
185 | return *itr; |
186 | } |
187 | } |
188 | |
189 | VirtualDevice VirtualDeviceCache::Unique(const VirtualDevice& virtual_device) { |
190 | return Make(virtual_device->device_type(), virtual_device->virtual_device_id, |
191 | virtual_device->target, virtual_device->memory_scope); |
192 | } |
193 | |
194 | TVM_REGISTER_GLOBAL("target.VirtualDevice_ForDeviceTargetAndMemoryScope" ) |
195 | .set_body_typed(VirtualDevice::ForDeviceTargetAndMemoryScope); |
196 | |
197 | } // namespace tvm |
198 | |