Skip to content

Commit

Permalink
Temporary Pigeons instrumentations to investigate bug
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandrebouchard committed Jul 17, 2023
1 parent 4a7578f commit 2063388
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/main/java/blang/distributions/Generators.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ public static double logLogistic(Random rand, double scale, double shape)
/** */
public static double halfstudentt(Random random, double nu, double sigma) {
double t = studentt(random, nu, 0, 1);
System.out.println("t = " + t);
return Math.abs(t) * sigma;
}

Expand Down
57 changes: 52 additions & 5 deletions src/main/java/blang/engines/internals/factories/Pigeons.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package blang.engines.internals.factories;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;

import bayonet.distributions.Random;
import blang.engines.internals.PosteriorInferenceEngine;
import blang.inits.Arg;
import blang.inits.DefaultValue;
import blang.inits.GlobalArg;
import blang.inits.experiments.ExperimentResults;
import blang.runtime.SampledModel;
import blang.runtime.internals.objectgraph.GraphAnalysis;

Expand All @@ -19,34 +23,72 @@
*/
public class Pigeons implements PosteriorInferenceEngine
{
@GlobalArg public ExperimentResults results = new ExperimentResults();

@Arg
public Random random;
public Long random;

public Random rng = null;

@Arg @DefaultValue("3")
public double nPassesPerScan = 3;

public boolean log = true; // TODO: default should be false!!
private BufferedWriter logger = null;

SampledModel model;

public static String LOG_POTENTIAL_CODE = "log_potential(";
public static String CALL_SAMPLER_CODE = "call_sampler!(";


public static <T> T static_log(T object) {
return instance.log(object);
}

public <T> T log(T object)
{
if (log)
{
String msg = "[time=" + System.currentTimeMillis() + ",seed=" + random + "] " + object.toString();
if (logger == null)
logger = results.getAutoClosedBufferedWriter("log.txt");
try
{
logger.append(msg + "\n");
logger.flush();
} catch (IOException e)
{
e.printStackTrace();
}
}
return object;
}

static Pigeons instance = null;

@Override
public void performInference()
{
instance = this;
if (rng == null)
rng = new Random(random);
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
try {
while (true)
{
String line = br.readLine();
log("input=" + line);
if (line == null)
return;
if (line.startsWith(LOG_POTENTIAL_CODE))
log_potential(parseAnnealingParameter(line,
LOG_POTENTIAL_CODE));
else
if (line.startsWith(CALL_SAMPLER_CODE))
else if (line.startsWith(CALL_SAMPLER_CODE))
call_sampler(parseAnnealingParameter(line,
CALL_SAMPLER_CODE));
else
throw new RuntimeException();
}
} catch (Exception ioe) {
throw new RuntimeException(ioe);
Expand All @@ -56,17 +98,22 @@ public void performInference()
private void call_sampler(double annealingParam)
{
model.setExponent(annealingParam);
log("call_sampler logd_before=" + model.logDensity());
if (annealingParam == 0.0)
model.forwardSample(random, false);
model.forwardSample(rng, false);
else
model.posteriorSamplingScan(random, nPassesPerScan);
model.posteriorSamplingScan(rng, nPassesPerScan);
System.out.println("response()");
log("call_sampler logd_after=" + model.logDensity());
}



private void log_potential(double annealingParam)
{
double result = model.logDensity(annealingParam);
System.out.println("response(" + result + ")");
log("log_potential logd=" + result);
}

public double parseAnnealingParameter(String line, String code)
Expand Down
6 changes: 5 additions & 1 deletion src/main/java/blang/mcmc/RealSliceSampler.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import blang.core.LogScaleFactor;
import blang.core.WritableRealVar;
import blang.distributions.Generators;
import blang.engines.internals.factories.Pigeons;


public class RealSliceSampler implements Sampler
Expand Down Expand Up @@ -52,6 +53,7 @@ public static RealSliceSampler build(WritableRealVar variable, List<LogScaleFact

public void execute(Random random)
{
Pigeons.static_log("before_slice: " + variable.doubleValue());
// sample slice
final double logSliceHeight = nextLogSliceHeight(random, logDensity()); // log(Y) in Neal's paper
final double oldState = variable.doubleValue(); // x0 in Neal's paper
Expand Down Expand Up @@ -100,6 +102,7 @@ public void execute(Random random)
if (logSliceHeight <= logDensityAt(newState) && accept(oldState, newState, logSliceHeight, leftProposalEndPoint, rightProposalEndPoint)) // *
{
variable.set(newState);
Pigeons.static_log("after_slice(1): " + variable.doubleValue());
return;
}
if (newState < oldState)
Expand All @@ -113,6 +116,7 @@ public void execute(Random random)
// This was observed in a case caused by the uniform generator excluding
// the right end point, creating an infinite loop.
variable.set(oldState);
Pigeons.static_log("after_slice(2): " + variable.doubleValue());
return;
}
}
Expand Down Expand Up @@ -144,7 +148,7 @@ public static double nextLogSliceHeight(Random random, double logDensity)
{
if (logDensity == Double.NEGATIVE_INFINITY)
return -1e100; // work around: if initialized at zero probability, e.g Beta at zero, we want this to be greater INF so than line (*) above is rejected for invalid configs
return logDensity - Generators.unitRateExponential(random);
return logDensity - Pigeons.static_log(Generators.unitRateExponential(random));
}

private double logDensityAt(double x)
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/blang/runtime/SampledModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import blang.core.LogScaleFactor;
import blang.core.Model;
import blang.core.Param;
import blang.engines.internals.factories.Pigeons;
import blang.inits.experiments.tabwriters.TidySerializer;
import blang.mcmc.Sampler;
import blang.mcmc.internals.BuiltSamplers;
Expand Down Expand Up @@ -267,6 +268,7 @@ public void posteriorSamplingStep(Random random)
if (currentPosition == -1)
{
Collections.shuffle(currentSamplingOrder, random);
Pigeons.static_log(currentSamplingOrder);
currentPosition = nPosteriorSamplers() - 1;
}
int samplerIndex = currentSamplingOrder.get(currentPosition--);
Expand Down

0 comments on commit 2063388

Please sign in to comment.