1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/parser/meta_ref.cc
22 * \brief An operator which allows forward referencing a yet-to-be parsed meta table reference.
23 */
24
#include "./meta_ref.h"

#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <utility>
31
32namespace tvm {
33namespace parser {
34
35using tvm::relay::transform::CreateFunctionPass;
36using tvm::transform::PassContext;
37
/*!
 * \brief Opt level for the ExpandMetaRefs function pass.
 *
 * Deliberately set to an arbitrarily high number: this pass should never be
 * scheduled by the normal pass manager flow. It is only invoked directly by
 * ExpandMetaRefs. It is `constexpr` because the value never changes at runtime.
 */
constexpr int kMetaExpandOptLevel = 1337;
40
// Register MetaRefAttrs as a TVM node type. MetaRef() below allocates these
// attrs with make_object and attaches them to each parser.MetaRef call.
TVM_REGISTER_NODE_TYPE(MetaRefAttrs);
42
43bool MetaRefRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
44 const TypeReporter& reporter) {
45 LOG(FATAL) << "need to expand before type checking";
46}
47
// Register the placeholder op "parser.MetaRef". It takes no inputs; its
// payload (a node type key plus an index into the meta table) lives in
// MetaRefAttrs. Its type relation, MetaRefRel, always aborts, so every call to
// this op must be rewritten by ExpandMetaRefs before type checking.
// NOTE(review): the TOpIsStateful=false / TNonComputational=true attributes
// presumably let other passes treat it as an inert placeholder — confirm.
RELAY_REGISTER_OP("parser.MetaRef")
    .describe(R"code(A reference into the meta table.)code" TVM_ADD_FILELINE)
    .set_attrs_type<MetaRefAttrs>()
    .set_num_inputs(0)
    .set_support_level(10)
    .add_type_rel("MetaRef", MetaRefRel)
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    .set_attr<TNonComputational>("TNonComputational", true);
56
57Expr MetaRef(std::string type_key, uint64_t node_index) {
58 static const Op& op = Op::Get("parser.MetaRef");
59 auto attrs = make_object<MetaRefAttrs>();
60 attrs->node_type_key = tvm::String(type_key);
61 attrs->node_index = node_index;
62 return Call(op, {}, Attrs(attrs), {});
63}
64
65struct MetaRefExpander : public ExprMutator {
66 MetaTable table;
67
68 explicit MetaRefExpander(const MetaTable& table) : table(table) {}
69
70 Expr VisitExpr_(const CallNode* call) final {
71 if (auto op_node = call->op.as<OpNode>()) {
72 if (op_node->name == "parser.MetaRef") {
73 auto meta_attrs = call->attrs.as<MetaRefAttrs>();
74 ICHECK(meta_attrs) << "an internal error has occurred";
75 auto nodes = table.at(meta_attrs->node_type_key);
76 ICHECK_LT(meta_attrs->node_index, nodes.size());
77 return Downcast<Expr>(nodes[meta_attrs->node_index]);
78 }
79 }
80
81 return ExprMutator::VisitExpr_(call);
82 }
83};
84
85Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func) {
86 MetaRefExpander expander(meta_table);
87 return Downcast<Function>(expander.VisitExpr(func));
88}
89
90IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod) {
91 auto pass = CreateFunctionPass([&](Function func, IRModule module,
92 PassContext ctx) { return ExpandMetaRefs(meta_table, func); },
93 kMetaExpandOptLevel, "ExpandMetaRefs", {});
94
95 return pass(mod, PassContext::Create());
96}
97
98} // namespace parser
99} // namespace tvm
100