1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/parser/meta_ref.cc |
22 | * \brief An operator which allows forward referencing a yet-to-be parsed meta table reference. |
23 | */ |
24 | |
25 | #include "./meta_ref.h" |
26 | |
27 | #include <tvm/relay/expr_functor.h> |
28 | #include <tvm/relay/op.h> |
29 | #include <tvm/relay/op_attr_types.h> |
30 | #include <tvm/relay/transform.h> |
31 | |
32 | namespace tvm { |
33 | namespace parser { |
34 | |
35 | using tvm::relay::transform::CreateFunctionPass; |
36 | using tvm::transform::PassContext; |
37 | |
/*
 * Opt level for the ExpandMetaRefs pass. It is set to an arbitrarily high number because
 * this pass should never be scheduled by the normal pass manager flow; it is only run
 * explicitly by ExpandMetaRefs(). Declared constexpr so it cannot be modified at runtime.
 */
static constexpr int kMetaExpandOptLevel = 1337;
40 | |
41 | TVM_REGISTER_NODE_TYPE(MetaRefAttrs); |
42 | |
43 | bool MetaRefRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
44 | const TypeReporter& reporter) { |
45 | LOG(FATAL) << "need to expand before type checking" ; |
46 | } |
47 | |
48 | RELAY_REGISTER_OP("parser.MetaRef" ) |
49 | .describe(R"code(A reference into the meta table.)code" TVM_ADD_FILELINE) |
50 | .set_attrs_type<MetaRefAttrs>() |
51 | .set_num_inputs(0) |
52 | .set_support_level(10) |
53 | .add_type_rel("MetaRef" , MetaRefRel) |
54 | .set_attr<TOpIsStateful>("TOpIsStateful" , false) |
55 | .set_attr<TNonComputational>("TNonComputational" , true); |
56 | |
57 | Expr MetaRef(std::string type_key, uint64_t node_index) { |
58 | static const Op& op = Op::Get("parser.MetaRef" ); |
59 | auto attrs = make_object<MetaRefAttrs>(); |
60 | attrs->node_type_key = tvm::String(type_key); |
61 | attrs->node_index = node_index; |
62 | return Call(op, {}, Attrs(attrs), {}); |
63 | } |
64 | |
65 | struct MetaRefExpander : public ExprMutator { |
66 | MetaTable table; |
67 | |
68 | explicit MetaRefExpander(const MetaTable& table) : table(table) {} |
69 | |
70 | Expr VisitExpr_(const CallNode* call) final { |
71 | if (auto op_node = call->op.as<OpNode>()) { |
72 | if (op_node->name == "parser.MetaRef" ) { |
73 | auto meta_attrs = call->attrs.as<MetaRefAttrs>(); |
74 | ICHECK(meta_attrs) << "an internal error has occurred" ; |
75 | auto nodes = table.at(meta_attrs->node_type_key); |
76 | ICHECK_LT(meta_attrs->node_index, nodes.size()); |
77 | return Downcast<Expr>(nodes[meta_attrs->node_index]); |
78 | } |
79 | } |
80 | |
81 | return ExprMutator::VisitExpr_(call); |
82 | } |
83 | }; |
84 | |
85 | Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func) { |
86 | MetaRefExpander expander(meta_table); |
87 | return Downcast<Function>(expander.VisitExpr(func)); |
88 | } |
89 | |
90 | IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod) { |
91 | auto pass = CreateFunctionPass([&](Function func, IRModule module, |
92 | PassContext ctx) { return ExpandMetaRefs(meta_table, func); }, |
93 | kMetaExpandOptLevel, "ExpandMetaRefs" , {}); |
94 | |
95 | return pass(mod, PassContext::Create()); |
96 | } |
97 | |
98 | } // namespace parser |
99 | } // namespace tvm |
100 | |