1 | #pragma once |
2 | |
3 | #include <ATen/core/jit_type.h> |
4 | #include <ATen/core/qualified_name.h> |
5 | #include <torch/csrc/jit/frontend/sugared_value.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
10 | struct Resolver; |
11 | using ResolverPtr = std::shared_ptr<Resolver>; |
12 | |
13 | /** |
14 | * class Resolver |
15 | * |
16 | * Represents an "outer environment" in which we an look up names and return |
17 | * a corresponding SugaredValue. This is used during compilation to resolve |
18 | * references to names which are not defined internal to the graph. |
19 | * |
20 | * Example: PythonResolver looks at the enclosing Python scope for `name`. |
21 | * |
22 | * NOTE: When adding methods, keep this an abstract class (i.e. all new methods |
23 | * should be purely virtual). Resist the urge to provide a default |
24 | * implementation; you should explicitly think about how each resolver would |
25 | * handle the method. |
26 | */ |
27 | struct Resolver { |
28 | virtual ~Resolver() = default; |
29 | |
30 | // Resolve a given name to a SugaredValue. This takes the method `m` that the |
31 | // caller is currently constructing, since we may need to insert nodes into |
32 | // the graph to create a value. |
33 | virtual std::shared_ptr<SugaredValue> resolveValue( |
34 | const std::string& name, |
35 | GraphFunction& m, |
36 | const SourceRange& loc) { |
37 | return nullptr; |
38 | } |
39 | |
40 | // Resolve `name` to a TypePtr. |
41 | virtual TypePtr resolveType(const std::string& name, const SourceRange& loc) { |
42 | return nullptr; |
43 | } |
44 | }; |
45 | |
46 | // A resolver that only understands "torch.foo()" lookups. |
47 | struct NativeResolver : public Resolver { |
48 | std::shared_ptr<SugaredValue> resolveValue( |
49 | const std::string& name, |
50 | GraphFunction& m, |
51 | const SourceRange& loc) override { |
52 | if (name == "torch" ) { |
53 | return std::make_shared<BuiltinModule>("aten" ); |
54 | } |
55 | return nullptr; |
56 | } |
57 | |
58 | TypePtr resolveType(const std::string& name, const SourceRange& loc) |
59 | override { |
60 | return nullptr; |
61 | } |
62 | }; |
63 | |
64 | inline std::shared_ptr<NativeResolver> nativeResolver() { |
65 | return std::make_shared<NativeResolver>(); |
66 | } |
67 | } // namespace jit |
68 | } // namespace torch |
69 | |