1#pragma once
2
3#include <c10/core/Backend.h>
4#include <c10/core/DefaultDtype.h>
5#include <c10/core/Device.h>
6#include <c10/core/DispatchKeySet.h>
7#include <c10/core/Layout.h>
8#include <c10/core/MemoryFormat.h>
9#include <c10/core/ScalarType.h>
10#include <c10/core/ScalarTypeToTypeMeta.h>
11
12#include <c10/macros/Macros.h>
13#include <c10/util/C++17.h>
14#include <c10/util/Optional.h>
15
16#include <cstddef>
17#include <iosfwd>
18#include <utility>
19
20namespace c10 {
21
22DispatchKey computeDispatchKey(
23 c10::optional<ScalarType> dtype,
24 c10::optional<Layout> layout,
25 c10::optional<Device> device);
26
27inline ScalarType dtype_or_default(c10::optional<ScalarType> dtype) {
28 return value_or_else(dtype, [] { return get_default_dtype_as_scalartype(); });
29}
30
31inline caffe2::TypeMeta dtype_or_default(
32 c10::optional<caffe2::TypeMeta> dtype) {
33 return value_or_else(dtype, [] { return get_default_dtype(); });
34}
35
36inline Layout layout_or_default(c10::optional<Layout> layout) {
37 return layout.value_or(kStrided);
38}
39
40inline Device device_or_default(c10::optional<Device> device) {
41 return value_or_else(device, [] { return Device(kCPU); });
42}
43
44inline bool pinned_memory_or_default(c10::optional<bool> pinned_memory) {
45 return pinned_memory.value_or(false);
46}
47
/// A class to encapsulate construction axes of a Tensor. TensorOptions was
49/// designed to support the Python style API for specifying construction options
50/// on factory functions, e.g.,
51///
52/// torch.zeros(2, 3, dtype=torch.int32)
53///
54/// Because C++ doesn't natively support keyword arguments, there must be
55/// another way of specifying keyword-like arguments. TensorOptions is a
56/// builder class which can be used to construct this "dictionary" of keyword
57/// arguments: functions which support TensorOptions conventionally take this
58/// argument optionally as their last argument.
59///
60/// WARNING: In PyTorch, there are `torch::` variants of factory functions,
61/// e.g., torch::zeros for at::zeros. These return Variables (while the
62/// stock ATen functions return plain Tensors). If you mix these functions
63/// up, you WILL BE SAD.
64///
65/// Rather than use the constructor of this class directly, you should prefer to
66/// use the constructor functions, and then chain setter methods on top of them.
67///
68/// at::device(at::kCUDA).dtype(kInt)
69/// at::dtype(at::kInt)
70///
71/// Additionally, anywhere a TensorOptions is expected, you can directly
72/// pass at::kCUDA / at::kInt, and it will implicitly convert to a
73/// TensorOptions.
74///
75/// Here are some recommended ways to create a 2x2 tensor of zeros
76/// with certain properties. These all *implicitly* make use of
77/// TensorOptions, even if they don't mention the class explicitly:
78///
79/// at::zeros({2,2}, at::kCUDA);
80/// at::zeros({2,2}, at::kLong);
/// at::zeros({2,2}, at::device(at::kCUDA).dtype(at::kLong));
82/// at::zeros({2,2}, at::device({at::kCUDA, 1})); // place on device 1
83/// at::zeros({2,2}, at::requires_grad());
84///
85
86/// NOTE [ TensorOptions Constructors ]
87///
88/// TensorOptions is like a dictionary with entries from the set:
89/// {requires_grad, device, dtype, layout}, where each entry may be
90/// unspecified (i.e., is optional). It is used to specify the properties of
91/// tensors in many places both in C++ internal and API, e.g., tensor factory
92/// methods like `at::empty({10}, options)`, tensor conversions like
93/// `tensor.to(...)`, etc.
94///
95/// To provide a simple API that is consistent with Python, where one can do
96/// `torch.empty(sizes, X)` with `X` being a `torch.device`, `torch.dtype`, or a
97/// `torch.layout`, we want TensorOptions to be implicitly convertible from
98/// `ScalarType dtype`, `Layout layout` and `Device device`. Therefore, we have
99/// three implicit constructors from each of these three types.
100///
101/// This is sufficient for `ScalarType` and `Layout` as they are simple Enum
102/// classes. However, `Device` is an ordinary class with implicit constructors
103/// `Device(DeviceType, DeviceIndex = -1)` and `Device(std::string)` to be
104/// consistent with Python API, where strings are treated as equivalent with a
105/// `torch.device` object (e.g., "cuda:1" can be passed to everywhere a
106/// `torch.device("cuda:1")` is accepted). To support the syntax
/// `at::empty({10}, {kCUDA, 1})` and `tensor.to(kCUDA)`, we need to make sure
/// that `TensorOptions` is implicitly constructible with any arguments that a
/// `Device` can be constructed from. So we have,
110///
111/// /* implicit */ TensorOptions(T&& device) : TensorOptions() {
112/// this->set_device(device);
113/// }
114///
115/// template <typename... Args,
116/// typename = std::enable_if_t<std::is_constructible<Device,
117/// Args&&...>::value>>
118/// /* implicit */ TensorOptions(Args&&... args)
119/// : TensorOptions(Device(std::forward<Args>(args)...)) {}
120///
121///
/// But this will be problematic. Consider this: `TensorOptions({kCUDA, 1})`.
/// Compiler will complain about ambiguity between the copy constructor and the
/// `Device` constructor because `{kCUDA, 1}` can be converted to both a
/// `TensorOptions` and a `Device`.
126///
127/// To get around this, we templatize the `Device` constructor. Since overload
128/// resolution is done before template resolution, our problem is solved.
129
130DispatchKey computeDispatchKey(
131 optional<ScalarType> dtype,
132 optional<Layout> layout,
133 optional<Device> device);
134
struct C10_API TensorOptions {
  /// Constructs a `TensorOptions` with every axis unspecified: all of the
  /// has_*_ bits start out false, while the value fields keep the in-class
  /// defaults declared at the bottom of this struct.
  TensorOptions()
      : requires_grad_(false),
        pinned_memory_(false),
        has_device_(false),
        has_dtype_(false),
        has_layout_(false),
        has_requires_grad_(false),
        has_pinned_memory_(false),
        has_memory_format_(false) {}

  /// Constructs a `TensorOptions` object with the given layout.
  /* implicit */ TensorOptions(Layout layout) : TensorOptions() {
    this->set_layout(layout);
  }

  /// Constructs a `TensorOptions` object with the given device.
  /// See NOTE [ TensorOptions Constructors ] on why this is templatized.
  template <
      typename T,
      typename = std::enable_if_t<std::is_same<std::decay_t<T>, Device>::value>>
  /* implicit */ TensorOptions(T&& device) : TensorOptions() {
    this->set_device(std::forward<T>(device));
  }

  /// Constructs a `TensorOptions` object from arguments allowed in `Device`
  /// constructors.
  ///
  /// See NOTE [ TensorOptions Constructors ].
  ///
  /// NB: Ideally we only allow implicit constructors here. But there is no easy
  /// way to detect them. So we have this one that allows explicit
  /// constructors too.
  template <
      typename... Args,
      typename =
          std::enable_if_t<std::is_constructible<Device, Args&&...>::value>>
  /* implicit */ TensorOptions(Args&&... args)
      : TensorOptions(Device(std::forward<Args>(args)...)) {}

  /// Constructs a `TensorOptions` object with the given dtype.
  /* implicit */ TensorOptions(caffe2::TypeMeta dtype) : TensorOptions() {
    this->set_dtype(dtype);
  }

  /// legacy constructor to support ScalarType
  /* implicit */ TensorOptions(ScalarType dtype) : TensorOptions() {
    this->set_dtype(dtype);
  }

  /// Constructs a `TensorOptions` object with the given memory format.
  /* implicit */ TensorOptions(MemoryFormat memory_format) : TensorOptions() {
    set_memory_format(memory_format);
  }

  /// Return a copy of `TensorOptions` with `device` set to the given one, or
  /// cleared if `device` is `nullopt`.
  C10_NODISCARD TensorOptions
  device(c10::optional<Device> device) const noexcept {
    TensorOptions r = *this;
    r.set_device(device);
    return r;
  }

  /// Return a copy of `TensorOptions` with `device` set to the given one.
  /// (This overload ensures that the variadic template c10::optional
  /// constructor for Device works correctly.)
  template <typename... Args>
  C10_NODISCARD TensorOptions device(Args&&... args) const noexcept {
    return device(
        c10::optional<Device>(c10::in_place, std::forward<Args>(args)...));
  }

  /// Return a copy of `TensorOptions`, but with device set to CUDA, and the
  /// device index set to the given one.
  ///
  /// TODO: This function encourages bad behavior (assuming CUDA is
  /// the only device that matters). Get rid of it / rename it.
  C10_NODISCARD TensorOptions
  device_index(c10::DeviceIndex device_index) const noexcept {
    return device(Device::Type::CUDA, device_index);
  }

  /// Return a copy of `TensorOptions` with `dtype` set to the given one, or
  /// cleared if `dtype` is `nullopt`.
  C10_NODISCARD TensorOptions
  dtype(c10::optional<caffe2::TypeMeta> dtype) const noexcept {
    TensorOptions r = *this;
    r.set_dtype(dtype);
    return r;
  }

  // legacy function to support ScalarType
  C10_NODISCARD TensorOptions
  dtype(c10::optional<ScalarType> dtype) const noexcept {
    TensorOptions r = *this;
    r.set_dtype(dtype);
    return r;
  }

  // Since dtype is taken...
  // NB: unlike the other setters, this mutates *this in place and is
  // chainable via the returned reference.
  template <typename T>
  TensorOptions& dtype() {
    dtype_ = caffe2::TypeMeta::Make<T>();
    has_dtype_ = true;
    return *this;
  }

  /// Sets the layout of the `TensorOptions` (on a returned copy).
  C10_NODISCARD TensorOptions
  layout(c10::optional<Layout> layout) const noexcept {
    TensorOptions r = *this;
    r.set_layout(layout);
    return r;
  }

  /// Sets the `requires_grad` property of the `TensorOptions` (on a returned
  /// copy).
  C10_NODISCARD TensorOptions
  requires_grad(c10::optional<bool> requires_grad) const noexcept {
    TensorOptions r = *this;
    r.set_requires_grad(requires_grad);
    return r;
  }

  /// Sets the `pinned_memory` property on the `TensorOptions` (on a returned
  /// copy).
  C10_NODISCARD TensorOptions
  pinned_memory(c10::optional<bool> pinned_memory) const noexcept {
    TensorOptions r = *this;
    r.set_pinned_memory(pinned_memory);
    return r;
  }

  /// Sets the `memory_format` property on `TensorOptions` (on a returned
  /// copy).
  C10_NODISCARD TensorOptions
  memory_format(c10::optional<MemoryFormat> memory_format) const noexcept {
    TensorOptions r = *this;
    r.set_memory_format(memory_format);
    return r;
  }

  /// Returns the device of the `TensorOptions`, falling back to the default
  /// (CPU) when unset.
  Device device() const noexcept {
    return device_or_default(device_opt());
  }

  /// Returns whether the device is specified.
  bool has_device() const noexcept {
    return has_device_;
  }

  /// Returns the device of the `TensorOptions`, or `c10::nullopt` if
  /// device is not specified.
  c10::optional<Device> device_opt() const noexcept {
    return has_device_ ? c10::make_optional(device_) : c10::nullopt;
  }

  /// Returns the device index of the `TensorOptions`.
  int32_t device_index() const noexcept {
    return device().index();
  }

  /// Returns the dtype of the `TensorOptions`, falling back to the global
  /// default dtype when unset.
  caffe2::TypeMeta dtype() const noexcept {
    return dtype_or_default(dtype_opt());
  }

  /// Returns whether the dtype is specified.
  bool has_dtype() const noexcept {
    return has_dtype_;
  }

  /// Returns the dtype of the `TensorOptions`, or `c10::nullopt` if
  /// dtype is not specified.
  c10::optional<caffe2::TypeMeta> dtype_opt() const noexcept {
    return has_dtype_ ? c10::make_optional(dtype_) : c10::nullopt;
  }

  /// Returns the layout of the `TensorOptions`, falling back to the default
  /// (strided) when unset.
  Layout layout() const noexcept {
    return layout_or_default(layout_opt());
  }

  /// Returns whether the layout is specified.
  bool has_layout() const noexcept {
    return has_layout_;
  }

  /// Returns the layout of the `TensorOptions`, or `c10::nullopt` if
  /// layout is not specified.
  c10::optional<Layout> layout_opt() const noexcept {
    return has_layout_ ? c10::make_optional(layout_) : c10::nullopt;
  }

  /// Returns the `requires_grad` property of the `TensorOptions` (false when
  /// unspecified).
  bool requires_grad() const noexcept {
    return has_requires_grad_ ? requires_grad_ : false;
  }

  /// Returns whether the `requires_grad` is specified.
  bool has_requires_grad() const noexcept {
    return has_requires_grad_;
  }

  /// Returns the `requires_grad` property of the `TensorOptions`, or
  /// `c10::nullopt` if `requires_grad` is not specified.
  c10::optional<bool> requires_grad_opt() const noexcept {
    return has_requires_grad_ ? c10::make_optional(requires_grad_)
                              : c10::nullopt;
  }

  /// Returns the `pinned_memory` property of the `TensorOptions` (false when
  /// unspecified).
  bool pinned_memory() const noexcept {
    return pinned_memory_or_default(pinned_memory_opt());
  }

  /// Returns whether the `pinned_memory` is specified.
  bool has_pinned_memory() const noexcept {
    return has_pinned_memory_;
  }

  /// Returns if the layout is sparse
  /// (NB: compares the raw layout_ field, which is initialized to kStrided,
  /// so this is false when no layout has been explicitly set).
  bool is_sparse() const {
    return layout_ == c10::Layout::Sparse;
  }

  /// Returns if the layout is the CSR sparse layout (raw layout_ field, as
  /// with is_sparse()).
  bool is_sparse_csr() const {
    return layout_ == c10::Layout::SparseCsr;
  }

  // For compatibility with legacy tensor.type() comparisons
  // Two options are "type equal" when they map to the same dispatch key and
  // scalar type. NB: reads dtype_ directly, so an unset dtype participates
  // with its default (float) value.
  bool type_equal(const TensorOptions& other) const {
    return computeDispatchKey() == other.computeDispatchKey() &&
        typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
  }

  /// Returns the `pinned_memory` property of the `TensorOptions`, or
  /// `c10::nullopt` if `pinned_memory` is not specified.
  c10::optional<bool> pinned_memory_opt() const noexcept {
    return has_pinned_memory_ ? c10::make_optional(pinned_memory_)
                              : c10::nullopt;
  }

  /// Returns whether the `memory_layout` is specified
  bool has_memory_format() const noexcept {
    return has_memory_format_;
  }

  // NB: memory_format() getter is PURPOSELY not defined, as the default
  // behavior of memory_format varies from function to function.

  /// Returns the `memory_format` property of `TensorOptions`, or
  /// `c10::nullopt` if `memory_format` is not specified.
  c10::optional<MemoryFormat> memory_format_opt() const noexcept {
    return has_memory_format_ ? c10::make_optional(memory_format_)
                              : c10::nullopt;
  }

  // Resolves the ATen backend specified by the current construction axes.
  // TODO: Deprecate this
  Backend backend() const {
    return at::dispatchKeyToBackend(computeDispatchKey());
  }

  /// Return the right-biased merge of two TensorOptions. This has the
  /// effect of overwriting settings from self with specified options
  /// of options.
  ///
  /// NB: This merging operation does NOT respect device merges.
  /// For example, if you device({kCUDA, 1}).merge_in(kCUDA)
  /// you will get kCUDA in the end! Functions like Tensor.new_empty
  /// ensure the right device is selected anyway by way of a
  /// device guard.
  ///
  TensorOptions merge_in(TensorOptions options) const noexcept {
    TensorOptions merged = *this;
    if (options.has_device())
      merged.set_device(options.device_opt());
    if (options.has_dtype())
      merged.set_dtype(options.dtype_opt());
    if (options.has_layout())
      merged.set_layout(options.layout_opt());
    // NB: requires grad is right biased; not a logical AND/OR!
    if (options.has_requires_grad())
      merged.set_requires_grad(options.requires_grad_opt());
    if (options.has_pinned_memory())
      merged.set_pinned_memory(options.pinned_memory_opt());
    if (options.has_memory_format())
      merged.set_memory_format(options.memory_format_opt());
    return merged;
  }

  // TODO remove after TensorOptions rationalization
  // Returns a copy with memory_format overwritten only when the argument
  // actually carries a value.
  TensorOptions merge_memory_format(
      c10::optional<MemoryFormat> optional_memory_format) const noexcept {
    TensorOptions merged = *this;
    if (optional_memory_format.has_value()) {
      merged.set_memory_format(*optional_memory_format);
    }
    return merged;
  }

  // INVARIANT: computeDispatchKey returns only the subset of dispatch keys for
  // which dispatchKeyToBackend is injective, if it is defined at all (for
  // the most part, this just means that this function never returns an
  // Autograd key)
  DispatchKey computeDispatchKey() const {
    return c10::computeDispatchKey(
        optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt());
  }

 private:
  // These methods are currently private because I'm not sure if it's wise
  // to actually publish them. They are methods because I need them in
  // the constructor and the functional API implementation.
  //
  // If you really, really need it, you can make these public, but check if you
  // couldn't just do what you need with the functional API. Similarly, these
  // methods are not chainable, because if you wanted chaining, you probably
  // want to use the functional API instead. (It's probably OK to make
  // these chainable, because these functions are all explicitly annotated
  // with a ref-qualifier, the trailing &, that makes them illegal to call
  // on temporaries.)

  /// Mutably set the device of `TensorOptions`; nullopt clears the axis.
  void set_device(c10::optional<Device> device) & noexcept {
    if (device) {
      device_ = *device;
      has_device_ = true;
    } else {
      has_device_ = false;
    }
  }

  /// Mutably set the dtype of `TensorOptions`; nullopt clears the axis.
  void set_dtype(c10::optional<caffe2::TypeMeta> dtype) & noexcept {
    if (dtype) {
      dtype_ = *dtype;
      has_dtype_ = true;
    } else {
      has_dtype_ = false;
    }
  }

  // legacy function to support ScalarType
  void set_dtype(c10::optional<ScalarType> dtype) & noexcept {
    if (dtype) {
      dtype_ = scalarTypeToTypeMeta(*dtype);
      has_dtype_ = true;
    } else {
      has_dtype_ = false;
    }
  }

  /// Mutably set the layout of `TensorOptions`; nullopt clears the axis.
  void set_layout(c10::optional<Layout> layout) & noexcept {
    if (layout) {
      layout_ = *layout;
      has_layout_ = true;
    } else {
      has_layout_ = false;
    }
  }

  /// Mutably set the `requires_grad` property of `TensorOptions`; nullopt
  /// clears the axis.
  void set_requires_grad(c10::optional<bool> requires_grad) & noexcept {
    if (requires_grad) {
      requires_grad_ = *requires_grad;
      has_requires_grad_ = true;
    } else {
      has_requires_grad_ = false;
    }
  }

  /// Mutably set the `pinned_memory` property of `TensorOptions`; nullopt
  /// clears the axis.
  void set_pinned_memory(c10::optional<bool> pinned_memory) & noexcept {
    if (pinned_memory) {
      pinned_memory_ = *pinned_memory;
      has_pinned_memory_ = true;
    } else {
      has_pinned_memory_ = false;
    }
  }

  /// Mutably set the `memory_format` property of `TensorOptions`; nullopt
  /// clears the axis.
  void set_memory_format(c10::optional<MemoryFormat> memory_format) & noexcept {
    if (memory_format) {
      memory_format_ = *memory_format;
      has_memory_format_ = true;
    } else {
      has_memory_format_ = false;
    }
  }

  // WARNING: If you edit TensorOptions to add more options, you
  // may need to adjust the implementation of Tensor::options.
  // The criteria for whether or not Tensor::options must be adjusted
  // is whether or not the new option you added should be preserved
  // by functions such as empty_like(); if it should be preserved,
  // you must adjust options().
  //
  // TODO: MemoryFormat is not implemented in this way

  // NB: We didn't use c10::optional here, because then we can't pack
  // the has_***_ boolean fields.

  Device device_ = at::kCPU; // 16-bit
  caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make<float>(); // 16-bit
  Layout layout_ = at::kStrided; // 8-bit
  MemoryFormat memory_format_ = MemoryFormat::Contiguous; // 8-bit

  // Bitmask required here to get this to fit inside 32 bits (or even 64 bits,
  // for that matter)

  bool requires_grad_ : 1;
  bool pinned_memory_ : 1;

  bool has_device_ : 1;
  bool has_dtype_ : 1;
  bool has_layout_ : 1;
  bool has_requires_grad_ : 1;
  bool has_pinned_memory_ : 1;
  bool has_memory_format_ : 1;
};
557
// We should aspire to fit in one machine-size word; but a size greater than two
// words is too much. (We are doing terribly on 32-bit archs, where we require
// three machine size words to store tensor options. Eek!)
// The 1-bit flag fields at the end of TensorOptions are what make this
// packing possible; this assert catches accidental growth of the struct.
static_assert(
    sizeof(TensorOptions) <= sizeof(int64_t) * 2,
    "TensorOptions must fit in 128-bits");
564
565/// Convenience function that returns a `TensorOptions` object with the `dtype`
566/// set to the given one.
567inline TensorOptions dtype(caffe2::TypeMeta dtype) {
568 return TensorOptions().dtype(dtype);
569}
570
571// legacy function to support ScalarType
572inline TensorOptions dtype(ScalarType dtype) {
573 return TensorOptions().dtype(scalarTypeToTypeMeta(dtype));
574}
575
576/// Convenience function that returns a `TensorOptions` object with the `layout`
577/// set to the given one.
578inline TensorOptions layout(Layout layout) {
579 return TensorOptions().layout(layout);
580}
581
582/// Convenience function that returns a `TensorOptions` object with the `device`
583/// set to the given one.
584inline TensorOptions device(Device device) {
585 return TensorOptions().device(device);
586}
587
588/// Convenience function that returns a `TensorOptions` object with the
589/// `device` set to CUDA and the `device_index` set to the given one.
590inline TensorOptions device_index(int16_t device_index) {
591 return TensorOptions().device_index(
592 static_cast<c10::DeviceIndex>(device_index));
593}
594
595/// Convenience function that returns a `TensorOptions` object with the
596/// `requires_grad` set to the given one.
597inline TensorOptions requires_grad(bool requires_grad = true) {
598 return TensorOptions().requires_grad(requires_grad);
599}
600
601/// Convenience function that returns a `TensorOptions` object with the
602/// `memory_format` set to the given one.
603inline TensorOptions memory_format(MemoryFormat memory_format) {
604 return TensorOptions().memory_format(memory_format);
605}
606
607C10_API std::ostream& operator<<(
608 std::ostream& stream,
609 const TensorOptions& options);
610
611template <typename T>
612inline TensorOptions dtype() {
613 return dtype(caffe2::TypeMeta::Make<T>());
614}
615
616inline std::string toString(const TensorOptions options) {
617 std::ostringstream stream;
618 stream << options;
619 return stream.str();
620}
621
// This is intended to be a centralized location by which we can determine
// what an appropriate DispatchKey for a tensor is.
//
// Unset axes fall back to the global defaults (strided layout, CPU device,
// default dtype) before the mapping is computed. Errors on device/layout
// combinations that have no backend key.
inline DispatchKey computeDispatchKey(
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device) {
  const auto layout_ = layout_or_default(layout);
  const auto device_ = device_or_default(device);
  switch (layout_) {
    case Layout::Strided: {
      const auto dtype_ = dtype_or_default(dtype);
      switch (device_.type()) {
        // For real backend device types: quantized integer dtypes map to the
        // Quantized<Backend> key, everything else to the plain backend key.
#define DO_CASE(device, _) \
  case DeviceType::device: { \
    if (isQIntType(dtype_)) { \
      return DispatchKey::Quantized##device; \
    } \
    return DispatchKey::device; \
  }
        C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
#undef DO_CASE
        case DeviceType::FPGA:
          return DispatchKey::FPGA;
        case DeviceType::ORT:
          return DispatchKey::ORT;
        case DeviceType::Vulkan:
          return DispatchKey::Vulkan;
        case DeviceType::Metal:
          return DispatchKey::Metal;
        case DeviceType::MKLDNN:
        case DeviceType::OPENGL:
        case DeviceType::OPENCL:
        case DeviceType::IDEEP:
          TORCH_INTERNAL_ASSERT(
              0,
              "This is a grandfathered Caffe2 device type ",
              device_.type(),
              ", it shouldn't ever convert to a DispatchKey. File a bug describing what you were doing if you think this is in error.");
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Unsupported device type for dense layout: ",
              device_.type());
      }
    }
    case Layout::Sparse:
      // COO sparse tensors: one Sparse<Backend> key per real backend.
      switch (device_.type()) {
#define DO_CASE(device, _) \
  case DeviceType::device: { \
    return DispatchKey::Sparse##device; \
  }
        C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
#undef DO_CASE
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Unsupported device type for sparse layout: ",
              device_.type());
      }
    case Layout::Mkldnn:
      // MKL-DNN tensors are CPU-only.
      switch (device_.type()) {
        case DeviceType::CPU:
          return DispatchKey::MkldnnCPU;
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Unsupported device type for mkldnn layout: ",
              device_.type());
      }
    case Layout::SparseCsr:
    case Layout::SparseCsc:
    case Layout::SparseBsr:
    case Layout::SparseBsc:
      // All four compressed-sparse layouts share a single SparseCsr* key
      // per device.
      switch (device_.type()) {
        case DeviceType::CPU:
          return DispatchKey::SparseCsrCPU;
        case DeviceType::CUDA:
          return DispatchKey::SparseCsrCUDA;
        default:
          AT_ERROR(
              "Unsupported device type for ",
              layout_,
              " layout: ",
              device_.type());
      }
    default:
      TORCH_CHECK(false, "Unsupported layout: ", layout_);
  }
}
711
// Inverse of (part of) computeDispatchKey: recovers the Layout implied by a
// backend dispatch key. Keys with no layout-specific meaning map to Strided.
inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
  switch (dispatch_key) {
    // Every per-backend Sparse* key corresponds to the COO sparse layout.
#define DO_CASE(bc, _) case DispatchKey::Sparse##bc:
    C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused)
#undef DO_CASE
      return Layout::Sparse;
    case DispatchKey::SparseCsrCPU:
    case DispatchKey::SparseCsrCUDA:
      // A single SparseCsr* key serves several compressed-sparse layouts
      // (see computeDispatchKey), so the inverse mapping is ambiguous here.
      TORCH_CHECK(
          false,
          "Cannot map DispatchKey ",
          dispatch_key,
          " to a unique layout.");
    case DispatchKey::MkldnnCPU:
      return Layout::Mkldnn;
    default:
      return Layout::Strided;
  }
}
731
// Recovers the DeviceType implied by a backend dispatch key; errors for keys
// that do not correspond to a single device.
inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
  switch (dispatch_key) {
    // stuff that's real
    // Expands to one case per (functionality prefix, device type) pair,
    // e.g. the CUDA-suffixed key of each functionality maps to
    // DeviceType::CUDA.
#define DO_CASE(suffix, prefix) \
  case DispatchKey::prefix##suffix: \
    return DeviceType::suffix;
#define DO_CASES(_, prefix) C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, prefix)
    C10_FORALL_FUNCTIONALITY_KEYS(DO_CASES)
#undef DO_CASES
#undef DO_CASE

    case DispatchKey::MkldnnCPU:
      return DeviceType::CPU;
    case DispatchKey::Vulkan:
      return DeviceType::Vulkan;

    case DispatchKey::ORT:
      return DeviceType::ORT;
    default:
      TORCH_CHECK(
          false,
          "DispatchKey ",
          dispatch_key,
          " doesn't correspond to a device");
  }
}
758
759inline TensorOptions dispatchKeyToTensorOptions(DispatchKey dispatch_key) {
760 return TensorOptions()
761 .layout(dispatchKeyToLayout(dispatch_key))
762 .device(dispatchKeyToDeviceType(dispatch_key));
763}
764
765namespace detail {
766inline bool backend_supports_empty_operator(const TensorOptions options) {
767 // Quantized backends don't support at::empty().
768 // They have separate operators like at::empty_quantized() that take in
769 // extra information about how to quantize the tensor.
770 return !isQIntType(typeMetaToScalarType(options.dtype()));
771}
772
773} // namespace detail
774
775} // namespace c10
776