1 | #include <stddef.h> |
2 | |
3 | #include <nnpack.h> |
4 | #include <nnpack/macros.h> |
5 | |
6 | #include <nnpack/hwinfo.h> |
7 | #include <nnpack/validation.h> |
8 | |
9 | |
10 | struct NNP_CACHE_ALIGN fully_connected_inference_context { |
11 | size_t input_channels; |
12 | const void* input; |
13 | const void* kernel; |
14 | void* output; |
15 | }; |
16 | |
17 | static void compute_fully_connected_inference_f32( |
18 | const struct fully_connected_inference_context context[restrict static 1], |
19 | size_t output_channels_subblock_start, size_t output_channels_subblock_size) |
20 | { |
21 | const size_t input_channels = context->input_channels; |
22 | const float* input = context->input; |
23 | const float* kernel = context->kernel; |
24 | float* output = context->output; |
25 | const nnp_sdotxf_function sdotxf = nnp_hwinfo.sdotxf.functions[output_channels_subblock_size - 1]; |
26 | |
27 | sdotxf(input, &kernel[output_channels_subblock_start * input_channels], |
28 | input_channels, &output[output_channels_subblock_start], input_channels); |
29 | } |
30 | |
31 | static void compute_fully_connected_inference_f16f32( |
32 | const struct fully_connected_inference_context context[restrict static 1], |
33 | size_t output_channels_subblock_start, size_t output_channels_subblock_size) |
34 | { |
35 | const size_t input_channels = context->input_channels; |
36 | const float* input = context->input; |
37 | const uint16_t* kernel = context->kernel; |
38 | float* output = context->output; |
39 | const nnp_shdotxf_function shdotxf = nnp_hwinfo.shdotxf.functions[output_channels_subblock_size - 1]; |
40 | |
41 | shdotxf(input, &kernel[output_channels_subblock_start * input_channels], |
42 | input_channels, &output[output_channels_subblock_start], input_channels); |
43 | } |
44 | |
45 | enum nnp_status nnp_fully_connected_inference( |
46 | size_t input_channels, |
47 | size_t output_channels, |
48 | const float* input, |
49 | const float* kernel, |
50 | float* output, |
51 | pthreadpool_t threadpool) |
52 | { |
53 | /* Basic validation of parameters. This check detects invalid, but not unsupported parameters. */ |
54 | enum nnp_status status = validate_fully_connected_arguments(1, input_channels, output_channels); |
55 | if (status != nnp_status_success) { |
56 | return status; |
57 | } |
58 | |
59 | /* Do the computation */ |
60 | const size_t output_channels_subblock_max = nnp_hwinfo.sdotxf.fusion; |
61 | struct fully_connected_inference_context fully_connected_inference_context = { |
62 | .input_channels = input_channels, |
63 | .input = input, |
64 | .kernel = kernel, |
65 | .output = output, |
66 | }; |
67 | pthreadpool_parallelize_1d_tile_1d(threadpool, |
68 | (pthreadpool_task_1d_tile_1d_t) compute_fully_connected_inference_f32, |
69 | &fully_connected_inference_context, |
70 | output_channels, output_channels_subblock_max, |
71 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
72 | |
73 | return nnp_status_success; |
74 | } |
75 | |
76 | enum nnp_status nnp_fully_connected_inference_f16f32( |
77 | size_t input_channels, |
78 | size_t output_channels, |
79 | const float* input, |
80 | const void* kernel, |
81 | float* output, |
82 | pthreadpool_t threadpool) |
83 | { |
84 | /* Basic validation of parameters. This check detects invalid, but not unsupported parameters. */ |
85 | enum nnp_status status = validate_fully_connected_arguments(1, input_channels, output_channels); |
86 | if (status != nnp_status_success) { |
87 | return status; |
88 | } |
89 | |
90 | /* Do the computation */ |
91 | const size_t output_channels_subblock_max = nnp_hwinfo.sdotxf.fusion; |
92 | struct fully_connected_inference_context fully_connected_inference_context = { |
93 | .input_channels = input_channels, |
94 | .input = input, |
95 | .kernel = kernel, |
96 | .output = output, |
97 | }; |
98 | pthreadpool_parallelize_1d_tile_1d(threadpool, |
99 | (pthreadpool_task_1d_tile_1d_t) compute_fully_connected_inference_f16f32, |
100 | &fully_connected_inference_context, |
101 | output_channels, output_channels_subblock_max, |
102 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
103 | |
104 | return nnp_status_success; |
105 | } |
106 | |