1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include <tvm/ir/attrs.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include <iomanip>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
22 | |
23 | namespace tvm { |
24 | namespace relay { |
25 | namespace transform { |
26 | |
27 | namespace { |
28 | |
29 | /*! \brief Collect all attributes whose name contains "layout". |
30 | */ |
31 | struct CollectAttrs : public AttrVisitor { |
32 | void Visit(const char* key, std::string* value) final { |
33 | if (std::string(key).find("layout" ) != std::string::npos) { |
34 | attrs[key] = String(*value); |
35 | } |
36 | } |
37 | void Visit(const char* key, double* value) final {} |
38 | void Visit(const char* key, uint64_t* value) final {} |
39 | void Visit(const char* key, int* value) final {} |
40 | void Visit(const char* key, int64_t* value) final {} |
41 | void Visit(const char* key, bool* value) final {} |
42 | void Visit(const char* key, runtime::NDArray* value) final {} |
43 | void Visit(const char* key, ObjectRef* value) final { |
44 | if (std::string(key).find("layout" ) != std::string::npos) { |
45 | attrs[key] = *value; |
46 | } |
47 | } |
48 | void Visit(const char* key, DataType* value) final {} |
49 | void Visit(const char* key, void** value) final {} |
50 | std::unordered_map<std::string, ObjectRef> attrs; |
51 | }; |
52 | } // namespace |
53 | |
54 | /*! \brief Visitor to add structural hash and layout information to `Function` |
55 | * nodes. Sets the "hash" field on the attr to the structural hash of the |
56 | * function. Propogates any attributes with "layout" in their name from call |
57 | * nodes in the Function to the Function's attrs. |
58 | */ |
59 | class LabelOpsMutator : public MixedModeMutator { |
60 | private: |
61 | using MixedModeMutator::VisitExpr_; |
62 | std::unordered_map<std::string, ObjectRef> body_attrs; |
63 | Expr VisitExpr_(const FunctionNode* op) final { |
64 | if (op->GetAttr<String>("hash" ).defined()) { |
65 | // Already labelled. |
66 | return ExprMutator::VisitExpr_(op); |
67 | } |
68 | |
69 | // body_attrs collects attrs from Calls in the body of this Function. Reset |
70 | // it so we only get attrs from this Function. |
71 | body_attrs = {}; |
72 | auto updated = ExprMutator::VisitExpr_(op); |
73 | size_t hash = StructuralHash()(updated); |
74 | |
75 | // format hash as fixed length hex string so it is easier to read |
76 | std::stringstream s; |
77 | s << std::setfill('0') << std::setw(sizeof(size_t) * 2) << std::hex << hash; |
78 | |
79 | Function f = WithAttr(Downcast<Function>(updated), "hash" , String(s.str())); |
80 | for (auto p : body_attrs) { |
81 | f = WithAttr(f, p.first, p.second); |
82 | } |
83 | return std::move(f); |
84 | } |
85 | |
86 | Expr VisitExpr_(const LetNode* op) final { |
87 | auto pre_visit = [this](const LetNode* op) { |
88 | this->Mutate(op->var); |
89 | this->Mutate(op->value); |
90 | }; |
91 | auto post_visit = [this](const LetNode* op) { |
92 | Var var = Downcast<Var>(this->Mutate(op->var)); |
93 | auto value = this->Mutate(op->value); |
94 | auto body = this->Mutate(op->body); |
95 | auto expr = GetRef<Expr>(op); |
96 | if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { |
97 | this->memo_[expr] = expr; |
98 | } else { |
99 | this->memo_[expr] = Let(var, value, body); |
100 | } |
101 | }; |
102 | ExpandANormalForm(op, pre_visit, post_visit); |
103 | return memo_[GetRef<Expr>(op)]; |
104 | } |
105 | |
106 | Expr Rewrite_(const CallNode* op, const Expr& post) final { |
107 | auto updated = MixedModeMutator::Rewrite_(op, post); |
108 | if (op->attrs.defined()) { |
109 | CollectAttrs collect; |
110 | const_cast<BaseAttrsNode*>(op->attrs.get())->VisitAttrs(&collect); |
111 | for (auto p : collect.attrs) { |
112 | if (body_attrs.find(p.first) != body_attrs.end() && p.second == body_attrs[p.first]) { |
113 | LOG(WARNING) << "LabelOps found two call sites with different values for " << p.first |
114 | << " (" << p.second << " vs " << body_attrs[p.first] |
115 | << "). Only the first will be recorded." ; |
116 | } |
117 | body_attrs[p.first] = p.second; |
118 | } |
119 | } |
120 | return updated; |
121 | } |
122 | }; |
123 | |
124 | /*! \brief Add structural hash and layout information to Function nodes. This |
125 | * information is used later by the profiler. |
126 | * |
127 | * The hash and layout information is added to the attrs field of the Function. |
128 | * The key "hash" contains the structural hash of the node. Any attributes with |
129 | * "layout" in their name are also added to attrs (for example, |
130 | * `attrs["src_layout"]` contains the `src_layout` attribute of the TVM op |
131 | * corresponding to this function call). |
132 | */ |
133 | Pass LabelOps() { |
134 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
135 | [=](Function f, IRModule m, PassContext pc) { |
136 | return Downcast<Function>(LabelOpsMutator().Mutate(f)); |
137 | }; |
138 | return CreateFunctionPass(pass_func, 1, "LabelOps" , {}); |
139 | } |
140 | |
141 | } // namespace transform |
142 | } // namespace relay |
143 | } // namespace tvm |
144 | |