1 | #pragma once |
2 | |
#include <ATen/WrapDimUtils.h>

#include <utility>
4 | |
5 | namespace at { |
6 | namespace namedinference { |
7 | |
8 | // TensorName and TensorNames are wrappers around Dimname and DimnameList |
9 | // that contain helper functions to make writing name inference rules easier. |
10 | // |
11 | // A TensorName represents a Dimname associated with some DimnameList (from a |
12 | // Tensor). This encapsulates all the information that is needed to check if |
13 | // names *match* and to *unify* names. |
14 | // |
15 | // Definition: Two names in two tensors *match* if they are equal, or if at |
16 | // least one of them is a wildcard that can be *refined* to the other name. |
17 | // |
18 | // Definition: unify(name, other) fails if the names do not match. Otherwise, |
19 | // it returns the most refined of name and other. |
20 | // |
21 | // Here is an example of checking if two names match. |
22 | // tensor: Tensor[A, None] |
23 | // other: Tensor[A] |
24 | // |
25 | // Let's say we wish to check if tensor.names[-1] matches other.names[-1]. |
26 | // None (in tensor) cannot match A (in other) because if the None were refined |
27 | // to A, `tensor` would have duplicate names [A, A]. Therefore we need to check |
28 | // tensor.names [A, None] for the existence of A. |
29 | struct TORCH_API TensorName { |
30 | explicit TensorName(ArrayRef<Dimname> origin, int origin_idx) |
31 | : origin_(origin), |
32 | name_(origin[maybe_wrap_dim(origin_idx, origin.size())]), |
33 | origin_idx_(origin_idx) {} |
34 | |
35 | // op_name is only used for error reporting. |
36 | const TensorName& unify(const TensorName& other, const char* op_name) const; |
37 | Dimname toDimname() const; |
38 | |
39 | private: |
40 | ArrayRef<Dimname> origin_; |
41 | Dimname name_; |
42 | int origin_idx_; // A named tensor can have at most 64 dims. |
43 | |
44 | TORCH_API friend std::ostream& operator<<( |
45 | std::ostream& out, |
46 | const TensorName& tensorname); |
47 | }; |
48 | |
49 | using TensorNameVec = SmallVector<TensorName, 10>; |
50 | |
51 | struct TORCH_API TensorNames { |
52 | explicit TensorNames(ArrayRef<Dimname> names); |
53 | |
54 | // Create TensorNames from names[start:end]. Each individual TensorName stores |
55 | // `names`, NOT names[start:end], because the original tensor's names are |
56 | // `names`. |
57 | explicit TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end); |
58 | |
59 | // op_name is only used for error reporting. |
60 | TensorNames& unifyFromRightInplace( |
61 | const TensorNames& other, |
62 | const char* op_name = "unify" ); |
63 | void checkUnique(const char* op_name) const; |
64 | |
65 | void append(TensorName&& name); |
66 | std::vector<Dimname> toDimnameVec() const; |
67 | |
68 | private: |
69 | explicit TensorNames(TensorNameVec&& names) : names_(names){}; |
70 | |
71 | TensorNameVec names_; |
72 | }; |
73 | |
74 | } // namespace namedinference |
75 | } // namespace at |
76 | |