#include <nnpack.h>
#include <nnpack/reference.h>
#include <nnpack/activations.h>

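/* Per-call arguments shared by all pthreadpool work items. */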
struct relu_input_gradient_context {
	size_t channels;
	const float* grad_output;
	const float* input;
	float* grad_input;
	float negative_slope;
};

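/*
 * Computes the ReLU input gradient for one sample of the batch:
 * grad_input[c] is grad_output[c] scaled by the leaky-ReLU derivative,
 * i.e. by 1 for positive inputs and by negative_slope for negative inputs.
 */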
static void compute_relu_input_gradient(
	const struct relu_input_gradient_context context[restrict static 1],
	size_t sample)
{
	const size_t channels = context->channels;
	const float* grad_output = context->grad_output + sample * channels;
	const float* input = context->input + sample * channels;
	float* grad_input = context->grad_input + sample * channels;
	const float negative_slope = context->negative_slope;

	for (size_t channel = 0; channel < channels; channel++) {
		grad_input[channel] = grad_relu(grad_output[channel], input[channel], negative_slope);
	}
}

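/*
 * Reference (scalar) implementation of the ReLU input gradient.
 * grad_output, input, and grad_input are dense batch_size x channels arrays;
 * samples are processed independently, in parallel across the threadpool.
 */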
void nnp_relu_input_gradient__reference(
	size_t batch_size,
	size_t channels,
	const float grad_output[],
	const float input[],
	float grad_input[],
	float negative_slope,
	pthreadpool_t threadpool)
{
	struct relu_input_gradient_context relu_input_gradient_context = {
		.channels = channels,
		.grad_output = grad_output,
		.input = input,
		.grad_input = grad_input,
		.negative_slope = negative_slope,
	};

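	/* One work item per sample; PTHREADPOOL_FLAG_DISABLE_DENORMALS flushes denormals to zero while the work items run. */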
	pthreadpool_parallelize_1d(threadpool,
		(pthreadpool_function_1d_t) compute_relu_input_gradient,
		&relu_input_gradient_context,
		batch_size,
		PTHREADPOOL_FLAG_DISABLE_DENORMALS);
}