PandA-2024.02
helm.c
Go to the documentation of this file.
#include "common.h"
#include <stdlib.h>
2 
/* Flat real_t* buffers are handed to the [P][P] / [P][P][P] array
 * parameters of the kernels below (in main() and helm_factor()); silence
 * the resulting pointer-type diagnostics for the whole file. */
#pragma GCC diagnostic ignored "-Wincompatible-pointer-types"

/*
 * Extent of every dimension of the P x P x P tensors (and of the P x P
 * matrix L) handled by the kernels in this file.
 *
 * Declared as an enumeration constant rather than `const size_t`: a
 * const-qualified object is not an integer constant expression in C, so
 * every `real_t w[P]`-style parameter was a variably modified type -- an
 * optional feature since C11 (__STDC_NO_VLA__) that some compilers do not
 * provide. With an enum constant the parameter arrays have fixed bounds.
 */
enum { P = 11 };
7 void helm_naive(
8  real_t w[P],
9  real_t L[P][P],
10  real_t d[4],
11  real_t u[P][P][P],
12  real_t r[P][P][P]
13 )
14 {
15  for (size_t x = 0; x < P; ++x)
16  for (size_t y = 0; y < P; ++y)
17  for (size_t z = 0; z < P; ++z) {
18  r[x][y][z] = d[0] * w[x] * w[y] * w[z] * u[x][y][z];
19  }
20 
21  for (size_t x = 0; x < P; ++x)
22  for (size_t y = 0; y < P; ++y)
23  for (size_t z = 0; z < P; ++z) {
24  real_t accu = 0;
25  for (size_t k = 0; k < P; ++k) {
26  accu += L[x][k] * w[y] * w[z] * u[k][y][z];
27  }
28  r[x][y][z] += d[1] * accu;
29  }
30 
31  for (size_t x = 0; x < P; ++x)
32  for (size_t y = 0; y < P; ++y)
33  for (size_t z = 0; z < P; ++z) {
34  real_t accu = 0;
35  for (size_t k = 0; k < P; ++k) {
36  accu += w[x] * L[y][k] * w[z] * u[x][k][z];
37  }
38  r[x][y][z] += d[2] * accu;
39  }
40 
41  for (size_t x = 0; x < P; ++x)
42  for (size_t y = 0; y < P; ++y)
43  for (size_t z = 0; z < P; ++z) {
44  real_t accu = 0;
45  for (size_t k = 0; k < P; ++k) {
46  accu += w[x] * w[y] * L[z][k] * u[x][y][k];
47  }
48  r[x][y][z] += d[3] * accu;
49  }
50 }
51 
53  real_t w[P],
54  real_t L[P][P],
55  real_t d[4],
56  real_t u[P][P][P],
57  real_t L_hat[P][P],
58  real_t M_u[P][P][P],
59  real_t r[P][P][P]
60 )
61 {
62  for (size_t x = 0; x < P; ++x)
63  for (size_t y = 0; y < P; ++y)
64  for (size_t z = 0; z < P; ++z) {
65  real_t M_u_xyz = w[x] * w[y] * w[z] * u[x][y][z];
66  M_u[x][y][z] = M_u_xyz;
67  r[x][y][z] = M_u_xyz * d[0];
68  }
69 
70  for (size_t i = 0; i < P; ++i)
71  for (size_t j = 0; j < P; ++j) {
72  L_hat[i][j] = L[i][j] / w[j];
73  }
74 
75  for (size_t x = 0; x < P; ++x)
76  for (size_t y = 0; y < P; ++y)
77  for (size_t z = 0; z < P; ++z) {
78  real_t accu = 0;
79  for (size_t k = 0; k < P; ++k) {
80  accu += L_hat[x][k] * M_u[k][y][z];
81  }
82  r[x][y][z] += d[1] * accu;
83  }
84 
85  for (size_t x = 0; x < P; ++x)
86  for (size_t y = 0; y < P; ++y)
87  for (size_t z = 0; z < P; ++z) {
88  real_t accu = 0;
89  for (size_t k = 0; k < P; ++k) {
90  accu += L_hat[y][k] * M_u[x][k][z];
91  }
92  r[x][y][z] += d[2] * accu;
93  }
94 
95  for (size_t x = 0; x < P; ++x)
96  for (size_t y = 0; y < P; ++y)
97  for (size_t z = 0; z < P; ++z) {
98  real_t accu = 0;
99  for (size_t k = 0; k < P; ++k) {
100  accu += L_hat[z][k] * M_u[x][y][k];
101  }
102  r[x][y][z] += d[3] * accu;
103  }
104 }
105 
107  real_t w[P],
108  real_t L[P][P],
109  real_t d[4],
110  real_t u[P][P][P],
111  real_t r[P][P][P]
112 )
113 {
114  real_t* L_hat = make_empty(P*P);
115  real_t* M_u = make_empty(P*P*P);
116 
118  w,
119  L,
120  d,
121  u,
122  L_hat,
123  M_u,
124  r
125  );
126 }
127 
128 int main(int argc, const char* argv[])
129 {
130  srandom(0xDEADBEEF);
131 
132  real_t* w = make_random(P);
133  real_t* L = make_random(P*P);
134  real_t* d = make_random(4);
135  real_t* u = make_random(P*P*P);
136 
137  real_t* r1 = make_empty(P*P*P);
138  helm_naive(w, L, d, u, r1);
139 
140  real_t* r2 = make_empty(P*P*P);
141  helm_factor(w, L, d, u, r2);
142  real_t mse2 = mse(r1, r2, P*P*P);
143  printf("mse2 = %G\n", mse2);
144 
145  return EXIT_SUCCESS;
146 }
real_t * make_random(size_t size)
Definition: common.h:15
void helm_factor_impl(real_t w[P], real_t L[P][P], real_t d[4], real_t u[P][P][P], real_t L_hat[P][P], real_t M_u[P][P][P], real_t r[P][P][P])
Definition: helm.c:52
real_t * make_empty(size_t size)
Definition: common.h:10
void helm_factor(real_t w[P], real_t L[P][P], real_t d[4], real_t u[P][P][P], real_t r[P][P][P])
Definition: helm.c:106
float real_t
Definition: common.h:8
void helm_naive(real_t w[P], real_t L[P][P], real_t d[4], real_t u[P][P][P], real_t r[P][P][P])
Definition: helm.c:7
int main(int argc, const char *argv[])
Definition: helm.c:128
const size_t P
Definition: helm.c:5
static const uint32_t k[]
Definition: sha-256.c:22
real_t mse(const real_t *a, const real_t *b, size_t size)
Definition: common.h:37
#define L
Definition: spmv.h:13
x
Return the smallest n such that 2^n >= _x.

Generated on Mon Feb 12 2024 13:02:48 for PandA-2024.02 by doxygen 1.8.13