1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file extract_intermediate_expr.cc
 * \brief Extracts a single intermediate Relay expression from the "main"
 *        function of a module, selected by the expression ID that can be
 *        seen in `print(mod["main"])`.
 */
26 | #include <tvm/node/structural_hash.h> |
27 | #include <tvm/relay/analysis.h> |
28 | #include <tvm/relay/expr.h> |
29 | #include <tvm/relay/expr_functor.h> |
30 | |
31 | namespace tvm { |
32 | namespace relay { |
33 | |
34 | class : private MixedModeVisitor { |
35 | public: |
36 | explicit (const IRModule& mod, const int expr_id) |
37 | : mod_(mod), target_expr_id_(expr_id), counter_(0) {} |
38 | |
39 | IRModule () { |
40 | VisitExpr(this->mod_->Lookup("main" )); |
41 | |
42 | // ensure the target expr_id we want to extract is valid. |
43 | ICHECK(target_expr_id_ >= 0 && target_expr_id_ < counter_); |
44 | |
45 | return IRModule::FromExpr(target_op_, {}); |
46 | } |
47 | |
48 | private: |
49 | using MixedModeVisitor::VisitExpr_; |
50 | |
51 | const IRModule ; |
52 | /*! \brief the expr id that we want to extract. */ |
53 | const int ; |
54 | int ; |
55 | Expr ; |
56 | |
57 | void (const CallNode* n) final { |
58 | CheckCounterAndIncrease(GetRef<Expr>(n)); |
59 | MixedModeVisitor::VisitExpr_(n); |
60 | } |
61 | |
62 | void (const TupleNode* n) final { |
63 | CheckCounterAndIncrease(GetRef<Expr>(n)); |
64 | MixedModeVisitor::VisitExpr_(n); |
65 | } |
66 | |
67 | void (const TupleGetItemNode* n) final { |
68 | CheckCounterAndIncrease(GetRef<Expr>(n)); |
69 | MixedModeVisitor::VisitExpr_(n); |
70 | } |
71 | |
72 | void CheckCounterAndIncrease(const Expr& expr) { |
73 | if (target_expr_id_ == counter_) { |
74 | target_op_ = expr; |
75 | } |
76 | ++counter_; |
77 | } |
78 | }; |
79 | |
80 | IRModule (const IRModule& mod, const int expr_id) { |
81 | return ExtractIntermediateExprWrapper(mod, expr_id).Extract(); |
82 | } |
83 | |
84 | TVM_REGISTER_GLOBAL("relay.analysis.ExtractIntermediateExpr" ) |
85 | .set_body_typed(ExtractIntermediateExprPacked); |
86 | |
87 | } // namespace relay |
88 | } // namespace tvm |
89 | |