1 | #pragma once |
---|---|
2 | |
3 | #include <ATen/core/symbol.h> |
4 | #include <c10/util/ArrayRef.h> |
5 | #include <c10/util/Optional.h> |
6 | #include <ostream> |
7 | |
8 | namespace at { |
9 | |
/// Kind of a dimension name. BASIC is what a Dimname built from a plain
/// symbol gets; WILDCARD tags the wildcard name. Backed by a single byte.
enum class NameType : uint8_t {
  BASIC = 0,
  WILDCARD = 1,
};
11 | |
12 | struct TORCH_API Dimname { |
13 | static Dimname fromSymbol(Symbol name); |
14 | static Dimname wildcard(); |
15 | static bool isValidName(const std::string& name); |
16 | |
17 | NameType type() const { return type_; } |
18 | Symbol symbol() const { return name_; } |
19 | |
20 | bool isBasic() const { return type_ == NameType::BASIC; } |
21 | bool isWildcard() const { return type_ == NameType::WILDCARD; } |
22 | |
23 | bool matches(Dimname other) const; |
24 | c10::optional<Dimname> unify(Dimname other) const; |
25 | |
26 | private: |
27 | Dimname(Symbol name) |
28 | : name_(name), type_(NameType::BASIC) {} |
29 | Dimname(Symbol name, NameType type) |
30 | : name_(name), type_(type) {} |
31 | |
32 | Symbol name_; |
33 | NameType type_; |
34 | }; |
35 | |
36 | using DimnameList = c10::ArrayRef<Dimname>; |
37 | |
38 | TORCH_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname); |
39 | |
40 | inline bool operator==(const Dimname& lhs, const Dimname& rhs) { |
41 | return lhs.symbol() == rhs.symbol(); |
42 | } |
43 | |
44 | inline bool operator!=(const Dimname& lhs, const Dimname& rhs) { |
45 | return !(lhs == rhs); |
46 | } |
47 | |
48 | } // namespace at |
49 |