#include <stddef.h>
#include <stdint.h>

#include <nnpack.h>
#include <nnpack/macros.h>

#include <nnpack/hwinfo.h>
#include <nnpack/validation.h>
8
9
/*
 * Per-call context shared by the fully-connected compute tasks below.
 * One instance lives on the stack of the public entry point for the duration
 * of the pthreadpool parallelization; tasks read it but never modify it.
 * NNP_CACHE_ALIGN cache-aligns the struct (presumably to avoid false sharing
 * between worker threads — see macro definition to confirm).
 */
struct NNP_CACHE_ALIGN fully_connected_inference_context {
	size_t input_channels;   /* length of each dot product (row length of the kernel matrix) */
	const void* input;       /* input vector; interpreted as const float* by both task variants */
	const void* kernel;      /* kernel matrix, [output_channels x input_channels];
	                            interpreted as const float* (f32) or const uint16_t* (f16) */
	void* output;            /* output vector; interpreted as float* by both task variants */
};
16
17static void compute_fully_connected_inference_f32(
18 const struct fully_connected_inference_context context[restrict static 1],
19 size_t output_channels_subblock_start, size_t output_channels_subblock_size)
20{
21 const size_t input_channels = context->input_channels;
22 const float* input = context->input;
23 const float* kernel = context->kernel;
24 float* output = context->output;
25 const nnp_sdotxf_function sdotxf = nnp_hwinfo.sdotxf.functions[output_channels_subblock_size - 1];
26
27 sdotxf(input, &kernel[output_channels_subblock_start * input_channels],
28 input_channels, &output[output_channels_subblock_start], input_channels);
29}
30
31static void compute_fully_connected_inference_f16f32(
32 const struct fully_connected_inference_context context[restrict static 1],
33 size_t output_channels_subblock_start, size_t output_channels_subblock_size)
34{
35 const size_t input_channels = context->input_channels;
36 const float* input = context->input;
37 const uint16_t* kernel = context->kernel;
38 float* output = context->output;
39 const nnp_shdotxf_function shdotxf = nnp_hwinfo.shdotxf.functions[output_channels_subblock_size - 1];
40
41 shdotxf(input, &kernel[output_channels_subblock_start * input_channels],
42 input_channels, &output[output_channels_subblock_start], input_channels);
43}
44
45enum nnp_status nnp_fully_connected_inference(
46 size_t input_channels,
47 size_t output_channels,
48 const float* input,
49 const float* kernel,
50 float* output,
51 pthreadpool_t threadpool)
52{
53 /* Basic validation of parameters. This check detects invalid, but not unsupported parameters. */
54 enum nnp_status status = validate_fully_connected_arguments(1, input_channels, output_channels);
55 if (status != nnp_status_success) {
56 return status;
57 }
58
59 /* Do the computation */
60 const size_t output_channels_subblock_max = nnp_hwinfo.sdotxf.fusion;
61 struct fully_connected_inference_context fully_connected_inference_context = {
62 .input_channels = input_channels,
63 .input = input,
64 .kernel = kernel,
65 .output = output,
66 };
67 pthreadpool_parallelize_1d_tile_1d(threadpool,
68 (pthreadpool_task_1d_tile_1d_t) compute_fully_connected_inference_f32,
69 &fully_connected_inference_context,
70 output_channels, output_channels_subblock_max,
71 PTHREADPOOL_FLAG_DISABLE_DENORMALS);
72
73 return nnp_status_success;
74}
75
76enum nnp_status nnp_fully_connected_inference_f16f32(
77 size_t input_channels,
78 size_t output_channels,
79 const float* input,
80 const void* kernel,
81 float* output,
82 pthreadpool_t threadpool)
83{
84 /* Basic validation of parameters. This check detects invalid, but not unsupported parameters. */
85 enum nnp_status status = validate_fully_connected_arguments(1, input_channels, output_channels);
86 if (status != nnp_status_success) {
87 return status;
88 }
89
90 /* Do the computation */
91 const size_t output_channels_subblock_max = nnp_hwinfo.sdotxf.fusion;
92 struct fully_connected_inference_context fully_connected_inference_context = {
93 .input_channels = input_channels,
94 .input = input,
95 .kernel = kernel,
96 .output = output,
97 };
98 pthreadpool_parallelize_1d_tile_1d(threadpool,
99 (pthreadpool_task_1d_tile_1d_t) compute_fully_connected_inference_f16f32,
100 &fully_connected_inference_context,
101 output_channels, output_channels_subblock_max,
102 PTHREADPOOL_FLAG_DISABLE_DENORMALS);
103
104 return nnp_status_success;
105}
106