#pragma once

#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/ir/ir.h>
#include <iostream>
#include <vector>

C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
#endif

namespace torch {
namespace jit {

// GraphExecutor creates specializations of Graphs for different
// dimensionalities and types of inputs.

struct ArgumentInfo {
  friend struct ArgumentSpec;
  using plain_data_type = uint64_t;

  bool defined() const {
    return defined_;
  }
  at::Device device() const {
    return at::Device(DeviceType(dev_type_), device_);
  }
  // XXX: It is guaranteed that this will return false when called on
  // non-tensor arguments
  bool requires_grad() const {
    return requires_grad_;
  }
  int dim() const {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return dim_;
  }
  at::ScalarType type() const {
    return at::ScalarType(type_);
  }
  TypePtr toType() const {
    if (!defined())
      return TensorType::get();

    return TensorType::create(
        type(), device(), c10::optional<size_t>(dim()), requires_grad());
  }
  operator TypePtr() const {
    return toType();
  }

 private:
  unsigned defined_ : 1;
  unsigned requires_grad_ : 1;
  unsigned : 5;
  unsigned dim_ : 8;
  unsigned device_ : 8;
  unsigned type_ : 8;
  unsigned dev_type_ : 16;
  unsigned : 16;
};

static_assert(
    std::is_standard_layout<ArgumentInfo>::value,
    "ArgumentInfo is expected to be a POD struct");
static_assert(
    sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type),
    "ArgumentInfo is expected to be a 64-bit struct");

struct ArgumentSpec {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  ArgumentSpec(size_t num_flat_tensor_inputs, size_t num_flat_optional_inputs) {
    hash_code =
        c10::hash_combine(num_flat_tensor_inputs, num_flat_optional_inputs);
    tensor_args.reserve(num_flat_tensor_inputs);
    optional_presence.reserve(num_flat_optional_inputs);
  }

  void addOptional(const IValue& input) {
    bool is_present = !input.isNone();
    optional_presence.push_back(is_present);
    hash_code = c10::hash_combine(hash_code, is_present);
  }

  void addTensor(const IValue& input, bool with_grad) {
    AT_ASSERT(input.isTensor(), "Expected Tensor but found ", input.tagKind());
    tensor_args.emplace_back();
    auto& arg = tensor_args.back();
    // Initialize all fields to 0. This is convenient, because e.g.
    // requires_grad() can be checked even on undefined tensors, and it makes
    // the padding bits all 0s.
    std::memset(&arg, 0, sizeof(ArgumentInfo));

    // [argspec refcounting] reinterpret the IValue to avoid having to refcount
    // the Tensor. Microbenchmarks
    // https://github.com/zdevito/pytorch/commit/21e7200a0a0fc456bea2f10e95b1781f83933d10
    // show overhead from the extra refcounting along this path.
    const at::Tensor* t = reinterpret_cast<const at::Tensor*>(&input);
    arg.defined_ = t->defined();
    if (arg.defined_) {
      arg.requires_grad_ = with_grad && autograd::Variable(*t).requires_grad();
      arg.dim_ = t->dim();
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      at::Device device = t->device();
      arg.dev_type_ =
          // NOLINTNEXTLINE(bugprone-signed-char-misuse)
          static_cast<std::underlying_type<DeviceType>::type>(device.type());
      // NOLINTNEXTLINE(bugprone-signed-char-misuse)
      arg.device_ = device.index();
      arg.type_ = static_cast<unsigned>(t->scalar_type());
    }
    combineHash(arg);
  }

  void combineHash(const ArgumentInfo& arg) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    ArgumentInfo::plain_data_type arg_data;
    std::memcpy(&arg_data, &arg, sizeof(ArgumentInfo));
    hash_code = c10::hash_combine(hash_code, arg_data);
  }

  // Equality is fast: compare the optional-presence bits, then memcmp the raw
  // ArgumentInfo array; there are no size/stride indirections.
  // (This relies on std::vector<bool> having reasonably fast equality.)
  bool operator==(const ArgumentSpec& spec) const {
    if (optional_presence != spec.optional_presence) {
      return false;
    }
    if (tensor_args.size() != spec.tensor_args.size())
      return false;
    // NB: we need to break out early when there are no elements, because
    // passing a nullptr to memcmp is UB.
    if (tensor_args.empty())
      return true;
    return std::memcmp(
               tensor_args.data(),
               spec.tensor_args.data(),
               tensor_args.size() * sizeof(ArgumentInfo)) == 0;
  }
  bool operator!=(const ArgumentSpec& spec) const {
    return !(*this == spec);
  }
  size_t numTensors() const {
    return tensor_args.size();
  }
  const ArgumentInfo& tensorAt(size_t i) const {
    return tensor_args[i];
  }
  size_t numOptionals() const {
    return optional_presence.size();
  }
  bool isPresent(size_t i) const {
    return optional_presence[i];
  }
  size_t hashCode() const {
    return hash_code;
  }

 private:
  size_t hash_code; // precomputed on construction
  std::vector<ArgumentInfo> tensor_args;
  std::vector<bool> optional_presence;
};
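
// A minimal usage sketch (illustrative only; `stack` and `ExecutionPlan` are
// hypothetical here, and real callers build specs via ArgumentSpecCreator
// below):
//
//   ArgumentSpec spec(/*num_flat_tensor_inputs=*/2,
//                     /*num_flat_optional_inputs=*/1);
//   spec.addTensor(stack[0], /*with_grad=*/true);
//   spec.addTensor(stack[1], /*with_grad=*/true);
//   spec.addOptional(stack[2]);
//   // The precomputed hash makes ArgumentSpec cheap to use as a cache key
//   // (see the std::hash<torch::jit::ArgumentSpec> specialization below).
//   std::unordered_map<ArgumentSpec, ExecutionPlan> plan_cache;
//   auto it = plan_cache.find(spec);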

namespace {
static constexpr size_t ARG_SPEC_DEPTH_LIMIT = 128;
}

// ArgumentSpecCreator takes an initial graph and comes up with a set
// of simple instructions to compute the ArgumentSpec given a set of
// input tensors.
struct TORCH_API ArgumentSpecCreator {
  // Instructions act on a stack of lists of input IValues. At the beginning,
  // the stack contains a single list of the inputs to the function. The
  // ENTER_ instructions descend into subobjects and push new lists onto the
  // stack. (See the example sequence after this struct.)
  enum Inst : char {
    ENTER_TUPLE, // consume a tuple ivalue from the top-most list, and push the
                 // list of its elements onto the stack as a new list
    ENTER_OBJECT, // same as ENTER_TUPLE, but the input is a class
    LEAVE, // pop the top-most list from the stack
    SKIP, // consume an element from the top-most list, and discard it
    SPECIALIZE_OPTIONAL_TENSOR, // consume an optional tensor from the top-most
                                // list, and add it to the ArgSpec key being
                                // created
    SPECIALIZE_TENSOR, // consume a tensor from the top-most
                       // list, and add it to the ArgSpec key being created
    SPECIALIZE_OPTIONAL,
    // consume a non-tensor optional from the top-most list,
    // and add it to the ArgSpec key being created
  };
  ArgumentSpecCreator(Graph& graph);
  ArgumentSpec create(bool with_grad, const Stack& stack) const;
  void specializeTypes(Graph& g, const ArgumentSpec& spec) const;
  void dump() const;
  using WrittenSlots = std::unordered_set<std::string>;

 private:
  void scan(
      const TypePtr& typ,
      size_t depth,
      const WrittenSlots& written_slots);
  size_t num_inputs_;
  size_t num_tensors_ = 0;
  size_t num_optionals_ = 0;
  std::vector<Inst> instructions_;
};
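
// A concrete (and purely illustrative) example of the instruction encoding
// above: for a graph whose inputs are
//   (Tensor, Tuple[Tensor, int], Tensor?)
// a plausible instruction sequence would be
//   SPECIALIZE_TENSOR,                            // input 0
//   ENTER_TUPLE, SPECIALIZE_TENSOR, SKIP, LEAVE,  // input 1 and its elements
//   SPECIALIZE_OPTIONAL_TENSOR                    // input 2
// The exact sequence is produced by scan() in argument_spec.cpp; this sketch
// only illustrates how ENTER_/LEAVE bracket the elements of an aggregate.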

// CompleteArgumentSpec represents one particular specialization.
// It is designed so that it can be created, hashed, and compared quickly
// since it is used along the hot-path of the JIT to check if the code
// we have created is valid for the given inputs.

// CompleteArgumentInfoPOD is only used internally in CompleteArgumentSpec;
// API users should use CompleteArgumentInfo instead.
struct CompleteArgumentInfoPOD {
  // total size is 64-bit
  unsigned is_tensor : 8; // all other fields are invalid if this is false
  unsigned type : 8; // scalar type
  unsigned defined : 1;
  unsigned requires_grad : 1;
  signed device : 14;
  unsigned dev_type : 16;
  unsigned
      total_dims : 16; // all TensorInfoPODs are in CompleteArgumentSpec's
                       // tensor_info() array. total_dims is the total number
                       // of dimensions seen so far in all previous members of
                       // tensor_info(), including this tensor. 2*total_dims
                       // becomes the offset into the sizes_strides list for
                       // the _next_ tensor in the tensor_info array. For
                       // tensor 0, the offset is always 0.
};

static_assert(
    sizeof(CompleteArgumentInfoPOD) == sizeof(int64_t),
    "CompleteArgumentInfoPOD must be 64-bit struct for CompleteArgumentSpec encoding to work");

struct CompleteArgumentInfo;

struct CompleteArgumentSpec {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  CompleteArgumentSpec(bool with_grad, at::ArrayRef<IValue> inputs)
      : hash_code(0), ninputs(inputs.size()) {
    int32_t all_dims = 0;
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    const int32_t num_inputs = inputs.size();
    for (const auto i : c10::irange(num_inputs)) {
      if (!inputs[i].isTensor())
        continue;
      auto& tensor = inputs[i].toTensor();
      all_dims += tensor.defined() ? tensor.ndimension() : 0;
    }
    // allocate enough room for all TensorPODs and dimensions
    data.resize(ninputs + all_dims * 2);

    // and reinterpret our data array as these structs
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    auto* pods = reinterpret_cast<CompleteArgumentInfoPOD*>(data.data());
    int64_t* next_dim = sizes_strides();
    int32_t total_dims = 0;
    for (const auto i : c10::irange(num_inputs)) {
      auto& pod = pods[i];
      pod.is_tensor = static_cast<uint32_t>(inputs[i].isTensor());
      if (pod.is_tensor) {
        at::Tensor t = inputs[i].toTensor();
        pod.defined = t.defined();
        if (pod.defined) {
          pod.type = static_cast<int>(t.scalar_type());
          // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
          at::Device device = t.device();
          // NOLINTNEXTLINE(bugprone-signed-char-misuse)
          pod.dev_type = static_cast<std::underlying_type<DeviceType>::type>(
              device.type());
          // NOLINTNEXTLINE(bugprone-signed-char-misuse)
          pod.device = device.index();
          pod.requires_grad = with_grad && t.requires_grad();
          total_dims += t.ndimension();
          auto sizes = t.sizes();
          std::copy(sizes.begin(), sizes.end(), next_dim);
          next_dim += sizes.size();
          auto strides = t.strides();
          std::copy(strides.begin(), strides.end(), next_dim);
          next_dim += strides.size();
        }
      }
      // each POD has a running tally of all dimensions, including its own
      TORCH_CHECK(
          total_dims < std::numeric_limits<uint16_t>::max(),
          "The number of dims cannot be packed into CompleteArgumentSpec:",
          total_dims);
      pod.total_dims = total_dims;
    }
    // we precompute the hash_code to minimize the time inside of hash
    // table operations where we may need to hold a compiler cache lock.
    hash_code = c10::hash_combine(0, ninputs);
    for (auto d : data) {
      hash_code = c10::hash_combine(hash_code, d);
    }
  }

  // equality is fast: check ninputs, and then check the raw array data;
  // there are no size/stride indirections
  bool operator==(const CompleteArgumentSpec& spec) const {
    return ninputs == spec.ninputs && data == spec.data;
  }
  bool operator!=(const CompleteArgumentSpec& spec) const {
    return !(*this == spec);
  }
  friend struct CompleteArgumentInfo;
  CompleteArgumentInfo at(size_t i) const;
  size_t size() const {
    return ninputs;
  }
  size_t hashCode() const {
    return hash_code;
  }

 private:
  ArrayRef<CompleteArgumentInfoPOD> tensor_info() const {
    return ArrayRef<CompleteArgumentInfoPOD>(
        reinterpret_cast<const CompleteArgumentInfoPOD*>(data.data()), ninputs);
  }
  // the start of the sizes_strides information, which comes after the
  // CompleteArgumentInfoPOD list.
  const int64_t* sizes_strides() const {
    return data.data() + ninputs;
  }
  int64_t* sizes_strides() {
    return data.data() + ninputs;
  }
  size_t hash_code; // precomputed on construction
  size_t ninputs;
  // layout is ninputs of TensorPOD (each 64-bit) followed by their size and
  // stride info, e.g. for 3 tensors:
  // [t0POD][t1POD][t2POD]...
  // [t0 sizes][t0 strides][t1 sizes][t1 strides][t2 sizes][t2 strides]
  std::vector<int64_t> data;
};
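
// A minimal construction sketch (illustrative only; the input values and
// shapes are made up). For inputs {3x4 float tensor, int 7, 2x2x2 tensor},
// ninputs == 3 and all_dims == 2 + 0 + 3 == 5, so data.size() == 3 + 2*5 == 13:
// three 64-bit PODs followed by [3,4][4,1] for the first tensor and
// [2,2,2][4,2,1] for the third (assuming contiguous tensors); the non-tensor
// input contributes only its POD, with is_tensor == 0.
//
//   std::vector<IValue> inputs = {
//       torch::rand({3, 4}), IValue(7), torch::rand({2, 2, 2})};
//   CompleteArgumentSpec spec(/*with_grad=*/false, inputs);
//   // spec.hashCode() is precomputed; std::hash<CompleteArgumentSpec> below
//   // just returns it, so spec can be used directly as a hash-map key.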

// public view of compressed CompleteArgumentInfo
struct CompleteArgumentInfo {
  CompleteArgumentInfo(const CompleteArgumentSpec& spec, const int i)
      : spec(spec), i(i) {}
  bool isTensor() const {
    return pod(i).is_tensor;
  }
  at::ScalarType type() const {
    return at::ScalarType(pod(i).type);
  }
  bool defined() const {
    return pod(i).defined;
  }
  bool requires_grad() const {
    return pod(i).requires_grad;
  }
  at::Device device() const {
    return at::Device(
        DeviceType(pod(i).dev_type),
        static_cast<c10::DeviceIndex>(pod(i).device));
  }
  int ndimension() const {
    // See [valid range]; it is always valid to ask for the offset at (i + 1)
    return (sizes_strides_offset(i + 1) - sizes_strides_offset(i)) / 2;
  }
  at::IntArrayRef sizes() const {
    return at::IntArrayRef(
        spec.sizes_strides() + sizes_strides_offset(i), ndimension());
  }
  at::IntArrayRef strides() const {
    int ndim = ndimension();
    return at::IntArrayRef(
        spec.sizes_strides() + sizes_strides_offset(i) + ndim, ndim);
  }
  operator TypePtr() const {
    if (!defined())
      return TensorType::get();
    return TensorType::create(
        type(),
        device(),
        c10::VaryingShape<int64_t>{sizes()},
        c10::VaryingShape<int64_t>{strides()},
        requires_grad());
  }

 private:
  // offset into the sizes_strides() array where the sizes start for tensor j
  // [valid range] the valid range is [0, ninputs]
  // (i.e. you can ask for the offset at ninputs, which would be the offset of
  // the next tensor if it existed)
  int sizes_strides_offset(int j) const {
    if (j == 0)
      return 0;
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return 2 * pod(j - 1).total_dims;
  }
  const CompleteArgumentInfoPOD& pod(int j) const {
    return spec.tensor_info().at(j);
  }
  const CompleteArgumentSpec& spec;
  const int i;
};
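
// Reading the packed data back out (illustrative; assumes `spec` is a
// CompleteArgumentSpec built over at least three flattened inputs):
//
//   CompleteArgumentInfo info = spec.at(2);  // third flattened input
//   if (info.isTensor() && info.defined()) {
//     // ndimension() is recovered from the total_dims running tally:
//     // (sizes_strides_offset(3) - sizes_strides_offset(2)) / 2
//     auto ndim = info.ndimension();
//     auto sizes = info.sizes();      // IntArrayRef view into spec's data
//     auto strides = info.strides();
//   }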

inline std::ostream& operator<<(std::ostream& out, const ArgumentInfo& info) {
  if (!info.defined()) {
    return out << "<undefined>";
  }
  out << "Tensor(device=" << info.device() << ", type=" << toString(info.type())
      << ", requires_grad=" << info.requires_grad() << ", dims=" << info.dim()
      << ")";
  return out;
}

inline std::ostream& operator<<(std::ostream& out, const ArgumentSpec& spec) {
  out << "{";
  for (const auto i : c10::irange(spec.numTensors())) {
    if (i > 0)
      out << ", ";
    out << spec.tensorAt(i);
  }
  out << "; ";
  for (const auto i : c10::irange(spec.numOptionals())) {
    if (i > 0)
      out << ", ";
    out << spec.isPresent(i);
  }
  out << "}";
  return out;
}

inline std::ostream& operator<<(
    std::ostream& out,
    const CompleteArgumentInfo& info) {
  if (!info.defined()) {
    return out << "<undefined>";
  }
  out << "Tensor(device=" << info.device() << ", type=" << toString(info.type())
      << ", requires_grad=" << info.requires_grad()
      << ", sizes=" << info.sizes() << ", strides=" << info.strides() << ")";
  return out;
}

inline std::ostream& operator<<(
    std::ostream& out,
    const CompleteArgumentSpec& spec) {
  out << "{";
  for (const auto i : c10::irange(spec.size())) {
    if (i > 0)
      out << ", ";
    out << spec.at(i);
  }
  out << "}";
  return out;
}

inline CompleteArgumentInfo CompleteArgumentSpec::at(size_t i) const {
  return CompleteArgumentInfo(*this, i);
}

inline c10::optional<int8_t> convertOptional(
    c10::optional<c10::ScalarType> const& from) {
  return (from) ? c10::optional<int8_t>(static_cast<int8_t>(*from))
                : c10::optional<int8_t>{};
}

} // namespace jit
} // namespace torch

namespace std {

template <typename T>
struct hash<c10::VaryingShape<T>> {
  size_t operator()(const c10::VaryingShape<T>& vs) const {
    return c10::get_hash(
        vs.size(),
        vs.size() ? vs.sizes().value() : std::vector<c10::optional<T>>());
  }
};

template <>
struct hash<c10::TensorType> {
  size_t operator()(const c10::TensorType& ptt) const {
    return c10::get_hash<
        c10::optional<int8_t>,
        c10::VaryingShape<int64_t>,
        c10::VaryingShape<int64_t>,
        c10::optional<bool>>(
        torch::jit::convertOptional(ptt.scalarType()),
        ptt.sizes(),
        ptt.strides(),
        ptt.requiresGrad());
  }
};

template <>
struct hash<torch::jit::ArgumentSpec> {
  size_t operator()(const torch::jit::ArgumentSpec& spec) const {
    return spec.hashCode();
  }
};
template <>
struct hash<torch::jit::CompleteArgumentSpec> {
  size_t operator()(const torch::jit::CompleteArgumentSpec& spec) const {
    return spec.hashCode();
  }
};
} // namespace std

C10_CLANG_DIAGNOSTIC_POP()