
Thursday, January 31, 2019

Gaussian Mixture model with EM (Expectation-Maximization) algorithm

In this article, I will share about the Gaussian mixture model with the EM (Expectation-Maximization) algorithm. When it comes to mixture models, I already introduced one with Bayesian inference in the following link: Poisson Mixture Model①. In this article, however, since I will utilize the EM (Expectation-Maximization) algorithm, we will basically apply *maximum likelihood estimation*. For instance, we don't put a prior distribution on the parameter π, which determines the mixing ratio of the clusters. Anyway I will be glad if you enjoy this article :)

01. The model

Let's say the number of clusters is $K$ here. At first, let me introduce a $K$-dimensional binary random variable $z$ in 1-of-K representation: each element satisfies $z_k \in \{0,1\}$ and $\sum_k z_k = 1$, so exactly one element is 1 (for example, with $K=3$, $z=(0,1,0)^T$ indicates the second cluster). This $z$ is drawn from the categorical distribution $\mathrm{Cat}(z|\pi)$. Therefore, the probability of $z$ can be written as follows,

$$p(z) = \prod_{k=1}^{K} \pi_k^{z_k}$$


Which cluster the observation comes from is determined by this $z_k$, hence the conditional probability $p(x|z)$ can be written as below,

$$p(x|z) = \prod_{k=1}^{K} \mathcal{N}(x|\mu_k, \Sigma_k)^{z_k}$$

The joint distribution of $x$ and $z$ is,

$$p(x,z) = p(x|z)\,p(z) = \prod_{k=1}^{K} \{\pi_k \mathcal{N}(x|\mu_k, \Sigma_k)\}^{z_k}$$

Thus, the marginal distribution of $x$ is,

$$p(x) = \sum_{z} p(x,z) = \sum_{k=1}^{K} \pi_k \mathcal{N}(x|\mu_k, \Sigma_k)$$
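As a quick numerical illustration of this formula, the mixture density at a point is just the weighted sum of the component densities. Here is a minimal sketch with made-up toy values for $\pi$, $\mu_k$ and $\Sigma_k$ (illustration only, not the parameters used later), using scipy's multivariate_normal:

from scipy.stats import multivariate_normal
import numpy as np

# toy parameters for K=2 (assumed values, for illustration only)
pi = [0.3, 0.7]
means = [np.array([1.0, 1.0]), np.array([-2.0, 3.0])]
covs = [np.eye(2) * 0.4, np.eye(2) * 0.8]

x = np.array([0.0, 0.0])
# p(x) = sum_k pi_k * N(x | mu_k, Sigma_k)
p_x = sum(pi[k] * multivariate_normal.pdf(x, mean=means[k], cov=covs[k])
          for k in range(2))
print(p_x)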

You can also derive the posterior distribution of $z$,

$$p(z|x) = \frac{p(x,z)}{p(x)} = \frac{\prod_{k=1}^{K} \{\pi_k \mathcal{N}(x|\mu_k, \Sigma_k)\}^{z_k}}{\sum_{j=1}^{K} \pi_j \mathcal{N}(x|\mu_j, \Sigma_j)}$$

As a graphical model, this can be expressed like this,

02. EM (Expectation-Maximization) algorithm

In order to get the optimized parameters, we will apply the EM (Expectation-Maximization) algorithm in this article. The steps of the EM algorithm are as follows,

  1. Initialize the means $\mu_k$, covariances $\Sigma_k$ and mixing coefficients $\pi_k$, and compute the initial log likelihood of $p(X)$.
  2. E step. Evaluate the following quantities (which are called "responsibilities"),

    $$\gamma(z_{nk}) = \frac{\pi_k \mathcal{N}(x_n|\mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j \mathcal{N}(x_n|\mu_j, \Sigma_j)}$$

  3. M step. Re-estimate $\mu_k$, $\Sigma_k$ and $\pi_k$ with the responsibilities we got in the previous step (the standard update equations are sketched right after this list).

  4. Evaluate the log likelihood of $p(X)$ again and check its convergence. If it has not converged, return to step 2.
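For reference, since step 3 only says "re-estimate", here is a sketch of the standard maximum likelihood M-step updates, writing $N_k = \sum_{n=1}^{N} \gamma(z_{nk})$ for the effective number of points assigned to cluster $k$:

$$\mu_k^{new} = \frac{1}{N_k} \sum_{n=1}^{N} \gamma(z_{nk})\, x_n, \qquad \Sigma_k^{new} = \frac{1}{N_k} \sum_{n=1}^{N} \gamma(z_{nk})\, (x_n - \mu_k^{new})(x_n - \mu_k^{new})^T, \qquad \pi_k^{new} = \frac{N_k}{N}$$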

03. Implementation

Now it's time to implement the Gaussian mixture model with the EM (Expectation-Maximization) algorithm. At first, we generate 300 observations (100 + 200) from 2 different clusters. One has $\mu = (1, 1)^T$ and $\Sigma = \begin{pmatrix} 0.4 & -0.3 \\ -0.3 & 0.4 \end{pmatrix}$, the other has $\mu = (-2, 3)^T$ and $\Sigma = \begin{pmatrix} 0.8 & 0.2 \\ 0.2 & 0.8 \end{pmatrix}$.

In [3]:
from scipy.stats import multivariate_normal
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn
In [4]:
obs_mean = [[1,1],
            [-2,3]]
obs_cov = [[[0.4,-0.3],
            [-0.3,0.4]],
           [[0.8,0.2],
            [0.2,0.8]]]
obs_num = [100,200]
obs_data = []

for mean, cov,num in zip(obs_mean,obs_cov,obs_num):
    obs_data.append(multivariate_normal.rvs(mean=mean,
                                                    cov=cov,
                                                    size=num,
                                                    random_state=42))
    
obs_data = np.concatenate([obs_data[0],obs_data[1]])
obs_data.shape
Out[4]:
(300, 2)

Observations look like this :)

In [7]:
plt.scatter(x=obs_data[:,0],y=obs_data[:,1])
plt.title('Observations',fontsize=14)
plt.show()

The following is the implementation of the EM algorithm.

In [8]:
# number of clusters
k_num = 2
# Initial values
mean = np.array([[1,0],
                 [-1,0]])
cov = np.array([[[0.2,0],
                 [0,0.2]],
                [[0.2,0],
                 [0,0.2]]])
pi = [0.5,0.5]

# Convergence threshold for the optimization
eps = 1e-8
# max number of iterations
max_iter = 100

# EM algorithm
for i in range(max_iter):
    # compute log likelihood of p(X) under the current parameters
    ln_p_X = np.array([
        np.log(
            np.array(
                [pi[k] * (multivariate_normal.pdf(x=obs_data[i],
                                                  mean=mean[k],
                                                  cov=cov[k])) 
                 for k in range(k_num)]
            ).sum()
        )
        for i in range(obs_data.shape[0])
    ]).sum()
    
    # E step: compute the posterior distribution of z (responsibilities)
    z_pos = np.empty((len(obs_data),k_num))
    for i in range(len(obs_data)):
        # compute the denominator (marginal probability over the K clusters)
        denom = np.array(
            [pi[k] * multivariate_normal.pdf(x=obs_data[i],
                                             mean=mean[k],
                                             cov=cov[k]) 
             for k in range(k_num)]
        ).sum()
        # posterior probability
        z_pos[i] = np.array(
            [ pi[k] * multivariate_normal.pdf(x=obs_data[i],
                                              mean=mean[k],
                                              cov=cov[k])/denom 
             for k in range(k_num)]
        )
        
    # M step: update pi, mean, and covariance with the responsibilities
    pi = z_pos.mean(axis=0)

    # update the means first so the covariance update can use the new means
    mean = np.array(
        [((obs_data * z_pos[:,k][:,np.newaxis]).sum(axis=0)) / z_pos.sum(axis=0)[k]
         for k in range(k_num)]
    )

    cov = np.array(
        [np.array(
            [z_pos[i,k] * ((obs_data[i] - mean[k])[:,np.newaxis] @ (obs_data[i] - mean[k])[np.newaxis,:])
             for i in range(len(obs_data))]
         ).sum(axis=0) / z_pos.sum(axis=0)[k]
         for k in range(k_num)]
    )
    
    # Recompute the log likelihood with the updated parameters
    ln_p_X_after = np.array([
        np.log(
            np.array(
                [pi[k] * (multivariate_normal.pdf(x=obs_data[i],
                                                  mean=mean[k],
                                                  cov=cov[k])) 
                 for k in range(k_num)]
            ).sum()
        )
        for i in range(obs_data.shape[0])
    ]).sum()
    # If the log likelihood has converged, stop the EM iterations
    if abs(ln_p_X - ln_p_X_after) < eps:
        break

The following are the results of the EM algorithm. We can say they are fairly close to the population parameters we set :)

In [13]:
print('mean :\n',mean)
print('covariance matrix :\n',cov)
print('pi :\n',pi)
mean :
 [[ 1.08056432  0.93457732]
 [-2.02086015  3.01428165]]
covariance matrix :
 [[[ 0.29324429 -0.20124271]
  [-0.20124271  0.30748327]]

 [[ 0.72269474  0.17299484]
  [ 0.17299484  0.75529692]]]
pi :
 [0.33179738 0.66820262]
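If we want to actually assign each observation to a cluster with the fitted parameters, one simple way is to compute the responsibilities once more with the final mean, cov and pi, and pick the cluster with the highest responsibility for each point. Here is a minimal sketch of that, reusing the variables defined above:

resp = np.array([
    [pi[k] * multivariate_normal.pdf(x=obs_data[i], mean=mean[k], cov=cov[k])
     for k in range(k_num)]
    for i in range(len(obs_data))
])
# normalize each row so it sums to 1 -> responsibilities gamma(z_nk)
resp /= resp.sum(axis=1, keepdims=True)
# hard assignment: the cluster with the highest responsibility
labels = resp.argmax(axis=1)

plt.scatter(x=obs_data[:,0], y=obs_data[:,1], c=labels)
plt.title('Cluster assignments', fontsize=14)
plt.show()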