IT++ Logo
gmm.cpp
Go to the documentation of this file.
1 
29 #include <itpp/srccode/gmm.h>
30 #include <itpp/srccode/vqtrain.h>
32 #include <itpp/base/matfunc.h>
33 #include <itpp/base/specmat.h>
34 #include <itpp/base/random.h>
35 #include <itpp/base/timing.h>
36 #include <iostream>
37 #include <fstream>
38 
40 
41 namespace itpp
42 {
43 
44 GMM::GMM()
45 {
46  d = 0;
47  M = 0;
48 }
49 
50 GMM::GMM(std::string filename)
51 {
52  load(filename);
53 }
54 
55 GMM::GMM(int M_in, int d_in)
56 {
57  M = M_in;
58  d = d_in;
59  m = zeros(M * d);
60  sigma = zeros(M * d);
61  w = 1. / M * ones(M);
62 
63  for (int i = 0;i < M;i++) {
64  w(i) = 1.0 / M;
65  }
66  compute_internals();
67 }
68 
69 void GMM::init_from_vq(const vec &codebook, int dim)
70 {
71 
72  mat C(dim, dim);
73  int i;
74  vec v;
75 
76  d = dim;
77  M = codebook.length() / dim;
78 
79  m = codebook;
80  w = ones(M) / double(M);
81 
82  C.clear();
83  for (i = 0;i < M;i++) {
84  v = codebook.mid(i * d, d);
85  C = C + outer_product(v, v);
86  }
87  C = 1. / M * C;
88  sigma.set_length(M*d);
89  for (i = 0;i < M;i++) {
90  sigma.replace_mid(i*d, diag(C));
91  }
92 
93  compute_internals();
94 }
95 
96 void GMM::init(const vec &w_in, const mat &m_in, const mat &sigma_in)
97 {
98  int i, j;
99  d = m_in.rows();
100  M = m_in.cols();
101 
102  m.set_length(M*d);
103  sigma.set_length(M*d);
104  for (i = 0;i < M;i++) {
105  for (j = 0;j < d;j++) {
106  m(i*d + j) = m_in(j, i);
107  sigma(i*d + j) = sigma_in(j, i);
108  }
109  }
110  w = w_in;
111 
112  compute_internals();
113 }
114 
115 void GMM::set_mean(const mat &m_in)
116 {
117  int i, j;
118 
119  d = m_in.rows();
120  M = m_in.cols();
121 
122  m.set_length(M*d);
123  for (i = 0;i < M;i++) {
124  for (j = 0;j < d;j++) {
125  m(i*d + j) = m_in(j, i);
126  }
127  }
128  compute_internals();
129 }
130 
131 void GMM::set_mean(int i, const vec &means, bool compflag)
132 {
133  m.replace_mid(i*length(means), means);
134  if (compflag) compute_internals();
135 }
136 
137 void GMM::set_covariance(const mat &sigma_in)
138 {
139  int i, j;
140 
141  d = sigma_in.rows();
142  M = sigma_in.cols();
143 
144  sigma.set_length(M*d);
145  for (i = 0;i < M;i++) {
146  for (j = 0;j < d;j++) {
147  sigma(i*d + j) = sigma_in(j, i);
148  }
149  }
150  compute_internals();
151 }
152 
153 void GMM::set_covariance(int i, const vec &covariances, bool compflag)
154 {
155  sigma.replace_mid(i*length(covariances), covariances);
156  if (compflag) compute_internals();
157 }
158 
159 void GMM::marginalize(int d_new)
160 {
161  it_error_if(d_new > d, "GMM.marginalize: cannot change to a larger dimension");
162 
163  vec mnew(d_new*M), sigmanew(d_new*M);
164  int i, j;
165 
166  for (i = 0;i < M;i++) {
167  for (j = 0;j < d_new;j++) {
168  mnew(i*d_new + j) = m(i * d + j);
169  sigmanew(i*d_new + j) = sigma(i * d + j);
170  }
171  }
172  m = mnew;
173  sigma = sigmanew;
174  d = d_new;
175 
176  compute_internals();
177 }
178 
179 void GMM::join(const GMM &newgmm)
180 {
181  if (d == 0) {
182  w = newgmm.w;
183  m = newgmm.m;
184  sigma = newgmm.sigma;
185  d = newgmm.d;
186  M = newgmm.M;
187  }
188  else {
189  it_error_if(d != newgmm.d, "GMM.join: cannot join GMMs of different dimension");
190 
191  w = concat(double(M) / (M + newgmm.M) * w, double(newgmm.M) / (M + newgmm.M) * newgmm.w);
192  w = w / sum(w);
193  m = concat(m, newgmm.m);
194  sigma = concat(sigma, newgmm.sigma);
195 
196  M = M + newgmm.M;
197  }
198  compute_internals();
199 }
200 
201 void GMM::clear()
202 {
203  w.set_length(0);
204  m.set_length(0);
205  sigma.set_length(0);
206  d = 0;
207  M = 0;
208 }
209 
210 void GMM::save(std::string filename)
211 {
212  std::ofstream f(filename.c_str());
213  int i, j;
214 
215  f << M << " " << d << std::endl ;
216  for (i = 0;i < w.length();i++) {
217  f << w(i) << std::endl ;
218  }
219  for (i = 0;i < M;i++) {
220  f << m(i*d) ;
221  for (j = 1;j < d;j++) {
222  f << " " << m(i*d + j) ;
223  }
224  f << std::endl ;
225  }
226  for (i = 0;i < M;i++) {
227  f << sigma(i*d) ;
228  for (j = 1;j < d;j++) {
229  f << " " << sigma(i*d + j) ;
230  }
231  f << std::endl ;
232  }
233 }
234 
235 void GMM::load(std::string filename)
236 {
237  std::ifstream GMMFile(filename.c_str());
238  int i, j;
239 
240  it_error_if(!GMMFile, std::string("GMM::load : cannot open file ") + filename);
241 
242  GMMFile >> M >> d ;
243 
244 
245  w.set_length(M);
246  for (i = 0;i < M;i++) {
247  GMMFile >> w(i) ;
248  }
249  m.set_length(M*d);
250  for (i = 0;i < M;i++) {
251  for (j = 0;j < d;j++) {
252  GMMFile >> m(i*d + j) ;
253  }
254  }
255  sigma.set_length(M*d);
256  for (i = 0;i < M;i++) {
257  for (j = 0;j < d;j++) {
258  GMMFile >> sigma(i*d + j) ;
259  }
260  }
261  compute_internals();
262  std::cout << " mixtures:" << M << " dim:" << d << std::endl ;
263 }
264 
265 double GMM::likelihood(const vec &x)
266 {
267  double fx = 0;
268  int i;
269 
270  for (i = 0;i < M;i++) {
271  fx += w(i) * likelihood_aposteriori(x, i);
272  }
273  return fx;
274 }
275 
276 vec GMM::likelihood_aposteriori(const vec &x)
277 {
278  vec v(M);
279  int i;
280 
281  for (i = 0;i < M;i++) {
282  v(i) = w(i) * likelihood_aposteriori(x, i);
283  }
284  return v;
285 }
286 
287 double GMM::likelihood_aposteriori(const vec &x, int mixture)
288 {
289  int j;
290  double s;
291 
292  it_error_if(d != x.length(), "GMM::likelihood_aposteriori : dimensions does not match");
293  s = 0;
294  for (j = 0;j < d;j++) {
295  s += normexp(mixture * d + j) * sqr(x(j) - m(mixture * d + j));
296  }
297  return normweight(mixture)*std::exp(s);;
298 }
299 
300 void GMM::compute_internals()
301 {
302  int i, j;
303  double s;
304  double constant = 1.0 / std::pow(2 * pi, d / 2.0);
305 
306  normweight.set_length(M);
307  normexp.set_length(M*d);
308 
309  for (i = 0;i < M;i++) {
310  s = 1;
311  for (j = 0;j < d;j++) {
312  normexp(i*d + j) = -0.5 / sigma(i * d + j); // check time
313  s *= sigma(i * d + j);
314  }
315  normweight(i) = constant / std::sqrt(s);
316  }
317 
318 }
319 
320 vec GMM::draw_sample()
321 {
322  static bool first = true;
323  static vec cumweight;
324  double u = randu();
325  int k;
326 
327  if (first) {
328  first = false;
329  cumweight = cumsum(w);
330  it_error_if(std::abs(cumweight(length(cumweight) - 1) - 1) > 1e-6, "weight does not sum to 0");
331  cumweight(length(cumweight) - 1) = 1;
332  }
333  k = 0;
334  while (u > cumweight(k)) k++;
335 
336  return elem_mult(sqrt(sigma.mid(k*d, d)), randn(d)) + m.mid(k*d, d);
337 }
338 
339 GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER, bool VERBOSE)
340 {
341  mat mean;
342  int i, j, d = TrainingData(0).length();
343  vec sig;
344  GMM gmm(M, d);
345  vec m(d*M);
346  vec sigma(d*M);
347  vec w(M);
348  vec normweight(M);
349  vec normexp(d*M);
350  double LL = 0, LLold, fx;
351  double constant = 1.0 / std::pow(2 * pi, d / 2.0);
352  int T = TrainingData.length();
353  vec x1;
354  int t, n;
355  vec msum(d*M);
356  vec sigmasum(d*M);
357  vec wsum(M);
358  vec p_aposteriori(M);
359  vec x2;
360  double s;
361  vec temp1, temp2;
362  //double MINIMUM_VARIANCE=0.03;
363 
364  //-----------initialization-----------------------------------
365 
366  mean = vqtrain(TrainingData, M, 200000, 0.5, VERBOSE);
367  for (i = 0;i < M;i++) gmm.set_mean(i, mean.get_col(i), false);
368  // for (i=0;i<M;i++) gmm.set_mean(i,TrainingData(randi(0,TrainingData.length()-1)),false);
369  sig = zeros(d);
370  for (i = 0;i < TrainingData.length();i++) sig += sqr(TrainingData(i));
371  sig /= TrainingData.length();
372  for (i = 0;i < M;i++) gmm.set_covariance(i, 0.5*sig, false);
373 
374  gmm.set_weight(1.0 / M*ones(M));
375 
376  //-----------optimization-----------------------------------
377 
378  tic();
379  for (i = 0;i < M;i++) {
380  temp1 = gmm.get_mean(i);
381  temp2 = gmm.get_covariance(i);
382  for (j = 0;j < d;j++) {
383  m(i*d + j) = temp1(j);
384  sigma(i*d + j) = temp2(j);
385  }
386  w(i) = gmm.get_weight(i);
387  }
388  for (n = 0;n < NOITER;n++) {
389  for (i = 0;i < M;i++) {
390  s = 1;
391  for (j = 0;j < d;j++) {
392  normexp(i*d + j) = -0.5 / sigma(i * d + j); // check time
393  s *= sigma(i * d + j);
394  }
395  normweight(i) = constant * w(i) / std::sqrt(s);
396  }
397  LLold = LL;
398  wsum.clear();
399  msum.clear();
400  sigmasum.clear();
401  LL = 0;
402  for (t = 0;t < T;t++) {
403  x1 = TrainingData(t);
404  x2 = sqr(x1);
405  fx = 0;
406  for (i = 0;i < M;i++) {
407  s = 0;
408  for (j = 0;j < d;j++) {
409  s += normexp(i * d + j) * sqr(x1(j) - m(i * d + j));
410  }
411  p_aposteriori(i) = normweight(i) * std::exp(s);
412  fx += p_aposteriori(i);
413  }
414  p_aposteriori /= fx;
415  LL = LL + std::log(fx);
416 
417  for (i = 0;i < M;i++) {
418  wsum(i) += p_aposteriori(i);
419  for (j = 0;j < d;j++) {
420  msum(i*d + j) += p_aposteriori(i) * x1(j);
421  sigmasum(i*d + j) += p_aposteriori(i) * x2(j);
422  }
423  }
424  }
425  for (i = 0;i < M;i++) {
426  for (j = 0;j < d;j++) {
427  m(i*d + j) = msum(i * d + j) / wsum(i);
428  sigma(i*d + j) = sigmasum(i * d + j) / wsum(i) - sqr(m(i * d + j));
429  }
430  w(i) = wsum(i) / T;
431  }
432  LL = LL / T;
433 
434  if (std::abs((LL - LLold) / LL) < 1e-6) break;
435  if (VERBOSE) {
436  std::cout << n << ": " << LL << " " << std::abs((LL - LLold) / LL) << " " << toc() << std::endl ;
437  std::cout << "---------------------------------------" << std::endl ;
438  tic();
439  }
440  else {
441  std::cout << n << ": LL = " << LL << " " << std::abs((LL - LLold) / LL) << "\r" ;
442  std::cout.flush();
443  }
444  }
445  for (i = 0;i < M;i++) {
446  gmm.set_mean(i, m.mid(i*d, d), false);
447  gmm.set_covariance(i, sigma.mid(i*d, d), false);
448  }
449  gmm.set_weight(w);
450  return gmm;
451 }
452 
453 } // namespace itpp
454 
SourceForge Logo

Generated on Sat May 25 2013 16:32:23 for IT++ by Doxygen 1.8.2