-
Notifications
You must be signed in to change notification settings - Fork 0
/
stanGgplot.R
270 lines (250 loc) · 9.23 KB
/
stanGgplot.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
# some ggplot2 graphics for RStan
# note that if the chains are long we also require Hadley's "bigvis" package
# to avoid spending too much time hanging around!
# do a path plot of N points, i.e. follow the sequence of the chain
# the default is for N=50 points and the first two variables
# useful for checking that:
# * I have converged
# * I am mixing properly
#
# explore multiple sub-samples in different colours
# perhaps label the starting point numbers of each subsample in the legend
#
# General outline... for a given chain - possibly chosen at random from
# a fit object... take a pair of parameters and extract a number of different
# subsamples (say 4 by default) and plot them in different colours
# label each subsample with the index of the first of the points in the subsample
#
# consider overlaying this on a two-dimensional density plot or contour plot
#
# validation needed:
# * that there is more than one parameter!
# * the chain specified exists
# *
#
# Parameter explanations:
# fit - an RStan "fit" object
# whichpair - which pair of parameters do you wish to plot?
# chaini - which chain shall we subsample
# N - the length of each subsample
# ns - the number of subsamples
#
# nb - if the length of the chain (after duplicate removal when ignoreDups=TRUE)
#      is <= N then the entire chain will be plotted
#    - if the length of the chain is <= N*ns, then the number of subsamples
#      will be reduced accordingly - no point in plotting the same point more than once
#    - subsamples will be chosen to be non-overlapping
#
# NB - we can end up with duplicate rows, so ideally we highlight these with a bigger point
# > head(e2)
# parameters
# iterations alpha beta
# [1,] 1450.828 0.0002808135
# [2,] 1631.769 0.0002418388
# [3,] 1631.769 0.0002418388
# [4,] 1718.608 0.0002664065
# [5,] 1718.608 0.0002664065
# [6,] 1482.773 0.0002456264
#
# geomlist=c('path','point') is also possible
#
stanPathPlot <- function(fit, whichpair=c(1,2), chaini=1, N=50, ns=4,
                         ignoreDups=TRUE, sDEBUG=FALSE, geomlist=c('path')) {
  # Path plot of up to ns random, non-overlapping runs of N consecutive draws
  # from chain `chaini` of an RStan fit, for the parameter pair `whichpair`.
  # Each run gets its own colour, labelled by its starting row index (counted
  # after duplicate removal when ignoreDups=TRUE). Axis limits cover the whole
  # chain so runs can be compared against the full extent of the sample.
  # Returns the ggplot object; it is NOT printed (wrap in print() in scripts).
  stopifnot(require(ggplot2))
  stopifnot(require(rstan))
  # ok... lets be sensible... these MUST be TRUE
  stopifnot(!is.null(fit))
  stopifnot(!is.null(whichpair), length(whichpair) >= 2)
  stopifnot(chaini >= 1)
  stopifnot(N > 1)
  stopifnot(ns >= 1)
  # need to set permuted as otherwise there will be a random order to the points in the chain!
  e1 <- extract(fit, permuted=FALSE)
  # dimension 2 of e1 indexes chains: fail clearly rather than with
  # "subscript out of bounds"
  stopifnot(chaini <= dim(e1)[2])
  # now extract the chain and pair we're interested in
  e2 <- data.frame(e1[, chaini, whichpair])
  if (ignoreDups) {
    e2 <- e2[!duplicated(e2), , drop=FALSE]
  }
  nr <- nrow(e2)
  if (sDEBUG) {
    cat('Effective Number of rows is', nr, '\n')
  }
  if (nr > N) {
    # as many whole runs of N as fit in the chain, capped at ns
    # (chooseSubsamples may still return fewer if placement runs out of room)
    nsub <- min(ns, nr %/% N)
    svec <- chooseSubsamples(nr, N, nsub)
  } else {
    # we can't do subsampling, so lets just plot the lot!
    svec <- 1
  }
  # each run is exactly N points (or the whole chain when nr <= N); starts are
  # at least N apart, so runs of this length never overlap
  evec <- pmin(svec + N - 1, nr)
  # one data.frame per run, tagged with its starting index, bound once
  pieces <- lapply(seq_along(svec), function(i) {
    piece <- e2[svec[i]:evec[i], , drop=FALSE]
    piece$label <- svec[i]
    piece
  })
  adf <- do.call(rbind, pieces)
  adf$label <- as.factor(adf$label)
  q <- qplot(adf[,1], adf[,2], xlab=names(adf)[1], ylab=names(adf)[2], colour=adf$label,
             xlim=range(e2[,1]), ylim=range(e2[,2]), geom=geomlist) +
    labs(colour='Starting Index')
  q
}
# having a fit object in memory already I can just say
# saveRDS(fit, file='sample_banana_fit_object.rds')
# and later read it back in using...
# fit <- readRDS('sample_banana_fit_object.rds')
# stanPathPlot(fit)
# nb. if put inside a script (or a loop), the returned plot is not
# auto-printed, so it needs a print() wrapper!!!
# ggsave('sample_banana_fit_object.pdf') # if suitable for saving
#
# Disabled example: write 25 random path plots into one PDF whose page is
# 297 x 210 mm less 20 mm in each direction, converted to inches.
if (FALSE) {
  pdf('testing_Path_Plots.pdf', width=(297-20)/25.4, height=(210-20)/25.4)
  for (i in seq_len(25)) {
    # ggplot objects are only drawn when printed; a for loop does not
    # auto-print, so without print() every page would be missing
    print(stanPathPlot(fit))
  }
  dev.off()
}
# here's an older dumber version: plots a single random run of consecutive
# draws (N+1 points when the chain is longer than N, otherwise the whole
# chain) for the parameter pair `whichpair` of chain `chaini`.
# `ns` is accepted for signature compatibility with stanPathPlot() but unused.
# Returns the ggplot object (not printed).
stanPathPlot.1 <- function(fit, whichpair=c(1,2), chaini=1, N=75, ns=4) {
  stopifnot(require(ggplot2))
  stopifnot(require(rstan))
  # ok... lets be sensible... these MUST be TRUE
  stopifnot(!is.null(fit))
  stopifnot(!is.null(whichpair))
  stopifnot(chaini >= 1)
  stopifnot(N > 1)
  # need to set permuted as otherwise there will be a random order to the points in the chain!
  e1 <- extract(fit, permuted=FALSE)
  # extract the chain and pair FIRST: its row count is needed just below
  e2 <- e1[,chaini,whichpair]
  eN <- min(N, nrow(e2)) # how many will we actually sample?
  # random start leaving room for eN further points; when the chain is no
  # longer than N there is nothing to choose, so start at the beginning
  if (nrow(e2) > eN) {
    eRangeMin <- sample.int(nrow(e2) - eN, 1)
  } else {
    eRangeMin <- 1
  }
  eRangeMax <- min(eRangeMin + eN, nrow(e2))
  e3 <- e2[eRangeMin:eRangeMax, , drop=FALSE]
  qplot(e3[,1], e3[,2], xlab=dimnames(e3)$parameters[1], ylab=dimnames(e3)$parameters[2],
        xlim=range(e2[,1]), ylim=range(e2[,2]), geom='path')
}
# stanPathPlot.1(fit)
# A small helper function to choose a set of random, non-overlapping
# subsamples for stanPathPlot().
#
# len    - the length of the given chain - typically something like 2000
# N      - the length of each subsample - defaults to 75 so we can see the paths
# ns     - the (maximum) number of subsamples - not too many, otherwise the
#          plot becomes hard to read
# sDEBUG - if TRUE, len/N/ns are IGNORED and a small fixed test case
#          (len=20, N=4, ns=3) is run, printing the candidates after each pick
#
# Algorithm - brute force for now: build the full set of allowed starting
# points 1:(len-N), pick one uniformly at random, then remove every start
# within N-1 of it (those would give rise to overlapping runs). Repeat up to
# ns times. Any two returned starts therefore differ by at least N, so runs
# of N consecutive points starting at them never overlap.
#
# Random greedy placement can use up all the room before ns starts have been
# chosen (e.g. when len is close to N*ns); the starts chosen so far are then
# returned, so the result may have fewer than ns elements.
#
# It's important that we take *random* subsamples.
#
# Returns a sorted integer vector of starting indices. Errors if len <= N
# (there is no admissible start).
chooseSubsamples <- function(len, N=75, ns=4, sDEBUG=FALSE) {
  # some basic sanity checks
  stopifnot(len > 2) # lets not be silly!
  stopifnot(N > 1)
  stopifnot(ns >= 1)
  # for the purposes of development use a small fixed test case
  if (sDEBUG) {
    len <- 20
    N <- 4
    ns <- 3
  }
  # there must be at least one admissible starting point
  stopifnot(len - N >= 1)
  # Rn is the full list of possible starting points that are allowed
  Rn <- 1:(len-N)
  starts <- vector(mode='integer', length=ns)
  nChosen <- 0
  for (i in seq_len(length(starts))) {
    # random placement may have used up all the room already
    if (length(Rn) == 0) {
      break
    }
    # index into Rn rather than sample(Rn, 1): with a single candidate k left,
    # sample(k, 1) would draw from 1:k instead of returning k
    starts[i] <- Rn[sample.int(length(Rn), 1)]
    nChosen <- i
    # remove this start and every start within N-1 of it so no two runs of N
    # points overlap; these are values, so use setdiff, NOT Rn[-c(...)]
    minRmve <- max(1, starts[i]-N+1)
    maxRmve <- min(len, starts[i]+N-1)
    Rn <- setdiff(Rn, c(minRmve:maxRmve))
    if (sDEBUG) {
      cat(i, starts[i], '\n')
      print(Rn)
    }
  }
  sort(starts[seq_len(nChosen)])
}
# chooseSubsamples(20, 4, 3) # for test purposes
# system.time( { chooseSubsamples(1000) } ) # not measurable at len=1000
# at len=1 million it takes about 0.3secs on this laptop
# Marginal density plot of a single quantity, using every draw returned by
# extract(fit) called with its default arguments.
#
# fit - an RStan "fit" object
# col - which quantity to plot: a column number of data.frame(extract(fit)),
#       or a column name of that data frame (e.g. 'lp__')
#
# The x axis is limited to the range of the draws. Returns the ggplot object
# (not printed).
marginalDensity <- function(fit, col=1) {
  stopifnot(require(ggplot2))
  stopifnot(require(rstan))
  # ok... lets be sensible... these MUST be TRUE
  stopifnot(!is.null(fit))
  stopifnot(!is.null(col), length(col) == 1)
  adf <- data.frame(extract(fit))
  # accept a column name as well as a number, so the x label is never NA
  if (is.character(col)) {
    col <- match(col, names(adf))
  }
  # an unknown name or out-of-range number would otherwise fail with an
  # opaque "undefined columns selected"
  stopifnot(!is.na(col), col >= 1, col <= ncol(adf))
  qplot(adf[,col], data=adf, geom='density',
        xlim=range(adf[,col]), ylab='Marginal Density', xlab=names(adf)[col])
}
# marginalDensity(fit)
# marginalDensity(fit, col=2)
# Disabled example: render a path plot and a marginal density plot as TikZ
# (.tex) files to see how they look inside a LaTeX document.
if (FALSE) {
  stopifnot(require(tikzDevice))
  # figure sizes in inches for a Tufte-style document:
  # margin figures...
  kTIKZ.width <- 2.5
  kTIKZ.height <- 2.5
  # ...and plain (full-width) figures
  kTIKZ.big.width <- 4.0
  kTIKZ.big.height <- 2.8
  # Caveat: tikz() cannot cope with a parameter named lp__ - the underscores
  # must be removed or escaped for TeX before plotting.
  tikzJobs <- list(
    list(texFile='samplePathPlot.tex', makePlot=stanPathPlot),
    list(texFile='sampleMarginalDensityPlot.tex', makePlot=marginalDensity)
  )
  for (job in tikzJobs) {
    tikz(job$texFile, width=kTIKZ.big.width, height=kTIKZ.big.height)
    print(job$makePlot(fit))
    dev.off()
  }
}
# Joint density plot of two quantities: a scatter of (at most 1000 randomly
# chosen) draws overlaid with stat_density2d() contours.
#
# fit         - an RStan "fit" object; only used (and may be omitted) when
#               adf is NULL, in which case the draws of alpha, beta and lp__
#               are taken from extract(fit)
# adf         - optional data frame (or matrix) of draws to plot instead
# cols        - the two columns of the draws to plot, x first then y
# colNames    - axis labels for the two plotted columns
# tikzOutFnam - if not NULL, the plot is also written to this file with
#               tikz(); the size is kTIKZ.width x kTIKZ.height when those
#               variables are visible, otherwise 2.5 x 2.5 inches
#
# The x axis runs from 0 to the largest x value, i.e. the x quantity is
# assumed to be non-negative.
#
# Prints the plot to the current device and returns it invisibly.
jointdensityplotfn <- function(fit, adf=NULL, cols=c(1,2),
                               colNames=c('Alpha', 'Beta'),
                               tikzOutFnam=NULL) {
  stopifnot(require(ggplot2))
  # draws come from the fit only when no data frame was supplied
  if (is.null(adf)) {
    stopifnot(!is.null(fit))
    stopifnot(require(rstan))
    fit.s <- extract(fit)
    adf <- data.frame(alpha=as.numeric(fit.s$alpha),
                      beta=as.numeric(fit.s$beta),
                      lp=as.numeric(fit.s$lp__))
  }
  adf <- data.frame(adf)
  stopifnot(length(cols) == 2)
  # subsample down from 20k+ points to 1k if lots of points...
  if (nrow(adf) > 1000) {
    adf <- adf[sample(nrow(adf), size=1000), , drop=FALSE]
  }
  adf <- adf[, cols, drop=FALSE]
  names(adf) <- c('alpha', 'beta')
  p <- qplot(alpha, beta, xlim=c(0, max(adf$alpha)),
             xlab=colNames[[1]], ylab=colNames[[2]],
             geom='point', data=adf) + stat_density2d()
  if (!is.null(tikzOutFnam)) {
    stopifnot(require(tikzDevice))
    # the kTIKZ.* sizes are normally set by the tikz example script; fall
    # back to the same margin-figure size when they are not defined
    tikz(tikzOutFnam,
         width=get0('kTIKZ.width', ifnotfound=2.5),
         height=get0('kTIKZ.height', ifnotfound=2.5))
    tikzDev <- dev.cur()
    # always close the .tex file, even if rendering fails
    tryCatch(print(p), finally=dev.off(tikzDev))
  }
  print(p)
  invisible(p)
}