1 | #pragma once |
2 | #include <torch/csrc/Export.h> |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | #include <torch/csrc/jit/ir/named_value.h> |
5 | |
6 | #include <ATen/core/function_schema.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | // Try to match a list of inputs and keyword 'attributes' to this |
12 | // schema. Return the flat list of positional inputs to the call or |
13 | // `c10::nullopt` on failure (`failure_messages` contains a good error |
14 | // report in this case) |
15 | |
16 | struct MatchedSchema { |
17 | std::vector<Value*> inputs; |
18 | std::vector<TypePtr> return_types; |
19 | c10::OptNameList return_field_names; |
20 | std::string schema_name; |
21 | }; |
22 | |
23 | TORCH_API bool isBlockListedSchema(const FunctionSchema& schema); |
24 | |
25 | TORCH_API MatchedSchema matchSchema( |
26 | const ::c10::FunctionSchema& schema, |
27 | const SourceRange& loc, |
28 | Graph& graph, |
29 | at::ArrayRef<NamedValue> args, |
30 | at::ArrayRef<NamedValue> kwargs, |
31 | const c10::optional<NamedValue>& self = c10::nullopt); |
32 | |
33 | TORCH_API std::pair<size_t, MatchedSchema> matchSchemas( |
34 | const std::vector<const ::c10::FunctionSchema*>& schemas, |
35 | const SourceRange& loc, |
36 | Graph& graph, |
37 | at::ArrayRef<NamedValue> args, |
38 | at::ArrayRef<NamedValue> kwargs, |
39 | const c10::optional<NamedValue>& self = c10::nullopt, |
40 | bool render_errors = false); |
41 | |
42 | TORCH_API bool convertibleToList( |
43 | const TypePtr& type, |
44 | const TypePtr& list_type_); |
45 | |
46 | TORCH_API std::string getFullSchemaName(const ::c10::FunctionSchema& schema); |
47 | |
48 | TORCH_API Value* emitBuiltinCall( |
49 | const SourceRange& loc, |
50 | Graph& graph, |
51 | Symbol name, |
52 | at::ArrayRef<NamedValue> args, |
53 | at::ArrayRef<NamedValue> kwargs, |
54 | const c10::optional<NamedValue>& self = c10::nullopt); |
55 | |
56 | TORCH_API c10::optional<size_t> findInputWithName( |
57 | const std::string& name, |
58 | at::ArrayRef<NamedValue> kwargs, |
59 | bool is_aten = false); |
60 | |
61 | // applies implicit conversion from value trying to turn it into type |
62 | // concrete_type it succeeds if the return_value->isSubtypeOf(concrete_type) |
63 | TORCH_API Value* tryConvertToType( |
64 | const SourceRange& loc, |
65 | Graph& graph, |
66 | const TypePtr& concrete_type, |
67 | Value* value, |
68 | bool allow_conversions); |
69 | } // namespace jit |
70 | } // namespace torch |
71 | |