Skip to content

Commit

Permalink
Merge pull request #7 from biaslab/upgrade
Browse files Browse the repository at this point in the history
Upgrade simulations
  • Loading branch information
ThijsvdLaar authored Sep 19, 2024
2 parents b90799a + b04026d commit 4b858a8
Show file tree
Hide file tree
Showing 26 changed files with 1,132 additions and 734 deletions.
1,255 changes: 830 additions & 425 deletions Manifest.toml

Large diffs are not rendered by default.

77 changes: 37 additions & 40 deletions Part1/Policy_Inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
"metadata": {},
"outputs": [],
"source": [
"using Pkg;Pkg.activate(\"..\");Pkg.instantiate();\n",
"using Pkg;Pkg.activate(\"..\");# Pkg.instantiate();\n",
"using RxInfer, LinearAlgebra, Distributions, Random\n",
"Random.seed!(666)"
"Random.seed!(666);"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "5d731951",
"metadata": {},
"outputs": [],
Expand All @@ -26,12 +26,12 @@
"include(\"transition_mixture/out.jl\")\n",
"include(\"transition_mixture/switch.jl\")\n",
"include(\"../goal_observation.jl\")\n",
"include(\"helpers.jl\")"
"include(\"helpers.jl\");"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "21614a0c",
"metadata": {},
"outputs": [],
Expand All @@ -51,83 +51,80 @@
" out[argmax(probvec(distribution))] = 1.\n",
"\n",
" PointMass(out)\n",
"end"
"end;"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "349bef1c",
"metadata": {},
"outputs": [],
"source": [
"# Create the model\n",
"@model function t_maze(A,D,B1,B2,B3,B4,T)\n",
"\n",
"@model function t_maze(A,D,B1,B2,B3,B4,T,c)\n",
" z_0 ~ Categorical(D)\n",
"\n",
" z = randomvar(T)\n",
" switch = randomvar(T)\n",
"\n",
" c = datavar(Vector{Float64}, T)\n",
" z_prev = z_0\n",
"\n",
" for t in 1:T\n",
" switch[t] ~ Categorical(fill(1. /4. ,4))\n",
" z[t] ~ TransitionMixture(z_prev,switch[t], B1,B2,B3,B4)\n",
" c[t] ~ GoalObservation(z[t], A) where {pipeline = GeneralizedPipeline(vague(Categorical, 8)) }\n",
" c[t] ~ GoalObservation(z[t], A) where {dependencies = GeneralizedPipeline(vague(Categorical, 8)) }\n",
" z_prev = z[t]\n",
" end\n",
"end;"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "b2dd120e",
"metadata": {},
"outputs": [],
"source": [
"#Pointmass constraints\n",
"@constraints function pointmass_q()\n",
" q(switch) :: PointMass\n",
" q(switch) :: PointMassFormConstraint()\n",
"end\n",
"\n",
"# Node constraints\n",
"@meta function t_maze_meta()\n",
" GoalObservation(c,z) -> GeneralizedMeta()\n",
"end"
"end\n",
"\n",
"@initialization function init_marginals()\n",
" q(z) = Categorical(fill(1. /8. ,8))\n",
"end;"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "c8bcefdb",
"metadata": {},
"outputs": [],
"source": [
"# Configure experiment\n",
"T =2; # Planning horizon\n",
"T = 2; # Planning horizon\n",
"its = 10; #Number of inference iterations to run\n",
"initmarginals = ( z = [Categorical(fill(1. /8. ,8)) for t in 1:T], ) # Initial marginals\n",
"\n",
"A,B,C,D = constructABCD(0.9,[2.0 for t in 1:T],T); # Generate the matrices we need"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "81d60bfc",
"metadata": {},
"outputs": [],
"source": [
"# Run inference\n",
"result = inference(model = t_maze(A,D,B[1],B[2],B[3],B[4],T),\n",
" data= (c = C,),\n",
" initmarginals = initmarginals,\n",
" meta= t_maze_meta(),\n",
" iterations=its,\n",
" )"
"result = infer(model = t_maze(A=A,D=D,B1=B[1],B2=B[2],B3=B[3],B4=B[4],T=T),\n",
" data = (c = C,),\n",
" initialization = init_marginals(),\n",
" meta = t_maze_meta(),\n",
" iterations = its,\n",
" );"
]
},
{
Expand All @@ -150,30 +147,30 @@
"outputs": [],
"source": [
"# Repeat experiments with pointmass constraints\n",
"result = inference(model = t_maze(A,D,B[1],B[2],B[3],B[4],T),\n",
" data= (c = C,),\n",
" initmarginals = initmarginals,\n",
" meta= t_maze_meta(),\n",
" constraints=pointmass_q(),\n",
" iterations=its,\n",
" )\n",
"result = infer(model = t_maze(A=A,D=D,B1=B[1],B2=B[2],B3=B[3],B4=B[4],T=T),\n",
" data = (c = C,),\n",
" initialization = init_marginals(),\n",
" meta = t_maze_meta(),\n",
" constraints = pointmass_q(),\n",
" iterations = its,\n",
" );\n",
"\n",
"println(\"Posterior controls as T=1, \",probvec(result.posteriors[:switch][end][1]), \"\\n\")\n",
"println(\"Posterior controls as T=2, \",probvec(result.posteriors[:switch][end][2]), \"\\n\")"
"println(\"Posterior controls as T=1, \", result.posteriors[:switch][end][1].point, \"\\n\")\n",
"println(\"Posterior controls as T=2, \", result.posteriors[:switch][end][2].point, \"\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.8.2",
"display_name": "Julia 1.10.4",
"language": "julia",
"name": "julia-1.8"
"name": "julia-1.10"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.8.2"
"version": "1.10.4"
}
},
"nbformat": 4,
Expand Down
12 changes: 0 additions & 12 deletions Part1/transition_mixture/testing_ground.jl

This file was deleted.

47 changes: 25 additions & 22 deletions Part2/T-maze_Aggregate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
"source": [
"using Pkg\n",
"Pkg.activate(\"..\")\n",
"Pkg.instantiate()"
"# Pkg.instantiate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -38,27 +38,23 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"include(\"../goal_observation.jl\")\n",
"\n",
"# Define the generative model\n",
"@model function t_maze(A_s, D_s, x)\n",
" u = datavar(Matrix{Int64}, 2) # Policy for evaluations\n",
" z = randomvar(2) # Latent states\n",
" c = datavar(Vector{Float64}, 2) # Goal prior statistics\n",
"\n",
"@model function t_maze(A_s, D_s, x, c, u)\n",
" z_0 ~ Categorical(D_s) # State prior\n",
" A ~ MatrixDirichlet(A_s) # Observation matrix prior\n",
"\n",
" z_k_min = z_0\n",
" for k=1:2\n",
" z[k] ~ Transition(z_k_min, u[k])\n",
" c[k] ~ GoalObservation(z[k], A) where {\n",
" meta=GeneralizedMeta(x[k]), \n",
" pipeline=GeneralizedPipeline(vague(Categorical,8))} # With breaker message\n",
" meta = GeneralizedMeta(x[k]), \n",
" dependencies = GeneralizedPipeline(vague(Categorical,8))} # With breaker message\n",
"\n",
" z_k_min = z[k] # Reset for next slice\n",
" end\n",
Expand All @@ -68,9 +64,16 @@
"@constraints function structured(approximate::Bool)\n",
" q(z_0, z, A) = q(z_0, z)q(A)\n",
" if approximate # Sampling approximation on A required for t<3\n",
" q(A) :: SampleList(20)\n",
" q(A) :: SampleListFormConstraint(20, LeftProposal())\n",
" end\n",
"end"
"end\n",
"\n",
"@initialization function init_marginals(A_s)\n",
" q(A) = MatrixDirichlet(asym(A_s))\n",
" q(z_0) = Categorical(asym(8))\n",
" q(z) = [Categorical(asym(8)), Categorical(asym(8))]\n",
"end\n",
";"
]
},
{
Expand Down Expand Up @@ -107,7 +110,7 @@
"@showprogress for r=1:R\n",
" rs = generateGoalSequence(S) # Returns random goal sequence\n",
" (reset, execute, observe) = initializeWorld(A, B, C, D, rs) # Let there be a world\n",
" (infer, act) = initializeAgent(A_0, B, C, D_0) # Let there be a constrained agent\n",
" (inference, act) = initializeAgent(A_0, B, C, D_0) # Let there be a constrained agent\n",
"\n",
" # Step through the experimental protocol\n",
" As = Vector{Matrix}(undef, S) # Posterior statistics for A\n",
Expand All @@ -117,22 +120,22 @@
" for s = 1:S\n",
" reset(s) # Reset world\n",
" for t=1:2\n",
" (Gs[s][t], _) = infer(t, as[s], os[s])\n",
" as[s][t] = act(t, Gs[s][t])\n",
" (Gs[s][t], _) = inference(t, as[s], os[s])\n",
" as[s][t] = act(t, Gs[s][t])\n",
" execute(as[s][t])\n",
" os[s][t] = observe()\n",
" os[s][t] = observe()\n",
" end\n",
" (Gs[s][3], As[s]) = infer(3, as[s], os[s]) # Learn at t=3\n",
" (Gs[s][3], As[s]) = inference(3, as[s], os[s]) # Learn at t=3\n",
" end\n",
" wins[r] = extractWins(os)\n",
" wins[r] = extractWins(os)\n",
" params[r] = deepcopy(As[end])\n",
"end\n",
";"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -215,15 +218,15 @@
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.8.2",
"display_name": "Julia 1.10.4",
"language": "julia",
"name": "julia-1.8"
"name": "julia-1.10"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.8.2"
"version": "1.10.4"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 4b858a8

Please sign in to comment.