1#pragma once
2
#include <ATen/core/symbol.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>

#include <cstdint>
#include <ostream>
#include <string>
7
8namespace at {
9
/// Discriminates the two kinds of Dimname: a regular (BASIC) name and the
/// WILDCARD. The underlying type is pinned to one byte (presumably to keep
/// Dimname, which stores a NameType member, compact).
enum class NameType : uint8_t {
  BASIC,
  WILDCARD,
};
11
12struct TORCH_API Dimname {
13 static Dimname fromSymbol(Symbol name);
14 static Dimname wildcard();
15 static bool isValidName(const std::string& name);
16
17 NameType type() const { return type_; }
18 Symbol symbol() const { return name_; }
19
20 bool isBasic() const { return type_ == NameType::BASIC; }
21 bool isWildcard() const { return type_ == NameType::WILDCARD; }
22
23 bool matches(Dimname other) const;
24 c10::optional<Dimname> unify(Dimname other) const;
25
26 private:
27 Dimname(Symbol name)
28 : name_(name), type_(NameType::BASIC) {}
29 Dimname(Symbol name, NameType type)
30 : name_(name), type_(type) {}
31
32 Symbol name_;
33 NameType type_;
34};
35
36using DimnameList = c10::ArrayRef<Dimname>;
37
38TORCH_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname);
39
40inline bool operator==(const Dimname& lhs, const Dimname& rhs) {
41 return lhs.symbol() == rhs.symbol();
42}
43
44inline bool operator!=(const Dimname& lhs, const Dimname& rhs) {
45 return !(lhs == rhs);
46}
47
48} // namespace at
49