Skip to content

Commit

Permalink
Adding a test for sampling from a CDF.
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp committed Jul 27, 2022
1 parent 4c12343 commit d4cfb11
Showing 1 changed file with 47 additions and 1 deletion.
48 changes: 47 additions & 1 deletion Core/src/test/java/org/tribuo/util/UtilTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,15 +17,22 @@
package org.tribuo.util;


import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;

Expand Down Expand Up @@ -80,4 +87,43 @@ public void testAUC() {
assertEquals(0.5,output,DELTA);
}

@Test
public void testSampleFromCDF() {
double[] pmf = new double[]{0.1,0.2,0.0,0.3,0.0,0.0,0.4,0.0};
double[] cdf = Util.generateCDF(pmf);

double[] expectedCDF = new double[]{0.1,0.3,0.3,0.6,0.6,0.6,1.0,1.0};

assertArrayEquals(expectedCDF,cdf,1e-10);

SplittableRandom rng = new SplittableRandom(1235L);

Map<Integer, MutableLong> counter = new HashMap<>();

final int numSamples = 10000;
for (int i = 0; i < numSamples; i++) {
int curSample = Util.sampleFromCDF(cdf,rng);
MutableLong l = counter.computeIfAbsent(curSample, k -> new MutableLong());
l.increment();
}

assertNotNull(counter.get(0));
assertNotNull(counter.get(1));
assertNull(counter.get(2));
assertNotNull(counter.get(3));
assertNull(counter.get(4));
assertNull(counter.get(5));
assertNotNull(counter.get(6));
assertNull(counter.get(7));

double total = 0;
for (Map.Entry<Integer, MutableLong> e : counter.entrySet()) {
total += e.getValue().longValue();
}
assertEquals(numSamples,total);
assertEquals(counter.get(0).longValue()/total,0.1,1e-1);
assertEquals(counter.get(1).longValue()/total,0.2,1e-1);
assertEquals(counter.get(3).longValue()/total,0.3,1e-1);
assertEquals(counter.get(6).longValue()/total,0.4,1e-1);
}
}

0 comments on commit d4cfb11

Please sign in to comment.