1/**
2 * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 * \author guonix
17 * \date Nov 2020
18 * \brief
19 */
20
21#include "query/knn_task.h"
22#include <gtest/gtest.h>
23#include "index/mock_segment.h" // for MockSegment
24#include "mock_query_context.h" // for Mock*Context
25
26TEST(KNNTaskTest, TestTaskRun) {
27 { // Test Invalid arguments
28 MockKNNQueryContext context;
29 KNNTask task(nullptr, &context);
30 ASSERT_TRUE(task.run() != 0);
31 KNNTask task1(nullptr, nullptr);
32 ASSERT_TRUE(task1.run() != 0);
33 KNNTask task2(std::make_shared<MockSegment>(), nullptr);
34 ASSERT_TRUE(task2.run() != 0);
35 }
36
37 {
38 std::string column{"column"};
39 std::string features{"features"};
40 QueryParams param;
41
42 MockKNNQueryContext context;
43 EXPECT_CALL(context, column()).WillRepeatedly(ReturnRef(column));
44 EXPECT_CALL(context, features()).WillRepeatedly(ReturnRef(features));
45 EXPECT_CALL(context, query_params()).WillRepeatedly(ReturnRef(param));
46 EXPECT_CALL(context, batch_count()).WillRepeatedly(Return(1));
47
48 auto segment = std::make_shared<MockSegment>();
49 // Specific batch knn_search
50 EXPECT_CALL(*segment, knn_search(_, _, _, _, _))
51 .Times(1)
52 .WillOnce(Return(1))
53 .RetiresOnSaturation();
54
55 { // Failed
56 KNNTask task(segment, &context);
57 task.status(Task::Status::SCHEDULED);
58 EXPECT_EQ(task.run(), 1);
59 EXPECT_EQ(task.exit_code(), 1);
60 EXPECT_TRUE(task.result().empty());
61 }
62
63 EXPECT_CALL(*segment, knn_search(_, _, _, _, _))
64 .Times(1)
65 .WillOnce(Return(0))
66 .RetiresOnSaturation();
67
68 {
69 KNNTask task(segment, &context);
70 task.status(Task::Status::SCHEDULED);
71 EXPECT_EQ(task.run(), 0);
72 EXPECT_EQ(task.exit_code(), 0);
73 EXPECT_TRUE(task.wait_finish());
74 EXPECT_TRUE(task.finished());
75 EXPECT_TRUE(task.result().empty());
76 }
77 }
78}
79