1#include "llvm_codegen_utils.h"
2
3namespace taichi::lang {
4
5std::string type_name(llvm::Type *type) {
6 std::string type_name_str;
7 llvm::raw_string_ostream rso(type_name_str);
8 type->print(rso, /*IsForDebug=*/false, /*NoDetails=*/true);
9 return type_name_str;
10}
11
12/*
13 * Determine whether two types are the same
14 * (a type is a renamed version of the other one) based on the
15 * type name. Check recursively if the types are function types.
16 *
17 * Types like `PhysicalCoordinates` occur in every struct module.
18 * When a struct module is copied into a LLVM context,
19 * types in the module which already exist in the context are renamed
20 * by adding a suffix starting with a "." following by a number.
21 * For example, "PhysicalCoordinates" may be renamed to
22 * names like "PhysicalCoordinates.0" and "PhysicalCoordinates.8".
23 */
24bool is_same_type(llvm::Type *a, llvm::Type *b) {
25 if (a == b) {
26 return true;
27 }
28 if (a->isPointerTy() != b->isPointerTy()) {
29 return false;
30 }
31 if (a->isPointerTy()) {
32 return is_same_type(a->getPointerElementType(), b->getPointerElementType());
33 }
34 if (a->isFunctionTy() != b->isFunctionTy()) {
35 return false;
36 }
37 if (a->isFunctionTy()) {
38 auto req_func = llvm::cast<llvm::FunctionType>(a);
39 auto prov_func = llvm::cast<llvm::FunctionType>(b);
40 if (!is_same_type(req_func->getReturnType(), prov_func->getReturnType())) {
41 return false;
42 }
43 if (req_func->getNumParams() != prov_func->getNumParams()) {
44 return false;
45 }
46 for (int j = 0; j < req_func->getNumParams(); j++) {
47 if (!is_same_type(req_func->getParamType(j),
48 prov_func->getParamType(j))) {
49 return false;
50 }
51 }
52 return true;
53 }
54
55 auto a_name = type_name(a);
56 auto b_name = type_name(b);
57 if (a_name.size() > b_name.size()) {
58 std::swap(a_name, b_name);
59 }
60 int len_same = 0;
61 while (len_same < a_name.size()) {
62 if (a_name[len_same] != b_name[len_same]) {
63 break;
64 }
65 len_same++;
66 }
67 if (len_same == a_name.size()) {
68 TI_ASSERT(len_same != b_name.size());
69 if (b_name[len_same] == '.') {
70 // a is xxx, and b is xxx.yyy, yyy is a number
71 for (int i = len_same + 1; i < b_name.size(); i++) {
72 if (!std::isdigit(b_name[i])) {
73 return false;
74 }
75 }
76 return true;
77 }
78 }
79 // a is xxx.yyy, and b is xxx.zzz, yyy and zzz are numbers
80 if (len_same == 0) {
81 return false;
82 }
83 int dot_pos = len_same - 1;
84 while (dot_pos && a_name[dot_pos] != '.') {
85 dot_pos--;
86 }
87 if (!dot_pos) {
88 return false;
89 }
90 for (int i = dot_pos + 1; i < a_name.size(); i++) {
91 if (!std::isdigit(a_name[i])) {
92 return false;
93 }
94 }
95 for (int i = dot_pos + 1; i < b_name.size(); i++) {
96 if (!std::isdigit(b_name[i])) {
97 return false;
98 }
99 }
100 return true;
101}
102
103void check_func_call_signature(llvm::FunctionType *func_type,
104 llvm::StringRef func_name,
105 std::vector<llvm::Value *> &arglist,
106 llvm::IRBuilder<> *builder) {
107 int num_params = func_type->getFunctionNumParams();
108 if (func_type->isFunctionVarArg()) {
109 TI_ASSERT(num_params <= arglist.size());
110 } else {
111 TI_ERROR_IF(num_params != arglist.size(),
112 "Function \"{}\" requires {} arguments but {} provided",
113 std::string(func_name), num_params, arglist.size());
114 }
115
116 for (int i = 0; i < num_params; i++) {
117 auto required = func_type->getFunctionParamType(i);
118 auto provided = arglist[i]->getType();
119 /*
120 * When importing a module from file, the imported `llvm::Type`s can get
121 * conflict with the same type in the llvm::Context. In such scenario,
122 * the imported types will be renamed from "original_type" to
123 * "original_type.xxx", making them separate types in essence.
124 * To make the types of the argument and parameter the same,
125 * a pointer cast must be performed.
126 */
127 if (required != provided) {
128 if (is_same_type(required, provided)) {
129 arglist[i] = builder->CreatePointerCast(arglist[i], required);
130 continue;
131 }
132 TI_INFO("Function : {}", std::string(func_name));
133 TI_INFO(" Type : {}", type_name(func_type));
134 if (&required->getContext() != &provided->getContext()) {
135 TI_INFO(" parameter {} types are from different contexts", i);
136 TI_INFO(" required from context {}",
137 (void *)&required->getContext());
138 TI_INFO(" provided from context {}",
139 (void *)&provided->getContext());
140 }
141 TI_ERROR(" parameter {} mismatch: required={}, provided={}", i,
142 type_name(required), type_name(provided));
143 }
144 }
145}
146
147} // namespace taichi::lang
148