## Sparse matrix-variate normal model
##
## x is a p by q by n array holding n observed p x q matrices X_1, ..., X_n.
##
## thr: threshold for convergence of the alternating (flip-flop) iterations.
## Default value is 1e-4. Iterations stop when the average absolute change of
## Omega %x% Gam between successive iterates is less than
## thr * ave(abs(Omega %x% Gam)) of the previous iterate.
## (Note: glasso's own thr instead compares against ave(abs(offdiag(s))).)

require(glasso)

## Sparse matrix variate graphical model.
##
## Estimates a row precision matrix Omega (p x p) and a column precision
## matrix Gam (q x q) for every penalty in `rholist` by alternating two
## glasso fits (diagonals unpenalized). With method == "scad" it then runs one
## extra alternating pass per penalty in which the glasso penalty matrices are
## scad(<lasso estimate>, rho).
##
## Arguments:
##   x          p x q x n array of observations.
##   rholist    penalty values; sorted decreasingly before use. Default:
##              2^seq(-10, 2, length.out = 100) * log(p) * log(q).
##   scale      if TRUE, after centering, each of the p*q entries is
##              standardised across the n observations with base::scale.
##   max.iter   the alternating loop runs at most max.iter - 1 iterations per
##              penalty; with max.iter <= 1 the starting values are kept.
##   thr        relative convergence threshold (see header).
##   method     "scad" adds the SCAD estimates; any other value returns only
##              the lasso path.
##   start      passed to glasso() as `start` for warm-started fits (every
##              lasso fit except the first penalty, and all SCAD fits).
##   Sig.init, Omega.init (p x p), Gam.init, Psi.init (q x q)
##              optional starting values; each defaults to the identity.
##
## Returns a list with elements, in this order:
##   M          p x q sample mean used for centering,
##   W          the "scaled:scale" attribute (length p*q) if scale = TRUE,
##              otherwise diag(p*q),
##   Sigma, Omega (p x p x nr), Psi, Gam (q x q x nr),
##   Sigma.scad, Omega.scad, Psi.scad, Gam.scad (only when method == "scad"),
##   n, p, q, rholist (sorted), nr.
## Each Sigma slice is normalised so that Sigma[1, 1, k] == 1, with Omega
## scaled inversely. If glasso reports non-convergence for a penalty, that
## penalty and all later (smaller) ones keep their starting values.
smgm <- function(x, rholist = NULL, scale = FALSE, max.iter = 100, thr = 1e-4,
                 method = "scad", start = "warm", Sig.init = NULL,
                 Omega.init = NULL, Gam.init = NULL, Psi.init = NULL) {
  dx <- dim(x)
  p <- dx[1]
  q <- dx[2]
  n <- dx[3]

  # Center each of the p*q entries across the n observations.
  M <- apply(x, c(1, 2), mean)
  x <- sweep(x, c(1, 2), M)

  if (scale) {
    # One row per observation (n x pq). base:: is explicit because the
    # `scale` argument shadows the function name in this scope.
    xs <- base::scale(t(matrix(x, ncol = n)))
    W <- attr(xs, "scaled:scale")
    x <- array(t(xs), c(p, q, n))
  } else {
    # NOTE(review): a (pq) x (pq) identity costs (pq)^2 memory; kept only
    # for compatibility of the returned `W`.
    W <- diag(p * q)
  }

  if (is.null(rholist)) {
    rholist <- 2^seq(-10, 2, length.out = 100) * log(p) * log(q)
  }
  rholist <- sort(rholist, decreasing = TRUE)
  nr <- length(rholist)

  # Per-penalty estimates, pre-filled with the starting values so that
  # penalties skipped after a glasso failure keep them.
  Sigma <- .smgm_init_path(Sig.init, p, nr)
  Omega <- .smgm_init_path(Omega.init, p, nr)
  Psi <- .smgm_init_path(Psi.init, q, nr)
  Gam <- .smgm_init_path(Gam.init, q, nr)

  # Lasso path: each penalty starts from the previous penalty's estimate;
  # the first penalty uses cold glasso starts.
  for (jj in seq_len(nr)) {
    prev <- max(jj - 1, 1)
    est <- .smgm_flipflop(
      x, rho.row = rholist[jj], rho.col = rholist[jj],
      Sigma0 = .smgm_slice(Sigma, prev), Omega0 = .smgm_slice(Omega, prev),
      Psi0 = .smgm_slice(Psi, prev), Gam0 = .smgm_slice(Gam, prev),
      warm = jj > 1, start = start, max.iter = max.iter, thr = thr
    )
    if (is.null(est)) break  # glasso did not converge: stop the path here
    Sigma[, , jj] <- est$Sigma
    Omega[, , jj] <- est$Omega
    Psi[, , jj] <- est$Psi
    Gam[, , jj] <- est$Gam
  }

  out <- list(M = M, W = W, Sigma = Sigma, Omega = Omega, Psi = Psi, Gam = Gam)

  if (method == "scad") {
    # One alternating pass per penalty (max.iter = 2) with SCAD-derivative
    # penalty matrices evaluated at the lasso estimate for the same penalty,
    # warm-started from the previous penalty's SCAD fit (for the first
    # penalty: its lasso fit).
    Sigma.scad <- Sigma
    Omega.scad <- Omega
    Psi.scad <- Psi
    Gam.scad <- Gam
    for (jj in seq_len(nr)) {
      prev <- max(jj - 1, 1)
      est <- .smgm_flipflop(
        x,
        rho.row = scad(.smgm_slice(Omega, jj), rholist[jj]),
        rho.col = scad(.smgm_slice(Gam, jj), rholist[jj]),
        Sigma0 = .smgm_slice(Sigma.scad, prev),
        Omega0 = .smgm_slice(Omega.scad, prev),
        Psi0 = .smgm_slice(Psi.scad, prev),
        Gam0 = .smgm_slice(Gam.scad, prev),
        warm = TRUE, start = start, max.iter = 2, thr = thr
      )
      if (is.null(est)) break
      Sigma.scad[, , jj] <- est$Sigma
      Omega.scad[, , jj] <- est$Omega
      Psi.scad[, , jj] <- est$Psi
      Gam.scad[, , jj] <- est$Gam
    }
    out <- c(out, list(Sigma.scad = Sigma.scad, Omega.scad = Omega.scad,
                       Psi.scad = Psi.scad, Gam.scad = Gam.scad))
  }

  c(out, list(n = n, p = p, q = q, rholist = rholist, nr = nr))
}

## d x d x nr array whose slices all equal `init` (the identity when NULL).
.smgm_init_path <- function(init, d, nr) {
  path <- array(0, c(d, d, nr))
  path[, , ] <- if (is.null(init)) diag(d) else init
  path
}

## Slice k of a 3-d array as a matrix, keeping dimensions of extent 1.
.smgm_slice <- function(A, k) {
  d <- dim(A)
  matrix(A[, , k], d[1], d[2])
}

## Row covariance given Gam: (1 / (n q)) * sum_i X_i %*% Gam %*% t(X_i).
.smgm_row_cov <- function(x, Gam) {
  d <- dim(x)
  S <- matrix(0, d[1], d[1])
  for (i in seq_len(d[3])) {
    Xi <- matrix(x[, , i], d[1], d[2])
    S <- S + Xi %*% Gam %*% t(Xi)
  }
  S / d[3] / d[2]
}

## Column covariance given Omega: (1 / (n p)) * sum_i t(X_i) %*% Omega %*% X_i.
.smgm_col_cov <- function(x, Omega) {
  d <- dim(x)
  S <- matrix(0, d[2], d[2])
  for (i in seq_len(d[3])) {
    Xi <- matrix(x[, , i], d[1], d[2])
    S <- S + t(Xi) %*% Omega %*% Xi
  }
  S / d[3] / d[1]
}

## glasso with unpenalized diagonal, optionally warm-started; returns NULL
## when glasso reports non-convergence.
.smgm_glasso <- function(S, rho, warm, start, w.init, wi.init) {
  fit <- if (warm) {
    glasso(S, rho, penalize.diagonal = FALSE, start = start,
           w.init = w.init, wi.init = wi.init)
  } else {
    glasso(S, rho, penalize.diagonal = FALSE)
  }
  # NOTE(review): niter of 10000 / 100000 is presumably glasso's iteration
  # cap, i.e. non-convergence; confirm against the glasso version in use.
  if (any(fit$niter %in% c(10000, 100000))) NULL else fit
}

## mean(abs(kronecker(A0, B0) - kronecker(A1, B1))), accumulated one column
## of A at a time so the (pq) x (pq) products are never materialised.
.smgm_kron_change <- function(A0, B0, A1, B1) {
  b0 <- as.vector(B0)
  b1 <- as.vector(B1)
  total <- 0
  for (j in seq_len(ncol(A0))) {
    total <- total + sum(abs(outer(A0[, j], b0) - outer(A1[, j], b1)))
  }
  total / (length(A0) * length(B0))
}

## Alternating glasso fits of (Sigma, Omega) and (Psi, Gam) at one penalty.
## rho.row / rho.col are scalars or penalty matrices for glasso. Returns NULL
## if either glasso fit fails, otherwise list(Sigma, Omega, Psi, Gam) from the
## last iteration (the starting values if no iteration ran).
.smgm_flipflop <- function(x, rho.row, rho.col, Sigma0, Omega0, Psi0, Gam0,
                           warm, start, max.iter, thr) {
  iter <- 1
  while (iter < max.iter) {
    # Row step: Sigma / Omega given the current column precision Gam0.
    fit1 <- .smgm_glasso(.smgm_row_cov(x, Gam0), rho.row, warm, start,
                         Sigma0, Omega0)
    if (is.null(fit1)) return(NULL)
    # Normalise so Sigma[1, 1] == 1 (Omega scaled inversely); presumably to
    # fix the scale non-identifiability of the Kronecker factors.
    aa <- fit1$w[1, 1]
    Sigma.new <- fit1$w / aa
    Omega.new <- fit1$wi * aa

    # Column step: Psi / Gam given the new row precision.
    fit2 <- .smgm_glasso(.smgm_col_cov(x, Omega.new), rho.col, warm, start,
                         Psi0, Gam0)
    if (is.null(fit2)) return(NULL)
    Psi.new <- fit2$w
    Gam.new <- fit2$wi

    # Relative change of Omega %x% Gam; mean(abs(A %x% B)) factorises as
    # mean(abs(A)) * mean(abs(B)).
    change <- .smgm_kron_change(Omega0, Gam0, Omega.new, Gam.new)
    size0 <- mean(abs(Omega0)) * mean(abs(Gam0))

    Sigma0 <- Sigma.new
    Omega0 <- Omega.new
    Psi0 <- Psi.new
    Gam0 <- Gam.new
    if (change < thr * size0) break
    iter <- iter + 1
  }
  list(Sigma = Sigma0, Omega = Omega0, Psi = Psi0, Gam = Gam0)
}

## First derivative of the SCAD penalty, evaluated elementwise at abs(A);
## smgm passes the result to glasso as a penalty matrix.
##   A    numeric matrix (or vector); only abs(A) is used.
##   rho  non-negative penalty level.
##   a    SCAD shape parameter (> 1), default 3.7.
## Returns an object with A's dimensions whose entries are
##   rho                              where |A| <= rho,
##   max(a * rho - |A|, 0) / (a - 1)  where |A| >  rho.
scad <- function(A, rho, a = 3.7) {
  A <- abs(A)
  # rho * (a*rho - |A|)_+ / ((a - 1) * rho) simplified so that rho = 0
  # yields 0 instead of 0/0 = NaN.
  rho * (A <= rho) + pmax(a * rho - A, 0) / (a - 1) * (A > rho)
}