1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
#include <tvm/ir/attrs.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
22
23namespace tvm {
24namespace relay {
25namespace transform {
26
27namespace {
28
29/*! \brief Collect all attributes whose name contains "layout".
30 */
31struct CollectAttrs : public AttrVisitor {
32 void Visit(const char* key, std::string* value) final {
33 if (std::string(key).find("layout") != std::string::npos) {
34 attrs[key] = String(*value);
35 }
36 }
37 void Visit(const char* key, double* value) final {}
38 void Visit(const char* key, uint64_t* value) final {}
39 void Visit(const char* key, int* value) final {}
40 void Visit(const char* key, int64_t* value) final {}
41 void Visit(const char* key, bool* value) final {}
42 void Visit(const char* key, runtime::NDArray* value) final {}
43 void Visit(const char* key, ObjectRef* value) final {
44 if (std::string(key).find("layout") != std::string::npos) {
45 attrs[key] = *value;
46 }
47 }
48 void Visit(const char* key, DataType* value) final {}
49 void Visit(const char* key, void** value) final {}
50 std::unordered_map<std::string, ObjectRef> attrs;
51};
52} // namespace
53
54/*! \brief Visitor to add structural hash and layout information to `Function`
55 * nodes. Sets the "hash" field on the attr to the structural hash of the
56 * function. Propogates any attributes with "layout" in their name from call
57 * nodes in the Function to the Function's attrs.
58 */
59class LabelOpsMutator : public MixedModeMutator {
60 private:
61 using MixedModeMutator::VisitExpr_;
62 std::unordered_map<std::string, ObjectRef> body_attrs;
63 Expr VisitExpr_(const FunctionNode* op) final {
64 if (op->GetAttr<String>("hash").defined()) {
65 // Already labelled.
66 return ExprMutator::VisitExpr_(op);
67 }
68
69 // body_attrs collects attrs from Calls in the body of this Function. Reset
70 // it so we only get attrs from this Function.
71 body_attrs = {};
72 auto updated = ExprMutator::VisitExpr_(op);
73 size_t hash = StructuralHash()(updated);
74
75 // format hash as fixed length hex string so it is easier to read
76 std::stringstream s;
77 s << std::setfill('0') << std::setw(sizeof(size_t) * 2) << std::hex << hash;
78
79 Function f = WithAttr(Downcast<Function>(updated), "hash", String(s.str()));
80 for (auto p : body_attrs) {
81 f = WithAttr(f, p.first, p.second);
82 }
83 return std::move(f);
84 }
85
86 Expr VisitExpr_(const LetNode* op) final {
87 auto pre_visit = [this](const LetNode* op) {
88 this->Mutate(op->var);
89 this->Mutate(op->value);
90 };
91 auto post_visit = [this](const LetNode* op) {
92 Var var = Downcast<Var>(this->Mutate(op->var));
93 auto value = this->Mutate(op->value);
94 auto body = this->Mutate(op->body);
95 auto expr = GetRef<Expr>(op);
96 if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
97 this->memo_[expr] = expr;
98 } else {
99 this->memo_[expr] = Let(var, value, body);
100 }
101 };
102 ExpandANormalForm(op, pre_visit, post_visit);
103 return memo_[GetRef<Expr>(op)];
104 }
105
106 Expr Rewrite_(const CallNode* op, const Expr& post) final {
107 auto updated = MixedModeMutator::Rewrite_(op, post);
108 if (op->attrs.defined()) {
109 CollectAttrs collect;
110 const_cast<BaseAttrsNode*>(op->attrs.get())->VisitAttrs(&collect);
111 for (auto p : collect.attrs) {
112 if (body_attrs.find(p.first) != body_attrs.end() && p.second == body_attrs[p.first]) {
113 LOG(WARNING) << "LabelOps found two call sites with different values for " << p.first
114 << " (" << p.second << " vs " << body_attrs[p.first]
115 << "). Only the first will be recorded.";
116 }
117 body_attrs[p.first] = p.second;
118 }
119 }
120 return updated;
121 }
122};
123
124/*! \brief Add structural hash and layout information to Function nodes. This
125 * information is used later by the profiler.
126 *
127 * The hash and layout information is added to the attrs field of the Function.
128 * The key "hash" contains the structural hash of the node. Any attributes with
129 * "layout" in their name are also added to attrs (for example,
130 * `attrs["src_layout"]` contains the `src_layout` attribute of the TVM op
131 * corresponding to this function call).
132 */
133Pass LabelOps() {
134 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
135 [=](Function f, IRModule m, PassContext pc) {
136 return Downcast<Function>(LabelOpsMutator().Mutate(f));
137 };
138 return CreateFunctionPass(pass_func, 1, "LabelOps", {});
139}
140
141} // namespace transform
142} // namespace relay
143} // namespace tvm
144