function varargout = gmm(X, K_or_centroids)
% GMM  Fit a Gaussian mixture model to data with the EM algorithm.
%
%   Px = GMM(X, K_OR_CENTROIDS) returns the N-by-K matrix of per-component
%   Gaussian densities p(x_n | k), evaluated with the parameters in effect
%   at the last E-step.
%
%   [pGamma, model] = GMM(X, K_OR_CENTROIDS) returns the N-by-K matrix of
%   posterior responsibilities from the last E-step and a struct with fields
%       Miu   - K-by-D component means
%       Sigma - D-by-D-by-K component covariances
%       Pi    - 1-by-K mixing weights
%
%   X is an N-by-D data matrix (one sample per row). K_OR_CENTROIDS is
%   either the number of components K (the initial means are K randomly
%   chosen, distinct rows of X) or a K-by-D matrix of initial means.
%
%   EM stops once the log-likelihood improves by less than THRESHOLD
%   between consecutive iterations.

threshold = 1e-15;   % convergence tolerance on the log-likelihood gain
lambda = 1e-5;       % ridge added to every covariance before it is factorized
[N, D] = size(X);

if isscalar(K_or_centroids)
    K = K_or_centroids;
    % Use K distinct random samples as the initial means.
    rndp = randperm(N);
    centroids = X(rndp(1:K), :);
else
    K = size(K_or_centroids, 1);
    centroids = K_or_centroids;
end

[pMiu, pPi, pSigma] = init_params();

Lprev = -inf;
while true
    % ---- E-step (log domain) ----
    % Working with log densities plus log-sum-exp keeps samples that are
    % far from every component from underflowing to 0/0 = NaN.
    logPx = calc_log_prob();
    logW = bsxfun(@plus, logPx, log(pPi));          % log(pi_k * p(x_n | k))
    maxW = max(logW, [], 2);
    logSum = maxW + log(sum(exp(bsxfun(@minus, logW, maxW)), 2));
    pGamma = exp(bsxfun(@minus, logW, logSum));      % responsibilities
    % Log-likelihood of the parameters that produced pGamma. It is computed
    % before the M-step so densities and weights come from the same iteration.
    L = sum(logSum);

    % ---- M-step ----
    Nk = sum(pGamma, 1);
    % Guard against 0/0 when a component has lost all of its weight.
    NkSafe = max(Nk, realmin);
    pMiu = bsxfun(@rdivide, pGamma' * X, NkSafe');
    pPi = Nk / N;
    for kk = 1:K
        Xshift = bsxfun(@minus, X, pMiu(kk, :));
        % Weight each row by its responsibility directly; building
        % diag(pGamma(:, kk)) would allocate an N-by-N matrix.
        Xweighted = bsxfun(@times, Xshift, pGamma(:, kk));
        pSigma(:, :, kk) = (Xshift' * Xweighted) / NkSafe(kk);
    end

    if L - Lprev < threshold
        break;
    end
    Lprev = L;
end

if nargout == 1
    varargout = {exp(logPx)};
else
    model = [];
    model.Miu = pMiu;
    model.Sigma = pSigma;
    model.Pi = pPi;
    varargout = {pGamma, model};
end

    function [pMiu, pPi, pSigma] = init_params()
    % Initialize parameters: assign every sample to its nearest initial mean,
    % then use each cluster's size and covariance as the starting weight and
    % covariance.
        pMiu = centroids;
        pPi = zeros(1, K);
        pSigma = zeros(D, D, K);
        % Squared Euclidean distances ||x||^2 + ||c||^2 - 2*x*c', N-by-K.
        distmat = bsxfun(@plus, sum(X .* X, 2), sum(pMiu .* pMiu, 2)') ...
                  - 2 * X * pMiu';
        [~, labels] = min(distmat, [], 2);
        for k = 1:K
            Xk = X(labels == k, :);
            nk = size(Xk, 1);
            if nk >= 2
                pSigma(:, :, k) = cov(Xk);
            else
                % cov([]) is NaN, and cov() of a single row is a scalar
                % variance over the features, so fall back to the global
                % covariance for empty or singleton clusters.
                pSigma(:, :, k) = cov(X);
            end
            % Flooring at 1 keeps every component alive; when all clusters
            % are non-empty this equals nk / N after normalization.
            pPi(k) = max(nk, 1);
        end
        pPi = pPi / sum(pPi);
    end

    function logPx = calc_log_prob()
    % Log Gaussian density of every sample under every component (N-by-K),
    % using the current pMiu and ridge-regularized pSigma.
        logPx = zeros(N, K);
        for k = 1:K
            Xshift = bsxfun(@minus, X, pMiu(k, :));
            Sigma_k = pSigma(:, :, k) + lambda * eye(D);
            [R, p] = chol(Sigma_k);
            if p == 0
                % Sigma_k = R'*R, so the Mahalanobis term is ||Xshift / R||^2
                % and log(det(Sigma_k)) = 2*sum(log(diag(R))).
                Q = Xshift / R;
                mahal = sum(Q .^ 2, 2);
                logdet = 2 * sum(log(diag(R)));
            else
                % Not numerically positive definite: use the explicit
                % inverse, as the original implementation did.
                inv_Sigma = inv(Sigma_k);
                mahal = sum((Xshift * inv_Sigma) .* Xshift, 2);
                logdet = log(det(Sigma_k));
            end
            logPx(:, k) = -0.5 * (D * log(2 * pi) + logdet + mahal);
        end
    end
end
|