In this post, I will share an infinite Poisson mixture model that uses a Dirichlet process as the prior on the number of clusters. In clustering, the number of classes 'k' is often hard to set in advance, and I think the Dirichlet process is one way to get around that. To approximate the posterior distribution, a sampling method will be presented, namely collapsed Gibbs sampling. I hope you find some fun in this post.
Dirichlet Process Poisson Mixture Model¶
1. Collapsed Gibbs sampling¶
First of all, we start from the joint distribution of the Poisson mixture model, which can be expressed as follows.
$$p(X,Z,\lambda,\pi) = p(X | \lambda, Z)p(Z|\pi)p(\pi)p(\lambda)$$where $X$ is the observations, $Z$ is the latent variables, $\lambda$ is the parameter of the Poisson distribution with prior distribution $p(\lambda)$, and $\pi$ is the parameter of the categorical distribution with prior distribution $p(\pi)$. If you are interested in the Poisson mixture model itself, you can check the following link.
Poisson Mixture Model①
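For concreteness, this post uses conjugate priors: a symmetric Dirichlet prior on $\pi$ and a Gamma prior on each Poisson rate, matching the parameters $\alpha$, $a$, $b$ that appear in the code below. The generative process is then
$$\begin{eqnarray} \pi \sim Dir(\alpha/K,\ldots,\alpha/K),\ \ \lambda_k \sim Gam(a,b),\ \ z_n \sim Cat(\pi),\ \ x_n \sim Poi(\lambda_{z_n})\end{eqnarray}$$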
If your interest is only in how the observations are classified, you can marginalize out $\lambda$ and $\pi$. As long as conjugate priors are used for $p(\lambda)$ and $p(\pi)$, the marginalization is tractable.
$$p(X,Z) = \int p(X,Z,\lambda,\pi)d\lambda d\pi$$We can then approximate the posterior distribution $p(Z|X)$ by Gibbs sampling.
$$p(z_n|X,Z_{\backslash n}) \propto p(x_n|X_{\backslash n},z_n,Z_{\backslash n})p(z_n|Z_{\backslash n})$$For $p(x_n|X_{\backslash n},z_n,Z_{\backslash n})$, the conjugate Gamma prior $Gam(a,b)$ yields a negative binomial predictive: $$p(x_n|X_{\backslash n},z_{n,k}=1,Z_{\backslash n}) = NB\left(x_n \,\middle|\, \sum_{n^{'}\neq n}z_{n^{'},k}x_{n^{'}}+a,\ \sum_{n^{'}\neq n}z_{n^{'},k}+b\right)$$
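As a sanity check, this predictive can be evaluated directly with scipy's nbinom; note that scipy parameterizes the negative binomial by the success probability, which here is $\hat{b}/(\hat{b}+1)$. A minimal sketch (the helper name predictive_pmf is my own, not part of the model):
from scipy.stats import nbinom

def predictive_pmf(x_n, X_rest, mask_k, a, b):
    # Posterior parameters of the Gamma(a, b) prior given the data in cluster k
    a_hat = X_rest[mask_k].sum() + a
    b_hat = mask_k.sum() + b
    # NB predictive: scipy's nbinom takes (number of successes, success probability)
    return nbinom.pmf(x_n, a_hat, b_hat / (b_hat + 1))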
Regarding $p(z_n|Z_{\backslash n})$, we want it to allow an infinite number of categories.
2. Infinite mixture model¶
Now let's think about a Dirichlet distribution with infinitely many categories. Let the parameters of the Dirichlet distribution all be $\alpha / K$, where $K$ is the number of clusters. Then, marginalizing over $\pi$, $p(z_i = k|z_{1:n}^{\backslash i},\alpha)$ can be computed as follows,
$$\begin{eqnarray} p(z_{i}=k|z_{1:n}^{\backslash i}, \alpha) = \begin{cases}\frac{n_k^{\backslash i}+\frac{\alpha}{K}}{n-1 + \alpha} \ \ if\ k \in \kappa^+(z^{\backslash i}_{1:n})\\ \frac{\frac{\alpha}{K}}{n-1+\alpha} \ \ if\ k \notin \kappa^+(z^{\backslash i}_{1:n})\end{cases}\end{eqnarray}$$where $\kappa^{+}(z^{\backslash i}_{1:n})$ denotes the set of categories that have already been drawn, and $n_k^{\backslash i}$ is the number of points assigned to cluster $k$ excluding point $i$. If $K \to \infty$, the probability of any single new category goes to zero, which is not what we want. However, this probability is the same for every cluster $k \notin \kappa^+(z^{\backslash i}_{1:n})$, so the total probability of drawing a new category is obtained by multiplying it by $K - |\kappa^+(z^{\backslash i}_{1:n})|$:
$$p(z_i \notin \kappa^+(z^{\backslash i}_{1:n})|z^{\backslash i}_{1:n},\alpha) = \left(1- \frac{|\kappa^{+}(z^{\backslash i}_{1:n})|}{K}\right)\frac{\alpha}{n-1+\alpha}$$If $K \to \infty $,
$$\begin{eqnarray} p(z_{i}=k|z_{1:n}^{\backslash i}, \alpha) = \begin{cases}\frac{n_k^{\backslash i}}{n-1+\alpha} \ \ if\ k \in \kappa^+(z^{\backslash i}_{1:n})\\
\frac{\alpha}{n-1+\alpha} \ \ if\ k \notin \kappa^+(z^{\backslash i}_{1:n})\end{cases}\end{eqnarray}$$
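This limit is exactly the Chinese restaurant process. A minimal sketch of the resulting prior over $z_i$ given the other assignments (the function crp_probs is my own illustration):
import numpy as np

def crp_probs(z_rest, alpha):
    # Existing clusters get probability proportional to their size n_k,
    # and one extra slot for a new cluster gets probability proportional to alpha.
    labels, counts = np.unique(z_rest, return_counts=True)
    denom = len(z_rest) + alpha          # n - 1 + alpha
    probs = np.append(counts, alpha) / denom
    return np.append(labels, labels.max() + 1), probs
Here the new cluster simply gets label max+1 for brevity; the sampler below instead reuses the smallest unused label, which is equivalent up to renaming.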
3. Sample data¶
Let's apply the infinite Poisson mixture model to toy data. The following sample data are drawn from two different Poisson distributions, so we expect the result of Gibbs sampling to show two classes.
from scipy.stats import poisson
from scipy.stats import multinomial
from scipy.stats import nbinom
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib import cm
plt.style.use('dark_background')
# Sample from two different Poisson distributions
# whose parameters are 15 and 30, respectively.
sample_15 = poisson.rvs(15,size=200,random_state=10)
sample_30 = poisson.rvs(30,size=300,random_state=10)
X = np.concatenate([sample_15,sample_30])
freq = Counter(X)
x = np.array([x for x in freq.keys()])
y = np.array([y for y in freq.values()])
# Plot data drawn from the two different Poisson distributions
plt.bar(x,y,width=1.0)
plt.title('Samples from two different Poisson distributions')
plt.show()
4. Apply the Poisson mixture model with a Dirichlet process¶
MAX_iter = 200
sample_z = []
# parameters of the prior distributions
# alpha : Dirichlet distribution
# a, b  : Gamma distribution
alpha = 4
a = 1
b = 1
debug=False
# initialize latent variables
z = np.zeros(X.shape[0])
for i in range(MAX_iter):
    # one Gibbs sweep: resample z_n for every data point
    for n in range(X.shape[0]):
        z_rest = np.delete(z, obj=n)
        X_rest = np.delete(X, obj=n)
        exist_cat = np.unique(z_rest)
        z_prob = {}
        x_prob = {}
        if debug: print(n, 'exist_cat :', exist_cat)
        # cluster numbers are assigned incrementally from 0
        for k in exist_cat:
            mask = z_rest == k
            # CRP prior: proportional to n_k, the size of cluster k
            z_prob[k] = mask.sum() / (X.shape[0] - 1 + alpha)
            # negative binomial predictive with posterior parameters a_hat, b_hat
            a_hat = X_rest[mask].sum() + a
            b_hat = mask.sum() + b
            x_prob[k] = nbinom.pmf(X[n], a_hat, b_hat / (b_hat + 1))
        # probability of opening a new cluster;
        # its label is the smallest number not among the existing labels
        for c in range(exist_cat.shape[0] + 1):
            if c not in set(exist_cat):
                new_c = c
                break
        z_prob[new_c] = alpha / (X.shape[0] - 1 + alpha)
        # predictive under the prior alone: NB(a, b / (b + 1))
        x_prob[new_c] = nbinom.pmf(X[n], a, b / (b + 1))
        category = np.append(exist_cat, new_c)
        if debug: print(n, 'category :', category)
        # normalize prior x likelihood (dicts preserve insertion order,
        # so the values align with `category`)
        pi_list = np.array([zp * xp for zp, xp in zip(z_prob.values(), x_prob.values())])
        pi = pi_list / pi_list.sum()
        if debug: print(n, 'pi :', pi)
        z[n] = category[np.argmax(multinomial.rvs(n=1, p=pi))]
        if debug: print(n, 'sample', z[:n+1])
    if debug: print(z)
    sample_z.append(z.copy())
# discard the first 10 samples as burn-in
sample_z = np.array(sample_z)[10:]
num_of_cluster = np.array([np.unique(s).shape[0] for s in sample_z])
count = Counter(num_of_cluster)
# frequency of samples containing 1, 2, 3, or 4 clusters
temp_dict = {i: count[i + 1] for i in range(4)}
x = list(temp_dict.keys())
Y = list(temp_dict.values())
plt.bar(x=x, height=Y, tick_label=[1, 2, 3, 4])
plt.title('Number of clusters', fontsize=16)
plt.show()
Since the early samples may still reflect the initialization, the first 10 samples were discarded. As expected, the cluster number "2" is dominant among the 190 remaining samples. Interestingly, there are a few samples with 3 clusters.
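To see what the recovered clusters look like, one can estimate each cluster's Poisson rate from the last Gibbs sample via the Gamma posterior mean; a quick sketch reusing the variables from the sampler above:
# Estimated rate of each cluster in the final sample:
# Gamma posterior mean (sum of counts + a) / (cluster size + b)
z_last = sample_z[-1]
for k in np.unique(z_last):
    mask = z_last == k
    rate = (X[mask].sum() + a) / (mask.sum() + b)
    print(f'cluster {int(k)}: {mask.sum()} points, rate ~ {rate:.1f}')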