diff --git a/src/main/java/blang/distributions/Generators.java b/src/main/java/blang/distributions/Generators.java index c6e294a..3bfe0c8 100644 --- a/src/main/java/blang/distributions/Generators.java +++ b/src/main/java/blang/distributions/Generators.java @@ -99,6 +99,7 @@ public static double logLogistic(Random rand, double scale, double shape) /** */ public static double halfstudentt(Random random, double nu, double sigma) { double t = studentt(random, nu, 0, 1); + System.out.println("t = " + t); return Math.abs(t) * sigma; } diff --git a/src/main/java/blang/engines/internals/factories/Pigeons.java b/src/main/java/blang/engines/internals/factories/Pigeons.java index 1d503eb..cbc6ad8 100644 --- a/src/main/java/blang/engines/internals/factories/Pigeons.java +++ b/src/main/java/blang/engines/internals/factories/Pigeons.java @@ -1,12 +1,16 @@ package blang.engines.internals.factories; import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; import java.io.InputStreamReader; import bayonet.distributions.Random; import blang.engines.internals.PosteriorInferenceEngine; import blang.inits.Arg; import blang.inits.DefaultValue; +import blang.inits.GlobalArg; +import blang.inits.experiments.ExperimentResults; import blang.runtime.SampledModel; import blang.runtime.internals.objectgraph.GraphAnalysis; @@ -19,34 +23,72 @@ */ public class Pigeons implements PosteriorInferenceEngine { + @GlobalArg public ExperimentResults results = new ExperimentResults(); + @Arg - public Random random; + public Long random; + + public Random rng = null; @Arg @DefaultValue("3") public double nPassesPerScan = 3; + public boolean log = true; // TODO: default should be false!! + private BufferedWriter logger = null; + SampledModel model; public static String LOG_POTENTIAL_CODE = "log_potential("; public static String CALL_SAMPLER_CODE = "call_sampler!("; + + + public static T static_log(T object) { + return instance.log(object); + } + + public T log(T object) + { + if (log) + { + String msg = "[time=" + System.currentTimeMillis() + ",seed=" + random + "] " + object.toString(); + if (logger == null) + logger = results.getAutoClosedBufferedWriter("log.txt"); + try + { + logger.append(msg + "\n"); + logger.flush(); + } catch (IOException e) + { + e.printStackTrace(); + } + } + return object; + } + + static Pigeons instance = null; @Override public void performInference() { + instance = this; + if (rng == null) + rng = new Random(random); BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); try { while (true) { String line = br.readLine(); + log("input=" + line); if (line == null) return; if (line.startsWith(LOG_POTENTIAL_CODE)) log_potential(parseAnnealingParameter(line, LOG_POTENTIAL_CODE)); - else - if (line.startsWith(CALL_SAMPLER_CODE)) + else if (line.startsWith(CALL_SAMPLER_CODE)) call_sampler(parseAnnealingParameter(line, CALL_SAMPLER_CODE)); + else + throw new RuntimeException(); } } catch (Exception ioe) { throw new RuntimeException(ioe); @@ -56,17 +98,22 @@ public void performInference() private void call_sampler(double annealingParam) { model.setExponent(annealingParam); + log("call_sampler logd_before=" + model.logDensity()); if (annealingParam == 0.0) - model.forwardSample(random, false); + model.forwardSample(rng, false); else - model.posteriorSamplingScan(random, nPassesPerScan); + model.posteriorSamplingScan(rng, nPassesPerScan); System.out.println("response()"); + log("call_sampler logd_after=" + model.logDensity()); } + + private void log_potential(double annealingParam) { double result = model.logDensity(annealingParam); System.out.println("response(" + result + ")"); + log("log_potential logd=" + result); } public double parseAnnealingParameter(String line, String code) diff --git a/src/main/java/blang/mcmc/RealSliceSampler.java b/src/main/java/blang/mcmc/RealSliceSampler.java index 3de64bf..e817ffa 100644 --- a/src/main/java/blang/mcmc/RealSliceSampler.java +++ b/src/main/java/blang/mcmc/RealSliceSampler.java @@ -6,6 +6,7 @@ import blang.core.LogScaleFactor; import blang.core.WritableRealVar; import blang.distributions.Generators; +import blang.engines.internals.factories.Pigeons; public class RealSliceSampler implements Sampler @@ -52,6 +53,7 @@ public static RealSliceSampler build(WritableRealVar variable, List