-
Notifications
You must be signed in to change notification settings - Fork 0
/
experiment3.jl
60 lines (48 loc) · 1.82 KB
/
experiment3.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/usr/bin/env julia
#SBATCH --time=02:55:00
#SBATCH --mem-per-cpu=1000M
#SBATCH --array=1-1920
#SBATCH [email protected]
#SBATCH --mail-type=ALL
using JLD2
using OrderedCollections
using Printf
using PrivateBandits.DifferentialPrivacy
using PrivateBandits.LinearBandits
using PrivateBandits.Experiments
# ---- Experiment-wide settings shared by every array task ----

# Episode length; passed to make_strategy / make_alg / run_episode below.
horizon = 5 * 10^7
# Privacy parameters (ε, δ); splatted as keyword arguments (`dp...`) into the
# strategy/algorithm constructors.
dp = (ε=1.0, δ=0.1)
env = EnvParams(dim=5, maxrewardmean=0.75, maxreward=1.0)

# Gap grid in ascending order: 0.0, then 0.1 * 0.5^k for k = 6, 5, ..., 0
# (8 values in total, ending at 0.1).
gaps = [0.0; reverse(0.1 .* 0.5 .^ (0:6))]

# Candidate mechanisms; their ρmin values are used to bracket the ρmin grid.
mechanisms = OrderedDict(
    "Gaussian" => GaussianMechanism,
    "Gaussian(Opt)" => OptShifted{GaussianMechanism}(env, horizon),
    "Wishart" => ShiftedWishart,
    "Wishart(Unshifted)" => WishartMechanism,
    "Wishart(Opt)" => OptShifted{WishartMechanism}(env, horizon),
)
# Smallest and largest ρmin over all candidate mechanisms. Each mechanism is turned
# into a strategy with the shared privacy parameters. Its ρmin is then read from
# `regparams` at α = 1/(2 * horizon).
ρmin_lo, ρmin_hi = let
    mechanism_ρmin(M) =
        regparams(make_strategy(env, horizon, M; dp...); α=1 / (2 * horizon)).ρmin
    extrema(map(mechanism_ρmin, values(mechanisms)))
end
# Log-spaced ρmin grid with `num_ρmins` points.
# Point k (k = 1:num_ρmins) sits at c = (k - 2) / (num_ρmins - 3), so c runs from
# -1/(num_ρmins - 3) up to 1 + 1/(num_ρmins - 3). The grid therefore spans
# [ρmin_lo, ρmin_hi] and extends one grid step beyond each end.
num_ρmins = 8

# Geometric interpolation between the bracketing values:
# c = 0 gives ρmin_lo, and c = 1 gives ρmin_hi.
function ρmin_interp(c)
    return exp((1 - c) * log(ρmin_lo) + c * log(ρmin_hi))
end

ρmins = [ρmin_interp((k - 2) // (num_ρmins - 3)) for k in 1:num_ρmins]
"""
    task_params(A...; taskid=parse(Int, ENV["SLURM_ARRAY_TASK_ID"]))

Decode a 1-based array-task id into a tuple of indices.

The tuple has one index per dimension of each array in `A`, followed by a final
"run" index. `taskid` is interpreted as a linear index into the grid
`size(A[1]) × size(A[2]) × … × taskid`, using column-major order (first index
varies fastest).

The trailing dimension has extent `taskid`, so the lookup can never go out of
bounds as long as every array in `A` is non-empty. Once the parameter grid is
exhausted, larger task ids wrap around and yield the same parameters with an
incremented run index.

By default `taskid` is read from the `SLURM_ARRAY_TASK_ID` environment variable.
That lookup throws a `KeyError` if the variable is unset.

# Examples
```julia
task_params(1:3, 1:2; taskid=4)  # (1, 2, 1)
task_params(1:3, 1:2; taskid=7)  # (1, 1, 2)
```
"""
function task_params(A...; taskid=parse(Int, ENV["SLURM_ARRAY_TASK_ID"]))
    # Concatenate every array's size tuple into one flat tuple of integer extents.
    # CartesianIndices needs plain integers, not nested size tuples.
    extents = (Iterators.flatten(map(size, A))..., taskid)
    return Tuple(CartesianIndices(extents)[taskid])
end
# Locate this array task in the ρmin × gap × run grid. task_params reads
# SLURM_ARRAY_TASK_ID by default. With 8 ρmin values, 8 gaps and
# `--array=1-1920`, each (ρmin, gap) pair is repeated for run_ix = 1:30.
ρmin_ix, gap_ix, run_ix = task_params(ρmins, gaps)
ρmin, gap = ρmins[ρmin_ix], gaps[gap_ix]

# Algorithms under comparison, all built for the same ρmin. The :NonPrivate entry
# receives no `dp` keywords. The two private entries use a `shifted` mechanism.
algs = OrderedDict{Symbol, ContextLinBandit}(
    :NonPrivate => make_alg(env, horizon; ρ=ρmin),
    :Wishart => make_alg(env, horizon, shifted(WishartMechanism; ρmin=ρmin); dp...),
    :Gaussian => make_alg(env, horizon, shifted(GaussianMechanism; ρmin=ρmin); dp...),
)
arms = GapArms(env; gap=gap)

# Run one episode per algorithm on the same arms. Each result is saved to
# "<alg>,ρmin=<ρmin>,Δ=<gap>/<run_ix as 3 zero-padded digits>.jld", relative to
# the working directory.
for (alg_name, alg) in pairs(algs)
    run_dir = "$alg_name,ρmin=$ρmin,Δ=$gap"
    mkpath(run_dir)
    @time result = run_episode(env, alg, arms, horizon; subsample=10^4)
    @save joinpath(run_dir, @sprintf("%03d.jld", run_ix)) result
end