Skip to content

Commit

Permalink
Made some improvements to the smallsim_hard example.
Browse files Browse the repository at this point in the history
  • Loading branch information
pcarbo committed Jun 23, 2024
1 parent 0b42379 commit 0bb4102
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 32 deletions.
25 changes: 0 additions & 25 deletions analysis/smallsim_easy.Rmd

This file was deleted.

41 changes: 34 additions & 7 deletions analysis/smallsim_hard.Rmd
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
title: Toy example comparing EM vs. SCD (hard case)
title: Toy example comparing EM vs. SCD
author: Peter Carbonetto
output: workflowr::wflow_html
---
Expand Down Expand Up @@ -40,15 +40,13 @@ simulate_correlated_loadings <- function (n, k) {
}
return(normalize.rows(L))
}
# Simulate a small count data set: n rows, m columns and k topics.
#
# NOTE(review): this span comes from a rendered diff hunk, so superseded
# and replacement lines appear together without +/- markers: set.seed()
# is called three times, and L and s are each assigned twice. As
# written, only the last assignment to L and to s is used to simulate
# X. Confirm against the actual file which lines are current.
set.seed(1)
set.seed(2)
n <- 100
m <- 400
k <- 6
F <- simulate_factors(m,k)
# NOTE(review): S is not defined anywhere in the visible source.
L <- simulate_loadings(n,k,S)
s <- simulate_sizes(n)
set.seed(1)
L <- simulate_correlated_loadings(n,k)
s <- simulate_sizes(n)
X <- simulate_multinom_counts(L,F,s)
# Keep only the columns of X that have at least one positive count.
X <- X[,colSums(X > 0) > 0]
```
Expand All @@ -67,9 +65,38 @@ ADD TEXT HERE.
# Fit the Poisson NMF model to X with two methods, "em" and "scd",
# both started from the same initial fit (fit0), which is obtained by
# running 20 "em" iterations with extrapolation turned off.
#
# NOTE(review): this span comes from a rendered diff hunk, so superseded
# and replacement lines appear together without +/- markers: fit2 is
# assigned twice from fit0 (numiter=180, then numiter=80), and the
# scatterplot is drawn twice (once bare, once inside print). Confirm
# against the actual file which lines are current.
control <- list(extrapolate = FALSE,numiter = 4)
fit0 <- fit_poisson_nmf(X,k,numiter = 20,method = "em",control = control)
fit1 <- fit_poisson_nmf(X,fit0=fit0,numiter=180,method="em",control=control)
fit2 <- fit_poisson_nmf(X,fit0=fit0,numiter=180,method="scd",control=control)
fit2 <- fit_poisson_nmf(X,fit0=fit0,numiter=80,method="scd",control=control)
# Switch extrapolation on and continue the "scd" fit for another 100
# iterations, starting from the current fit2.
control$extrapolate <- TRUE
fit2 <- fit_poisson_nmf(X,fit0=fit2,numiter=100,method="scd",control=control)
# presumably converts each Poisson NMF fit to the multinomial topic
# model parameterization; verify against the poisson2multinom docs.
fit0 <- poisson2multinom(fit0)
fit1 <- poisson2multinom(fit1)
fit2 <- poisson2multinom(fit2)
# Scatterplot of the "em" loadings (fit1$L) against the "scd" loadings
# (fit2$L).
loadings_scatterplot(fit1$L,fit2$L,topic_colors,"em","scd")
print(loadings_scatterplot(fit1$L,fit2$L,topic_colors,"em","scd"))
```

```{r}
# Compare the convergence of the "em" fit (fit1) and the "scd" fit
# (fit2): for every iteration, plot the gap between the largest
# multinomial log-likelihood recorded by either method and the
# log-likelihood recorded at that iteration.
#
# The iteration index is taken from the length of each progress record
# instead of being hard-coded as 1:200. data.frame() silently recycles
# a shorter column whenever the lengths divide evenly, so a change in
# the number of iterations would otherwise produce a wrong plot rather
# than an error.
pdat <- rbind(
  data.frame(iter   = seq_along(fit1$progress$loglik.multinom),
             ll     = fit1$progress$loglik.multinom,
             res    = fit1$progress$res,
             method = "em"),
  data.frame(iter   = seq_along(fit2$progress$loglik.multinom),
             ll     = fit2$progress$loglik.multinom,
             res    = fit2$progress$res,
             method = "scd"))

# Convert the log-likelihoods to distances from the best value. The
# 0.1 offset keeps the smallest distance strictly positive so that it
# can be drawn on the log10 axis below.
pdat <- transform(pdat,ll = max(ll) - ll + 0.1)

# Line thickness is set with "linewidth"; using "size" for lines is
# deprecated as of ggplot2 3.4.0.
p <- ggplot(pdat,aes(x = iter,y = ll,color = method)) +
  geom_line(linewidth = 0.75) +
  scale_y_continuous(trans = "log10") +
  scale_color_manual(values = c("dodgerblue","darkorange")) +
  labs(x = "iteration",y = "loglik difference") +
  theme_cowplot(font_size = 10)
print(p)
```

```{r}
# Draw three structure plots, stacked in a single column: one from L,
# one from the "em" loadings (fit1$L) and one from the "scd" loadings
# (fit2$L), all using the same topic colors.
p1 <- simdata_structure_plot(L,topic_colors)
p2 <- simdata_structure_plot(fit1$L,topic_colors)
p3 <- simdata_structure_plot(fit2$L,topic_colors)
plot_grid(plotlist = list(p1,p2,p3),
          nrow = 3,
          ncol = 1)
```

0 comments on commit 0bb4102

Please sign in to comment.