Applying function to priors doesn't work with "new" model fitting approach #2477

Closed
DominiqueMakowski opened this issue Jan 22, 2025 · 2 comments

@DominiqueMakowski
Contributor

Minimal working example

*see below*

Description

I've been using @torfjelde's trick to set the initial parameters to the theoretical means of the priors. Unfortunately, it doesn't seem to work with the "new" syntax for fitting models, where the data are passed by conditioning with `|` instead of as a function argument.

Works:

using Turing

@model function mymodel(y)
    x ~ Normal(0, 1)
    for i in 1:length(y)
        y[i] ~ Normal(x, 1)
    end
end

fitted = mymodel([1.0, 2.0, 3.0])
initial_params = mapreduce(DynamicPPL.tovec ∘ mean, vcat, values(Turing.extract_priors(fitted)))
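For context, the prior means computed this way are then handed to the sampler. A minimal sketch, assuming a recent Turing version whose `sample` accepts the `initial_params` and `progress` keywords; the NUTS sampler and the 200 draws are arbitrary choices, and `DynamicPPL` is pulled in through Turing rather than added as a separate package:

```julia
using Turing
using Turing: DynamicPPL  # assumption: DynamicPPL is reachable as a binding inside Turing

@model function mymodel(y)
    x ~ Normal(0, 1)
    for i in 1:length(y)
        y[i] ~ Normal(x, 1)
    end
end

fitted = mymodel([1.0, 2.0, 3.0])

# One entry per latent variable; here only `x`, whose prior mean is 0.0.
initial_params = mapreduce(DynamicPPL.tovec ∘ mean, vcat, values(Turing.extract_priors(fitted)))

# Start the chain at the prior means instead of a random initialisation.
chain = sample(fitted, NUTS(), 200; initial_params=initial_params, progress=false)
```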

Errors:

@model function mymodel(y)
    x ~ Normal(0, 1)
    for i in 1:length(y)
        y[i] ~ Normal(x, 1)
    end
end

fitted = mymodel() | (y=[1.0, 2.0, 3.0],)
initial_params = mapreduce(DynamicPPL.tovec ∘ mean, vcat, values(Turing.extract_priors(fitted)))
ERROR: UndefVarError: `y` not defined in `Main`
Suggestion: check for spelling errors or missing imports.
Stacktrace:
 [1] macro expansion
   @ C:\Users\domma\.julia\packages\DynamicPPL\cvlfK\src\compiler.jl:584 [inlined]
 [2] mymodel(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{…}, __context__::DynamicPPL.PriorExtractorContext{…})
   @ Main .\Untitled-1:5
 [3] _evaluate!!
   @ C:\Users\domma\.julia\packages\DynamicPPL\cvlfK\src\model.jl:914 [inlined]
 [4] evaluate_threadsafe!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.UntypedVarInfo{…}, context::DynamicPPL.PriorExtractorContext{…})
   @ DynamicPPL C:\Users\domma\.julia\packages\DynamicPPL\cvlfK\src\model.jl:903
 [5] evaluate!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.UntypedVarInfo{…}, context::DynamicPPL.PriorExtractorContext{…})
   @ DynamicPPL C:\Users\domma\.julia\packages\DynamicPPL\cvlfK\src\model.jl:833
 [6] extract_priors
   @ C:\Users\domma\.julia\packages\DynamicPPL\cvlfK\src\extract_priors.jl:117 [inlined]
 [7] extract_priors(args::DynamicPPL.Model{typeof(mymodel), (), (), (), Tuple{}, Tuple{}, DynamicPPL.ConditionContext{…}})
   @ DynamicPPL C:\Users\domma\.julia\packages\DynamicPPL\cvlfK\src\extract_priors.jl:113
 [8] top-level scope
   @ Untitled-1:23
Some type information was truncated. Use `show(err)` to see complete types.

Julia version info

versioninfo()
Julia Version 1.11.2
Commit 5e9a32e7af (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 16 × 11th Gen Intel(R) Core(TM) i9-11950H @ 2.60GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, tigerlake)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_EDITOR = code

Manifest

]st --manifest
# This file is machine-generated - editing it directly is not advised

julia_version = "1.11.2"
manifest_format = "2.0"
project_hash = "8cb754c24264dd6e5f40106269377e003b9cda9f"

[[deps.ADTypes]]
git-tree-sha1 = "72af59f5b8f09faee36b4ec48e014a79210f2f4f"
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
version = "1.11.0"

    [deps.ADTypes.extensions]
    ADTypesChainRulesCoreExt = "ChainRulesCore"
    ADTypesConstructionBaseExt = "ConstructionBase"
    ADTypesEnzymeCoreExt = "EnzymeCore"

    [deps.ADTypes.weakdeps]
    ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
    ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
    EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[[deps.AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.5.0"
weakdeps = ["ChainRulesCore", "Test"]

    [deps.AbstractFFTs.extensions]
    AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
    AbstractFFTsTestExt = "Test"

[[deps.AbstractMCMC]]
deps = ["BangBang", "ConsoleProgressMonitor", "Distributed", "FillArrays", "LogDensityProblems", "Logging", "LoggingExtras", "ProgressLogging", "Random", "StatsBase", "TerminalLoggers", "Transducers"]
git-tree-sha1 = "aa469a7830413bd4c855963e3f648bd9d145c2c3"
uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
version = "5.6.0"

[[deps.AbstractPPL]]
deps = ["AbstractMCMC", "Accessors", "DensityInterface", "JSON", "Random"]
git-tree-sha1 = "bdb19638644450ee1b0fd63740381835069d34b9"
uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
version = "0.9.0"

[[deps.AbstractTrees]]
git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.4.5"
@penelopeysm
Member

penelopeysm commented Jan 23, 2025

This should work:

@model function mymodel() # `y` is no longer a function argument
    x ~ Normal(0, 1)
    y = Vector{Float64}(undef, 3) # this line needs to be added
    for i in 1:length(y)
        y[i] ~ Normal(x, 1)
    end
end

fitted = mymodel() | (y=[1.0, 2.0, 3.0],)

The main issue is that `lhs ~ rhs` doesn't define `lhs` until after the tilde statement, so `1:length(y)` fails with `UndefVarError` because `y` does not exist yet when the loop starts. (In the original model it does work, because the function argument `y` is in scope at that point.)
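Combining this workaround with the prior-mean trick from the original report gives roughly the following. It is a sketch that assumes, as in the reply, that the number of observations (3) is known when the model is written. Because `y` is conditioned, it should be treated as observed, so `extract_priors` should only pick up `x`:

```julia
using Turing
using Turing: DynamicPPL  # assumption: DynamicPPL is reachable as a binding inside Turing

@model function mymodel()
    x ~ Normal(0, 1)
    # `y[i] ~ ...` assigns into `y`, so the container must exist first;
    # the length 3 is hardcoded to match the conditioned data.
    y = Vector{Float64}(undef, 3)
    for i in 1:length(y)
        y[i] ~ Normal(x, 1)
    end
end

fitted = mymodel() | (y=[1.0, 2.0, 3.0],)

# Only latent (non-conditioned) variables are recorded, i.e. just `x`.
priors = Turing.extract_priors(fitted)
initial_params = mapreduce(DynamicPPL.tovec ∘ mean, vcat, values(priors))
```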

I guess in principle this could be reworked, but it would require some fairly complex surgery in DynamicPPL and imo there are other underlying problems with conditioning on vectors that are more pressing. So, although I'll open a corresponding issue in DynamicPPL to track this, I think it's unlikely to be 'fixed' in the near future - hopefully this is a good enough fix :)

If the above doesn't work for any reason, let me know, and we can reopen.
