Saturday, July 21, 2018

Poisson Mixture Model ③

This article is a continuation of Poisson Mixture Model ②. In this article, I'm going to implement the Poisson mixture model in practice, applying Gibbs sampling to approximate its posterior distribution.

0. Implementation of Gibbs sampling

First of all, we prepare the sample (observed) data used in Poisson Mixture Model ②.

In [48]:
# Required import modules
from scipy.stats import poisson
from scipy.stats import gamma
from scipy.stats import dirichlet
from scipy.stats import multinomial
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import collections
import pandas as pd 
from matplotlib import cm
In [2]:
# Sample from two different Poisson distributions
# whose parameters are 15 and 30 respectively.
sample_15 = poisson.rvs(15,size=200)
sample_30 = poisson.rvs(30,size=300)
sample_freq = np.concatenate([sample_15,sample_30])
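
Before fitting the model, it's worth taking a quick look at the observed data. The following small cell (not part of the original notebook) just plots a histogram of sample_freq so we can see the two overlapping bumps:

In [ ]:
# Quick look at the observed data (illustrative extra cell)
plt.hist(sample_freq, bins=40)
plt.title('Observed data from the two Poisson distributions')
plt.xlabel('observed count')
plt.ylabel('frequency')
plt.show()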

We discussed how to avoid overflow in the log-sum-exp computation in Poisson Mixture Model ②. The following is a simple implementation.

In [3]:
def log_sum_exp(exponent):
    """
    Argument :
    assume exponent is matrix which is shapen 
    (n (number of x), k (number of cluster))
    
    Return :
    The value of "const" to be used normalization.
    It's gonna be n (number of x) dimentional vector.
    """
    max_val = np.max(exponent, axis=1).reshape(-1,1)
    return np.log(np.sum(np.exp(exponent - max_val),
                         axis=1).reshape((-1,1))) + max_val
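
As a quick sanity check (this cell is not in the original notebook), we can compare the function above with scipy.special.logsumexp on inputs large enough to overflow a naive implementation:

In [ ]:
# Sanity check of log_sum_exp against scipy's implementation
from scipy.special import logsumexp

test = np.random.randn(5, 3) * 100  # values that would overflow a naive exp
print(np.allclose(log_sum_exp(test),
                  logsumexp(test, axis=1).reshape(-1, 1)))  # expect True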

We also implement a function to sample from the posterior distribution via Gibbs sampling. In this implementation, the hyperparameters were tentatively set to $a = b = 1$ and $\alpha = (1, 1)$.
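For reference, here are the full conditional distributions that the function below samples from at each iteration; the $\eta_{n,k}$ are normalized so that $\sum_k \eta_{n,k} = 1$, which is where the log-sum-exp function above comes in:

$$\eta_{n,k} \propto \exp\left(x_n \ln \lambda_k - \lambda_k + \ln \pi_k\right), \qquad \boldsymbol{s}_n \sim \mathrm{Cat}(\boldsymbol{\eta}_n)$$

$$\lambda_k \sim \mathrm{Gam}(\hat{a}_k, \hat{b}_k), \qquad \hat{a}_k = \sum_{n=1}^{N} s_{n,k}\, x_n + a, \qquad \hat{b}_k = \sum_{n=1}^{N} s_{n,k} + b$$

$$\boldsymbol{\pi} \sim \mathrm{Dir}(\hat{\boldsymbol{\alpha}}), \qquad \hat{\alpha}_k = \sum_{n=1}^{N} s_{n,k} + \alpha_k$$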

In [4]:
def gibbs_mix_poi(X, k_num, iter_num):
    """
    parameter:
    -------------------------
    Assume X be numpy
    """
    # Hyper parameter
    a = 1
    b = 1
    alpha = np.ones((k_num,1))
    
    # Initialize lambda
    lmd = np.ones((k_num,1))
    
    # Initialize pi 
    pi = np.ones((k_num,1)) / k_num
    
    # Arrays to store the results of Gibbs sampling
    sampled_s = np.empty((iter_num,X.shape[0],k_num))
    sampled_lmd = np.empty((iter_num,k_num))
    sampled_pi = np.empty((iter_num,k_num))
    
    for i in range(iter_num):
        # Compute eta (responsibilities), normalized with log-sum-exp
        exponent = (np.dot(X, np.log(lmd).reshape(1,-1))
                    - lmd.reshape(1,-1) + np.log(pi).reshape(1,-1))
        const = - log_sum_exp(exponent)
        eta = np.exp(exponent + const)

        # Sample s (cluster assignment of each data point)
        S = np.array([multinomial.rvs(n=1, p=temp_eta) for temp_eta in eta])
        sampled_s[i] = S

        # Sample lambda
        hat_a = np.dot(X.T, S).reshape(-1,1) + a
        hat_b = np.sum(S, axis=0).reshape(-1,1) + b
        lmd = np.array([gamma.rvs(a=temp_a, scale=1/temp_b) for
                        temp_a, temp_b in zip(hat_a.ravel(), hat_b.ravel())]).reshape(-1,1)

        sampled_lmd[i] = lmd.reshape(-1)
        
        # Sample pi
        hat_alpha = np.sum(S,axis=0).reshape(-1,1) + alpha
        pi = dirichlet.rvs(alpha=hat_alpha.reshape(-1),size=1).reshape(-1,1)
        sampled_pi[i] = pi.reshape(-1)
    
    return sampled_s, sampled_lmd, sampled_pi

Now we are ready to run Gibbs sampling. Here we're going to sample with 500 iterations.

In [6]:
# Sample with 500 iterations
sample_s, sample_lmd, sample_pi = gibbs_mix_poi(
    sample_freq.reshape(-1,1), 2, 500)
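
Each returned array holds one sample per Gibbs iteration, which we can confirm by checking the shapes (a small check that is not part of the original notebook):

In [ ]:
# One entry per Gibbs iteration
print(sample_s.shape)    # (500, 500, 2): (iterations, data points, clusters)
print(sample_lmd.shape)  # (500, 2)
print(sample_pi.shape)   # (500, 2)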

Let's check the result of the sampling :) First of all, we'll look at $\lambda$. As you may know, the expectation of a Poisson distribution is its parameter $\lambda$, and we prepared observed data from two Poisson distributions whose parameters are 15 and 30. Hence we expect values close to 15 and 30.

In [7]:
# Check sampled lambda 
df_lam = pd.DataFrame(sample_lmd, columns=['lambda1','lambda2'])

# Plot lambda
for lam in df_lam.columns:
    df_lam[lam].plot.hist(bins=20)
    plt.title('Sample of {}'.format(lam),fontsize=16)
    plt.text(19,10,'Mean of lambda = \n{}'.format(df_lam[lam].mean()))
    plt.xlim((10,50))
    plt.show()

As you can see, the sampled $\lambda$ values are close to what they should be :)
Next we're going to check the samples of $\pi$.

In [8]:
df_pi = pd.DataFrame(sample_pi,columns=['pi_1','pi_2'])

# plot histogram of sample of pi
for pi in df_pi.columns:
    df_pi[pi].plot.hist(bins=20)
    plt.title('Sample of {}'.format(pi),fontsize=16)
    plt.text(0.22,10,'Mean of pi =\n {}'.format(df_pi[pi].mean()))
    plt.xlim((0.2,0.8))
    plt.show()

We generated the observed data with a ratio of 4 to 6, and the estimated mixing ratio is about 3.9 to 6.1, so it captures the structure of the samples well :)

Checking the outcome of clustering is a little bit tricky. I implemented it with the following steps.

  1. Classify the sample data into bins. In this example, I used 41 bin edges.
  2. For each bin, compute the average ratio of cluster assignments of the samples in it (averaged over the Gibbs iterations).
  3. Plot a histogram and color each bin according to the value obtained in step 2.
In [36]:
# Create bin edges
max_samp = np.max(sample_freq)
min_samp = np.min(sample_freq)
binn = np.linspace(min_samp, max_samp, 41)

# Compute which bin each sample falls into
assign = np.digitize(sample_freq, binn)

# Reorder samples to (data points, iterations, clusters)
sample_s_t = sample_s.transpose(1,0,2)

# Default to 0.5 (neutral color) for bins that contain no samples
ratio_list = np.full((binn.shape[0], 2), 0.5)

# Compute the average cluster membership ratio for each bin
for assign_num in np.unique(assign):
    # For data points falling into this bin, sum the assignments over
    # iterations and normalize to get each point's cluster membership ratio
    sum_over_iter = np.sum(sample_s_t[assign == assign_num], axis=1)
    total = np.sum(sum_over_iter, axis=1)
    ratio = sum_over_iter / total[:, np.newaxis]
    # Index by bin number so colors line up with the histogram patches
    ratio_list[assign_num - 1] = np.mean(ratio, axis=0)

The following is the outcome of clustering with the Poisson mixture model. You can see that the samples around 20 have a lighter color, since it is ambiguous which Poisson distribution yielded them, whereas larger and smaller values have darker colors. I believe this matches your expectation :)

In [62]:
# Plot result of clustering
plt.figure(figsize=(8,4))
_,_, patches = plt.hist(sample_freq,bins=binn)
cm_map = cm.get_cmap('coolwarm')

# Create colors for each bin
colors = np.array([cm_map(ratio) for ratio in ratio_list[:,0]])

# Set color for each bin
for patch, color in zip(patches,colors):
    patch.set_fc(color)
    
plt.title('Result of clustering',fontsize=16)
plt.show()
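
Since we know how the data were generated (the first 200 points came from a Poisson distribution with parameter 15 and the remaining 300 from one with parameter 30), one optional extra check, not part of the original notebook, is to compare the clustering with that ground truth:

In [ ]:
# Optional comparison with the known generating process (illustrative)
# Average the one-hot assignments over iterations to get each point's
# posterior cluster membership probability, then take a hard assignment.
post_s = sample_s.mean(axis=0)            # shape (500, 2)
hard_assign = np.argmax(post_s, axis=1)

# Ground truth labels: first 200 points from Poisson(15), last 300 from Poisson(30)
truth = np.concatenate([np.zeros(200, dtype=int), np.ones(300, dtype=int)])

# Cluster labels are arbitrary, so check both label permutations
acc = max(np.mean(hard_assign == truth), np.mean(hard_assign == 1 - truth))
print('Agreement with ground truth: {:.3f}'.format(acc))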
