Comment utiliser un modèle de mélange gaussien (GMM) avec sklearn en python

Comment utiliser un modèle de mélange gaussien (GMM) avec sklearn en python

Daidalos 07 mai 2020


Exemples de comment utiliser un modèle de mélange gaussien (GMM) avec sklearn en python :

from sklearn import mixture

import numpy as np
import matplotlib.pyplot as plt

1 -- Exemple avec une seule gaussienne

Générons par exemple des données aléatoirement suivant une loi normale avec une moyenne $\mu_0 = 5$ et une standard déviation $\sigma_0 = 2$

mu_0 = 5.0
srd_0 = 2.0

data = np.random.randn(100000)
data = data * srd_0 + mu_0

data = data.reshape(-1, 1)

Visualisons les données:

hx, hy, _ = plt.hist(data, bins=50, density=1,color="lightblue")

plt.ylim(0.0,max(hx)+0.05)
plt.title('Gaussian mixture example 01')
plt.grid()

plt.xlim(mu_0-4*srd_0,mu_0+4*srd_0)

plt.savefig("example_gmm_01.png", bbox_inches='tight')
plt.show()

Comment utiliser un modèle de mélange gaussien (GMM) avec sklearn en python

gmm = mixture.GaussianMixture(n_components=1, covariance_type='full').fit(data)

print(gmm.means_)
print(np.sqrt(gmm.covariances_))

[[5.00715457]]
[[[1.99746652]]]

Comparaisons avec numpy:

print(np.mean(data))
print(np.std(data))

4.998997166872173
2.0008903305868855

2 -- Exemple avec un mélange de deux gaussiennes

mu_1 = 2.0
srd_1 = 4.0

mu_2 = 10.0
srd_2 = 1.0

new_data = np.random.randn(50000)
data_1 = new_data * srd_1 + mu_1

new_data = np.random.randn(50000)
data_2 = new_data * srd_2 + mu_2

new_data = np.concatenate((data_1, data_2), axis=0)

new_data = new_data.reshape(-1, 1)

hx, hy, _ = plt.hist(new_data, bins=100, density=1,color="lightblue")

plt.title('Gaussian mixture example 02')
plt.grid()

plt.savefig("example_gmm_02.png", bbox_inches='tight')
plt.show()

Comment utiliser un modèle de mélange gaussien (GMM) avec sklearn en python

gmm = mixture.GaussianMixture(n_components=2, max_iter=1000, covariance_type='full').fit(new_data)

print('means')
print(gmm.means_)
#print(gmm.covariances_)
print('std')
print(np.sqrt(gmm.covariances_))

means
[[1.82377272]
 [9.9837662 ]]
std
[[[3.89836502]]

 [[1.02825841]]]

3 -- Références