1 | #pragma once |
2 | |
3 | #include <ATen/core/jit_type.h> |
4 | #include <ATen/core/stack.h> |
5 | #include <c10/util/hash.h> |
6 | #include <c10/util/irange.h> |
7 | #include <torch/csrc/Export.h> |
8 | #include <torch/csrc/autograd/variable.h> |
9 | #include <torch/csrc/jit/ir/ir.h> |
10 | #include <iostream> |
11 | #include <vector> |
12 | |
13 | C10_CLANG_DIAGNOSTIC_PUSH() |
14 | #if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") |
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
16 | #endif |
17 | |
18 | namespace torch { |
19 | namespace jit { |
20 | |
21 | // GraphExecutor creates specializations of Graphs for different |
// dimensionalities and types of inputs.
23 | |
24 | struct ArgumentInfo { |
25 | friend struct ArgumentSpec; |
26 | using plain_data_type = uint64_t; |
27 | |
28 | bool defined() const { |
29 | return defined_; |
30 | } |
31 | at::Device device() const { |
32 | return at::Device(DeviceType(dev_type_), device_); |
33 | } |
34 | // XXX: It is guaranteed that this will return false when called on non-tensor |
35 | // arguments |
36 | bool requires_grad() const { |
37 | return requires_grad_; |
38 | } |
39 | int dim() const { |
40 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
41 | return dim_; |
42 | } |
43 | at::ScalarType type() const { |
44 | return at::ScalarType(type_); |
45 | } |
46 | TypePtr toType() const { |
47 | if (!defined()) |
48 | return TensorType::get(); |
49 | |
50 | return TensorType::create( |
51 | type(), device(), c10::optional<size_t>(dim()), requires_grad()); |
52 | } |
53 | operator TypePtr() const { |
54 | return toType(); |
55 | } |
56 | |
57 | private: |
58 | unsigned defined_ : 1; |
59 | unsigned requires_grad_ : 1; |
60 | unsigned : 5; |
61 | unsigned dim_ : 8; |
62 | unsigned device_ : 8; |
63 | unsigned type_ : 8; |
64 | unsigned dev_type_ : 16; |
65 | unsigned : 16; |
66 | }; |
67 | |
static_assert(
    std::is_standard_layout<ArgumentInfo>::value,
    "ArgumentInfo is expected to be a POD struct");
static_assert(
    sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type),
    "ArgumentInfo is expected to be a 64-bit struct");
74 | |
75 | struct ArgumentSpec { |
76 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
77 | ArgumentSpec(size_t num_flat_tensor_inputs, size_t num_flat_optional_inputs) { |
78 | hash_code = |
79 | c10::hash_combine(num_flat_tensor_inputs, num_flat_optional_inputs); |
80 | tensor_args.reserve(num_flat_tensor_inputs); |
81 | optional_presence.reserve(num_flat_optional_inputs); |
82 | } |
83 | |
84 | void addOptional(const IValue& input) { |
85 | bool is_present = !input.isNone(); |
86 | optional_presence.push_back(is_present); |
87 | hash_code = c10::hash_combine(hash_code, is_present); |
88 | } |
89 | |
90 | void addTensor(const IValue& input, bool with_grad) { |
    AT_ASSERT(input.isTensor(), "Expected Tensor but found ", input.tagKind());
92 | tensor_args.emplace_back(); |
93 | auto& arg = tensor_args.back(); |
    // Initialize all fields to 0. This is convenient, because e.g.
    // requires_grad() can be checked even on undefined tensors, and it
    // guarantees that the padding bits are all 0s.
97 | std::memset(&arg, 0, sizeof(ArgumentInfo)); |
98 | |
    // [argspec refcounting] reinterpret the IValue to avoid having to
    // refcount the Tensor; microbenchmarks
    // (https://github.com/zdevito/pytorch/commit/21e7200a0a0fc456bea2f10e95b1781f83933d10)
    // showed overhead from the extra refcounting along this path.
103 | const at::Tensor* t = reinterpret_cast<const at::Tensor*>(&input); |
104 | arg.defined_ = t->defined(); |
105 | if (arg.defined_) { |
106 | arg.requires_grad_ = with_grad && autograd::Variable(*t).requires_grad(); |
107 | arg.dim_ = t->dim(); |
108 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
109 | at::Device device = t->device(); |
110 | arg.dev_type_ = |
111 | // NOLINTNEXTLINE(bugprone-signed-char-misuse) |
112 | static_cast<std::underlying_type<DeviceType>::type>(device.type()); |
113 | // NOLINTNEXTLINE(bugprone-signed-char-misuse) |
114 | arg.device_ = device.index(); |
115 | arg.type_ = static_cast<unsigned>(t->scalar_type()); |
116 | } |
117 | combineHash(arg); |
118 | } |
119 | |
120 | void combineHash(const ArgumentInfo& arg) { |
121 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
122 | ArgumentInfo::plain_data_type arg_data; |
123 | std::memcpy(&arg_data, &arg, sizeof(ArgumentInfo)); |
124 | hash_code = c10::hash_combine(hash_code, arg_data); |
125 | } |
126 | |
  // equality is fast: check the optional presence bits and the number of
  // tensors, then compare the raw tensor_args bytes; there are no
  // size/stride indirections
  // hopefully std::vector<bool> has fast equality
130 | bool operator==(const ArgumentSpec& spec) const { |
131 | if (optional_presence != spec.optional_presence) { |
132 | return false; |
133 | } |
134 | if (tensor_args.size() != spec.tensor_args.size()) |
135 | return false; |
136 | // NB: we need to break out early when there are no elements, because |
137 | // passing a nullptr to memcmp is UB. |
138 | if (tensor_args.empty()) |
139 | return true; |
140 | return std::memcmp( |
141 | tensor_args.data(), |
142 | spec.tensor_args.data(), |
143 | tensor_args.size() * sizeof(ArgumentInfo)) == 0; |
144 | } |
145 | bool operator!=(const ArgumentSpec& spec) const { |
146 | return !(*this == spec); |
147 | } |
148 | size_t numTensors() const { |
149 | return tensor_args.size(); |
150 | } |
151 | const ArgumentInfo& tensorAt(size_t i) const { |
152 | return tensor_args[i]; |
153 | } |
154 | size_t numOptionals() const { |
155 | return optional_presence.size(); |
156 | } |
157 | bool isPresent(size_t i) const { |
158 | return optional_presence[i]; |
159 | } |
160 | size_t hashCode() const { |
161 | return hash_code; |
162 | } |
163 | |
164 | private: |
165 | size_t hash_code; // precomputed on construction |
166 | std::vector<ArgumentInfo> tensor_args; |
167 | std::vector<bool> optional_presence; |
168 | }; |
169 | |
170 | namespace { |
171 | static constexpr size_t ARG_SPEC_DEPTH_LIMIT = 128; |
172 | } |
173 | |
174 | // ArgumentSpecCreator takes an initial graph and comes up with a set |
175 | // of simple instructions to compute the ArgumentSpec given a set of |
176 | // input tensors. |
177 | struct TORCH_API ArgumentSpecCreator { |
  // Instructions act on a stack of lists of input IValues. At the
  // beginning, the stack contains a single list of the inputs to the
  // function; the ENTER_* instructions descend into sub-objects and push
  // new lists onto the stack.
182 | enum Inst : char { |
183 | ENTER_TUPLE, // consume a tuple ivalue from the top-most list, and push the |
184 | // list of its elements onto the stack as a new list |
185 | ENTER_OBJECT, // same as ENTER_TUPLE, but the input is a class |
186 | LEAVE, // pop the top-most list from the stack |
187 | SKIP, // consume an element from the top-most list, and discard |
    SPECIALIZE_OPTIONAL_TENSOR, // consume an optional tensor from the
                                // top-most list, and add it to the ArgSpec
                                // key being created
    SPECIALIZE_TENSOR, // consume a tensor from the top-most list, and add
                       // it to the ArgSpec key being created
    SPECIALIZE_OPTIONAL, // consume a non-tensor optional from the top-most
                         // list, and add it to the ArgSpec key being created
196 | }; |
197 | ArgumentSpecCreator(Graph& graph); |
198 | ArgumentSpec create(bool with_grad, const Stack& stack) const; |
199 | void specializeTypes(Graph& g, const ArgumentSpec& spec) const; |
200 | void dump() const; |
201 | using WrittenSlots = std::unordered_set<std::string>; |
202 | |
203 | private: |
204 | void scan( |
205 | const TypePtr& typ, |
206 | size_t depth, |
207 | const WrittenSlots& written_slots); |
208 | size_t num_inputs_; |
209 | size_t num_tensors_ = 0; |
210 | size_t num_optionals_ = 0; |
211 | std::vector<Inst> instructions_; |
212 | }; |
213 | |
214 | // CompleteArgumentSpec represents one particular specialization. |
215 | // It is designed so that it can be created, hashed, and compared quickly |
216 | // since it is used along the hot-path of the JIT to check if the code |
217 | // we have created is valid for the given inputs. |
218 | |
// CompleteArgumentInfoPOD is only used internally in CompleteArgumentSpec;
// API users should use CompleteArgumentInfo instead.
221 | struct CompleteArgumentInfoPOD { |
  // total size is 64 bits: 8 + 8 + 1 + 1 + 14 + 16 + 16
223 | unsigned is_tensor : 8; // all other fields are invalid if this is false |
224 | unsigned type : 8; // scalar type |
225 | unsigned defined : 1; |
226 | unsigned requires_grad : 1; |
227 | signed device : 14; |
228 | unsigned dev_type : 16; |
  unsigned
      total_dims : 16; // all TensorInfoPODs are in CompleteArgumentSpec's
                       // tensor_info() array. total_dims is the total
                       // number of dimensions seen so far in all previous
                       // members of tensor_info(), including this tensor.
                       // 2 * total_dims becomes the offset into the
                       // sizes_strides list for the _next_ tensor in the
                       // tensor_info array; for tensor 0, the offset is
                       // always 0.
237 | }; |
238 | |
239 | static_assert( |
240 | sizeof(CompleteArgumentInfoPOD) == sizeof(int64_t), |
241 | "CompleteArgumentInfoPOD must be 64-bit struct for CompleteArgumentSpec encoding to work" ); |
242 | |
243 | struct CompleteArgumentInfo; |
244 | |
245 | struct CompleteArgumentSpec { |
246 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
247 | CompleteArgumentSpec(bool with_grad, at::ArrayRef<IValue> inputs) |
248 | : hash_code(0), ninputs(inputs.size()) { |
249 | int32_t all_dims = 0; |
250 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
251 | const int32_t num_inputs = inputs.size(); |
252 | for (const auto i : c10::irange(num_inputs)) { |
253 | if (!inputs[i].isTensor()) |
254 | continue; |
255 | auto& tensor = inputs[i].toTensor(); |
256 | all_dims += tensor.defined() ? tensor.ndimension() : 0; |
257 | } |
258 | // allocate enough room for all TensorPODs and dimensions |
259 | data.resize(ninputs + all_dims * 2); |
260 | |
261 | // and reinterpret our data array as these structs |
262 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
263 | auto* pods = reinterpret_cast<CompleteArgumentInfoPOD*>(data.data()); |
264 | int64_t* next_dim = sizes_strides(); |
265 | int32_t total_dims = 0; |
266 | for (const auto i : c10::irange(num_inputs)) { |
267 | auto& pod = pods[i]; |
268 | pod.is_tensor = static_cast<uint32_t>(inputs[i].isTensor()); |
269 | if (pod.is_tensor) { |
270 | at::Tensor t = inputs[i].toTensor(); |
271 | pod.defined = t.defined(); |
272 | if (pod.defined) { |
273 | pod.type = static_cast<int>(t.scalar_type()); |
274 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
275 | at::Device device = t.device(); |
276 | // NOLINTNEXTLINE(bugprone-signed-char-misuse) |
277 | pod.dev_type = static_cast<std::underlying_type<DeviceType>::type>( |
278 | device.type()); |
279 | // NOLINTNEXTLINE(bugprone-signed-char-misuse) |
280 | pod.device = device.index(); |
281 | pod.requires_grad = with_grad && t.requires_grad(); |
282 | total_dims += t.ndimension(); |
283 | auto sizes = t.sizes(); |
284 | std::copy(sizes.begin(), sizes.end(), next_dim); |
285 | next_dim += sizes.size(); |
286 | auto strides = t.strides(); |
287 | std::copy(strides.begin(), strides.end(), next_dim); |
288 | next_dim += strides.size(); |
289 | } |
290 | } |
291 | // each POD has a running tally of all dimensions including its own |
292 | TORCH_CHECK( |
293 | total_dims < std::numeric_limits<uint16_t>::max(), |
294 | "The number of dims cannot be packed into CompleteArgumentSpec:" , |
295 | total_dims); |
296 | pod.total_dims = total_dims; |
297 | } |
298 | // we precompute the hash_code to minimize the time inside of hash |
299 | // table operations where we may need to hold a compiler cache lock. |
300 | hash_code = c10::hash_combine(0, ninputs); |
301 | for (auto d : data) { |
302 | hash_code = c10::hash_combine(hash_code, d); |
303 | } |
304 | } |
305 | |
306 | // equality is fast: check ninputs, and then check the raw array data, |
307 | // there are no size/stride indirections |
308 | bool operator==(const CompleteArgumentSpec& spec) const { |
309 | return ninputs == spec.ninputs && data == spec.data; |
310 | } |
311 | bool operator!=(const CompleteArgumentSpec& spec) const { |
312 | return !(*this == spec); |
313 | } |
314 | friend struct CompleteArgumentInfo; |
315 | CompleteArgumentInfo at(size_t i) const; |
316 | size_t size() const { |
317 | return ninputs; |
318 | } |
319 | size_t hashCode() const { |
320 | return hash_code; |
321 | } |
322 | |
323 | private: |
324 | ArrayRef<CompleteArgumentInfoPOD> tensor_info() const { |
325 | return ArrayRef<CompleteArgumentInfoPOD>( |
326 | reinterpret_cast<const CompleteArgumentInfoPOD*>(data.data()), ninputs); |
327 | } |
328 | // the start of the sizes_strides information, which comes after the |
329 | // CompleteArgumentInfoPOD list. |
330 | const int64_t* sizes_strides() const { |
331 | return data.data() + ninputs; |
332 | } |
333 | int64_t* sizes_strides() { |
334 | return data.data() + ninputs; |
335 | } |
336 | size_t hash_code; // precomputed on construction |
337 | size_t ninputs; |
  // layout is ninputs of TensorPOD (each 64-bit), followed by their size
  // and stride info; e.g., for 3 tensors:
340 | // [t0POD][t1POD][t2POD]... |
341 | // [t0 sizes][t0 strides][t1 sizes][t1 strides][t2 sizes][t2 strides] |
342 | std::vector<int64_t> data; |
343 | }; |
344 | |
345 | // public view of compressed CompleteArgumentInfo |
346 | struct CompleteArgumentInfo { |
347 | CompleteArgumentInfo(const CompleteArgumentSpec& spec, const int i) |
348 | : spec(spec), i(i) {} |
349 | bool isTensor() const { |
350 | return pod(i).is_tensor; |
351 | } |
352 | at::ScalarType type() const { |
353 | return at::ScalarType(pod(i).type); |
354 | } |
355 | bool defined() const { |
356 | return pod(i).defined; |
357 | } |
358 | bool requires_grad() const { |
359 | return pod(i).requires_grad; |
360 | } |
361 | at::Device device() const { |
362 | return at::Device( |
363 | DeviceType(pod(i).dev_type), |
364 | static_cast<c10::DeviceIndex>(pod(i).device)); |
365 | } |
366 | int ndimension() const { |
367 | // See [valid range], it is always valid to ask for offset for (i + 1) |
368 | return (sizes_strides_offset(i + 1) - sizes_strides_offset(i)) / 2; |
369 | } |
370 | at::IntArrayRef sizes() const { |
371 | return at::IntArrayRef( |
372 | spec.sizes_strides() + sizes_strides_offset(i), ndimension()); |
373 | } |
374 | at::IntArrayRef strides() const { |
375 | int ndim = ndimension(); |
376 | return at::IntArrayRef( |
377 | spec.sizes_strides() + sizes_strides_offset(i) + ndim, ndim); |
378 | } |
379 | operator TypePtr() const { |
380 | if (!defined()) |
381 | return TensorType::get(); |
382 | return TensorType::create( |
383 | type(), |
384 | device(), |
385 | c10::VaryingShape<int64_t>{sizes()}, |
386 | c10::VaryingShape<int64_t>{strides()}, |
387 | requires_grad()); |
388 | } |
389 | |
390 | private: |
  // offset into the sizes_strides() array where the sizes start for tensor j
392 | // [valid range] valid range is [0, ninputs] |
393 | // (i.e. you can ask for the offset at ninputs, which would be the offset of |
394 | // the next tensor if it existed) |
395 | int sizes_strides_offset(int j) const { |
396 | if (j == 0) |
397 | return 0; |
398 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
399 | return 2 * pod(j - 1).total_dims; |
400 | } |
401 | const CompleteArgumentInfoPOD& pod(int j) const { |
402 | return spec.tensor_info().at(j); |
403 | } |
404 | const CompleteArgumentSpec& spec; |
405 | const int i; |
406 | }; |
407 | |
408 | inline std::ostream& operator<<(std::ostream& out, const ArgumentInfo& info) { |
409 | if (!info.defined()) { |
410 | return out << "<undefined>" ; |
411 | } |
412 | out << "Tensor(device=" << info.device() << ", type=" << toString(info.type()) |
413 | << ", requires_grad=" << info.requires_grad() << ", dims=" << info.dim() |
414 | << ")" ; |
415 | return out; |
416 | } |
417 | |
418 | inline std::ostream& operator<<(std::ostream& out, const ArgumentSpec& spec) { |
419 | out << "{" ; |
420 | for (const auto i : c10::irange(spec.numTensors())) { |
421 | if (i > 0) |
422 | out << ", " ; |
423 | out << spec.tensorAt(i); |
424 | } |
425 | out << "; " ; |
426 | for (const auto i : c10::irange(spec.numOptionals())) { |
427 | if (i > 0) |
428 | out << ", " ; |
429 | out << spec.isPresent(i); |
430 | } |
431 | out << "}" ; |
432 | return out; |
433 | } |
434 | |
435 | inline std::ostream& operator<<( |
436 | std::ostream& out, |
437 | const CompleteArgumentInfo& info) { |
438 | if (!info.defined()) { |
439 | return out << "<undefined>" ; |
440 | } |
441 | out << "Tensor(device=" << info.device() << ", type=" << toString(info.type()) |
442 | << ", requires_grad=" << info.requires_grad() |
443 | << ", sizes=" << info.sizes() << ", strides=" << info.strides() << ")" ; |
444 | return out; |
445 | } |
446 | |
447 | inline std::ostream& operator<<( |
448 | std::ostream& out, |
449 | const CompleteArgumentSpec& spec) { |
450 | out << "{" ; |
451 | for (const auto i : c10::irange(spec.size())) { |
452 | if (i > 0) |
453 | out << ", " ; |
454 | out << spec.at(i); |
455 | } |
456 | out << "}" ; |
457 | return out; |
458 | } |
459 | |
460 | inline CompleteArgumentInfo CompleteArgumentSpec::at(size_t i) const { |
461 | return CompleteArgumentInfo(*this, i); |
462 | } |
463 | |
464 | inline c10::optional<int8_t> convertOptional( |
465 | c10::optional<c10::ScalarType> const& from) { |
466 | return (from) ? c10::optional<int8_t>(static_cast<int8_t>(*from)) |
467 | : c10::optional<int8_t>{}; |
468 | } |
469 | |
470 | } // namespace jit |
471 | } // namespace torch |
472 | |
473 | namespace std { |
474 | |
475 | template <typename T> |
476 | struct hash<c10::VaryingShape<T>> { |
477 | size_t operator()(const c10::VaryingShape<T>& vs) const { |
478 | return c10::get_hash( |
479 | vs.size(), |
480 | vs.size() ? vs.sizes().value() : std::vector<c10::optional<T>>()); |
481 | } |
482 | }; |
483 | |
484 | template <> |
485 | struct hash<c10::TensorType> { |
486 | size_t operator()(const c10::TensorType& ptt) const { |
487 | return c10::get_hash< |
488 | c10::optional<int8_t>, |
489 | c10::VaryingShape<int64_t>, |
490 | c10::VaryingShape<int64_t>, |
491 | c10::optional<bool>>( |
492 | torch::jit::convertOptional(ptt.scalarType()), |
493 | ptt.sizes(), |
494 | ptt.strides(), |
495 | ptt.requiresGrad()); |
496 | } |
497 | }; |
498 | |
499 | template <> |
500 | struct hash<torch::jit::ArgumentSpec> { |
501 | size_t operator()(const torch::jit::ArgumentSpec& spec) const { |
502 | return spec.hashCode(); |
503 | } |
504 | }; |
505 | template <> |
506 | struct hash<torch::jit::CompleteArgumentSpec> { |
507 | size_t operator()(const torch::jit::CompleteArgumentSpec& spec) const { |
508 | return spec.hashCode(); |
509 | } |
510 | }; |
511 | } // namespace std |
512 | |
513 | C10_CLANG_DIAGNOSTIC_POP() |
514 | |