1 | #include <nnpack.h> |
2 | #include <nnpack/reference.h> |
3 | #include <nnpack/activations.h> |
4 | |
5 | |
/*
 * Arguments shared by every per-sample invocation of
 * compute_relu_input_gradient. The three arrays are indexed as
 * [sample * channels + channel] by that worker.
 */
struct relu_input_gradient_context {
	/* Number of elements processed per sample (row length of each array). */
	size_t channels;
	/* Read-only gradient with respect to the layer output. */
	const float *grad_output;
	/* Read-only input values that were fed to the forward pass. */
	const float *input;
	/* Destination for the computed gradient with respect to the input. */
	float *grad_input;
	/* Slope parameter forwarded unchanged to grad_relu. */
	float negative_slope;
};
13 | |
14 | static void compute_relu_input_gradient( |
15 | const struct relu_input_gradient_context context[restrict static 1], |
16 | size_t sample) |
17 | { |
18 | const size_t channels = context->channels; |
19 | const float* grad_output = context->grad_output + sample * channels; |
20 | const float* input = context->input + sample * channels; |
21 | float* grad_input = context->grad_input + sample * channels; |
22 | float negative_slope = context->negative_slope; |
23 | |
24 | for (size_t channel = 0; channel < channels; channel++) { |
25 | grad_input[channel] = grad_relu(grad_output[channel], input[channel], negative_slope); |
26 | } |
27 | } |
28 | |
29 | void nnp_relu_input_gradient__reference( |
30 | size_t batch_size, |
31 | size_t channels, |
32 | const float grad_output[], |
33 | const float input[], |
34 | float grad_input[], |
35 | float negative_slope, |
36 | pthreadpool_t threadpool) |
37 | { |
38 | struct relu_input_gradient_context relu_input_gradient_context = { |
39 | .channels = channels, |
40 | .grad_output = grad_output, |
41 | .input = input, |
42 | .grad_input = grad_input, |
43 | .negative_slope = negative_slope, |
44 | }; |
45 | |
46 | pthreadpool_parallelize_1d(threadpool, |
47 | (pthreadpool_function_1d_t) compute_relu_input_gradient, |
48 | &relu_input_gradient_context, |
49 | batch_size, |
50 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
51 | } |
52 | |