1 | #pragma once |
2 | |
3 | #include <c10/core/Backend.h> |
4 | #include <c10/core/DefaultDtype.h> |
5 | #include <c10/core/Device.h> |
6 | #include <c10/core/DispatchKeySet.h> |
7 | #include <c10/core/Layout.h> |
8 | #include <c10/core/MemoryFormat.h> |
9 | #include <c10/core/ScalarType.h> |
10 | #include <c10/core/ScalarTypeToTypeMeta.h> |
11 | |
12 | #include <c10/macros/Macros.h> |
13 | #include <c10/util/C++17.h> |
14 | #include <c10/util/Optional.h> |
15 | |
#include <cstddef>
#include <iosfwd>
#include <sstream>
#include <utility>
19 | |
20 | namespace c10 { |
21 | |
/// Computes the DispatchKey implied by a (possibly unspecified)
/// dtype/layout/device triple; the inline definition appears near the
/// bottom of this file.
DispatchKey computeDispatchKey(
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device);
26 | |
27 | inline ScalarType dtype_or_default(c10::optional<ScalarType> dtype) { |
28 | return value_or_else(dtype, [] { return get_default_dtype_as_scalartype(); }); |
29 | } |
30 | |
31 | inline caffe2::TypeMeta dtype_or_default( |
32 | c10::optional<caffe2::TypeMeta> dtype) { |
33 | return value_or_else(dtype, [] { return get_default_dtype(); }); |
34 | } |
35 | |
36 | inline Layout layout_or_default(c10::optional<Layout> layout) { |
37 | return layout.value_or(kStrided); |
38 | } |
39 | |
40 | inline Device device_or_default(c10::optional<Device> device) { |
41 | return value_or_else(device, [] { return Device(kCPU); }); |
42 | } |
43 | |
44 | inline bool pinned_memory_or_default(c10::optional<bool> pinned_memory) { |
45 | return pinned_memory.value_or(false); |
46 | } |
47 | |
/// A class to encapsulate construction axes of a Tensor.  TensorOptions was
49 | /// designed to support the Python style API for specifying construction options |
50 | /// on factory functions, e.g., |
51 | /// |
52 | /// torch.zeros(2, 3, dtype=torch.int32) |
53 | /// |
54 | /// Because C++ doesn't natively support keyword arguments, there must be |
55 | /// another way of specifying keyword-like arguments. TensorOptions is a |
56 | /// builder class which can be used to construct this "dictionary" of keyword |
57 | /// arguments: functions which support TensorOptions conventionally take this |
58 | /// argument optionally as their last argument. |
59 | /// |
60 | /// WARNING: In PyTorch, there are `torch::` variants of factory functions, |
61 | /// e.g., torch::zeros for at::zeros. These return Variables (while the |
62 | /// stock ATen functions return plain Tensors). If you mix these functions |
63 | /// up, you WILL BE SAD. |
64 | /// |
65 | /// Rather than use the constructor of this class directly, you should prefer to |
66 | /// use the constructor functions, and then chain setter methods on top of them. |
67 | /// |
68 | /// at::device(at::kCUDA).dtype(kInt) |
69 | /// at::dtype(at::kInt) |
70 | /// |
71 | /// Additionally, anywhere a TensorOptions is expected, you can directly |
72 | /// pass at::kCUDA / at::kInt, and it will implicitly convert to a |
73 | /// TensorOptions. |
74 | /// |
75 | /// Here are some recommended ways to create a 2x2 tensor of zeros |
76 | /// with certain properties. These all *implicitly* make use of |
77 | /// TensorOptions, even if they don't mention the class explicitly: |
78 | /// |
79 | /// at::zeros({2,2}, at::kCUDA); |
80 | /// at::zeros({2,2}, at::kLong); |
81 | /// at::zeros({2,2}, at::device(at::kCUDA).dtype(at::kLong())); |
82 | /// at::zeros({2,2}, at::device({at::kCUDA, 1})); // place on device 1 |
83 | /// at::zeros({2,2}, at::requires_grad()); |
84 | /// |
85 | |
86 | /// NOTE [ TensorOptions Constructors ] |
87 | /// |
88 | /// TensorOptions is like a dictionary with entries from the set: |
89 | /// {requires_grad, device, dtype, layout}, where each entry may be |
90 | /// unspecified (i.e., is optional). It is used to specify the properties of |
91 | /// tensors in many places both in C++ internal and API, e.g., tensor factory |
92 | /// methods like `at::empty({10}, options)`, tensor conversions like |
93 | /// `tensor.to(...)`, etc. |
94 | /// |
95 | /// To provide a simple API that is consistent with Python, where one can do |
96 | /// `torch.empty(sizes, X)` with `X` being a `torch.device`, `torch.dtype`, or a |
97 | /// `torch.layout`, we want TensorOptions to be implicitly convertible from |
98 | /// `ScalarType dtype`, `Layout layout` and `Device device`. Therefore, we have |
99 | /// three implicit constructors from each of these three types. |
100 | /// |
101 | /// This is sufficient for `ScalarType` and `Layout` as they are simple Enum |
102 | /// classes. However, `Device` is an ordinary class with implicit constructors |
103 | /// `Device(DeviceType, DeviceIndex = -1)` and `Device(std::string)` to be |
104 | /// consistent with Python API, where strings are treated as equivalent with a |
105 | /// `torch.device` object (e.g., "cuda:1" can be passed to everywhere a |
106 | /// `torch.device("cuda:1")` is accepted). To support the syntax |
107 | /// `at::empty({10}, {kCUDA, 1})` and `tensor.to(kCUDA)`, we need to make sure |
/// that `TensorOptions` is implicitly constructible with any arguments that a
/// `Device` can be constructed from. So we have,
110 | /// |
111 | /// /* implicit */ TensorOptions(T&& device) : TensorOptions() { |
112 | /// this->set_device(device); |
113 | /// } |
114 | /// |
115 | /// template <typename... Args, |
116 | /// typename = std::enable_if_t<std::is_constructible<Device, |
117 | /// Args&&...>::value>> |
118 | /// /* implicit */ TensorOptions(Args&&... args) |
119 | /// : TensorOptions(Device(std::forward<Args>(args)...)) {} |
120 | /// |
121 | /// |
122 | /// But this will be problematic. Consider this: `TensorOptions({kCUDA, 1})`. |
/// The compiler will complain about ambiguity between the copy constructor and
124 | /// `Device` constructor because `{kCUDA, 1}` can be converted to both a |
/// `TensorOptions` and a `Device`.
126 | /// |
127 | /// To get around this, we templatize the `Device` constructor. Since overload |
128 | /// resolution is done before template resolution, our problem is solved. |
129 | |
// NB: redundant re-declaration of computeDispatchKey (already declared at the
// top of this file); kept here so the declaration is visible right next to
// TensorOptions, which uses it.
DispatchKey computeDispatchKey(
    optional<ScalarType> dtype,
    optional<Layout> layout,
    optional<Device> device);
134 | |
struct C10_API TensorOptions {
  // Default state: every axis unspecified.  Effective defaults (float dtype,
  // strided layout, CPU device, requires_grad=false, pinned_memory=false) are
  // resolved lazily by the *_or_default helpers when a getter is called.
  TensorOptions()
      : requires_grad_(false),
        pinned_memory_(false),
        has_device_(false),
        has_dtype_(false),
        has_layout_(false),
        has_requires_grad_(false),
        has_pinned_memory_(false),
        has_memory_format_(false) {}

  /// Constructs a `TensorOptions` object with the given layout.
  /* implicit */ TensorOptions(Layout layout) : TensorOptions() {
    this->set_layout(layout);
  }

  /// Constructs a `TensorOptions` object with the given device.
  /// See NOTE [ TensorOptions Constructors ] on why this is templatized.
  template <
      typename T,
      typename = std::enable_if_t<std::is_same<std::decay_t<T>, Device>::value>>
  /* implicit */ TensorOptions(T&& device) : TensorOptions() {
    this->set_device(std::forward<T>(device));
  }

  /// Constructs a `TensorOptions` object from arguments allowed in `Device`
  /// constructors.
  ///
  /// See NOTE [ TensorOptions Constructors ].
  ///
  /// NB: Ideally we only allow implicit constructors here. But there is no easy
  /// way to detect them. So we have this one that allows explicit
  /// constructors too.
  template <
      typename... Args,
      typename =
          std::enable_if_t<std::is_constructible<Device, Args&&...>::value>>
  /* implicit */ TensorOptions(Args&&... args)
      : TensorOptions(Device(std::forward<Args>(args)...)) {}

  /// Constructs a `TensorOptions` object with the given dtype.
  /* implicit */ TensorOptions(caffe2::TypeMeta dtype) : TensorOptions() {
    this->set_dtype(dtype);
  }

  /// legacy constructor to support ScalarType
  /* implicit */ TensorOptions(ScalarType dtype) : TensorOptions() {
    this->set_dtype(dtype);
  }

  /// Constructs a `TensorOptions` object with the given memory format.
  /* implicit */ TensorOptions(MemoryFormat memory_format) : TensorOptions() {
    set_memory_format(memory_format);
  }

  /// Return a copy of `TensorOptions` with `device` set to the given one, or
  /// cleared if `device` is `nullopt`.
  C10_NODISCARD TensorOptions
  device(c10::optional<Device> device) const noexcept {
    TensorOptions r = *this;
    r.set_device(device);
    return r;
  }

  /// Return a copy of `TensorOptions` with `device` set to the given one.
  /// (This overload ensures that variadic template c10::optional constructor
  /// for Device work correctly.)
  template <typename... Args>
  C10_NODISCARD TensorOptions device(Args&&... args) const noexcept {
    return device(
        c10::optional<Device>(c10::in_place, std::forward<Args>(args)...));
  }

  /// Return a copy of `TensorOptions`, but with device set to CUDA, and the
  /// device index set to the given one.
  ///
  /// TODO: This function encourages bad behavior (assuming CUDA is
  /// the only device that matters).  Get rid of it / rename it.
  C10_NODISCARD TensorOptions
  device_index(c10::DeviceIndex device_index) const noexcept {
    return device(Device::Type::CUDA, device_index);
  }

  /// Return a copy of `TensorOptions` with `dtype` set to the given one.
  C10_NODISCARD TensorOptions
  dtype(c10::optional<caffe2::TypeMeta> dtype) const noexcept {
    TensorOptions r = *this;
    r.set_dtype(dtype);
    return r;
  }

  // legacy function to support ScalarType
  C10_NODISCARD TensorOptions
  dtype(c10::optional<ScalarType> dtype) const noexcept {
    TensorOptions r = *this;
    r.set_dtype(dtype);
    return r;
  }

  // Since dtype is taken...
  // NB: unlike the C10_NODISCARD builders above, this templated overload
  // mutates *this in place and returns a reference for chaining.
  template <typename T>
  TensorOptions& dtype() {
    dtype_ = caffe2::TypeMeta::Make<T>();
    has_dtype_ = true;
    return *this;
  }

  /// Sets the layout of the `TensorOptions`.
  C10_NODISCARD TensorOptions
  layout(c10::optional<Layout> layout) const noexcept {
    TensorOptions r = *this;
    r.set_layout(layout);
    return r;
  }

  /// Sets the `requires_grad` property of the `TensorOptions`.
  C10_NODISCARD TensorOptions
  requires_grad(c10::optional<bool> requires_grad) const noexcept {
    TensorOptions r = *this;
    r.set_requires_grad(requires_grad);
    return r;
  }

  /// Sets the `pinned_memory` property on the `TensorOptions`.
  C10_NODISCARD TensorOptions
  pinned_memory(c10::optional<bool> pinned_memory) const noexcept {
    TensorOptions r = *this;
    r.set_pinned_memory(pinned_memory);
    return r;
  }

  /// Sets the `memory_format` property on `TensorOptions`.
  C10_NODISCARD TensorOptions
  memory_format(c10::optional<MemoryFormat> memory_format) const noexcept {
    TensorOptions r = *this;
    r.set_memory_format(memory_format);
    return r;
  }

  /// Returns the device of the `TensorOptions`.
  Device device() const noexcept {
    return device_or_default(device_opt());
  }

  /// Returns whether the device is specified.
  bool has_device() const noexcept {
    return has_device_;
  }

  /// Returns the device of the `TensorOptions`, or `c10::nullopt` if
  /// device is not specified.
  c10::optional<Device> device_opt() const noexcept {
    return has_device_ ? c10::make_optional(device_) : c10::nullopt;
  }

  /// Returns the device index of the `TensorOptions`.
  int32_t device_index() const noexcept {
    return device().index();
  }

  /// Returns the dtype of the `TensorOptions`.
  caffe2::TypeMeta dtype() const noexcept {
    return dtype_or_default(dtype_opt());
  }

  /// Returns whether the dtype is specified.
  bool has_dtype() const noexcept {
    return has_dtype_;
  }

  /// Returns the dtype of the `TensorOptions`, or `c10::nullopt` if
  /// dtype is not specified.
  c10::optional<caffe2::TypeMeta> dtype_opt() const noexcept {
    return has_dtype_ ? c10::make_optional(dtype_) : c10::nullopt;
  }

  /// Returns the layout of the `TensorOptions`.
  Layout layout() const noexcept {
    return layout_or_default(layout_opt());
  }

  /// Returns whether the layout is specified.
  bool has_layout() const noexcept {
    return has_layout_;
  }

  /// Returns the layout of the `TensorOptions`, or `c10::nullopt` if
  /// layout is not specified.
  c10::optional<Layout> layout_opt() const noexcept {
    return has_layout_ ? c10::make_optional(layout_) : c10::nullopt;
  }

  /// Returns the `requires_grad` property of the `TensorOptions`.
  bool requires_grad() const noexcept {
    return has_requires_grad_ ? requires_grad_ : false;
  }

  /// Returns whether the `requires_grad` is specified.
  bool has_requires_grad() const noexcept {
    return has_requires_grad_;
  }

  /// Returns the `requires_grad` property of the `TensorOptions`, or
  /// `c10::nullopt` if `requires_grad` is not specified.
  c10::optional<bool> requires_grad_opt() const noexcept {
    return has_requires_grad_ ? c10::make_optional(requires_grad_)
                              : c10::nullopt;
  }

  /// Returns the `pinned_memory` property of the `TensorOptions`.
  bool pinned_memory() const noexcept {
    return pinned_memory_or_default(pinned_memory_opt());
  }

  /// Returns whether the `pinned_memory` is specified.
  bool has_pinned_memory() const noexcept {
    return has_pinned_memory_;
  }

  /// Returns if the layout is sparse
  /// (reads the raw member; when layout is unset the member holds its
  /// default kStrided, so the result matches layout()).
  bool is_sparse() const {
    return layout_ == c10::Layout::Sparse;
  }

  bool is_sparse_csr() const {
    return layout_ == c10::Layout::SparseCsr;
  }

  // For compatibility with legacy tensor.type() comparisons
  // NOTE(review): this compares the raw `dtype_` member on `this` but the
  // default-resolving `dtype()` getter on `other`.  If the global default
  // dtype differs from float and `this` has no dtype set, the comparison
  // is asymmetric -- confirm whether both sides should use `dtype()`.
  bool type_equal(const TensorOptions& other) const {
    return computeDispatchKey() == other.computeDispatchKey() &&
        typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
  }

  /// Returns the `pinned_memory` property of the `TensorOptions`, or
  /// `c10::nullopt` if `pinned_memory` is not specified.
  c10::optional<bool> pinned_memory_opt() const noexcept {
    return has_pinned_memory_ ? c10::make_optional(pinned_memory_)
                              : c10::nullopt;
  }

  /// Returns whether the `memory_layout` is specified
  bool has_memory_format() const noexcept {
    return has_memory_format_;
  }

  // NB: memory_format() getter is PURPOSELY not defined, as the default
  // behavior of memory_format varies from function to function.

  /// Returns the `memory_layout` property of `TensorOptions`, or
  /// `c10::nullopt` if `memory_format` is not specified.
  c10::optional<MemoryFormat> memory_format_opt() const noexcept {
    return has_memory_format_ ? c10::make_optional(memory_format_)
                              : c10::nullopt;
  }

  // Resolves the ATen backend specified by the current construction axes.
  // TODO: Deprecate this
  Backend backend() const {
    return at::dispatchKeyToBackend(computeDispatchKey());
  }

  /// Return the right-biased merge of two TensorOptions.  This has the
  /// effect of overwriting settings from self with specified options
  /// of options.
  ///
  /// NB: This merging operation does NOT respect device merges.
  /// For example, if you device({kCUDA, 1}).merge_in(kCUDA)
  /// you will get kCUDA in the end!  Functions like Tensor.new_empty
  /// ensure the right device is selected anyway by way of a
  /// device guard.
  ///
  TensorOptions merge_in(TensorOptions options) const noexcept {
    TensorOptions merged = *this;
    if (options.has_device())
      merged.set_device(options.device_opt());
    if (options.has_dtype())
      merged.set_dtype(options.dtype_opt());
    if (options.has_layout())
      merged.set_layout(options.layout_opt());
    // NB: requires grad is right biased; not a logical AND/OR!
    if (options.has_requires_grad())
      merged.set_requires_grad(options.requires_grad_opt());
    if (options.has_pinned_memory())
      merged.set_pinned_memory(options.pinned_memory_opt());
    if (options.has_memory_format())
      merged.set_memory_format(options.memory_format_opt());
    return merged;
  }

  // TODO remove after TensorOptions rationalization
  TensorOptions merge_memory_format(
      c10::optional<MemoryFormat> optional_memory_format) const noexcept {
    TensorOptions merged = *this;
    if (optional_memory_format.has_value()) {
      merged.set_memory_format(*optional_memory_format);
    }
    return merged;
  }

  // INVARIANT: computeDispatchKey returns only the subset of dispatch keys for
  // which dispatchKeyToBackend is injective, if it is defined at all  (for
  // the most part, this just means that this function never returns an
  // Autograd key)
  DispatchKey computeDispatchKey() const {
    return c10::computeDispatchKey(
        optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt());
  }

 private:
  // These methods are currently private because I'm not sure if it's wise
  // to actually publish them.  They are methods because I need them in
  // the constructor and the functional API implementation.
  //
  // If you really, really need it, you can make these public, but check if you
  // couldn't just do what you need with the functional API.  Similarly, these
  // methods are not chainable, because if you wanted chaining, you probably
  // want to use the functional API instead.  (It's probably OK to make
  // these chainable, because these functions are all explicitly annotated
  // with a ref-qualifier, the trailing &, that makes them illegal to call
  // on temporaries.)

  /// Mutably set the device of `TensorOptions`.
  void set_device(c10::optional<Device> device) & noexcept {
    if (device) {
      device_ = *device;
      has_device_ = true;
    } else {
      has_device_ = false;
    }
  }

  /// Mutably set the dtype of `TensorOptions`.
  void set_dtype(c10::optional<caffe2::TypeMeta> dtype) & noexcept {
    if (dtype) {
      dtype_ = *dtype;
      has_dtype_ = true;
    } else {
      has_dtype_ = false;
    }
  }

  // legacy function to support ScalarType
  void set_dtype(c10::optional<ScalarType> dtype) & noexcept {
    if (dtype) {
      dtype_ = scalarTypeToTypeMeta(*dtype);
      has_dtype_ = true;
    } else {
      has_dtype_ = false;
    }
  }

  /// Mutably set the layout of `TensorOptions`.
  void set_layout(c10::optional<Layout> layout) & noexcept {
    if (layout) {
      layout_ = *layout;
      has_layout_ = true;
    } else {
      has_layout_ = false;
    }
  }

  /// Mutably set the `requires_grad` property of `TensorOptions`.
  void set_requires_grad(c10::optional<bool> requires_grad) & noexcept {
    if (requires_grad) {
      requires_grad_ = *requires_grad;
      has_requires_grad_ = true;
    } else {
      has_requires_grad_ = false;
    }
  }

  /// Mutably set the `pinned_memory` property of `TensorOptions`.
  void set_pinned_memory(c10::optional<bool> pinned_memory) & noexcept {
    if (pinned_memory) {
      pinned_memory_ = *pinned_memory;
      has_pinned_memory_ = true;
    } else {
      has_pinned_memory_ = false;
    }
  }

  /// Mutably set the `memory_Format` property of `TensorOptions`.
  void set_memory_format(c10::optional<MemoryFormat> memory_format) & noexcept {
    if (memory_format) {
      memory_format_ = *memory_format;
      has_memory_format_ = true;
    } else {
      has_memory_format_ = false;
    }
  }

  // WARNING: If you edit TensorOptions to add more options, you
  // may need to adjust the implementation of Tensor::options.
  // The criteria for whether or not Tensor::options must be adjusted
  // is whether or not the new option you added should preserved
  // by functions such as empty_like(); if it should be preserved,
  // you must adjust options().
  //
  // TODO: MemoryFormat is not implemented in this way

  // NB: We didn't use c10::optional here, because then we can't pack
  // the has_***_ boolean fields.

  Device device_ = at::kCPU; // 16-bit
  caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make<float>(); // 16-bit
  Layout layout_ = at::kStrided; // 8-bit
  MemoryFormat memory_format_ = MemoryFormat::Contiguous; // 8-bit

  // Bitmask required here to get this to fit inside 32 bits (or even 64 bits,
  // for that matter)

  bool requires_grad_ : 1;
  bool pinned_memory_ : 1;

  bool has_device_ : 1;
  bool has_dtype_ : 1;
  bool has_layout_ : 1;
  bool has_requires_grad_ : 1;
  bool has_pinned_memory_ : 1;
  bool has_memory_format_ : 1;
};
557 | |
// We should aspire to fit in one machine-size word; but a size greater than two
// words is too much.  (We are doing terribly on 32-bit archs, where we require
// three machine size words to store tensor options.  Eek!)
// NB: the bit-packed has_* flags in TensorOptions exist precisely to keep
// this bound; see the member declarations in the struct above.
static_assert(
    sizeof(TensorOptions) <= sizeof(int64_t) * 2,
    "TensorOptions must fit in 128-bits");
564 | |
565 | /// Convenience function that returns a `TensorOptions` object with the `dtype` |
566 | /// set to the given one. |
567 | inline TensorOptions dtype(caffe2::TypeMeta dtype) { |
568 | return TensorOptions().dtype(dtype); |
569 | } |
570 | |
571 | // legacy function to support ScalarType |
572 | inline TensorOptions dtype(ScalarType dtype) { |
573 | return TensorOptions().dtype(scalarTypeToTypeMeta(dtype)); |
574 | } |
575 | |
576 | /// Convenience function that returns a `TensorOptions` object with the `layout` |
577 | /// set to the given one. |
578 | inline TensorOptions layout(Layout layout) { |
579 | return TensorOptions().layout(layout); |
580 | } |
581 | |
582 | /// Convenience function that returns a `TensorOptions` object with the `device` |
583 | /// set to the given one. |
584 | inline TensorOptions device(Device device) { |
585 | return TensorOptions().device(device); |
586 | } |
587 | |
588 | /// Convenience function that returns a `TensorOptions` object with the |
589 | /// `device` set to CUDA and the `device_index` set to the given one. |
590 | inline TensorOptions device_index(int16_t device_index) { |
591 | return TensorOptions().device_index( |
592 | static_cast<c10::DeviceIndex>(device_index)); |
593 | } |
594 | |
595 | /// Convenience function that returns a `TensorOptions` object with the |
596 | /// `requires_grad` set to the given one. |
597 | inline TensorOptions requires_grad(bool requires_grad = true) { |
598 | return TensorOptions().requires_grad(requires_grad); |
599 | } |
600 | |
601 | /// Convenience function that returns a `TensorOptions` object with the |
602 | /// `memory_format` set to the given one. |
603 | inline TensorOptions memory_format(MemoryFormat memory_format) { |
604 | return TensorOptions().memory_format(memory_format); |
605 | } |
606 | |
/// Pretty-prints the specified axes of `options` to `stream`; declared here,
/// defined out of line.
C10_API std::ostream& operator<<(
    std::ostream& stream,
    const TensorOptions& options);
610 | |
611 | template <typename T> |
612 | inline TensorOptions dtype() { |
613 | return dtype(caffe2::TypeMeta::Make<T>()); |
614 | } |
615 | |
616 | inline std::string toString(const TensorOptions options) { |
617 | std::ostringstream stream; |
618 | stream << options; |
619 | return stream.str(); |
620 | } |
621 | |
// This is intended to be a centralized location by which we can determine
// what an appropriate DispatchKey for a tensor is.
//
// Unset axes are first resolved to their defaults (strided/CPU/default
// dtype) and the key is then chosen by (layout, device type, dtype).
inline DispatchKey computeDispatchKey(
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device) {
  const auto layout_ = layout_or_default(layout);
  const auto device_ = device_or_default(device);
  switch (layout_) {
    case Layout::Strided: {
      const auto dtype_ = dtype_or_default(dtype);
      switch (device_.type()) {
        // Registered backend device types: quantized dtypes dispatch to the
        // corresponding Quantized* key, everything else to the plain backend
        // key.
#define DO_CASE(device, _)                   \
  case DeviceType::device: {                 \
    if (isQIntType(dtype_)) {                \
      return DispatchKey::Quantized##device; \
    }                                        \
    return DispatchKey::device;              \
  }
        C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
#undef DO_CASE

        case DeviceType::FPGA:
          return DispatchKey::FPGA;
        case DeviceType::ORT:
          return DispatchKey::ORT;
        case DeviceType::Vulkan:
          return DispatchKey::Vulkan;
        case DeviceType::Metal:
          return DispatchKey::Metal;
        case DeviceType::MKLDNN:
        case DeviceType::OPENGL:
        case DeviceType::OPENCL:
        case DeviceType::IDEEP:
          TORCH_INTERNAL_ASSERT(
              0,
              "This is a grandfathered Caffe2 device type ",
              device_.type(),
              ", it shouldn't ever convert to a DispatchKey.  File a bug describing what you were doing if you think this is in error.");
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Unsupported device type for dense layout: ",
              device_.type());
      }
    }
    case Layout::Sparse:
      // COO sparse: one Sparse* key per registered backend device type.
      switch (device_.type()) {
#define DO_CASE(device, _)              \
  case DeviceType::device: {            \
    return DispatchKey::Sparse##device; \
  }
        C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
#undef DO_CASE
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Unsupported device type for sparse layout: ",
              device_.type());
      }
    case Layout::Mkldnn:
      switch (device_.type()) {
        case DeviceType::CPU:
          return DispatchKey::MkldnnCPU;
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Unsupported device type for mkldnn layout: ",
              device_.type());
      }
    case Layout::SparseCsr:
    case Layout::SparseCsc:
    case Layout::SparseBsr:
    case Layout::SparseBsc:
      // NB: all four compressed-sparse layouts share the SparseCsr* keys, so
      // the mapping layout -> key is not injective here (see
      // dispatchKeyToLayout below, which refuses to invert it).
      switch (device_.type()) {
        case DeviceType::CPU:
          return DispatchKey::SparseCsrCPU;
        case DeviceType::CUDA:
          return DispatchKey::SparseCsrCUDA;
        default:
          AT_ERROR(
              "Unsupported device type for ",
              layout_,
              " layout: ",
              device_.type());
      }
    default:
      TORCH_CHECK(false, "Unsupported layout: ", layout_);
  }
}
711 | |
// Maps a backend DispatchKey back to the Layout it implies.  Anything that is
// not a sparse or mkldnn key is treated as strided.
inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
  switch (dispatch_key) {
    // All per-backend Sparse* keys map to the COO sparse layout.
#define DO_CASE(bc, _) case DispatchKey::Sparse##bc:
    C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused)
#undef DO_CASE
    return Layout::Sparse;
    case DispatchKey::SparseCsrCPU:
    case DispatchKey::SparseCsrCUDA:
      // SparseCsr* keys cover CSR/CSC/BSR/BSC (see computeDispatchKey), so no
      // single layout can be recovered from them.
      TORCH_CHECK(
          false,
          "Cannot map DispatchKey ",
          dispatch_key,
          " to a unique layout.");
    case DispatchKey::MkldnnCPU:
      return Layout::Mkldnn;
    default:
      return Layout::Strided;
  }
}
731 | |
// Maps a DispatchKey to the DeviceType a tensor with that key lives on;
// errors (TORCH_CHECK) on keys with no device association.
inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
  switch (dispatch_key) {
    // stuff that's real: every (functionality, backend device type) key pair
    // generated by the cross-product macros below maps to its device type.
#define DO_CASE(suffix, prefix)     \
  case DispatchKey::prefix##suffix: \
    return DeviceType::suffix;
#define DO_CASES(_, prefix) C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, prefix)
    C10_FORALL_FUNCTIONALITY_KEYS(DO_CASES)
#undef DO_CASES
#undef DO_CASE

    case DispatchKey::MkldnnCPU:
      return DeviceType::CPU;
    case DispatchKey::Vulkan:
      return DeviceType::Vulkan;

    case DispatchKey::ORT:
      return DeviceType::ORT;
    default:
      TORCH_CHECK(
          false,
          "DispatchKey ",
          dispatch_key,
          " doesn't correspond to a device");
  }
}
758 | |
759 | inline TensorOptions dispatchKeyToTensorOptions(DispatchKey dispatch_key) { |
760 | return TensorOptions() |
761 | .layout(dispatchKeyToLayout(dispatch_key)) |
762 | .device(dispatchKeyToDeviceType(dispatch_key)); |
763 | } |
764 | |
765 | namespace detail { |
766 | inline bool backend_supports_empty_operator(const TensorOptions options) { |
767 | // Quantized backends don't support at::empty(). |
768 | // They have separate operators like at::empty_quantized() that take in |
769 | // extra information about how to quantize the tensor. |
770 | return !isQIntType(typeMetaToScalarType(options.dtype())); |
771 | } |
772 | |
773 | } // namespace detail |
774 | |
775 | } // namespace c10 |
776 | |