1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file extract_intermediate_expr.cc
 * \brief Extracts an intermediate Relay expression from the main function
 * by its expression ID, i.e. the ID we can see in `print(mod["main"])`.
 */
26#include <tvm/node/structural_hash.h>
27#include <tvm/relay/analysis.h>
28#include <tvm/relay/expr.h>
29#include <tvm/relay/expr_functor.h>
30
31namespace tvm {
32namespace relay {
33
34class ExtractIntermediateExprWrapper : private MixedModeVisitor {
35 public:
36 explicit ExtractIntermediateExprWrapper(const IRModule& mod, const int expr_id)
37 : mod_(mod), target_expr_id_(expr_id), counter_(0) {}
38
39 IRModule Extract() {
40 VisitExpr(this->mod_->Lookup("main"));
41
42 // ensure the target expr_id we want to extract is valid.
43 ICHECK(target_expr_id_ >= 0 && target_expr_id_ < counter_);
44
45 return IRModule::FromExpr(target_op_, {});
46 }
47
48 private:
49 using MixedModeVisitor::VisitExpr_;
50
51 const IRModule mod_;
52 /*! \brief the expr id that we want to extract. */
53 const int target_expr_id_;
54 int counter_;
55 Expr target_op_;
56
57 void VisitExpr_(const CallNode* n) final {
58 CheckCounterAndIncrease(GetRef<Expr>(n));
59 MixedModeVisitor::VisitExpr_(n);
60 }
61
62 void VisitExpr_(const TupleNode* n) final {
63 CheckCounterAndIncrease(GetRef<Expr>(n));
64 MixedModeVisitor::VisitExpr_(n);
65 }
66
67 void VisitExpr_(const TupleGetItemNode* n) final {
68 CheckCounterAndIncrease(GetRef<Expr>(n));
69 MixedModeVisitor::VisitExpr_(n);
70 }
71
72 void CheckCounterAndIncrease(const Expr& expr) {
73 if (target_expr_id_ == counter_) {
74 target_op_ = expr;
75 }
76 ++counter_;
77 }
78};
79
80IRModule ExtractIntermediateExprPacked(const IRModule& mod, const int expr_id) {
81 return ExtractIntermediateExprWrapper(mod, expr_id).Extract();
82}
83
84TVM_REGISTER_GLOBAL("relay.analysis.ExtractIntermediateExpr")
85 .set_body_typed(ExtractIntermediateExprPacked);
86
87} // namespace relay
88} // namespace tvm
89