1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
18 | #include "glow/Graph/Graph.h" |
19 | #include "glow/IR/IR.h" |
20 | #include "glow/IR/IRBuilder.h" |
21 | #include "glow/IR/Instrs.h" |
22 | #include "glow/Support/Random.h" |
23 | |
24 | #include "gtest/gtest.h" |
25 | |
26 | #include <cassert> |
27 | #include <string> |
28 | |
29 | using namespace glow; |
30 | using llvm::cast; |
31 | |
32 | extern "C" { |
33 | // Forward declare functions from libjit. |
34 | extern void libjit_matmul_f(float *c, const float *a, const float *b, |
35 | const dim_t *cDims, const dim_t *aDims, |
36 | const dim_t *bDims); |
37 | } |
38 | |
39 | void infer(Tensor *out, Tensor *lhs, Tensor *rhs) { |
40 | ExecutionEngine EE; |
41 | PlaceholderBindings bindings; |
42 | |
43 | auto &mod = EE.getModule(); |
44 | Function *F = mod.createFunction("main" ); |
45 | auto *lhsVar = |
46 | mod.createPlaceholder(lhs->getElementType(), lhs->dims(), "lhs" , false); |
47 | bindings.allocate(lhsVar); |
48 | auto *rhsVar = |
49 | mod.createPlaceholder(rhs->getElementType(), rhs->dims(), "rhs" , false); |
50 | bindings.allocate(rhsVar); |
51 | auto OT = F->getParent()->uniqueType(out->getElementType(), out->dims()); |
52 | auto *matmul = F->createMatMul("matmul" , OT, lhsVar, rhsVar); |
53 | auto *save = F->createSave("ret" , matmul); |
54 | auto *res = bindings.allocate(save->getPlaceholder()); |
55 | |
56 | EE.compile(CompilationMode::Infer); |
57 | |
58 | updateInputPlaceholders(bindings, {lhsVar, rhsVar}, {lhs, rhs}); |
59 | EE.run(bindings); |
60 | |
61 | out->assign(res); |
62 | } |
63 | |
64 | static void testGemm(dim_t m, dim_t n, dim_t k) { |
65 | PseudoRNG PRNG; |
66 | |
67 | Tensor lhs(ElemKind::FloatTy, {m, k}); |
68 | Tensor rhs(ElemKind::FloatTy, {k, n}); |
69 | lhs.getHandle().randomize(-7.2, 8.3, PRNG); |
70 | rhs.getHandle().randomize(-6.3, 10.1, PRNG); |
71 | Tensor out1(ElemKind::FloatTy, {m, n}); |
72 | Tensor out2(ElemKind::FloatTy, {m, n}); |
73 | |
74 | libjit_matmul_f((float *)out1.getUnsafePtr(), (float *)lhs.getUnsafePtr(), |
75 | (float *)rhs.getUnsafePtr(), out1.dims().data(), |
76 | lhs.dims().data(), rhs.dims().data()); |
77 | |
78 | infer(&out2, &lhs, &rhs); |
79 | |
80 | EXPECT_TRUE(out1.isEqual(out2, 2e-2)); |
81 | } |
82 | |
83 | TEST(Gemm, Sweep) { |
84 | for (size_t m : {1, 4, 5, 8}) { |
85 | for (size_t n : {1, 16, 17, 1024}) { |
86 | for (size_t k : {1, 3}) { |
87 | testGemm(m, n, k); |
88 | } |
89 | } |
90 | } |
91 | } |
92 | |
93 | TEST(Gemm, Big) { testGemm(1, 1028, 32); } |
94 | |