1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/utils.cc
22 * \brief Misc helpers.
23 */
24
25#include "./utils.h"
26
27#include "../../support/scalars.h"
28#include "../op/memory/device_copy.h"
29
30namespace tvm {
31namespace relay {
32namespace collage {
33
34String GetSpecName(const Target& target) {
35 if (target.IsExternalCodegen()) {
36 return target->kind->name;
37 } else {
38 return std::string(kTVMSpecNamePrefix) + target->kind->name;
39 }
40}
41
42String UnionLabels(String left, String right) {
43 if (left.empty()) {
44 return right;
45 }
46 if (right.empty()) {
47 return left;
48 }
49 return left + "+" + right;
50}
51
52String NestLabels(String left, String right) {
53 if (left.empty()) {
54 return right;
55 }
56 if (right.empty()) {
57 return left;
58 }
59 if (right.size() > left.size()) {
60 std::string right_str = right;
61 if (right_str.substr(0, left.size()) == left) {
62 return right;
63 }
64 }
65 return left + "." + right;
66}
67
68std::string KindToString(OpPatternKind kind) {
69 switch (kind) {
70 case kElemWise:
71 return "E";
72 case kBroadcast:
73 return "B";
74 case kInjective:
75 return "I";
76 case kCommReduce:
77 return "R";
78 case kOutEWiseFusable:
79 return "A";
80 case kTuple:
81 return "T";
82 case kOpaque:
83 return "O";
84 }
85 return "?";
86}
87
88OpPatternKind CombineKinds(OpPatternKind left, OpPatternKind right) {
89 return std::max(left, right);
90}
91
92bool CanInline(const Expr& expr) {
93 if (expr.as<OpNode>() || expr.as<ConstructorNode>() || expr.as<FunctionNode>()) {
94 return true;
95 }
96 if (const auto* constant_node = expr.as<ConstantNode>()) {
97 return support::IsSimpleScalar(constant_node);
98 }
99 return false;
100}
101
102bool IsSpecialOp(const OpNode* op_node) {
103 auto op = GetRef<Op>(op_node);
104 static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
105 if (fnoncomputational.count(op) && fnoncomputational[op]) {
106 // Operator has been marked as non-computational.
107 return true;
108 }
109 // TODO(mbs): This is incomplete.
110 static auto shape_of_op_ = Op::Get("shape_of");
111 static auto vm_shape_of_op_ = Op::Get("vm.shape_of");
112 if (op == DeviceCopyOp() || op == shape_of_op_ || op == vm_shape_of_op_) {
113 // Operator is compiled away by the VM compilation flow.
114 return true;
115 }
116 return false;
117}
118
119bool MustBeLowered(const Expr& expr) {
120 if (const auto* call_node = expr.as<CallNode>()) {
121 if (const auto* function_node = call_node->op.as<FunctionNode>()) {
122 if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
123 // We've already committed to this call being to one or more operators which must be
124 // lowered.
125 return true;
126 }
127 } else if (const auto* op_node = call_node->op.as<OpNode>()) {
128 if (!IsSpecialOp(op_node)) {
129 // The VM compilation path won't rewrite this call.
130 return true;
131 }
132 }
133 }
134 return false;
135}
136
/*!
 * \brief Splits \p stmt on every occurrence of the delimiter \p del.
 *
 * Empty tokens are preserved: "a,,b" yields {"a", "", "b"}, and a trailing delimiter yields a
 * trailing empty token. The result always contains at least one element.
 *
 * \param stmt The string to split.
 * \param del The delimiter, which may be longer than one character. A null or empty delimiter
 *            performs no splitting and yields {stmt}.
 * \return The tokens in order of appearance.
 */
std::vector<std::string> SplitString(std::string stmt, const char* del) {
  std::vector<std::string> str_tokens;
  const std::string delim = (del == nullptr) ? std::string() : std::string(del);
  size_t start = 0;
  // An empty delimiter would match at every position without ever advancing, so skip splitting.
  if (!delim.empty()) {
    for (size_t end = stmt.find(delim, start); end != std::string::npos;
         end = stmt.find(delim, start)) {
      str_tokens.emplace_back(stmt.substr(start, end - start));
      // Skip the entire delimiter, not just its first character.
      start = end + delim.size();
    }
  }
  // Remainder after the last delimiter (or the whole string if none was found).
  str_tokens.emplace_back(stmt.substr(start));
  return str_tokens;
}
149
150} // namespace collage
151} // namespace relay
152} // namespace tvm
153