1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/target/parsers/aprofile.cc
22 * \brief Target Parser for Arm(R) Cortex(R) A-Profile CPUs
23 */
24
25#include "aprofile.h"
26
#include <algorithm>
#include <cstdlib>
#include <string>
28
29#include "../../support/utils.h"
30
31namespace tvm {
32namespace target {
33namespace parsers {
34namespace aprofile {
35
36double GetArchVersion(Array<String> mattr) {
37 for (const String& attr : mattr) {
38 std::string attr_string = attr;
39 size_t attr_len = attr_string.size();
40 if (attr_len >= 4 && attr_string.substr(0, 2) == "+v" && attr_string.back() == 'a') {
41 std::string version_string = attr_string.substr(2, attr_string.size() - 2);
42 return atof(version_string.data());
43 }
44 }
45 return 0.0;
46}
47
48double GetArchVersion(Optional<Array<String>> attr) {
49 if (!attr) {
50 return false;
51 }
52 return GetArchVersion(attr.value());
53}
54
55static inline bool HasFlag(String attr, std::string flag) {
56 std::string attr_str = attr;
57 return attr_str.find(flag) != std::string::npos;
58}
59
60static inline bool HasFlag(Optional<String> attr, std::string flag) {
61 if (!attr) {
62 return false;
63 }
64 return HasFlag(attr.value(), flag);
65}
66
67static inline bool HasFlag(Optional<Array<String>> attr, std::string flag) {
68 if (!attr) {
69 return false;
70 }
71 Array<String> attr_array = attr.value();
72
73 auto matching_attr = std::find_if(attr_array.begin(), attr_array.end(),
74 [flag](String attr_str) { return HasFlag(attr_str, flag); });
75 return matching_attr != attr_array.end();
76}
77
78static bool HasFlag(Optional<String> mcpu, Optional<Array<String>> mattr, std::string flag) {
79 return HasFlag(mcpu, flag) || HasFlag(mattr, flag);
80}
81
82bool IsAArch32(Optional<String> mtriple, Optional<String> mcpu) {
83 if (mtriple) {
84 bool is_mprofile = mcpu && support::StartsWith(mcpu.value(), "cortex-m");
85 return support::StartsWith(mtriple.value(), "arm") && !is_mprofile;
86 }
87 return false;
88}
89
90bool IsAArch64(Optional<String> mtriple) {
91 if (mtriple) {
92 return support::StartsWith(mtriple.value(), "aarch64");
93 }
94 return false;
95}
96
97bool IsArch(TargetJSON attrs) {
98 Optional<String> mtriple = Downcast<Optional<String>>(attrs.Get("mtriple"));
99 Optional<String> mcpu = Downcast<Optional<String>>(attrs.Get("mcpu"));
100
101 return IsAArch32(mtriple, mcpu) || IsAArch64(mtriple);
102}
103
104static TargetFeatures GetFeatures(TargetJSON target) {
105 Optional<String> mcpu = Downcast<Optional<String>>(target.Get("mcpu"));
106 Optional<String> mtriple = Downcast<Optional<String>>(target.Get("mtriple"));
107 Optional<Array<String>> mattr = Downcast<Optional<Array<String>>>(target.Get("mattr"));
108
109 double arch_version = GetArchVersion(mattr);
110
111 bool is_aarch64 = IsAArch64(mtriple);
112
113 bool simd_flag = HasFlag(mcpu, mattr, "+neon") || HasFlag(mcpu, mattr, "+simd");
114 bool has_asimd = is_aarch64 || simd_flag;
115
116 bool i8mm_flag = HasFlag(mcpu, mattr, "+i8mm");
117 bool i8mm_disable = HasFlag(mcpu, mattr, "+noi8mm");
118 bool i8mm_default = arch_version >= 8.6;
119 bool i8mm_support = arch_version >= 8.2 && arch_version <= 8.5;
120 bool has_i8mm = (i8mm_default && !i8mm_disable) || (i8mm_support && i8mm_flag);
121
122 bool dotprod_flag = HasFlag(mcpu, mattr, "+dotprod");
123 bool dotprod_disable = HasFlag(mcpu, mattr, "+nodotprod");
124 bool dotprod_default = arch_version >= 8.4;
125 bool dotprod_support = arch_version >= 8.2 && arch_version <= 8.3;
126 bool has_dotprod = (dotprod_default && !dotprod_disable) || (dotprod_support && dotprod_flag);
127
128 return {
129 {"is_aarch64", Bool(is_aarch64)},
130 {"has_asimd", Bool(has_asimd)},
131 {"has_dotprod", Bool(has_dotprod)},
132 {"has_matmul_i8", Bool(has_i8mm)},
133 };
134}
135
136static Array<String> MergeKeys(Optional<Array<String>> existing_keys) {
137 const Array<String> kExtraKeys = {"arm_cpu", "cpu"};
138
139 if (!existing_keys) {
140 return kExtraKeys;
141 }
142
143 Array<String> keys = existing_keys.value();
144 for (String key : kExtraKeys) {
145 if (std::find(keys.begin(), keys.end(), key) == keys.end()) {
146 keys.push_back(key);
147 }
148 }
149 return keys;
150}
151
152TargetJSON ParseTarget(TargetJSON target) {
153 target.Set("features", GetFeatures(target));
154 target.Set("keys", MergeKeys(Downcast<Optional<Array<String>>>(target.Get("keys"))));
155
156 return target;
157}
158
159} // namespace aprofile
160} // namespace parsers
161} // namespace target
162} // namespace tvm
163