1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/collage/utils.cc |
22 | * \brief Misc helpers. |
23 | */ |
24 | |
25 | #include "./utils.h" |
26 | |
27 | #include "../../support/scalars.h" |
28 | #include "../op/memory/device_copy.h" |
29 | |
30 | namespace tvm { |
31 | namespace relay { |
32 | namespace collage { |
33 | |
34 | String GetSpecName(const Target& target) { |
35 | if (target.IsExternalCodegen()) { |
36 | return target->kind->name; |
37 | } else { |
38 | return std::string(kTVMSpecNamePrefix) + target->kind->name; |
39 | } |
40 | } |
41 | |
42 | String UnionLabels(String left, String right) { |
43 | if (left.empty()) { |
44 | return right; |
45 | } |
46 | if (right.empty()) { |
47 | return left; |
48 | } |
49 | return left + "+" + right; |
50 | } |
51 | |
52 | String NestLabels(String left, String right) { |
53 | if (left.empty()) { |
54 | return right; |
55 | } |
56 | if (right.empty()) { |
57 | return left; |
58 | } |
59 | if (right.size() > left.size()) { |
60 | std::string right_str = right; |
61 | if (right_str.substr(0, left.size()) == left) { |
62 | return right; |
63 | } |
64 | } |
65 | return left + "." + right; |
66 | } |
67 | |
68 | std::string KindToString(OpPatternKind kind) { |
69 | switch (kind) { |
70 | case kElemWise: |
71 | return "E" ; |
72 | case kBroadcast: |
73 | return "B" ; |
74 | case kInjective: |
75 | return "I" ; |
76 | case kCommReduce: |
77 | return "R" ; |
78 | case kOutEWiseFusable: |
79 | return "A" ; |
80 | case kTuple: |
81 | return "T" ; |
82 | case kOpaque: |
83 | return "O" ; |
84 | } |
85 | return "?" ; |
86 | } |
87 | |
88 | OpPatternKind CombineKinds(OpPatternKind left, OpPatternKind right) { |
89 | return std::max(left, right); |
90 | } |
91 | |
92 | bool CanInline(const Expr& expr) { |
93 | if (expr.as<OpNode>() || expr.as<ConstructorNode>() || expr.as<FunctionNode>()) { |
94 | return true; |
95 | } |
96 | if (const auto* constant_node = expr.as<ConstantNode>()) { |
97 | return support::IsSimpleScalar(constant_node); |
98 | } |
99 | return false; |
100 | } |
101 | |
102 | bool IsSpecialOp(const OpNode* op_node) { |
103 | auto op = GetRef<Op>(op_node); |
104 | static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational" ); |
105 | if (fnoncomputational.count(op) && fnoncomputational[op]) { |
106 | // Operator has been marked as non-computational. |
107 | return true; |
108 | } |
109 | // TODO(mbs): This is incomplete. |
110 | static auto shape_of_op_ = Op::Get("shape_of" ); |
111 | static auto vm_shape_of_op_ = Op::Get("vm.shape_of" ); |
112 | if (op == DeviceCopyOp() || op == shape_of_op_ || op == vm_shape_of_op_) { |
113 | // Operator is compiled away by the VM compilation flow. |
114 | return true; |
115 | } |
116 | return false; |
117 | } |
118 | |
119 | bool MustBeLowered(const Expr& expr) { |
120 | if (const auto* call_node = expr.as<CallNode>()) { |
121 | if (const auto* function_node = call_node->op.as<FunctionNode>()) { |
122 | if (function_node->HasNonzeroAttr(attr::kPrimitive)) { |
123 | // We've already committed to this call being to one or more operators which must be |
124 | // lowered. |
125 | return true; |
126 | } |
127 | } else if (const auto* op_node = call_node->op.as<OpNode>()) { |
128 | if (!IsSpecialOp(op_node)) { |
129 | // The VM compilation path won't rewrite this call. |
130 | return true; |
131 | } |
132 | } |
133 | } |
134 | return false; |
135 | } |
136 | |
/*!
 * \brief Splits \p stmt into the substrings separated by the delimiter \p del.
 *
 * The result always has at least one token. Adjacent, leading or trailing delimiters produce
 * empty tokens, e.g. splitting "a,,b," on "," yields {"a", "", "b", ""}. The delimiter is matched
 * as a whole, so multi-character delimiters are supported.
 *
 * \param stmt The string to split.
 * \param del The delimiter. If null or empty, \p stmt is returned as the single token.
 * \return The tokens in order of appearance.
 */
std::vector<std::string> SplitString(std::string stmt, const char* del) {
  std::vector<std::string> str_tokens;
  if (del == nullptr || *del == '\0') {
    // There is nothing to split on. (An empty delimiter would otherwise match at every position.)
    str_tokens.emplace_back(stmt);
    return str_tokens;
  }
  const std::string delimiter(del);
  std::string::size_type start = 0;
  std::string::size_type end = stmt.find(delimiter, start);
  while (end != std::string::npos) {
    str_tokens.emplace_back(stmt.substr(start, end - start));
    // Skip over the whole delimiter, not just its first character.
    start = end + delimiter.size();
    end = stmt.find(delimiter, start);
  }
  // Whatever follows the last delimiter (possibly nothing) is the final token.
  str_tokens.emplace_back(stmt.substr(start));
  return str_tokens;
}
149 | |
150 | } // namespace collage |
151 | } // namespace relay |
152 | } // namespace tvm |
153 | |