diff --git a/src/main/java/nl/peterbloem/motive/exec/ClassExperiment.java b/src/main/java/nl/peterbloem/motive/exec/ClassExperiment.java deleted file mode 100644 index 1c7dd62..0000000 --- a/src/main/java/nl/peterbloem/motive/exec/ClassExperiment.java +++ /dev/null @@ -1,596 +0,0 @@ -package nl.peterbloem.motive.exec; - -import static java.util.concurrent.TimeUnit.MINUTES; -import static nl.peterbloem.kit.Functions.dot; -import static nl.peterbloem.kit.Functions.subset; -import static nl.peterbloem.kit.Functions.tic; -import static nl.peterbloem.kit.Functions.toc; -import static nl.peterbloem.kit.Series.series; -import static org.nodes.LightUGraph.copy; -import static org.nodes.models.USequenceEstimator.perturbation; - -import static nl.peterbloem.kit.Functions.tic; - -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.File; -import java.io.FileReader; -import java.io.FileWriter; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.LinkedHashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; - -import org.apache.xerces.util.SynchronizedSymbolTable; -import org.nodes.DGraph; -import org.nodes.Graph; -import org.nodes.Graphs; -import org.nodes.LightUGraph; -import org.nodes.MapDTGraph; -import org.nodes.MapUTGraph; -import org.nodes.Subgraph; -import org.nodes.UGraph; -import org.nodes.ULink; -import org.nodes.UNode; -import org.nodes.UTGraph; -import org.nodes.algorithms.Nauty; -import org.nodes.data.Examples; -import org.nodes.data.RDF; -import org.nodes.models.EdgeListModel; -import org.nodes.models.USequenceEstimator; -import org.nodes.models.DegreeSequenceModel.Prior; -import org.nodes.motifs.AllSubgraphs; -import org.openrdf.rio.RDFFormat; - -import nl.peterbloem.kit.FileIO; -import nl.peterbloem.kit.FrequencyModel; -import nl.peterbloem.kit.Functions; -import nl.peterbloem.kit.Generator; -import nl.peterbloem.kit.Global; -import nl.peterbloem.kit.MaxObserver; -import nl.peterbloem.kit.Series; -import nl.peterbloem.kit.data.Point; -import nl.peterbloem.kit.data.classification.Classification; -import nl.peterbloem.kit.data.classification.Classified; -import nl.peterbloem.motive.MotifSearchModel; -import nl.peterbloem.motive.UPlainMotifExtractor; - -public class ClassExperiment -{ - /** - * Which motif sizes to use in constructing the features - */ - public List sizes= Arrays.asList(3, 4, 5); - - /** - * How many samples to take from the FANMOD null model - */ - public int samples = 1000; - - /** - * How often to iterate the null model Markov chain for each sample - */ - public int mixingTime = 10000; - - /** - * The significance level (for both fanmod and motive) - */ - public double alpha = 0.05; - - /** - * How many samples to take in the motive version - */ - public int motiveSamples = 100000; - - /** - * How many hubs to remove from the data before counting instances - * (this reduces the size of the instances) - */ - public int hubsToRemove = 500; - - /** - * The number of instances to use (a sample form the total number available) - */ - public int numInstances = 100; - - /** - * How many steps to use for the instances - */ - public int instanceDepth = 2; - - /** - * The probability of expanding a search tree node in the fanmod enumeration - */ - public double prob = 0.001; - - /** - * The seed for the random number generator - */ - public int seed = 0; - - public Comparator natural = Functions.natural(); - - public UGraph graph = null; - public Map map = null; - - - public void main() - throws IOException - { - tic(); - -// UGraph graph = RDF.readSimple(new File("/Users/Peter/Documents/Datasets/graphs/aifb/aifb_fixed_complete.n3")); -// Map map = tsv(new File("/Users/Peter/Documents/Datasets/graphs/aifb/completeDataset.tsv")); - - // * Remove hubs - LinkedList> nodes = new LinkedList>(graph.nodes()); - Comparator> comp = new Comparator>() - { - @Override - public int compare(UNode n1, UNode n2) - { - int d1 = n1.degree(), - d2 = n2.degree(); - - return - Integer.compare(d1, d2); - } - }; - Collections.sort(nodes, comp); - - List toRemove = new ArrayList(hubsToRemove); - while(toRemove.size() < hubsToRemove) - { - UNode node = nodes.pop(); - - if(! map.containsKey(node.label())) - { - toRemove.add(node.index()); - } - } - - Comparator intComp = Collections.reverseOrder(); - Collections.sort(toRemove, intComp); - - for(int index : toRemove) - graph.get(index).remove(); - - System.out.println(MaxObserver.quickSelect(10, Graphs.degrees(graph), intComp, false)); - - System.out.println("graph read: " + Functions.toc() + " s, size " + graph.size()); - - List classInts = new ArrayList(new LinkedHashSet(map.values())); - System.out.println("table read: " + map.size()); - - int c = 0; - Classified> data = Classification.empty(); - - - tic(); - // - Sample random instances - for(String name : subset(map.keySet(), numInstances)) - { - UNode instanceNode = graph.node(name); - UGraph instance = instance(instanceNode, instanceDepth); - - data.add(instance, classInts.indexOf( map.get(name))); - } - - double size = 0; - double numLinks = 0; - for(int i : series(data.size())) - { - size += data.get(i).size(); - numLinks += data.get(i).numLinks(); - - dot(i, data.size()); - } - - size /= data.size(); - numLinks /= data.size(); - - System.out.println("Instances loaded, n: " + size + ", m: " + numLinks + ", time: " + toc() + "s."); - - // * Collect all connected motifs (up to isomorphism) for the given sizes - List> motifs = new ArrayList>(); - for(int msize : sizes) - motifs.addAll(Graphs.allIsoConnected(msize, "x")); - - System.out.println(motifs.size() + " features"); - - BufferedWriter motiveWriter = new BufferedWriter(new FileWriter( - new File(String.format("motive.%05d.csv", seed)))); - BufferedWriter fanmodWriter = new BufferedWriter(new FileWriter( - new File(String.format("fanmod.%05d.csv", seed)))); - - for(int i : Series.series(data.size())) - { - Global.log().info("Starting instance " + i + " of " + data.size()); - - UGraph bg = Graphs.blank(data.get(i), "x"); - int cls = data.cls(i); - - motiveWriter.write(cls + ""); - for(double v : featuresMotive(bg, motifs)) - motiveWriter.write(", " + v); - motiveWriter.newLine(); - motiveWriter.flush(); - - fanmodWriter.write(cls + ""); - for(double v : featuresFANMOD(bg, motifs)) - fanmodWriter.write(", " + v); - fanmodWriter.newLine(); - fanmodWriter.flush(); - } - - motiveWriter.close(); - fanmodWriter.close(); - - try { - FileIO.python(new File("."), "scripts/plot.classification.py"); - } catch (Exception e) - { - Global.log().warning("Failed to run plot script. The script has been copied to the output directory. (trace:" + e + ")"); - } - } - - private UGraph instance(UNode instanceNode, int depth) - { - Map map = new HashMap(); - - UGraph graph = instanceNode.graph(); - UGraph instance = new LightUGraph(); - - Set added = new LinkedHashSet(); - - UNode iiNode = instance.add(instanceNode.label()); - map.put(instanceNode.index(), iiNode.index()); - added.add(instanceNode.index()); - - expand(graph, instance, map, added, depth); - - return instance; - } - - /** - * Expands 'instance', a subgraph of 'graph', by adding all neighbours of - * nodes indicated in 'addded'. Repeats recursively to the given depth. - * - * @param graph - * @param instance - * @param added Indices in the supergraph of nodes that were added in the last iteration. - */ - private void expand(UGraph graph, UGraph instance, Map map, Set added, int depth) - { - if(depth == 0) - return; - - Set inInstance = new HashSet(map.keySet()); - - Set newAdded = new LinkedHashSet(); - for (int index : added) - for (UNode neighbor : graph.get(index).neighbors()) - { - if (inInstance.contains(neighbor.index())) - { - // If this happens, all relevant links should already be included - } - else if (map.containsKey(neighbor.index())) // node already added - { - if(! instance.get(map.get(index)).connected(instance.get(map.get(neighbor.index())))) - instance.get(map.get(index)).connect(instance.get(map.get(neighbor.index()))); - } else // new node - { - UNode nwNode = instance.add(neighbor.label()); - map.put(neighbor.index(), nwNode.index()); - - instance.get(map.get(index)).connect(nwNode); - - newAdded.add(neighbor.index()); - } - } - - // * Add all links between nodes in nwAdded - for(int oldIndex : newAdded) - for(UNode neighbor : graph.get(oldIndex).neighbors()) - { - int oldNIndex = neighbor.index(); - if(newAdded.contains(oldNIndex) && oldNIndex > oldIndex) - { - int index = map.get(oldIndex); - int nIndex = map.get(oldNIndex); - if(! instance.get(index).connected(instance.get(nIndex))) - instance.get(index).connect(instance.get(nIndex)); - } - } - - expand(graph, instance, map, newAdded, depth -1); - } - -// private UGraph instanceOld(UNode instanceNode, int depth) -// { -// UGraph graph = instanceNode.graph(); -// -// Set neighbors = new LinkedHashSet(); -// neighbors.add(instanceNode.index()); -// -// expandOld(neighbors, graph, depth); -// -// return Subgraph.uSubgraphIndices(graph, neighbors); -// } -// -// private void expandOld(Set neighbors, UGraph graph, int depth) -// { -// if(depth == 0) -// return; -// -// Set nw = new LinkedHashSet(); -// for(int index : neighbors) -// for(UNode neighbor : graph.get(index).neighbors()) -// nw.add(neighbor.index()); -// -// neighbors.addAll(nw); -// -// expandOld(neighbors, graph, depth - 1); -// } - - private List featuresFANMOD(final UGraph graph, final List> motifs) - { - FrequencyModel> dataCounts = count(graph, sizes, prob); - - Global.log().info("Data count completed"); - - // * Synched object to collect the counts from the different threads - class Collector { - private Map, List> nullCounts = new LinkedHashMap, List>(); - - public Collector() - { - for(UGraph motif : motifs) - nullCounts.put(motif, new ArrayList(samples)); - } - - public synchronized void register(FrequencyModel> sampleCounts) - { - for(UGraph token : motifs) - nullCounts.get(token).add((int)sampleCounts.frequency(token)); - } - - public Map, List> nullCounts() - { - return nullCounts; - } - } - - final Collector collector = new Collector(); - - final USequenceEstimator model = new USequenceEstimator(graph, "x"); - - ExecutorService executor = Executors.newFixedThreadPool(Global.numThreads()); - final AtomicInteger finished = new AtomicInteger(0); - - for(final int i : series(samples)) - { - executor.execute(new Thread() - { - public void run() - { - // * Generate a random graph - final Generator> gen = model.uniform(mixingTime); - UGraph sample = gen.generate(); - - // * Count its subgraphs - FrequencyModel> sampleCounts = count(sample, sizes, prob); - - // * Add the counts to the nullCounts - collector.register(sampleCounts); - - finished.incrementAndGet(); - } - }); - } - - // * Execute all threads and wait until finished - executor.shutdown(); - try - { - while(! executor.awaitTermination(1, MINUTES)) - Global.log().info(finished + " of " + samples + " completed."); - - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - - Global.log().info("Sampling finished"); - - Map, List> nullCounts = collector.nullCounts(); - - for(List counts : nullCounts.values()) - Collections.sort(counts); - - List features = new ArrayList(motifs.size()); - for(UGraph motif : motifs) - { - int dc = (int) dataCounts.frequency(motif); - List nc = collector.nullCounts().get(motif); - - features.add(prop(nc, dc) <= alpha ? 1.0 : 0.0); - } - - return features; - } - - private List featuresMotive(UGraph graph, List> motifs) - { - - System.out.println("n: " + graph.size() + " m: " + graph.numLinks() + ""); - UPlainMotifExtractor ex = new UPlainMotifExtractor(graph, motiveSamples, sizes.get(0), sizes.get(sizes.size()-1)); - - double threshold = - Functions.log2(alpha); - System.out.println("Threshold: " + threshold); - - double baseline = new EdgeListModel(Prior.COMPLETE).codelength(graph); - System.out.println("Baseline: " + baseline); - - List degrees = Graphs.degrees(graph); - - List features = new ArrayList(motifs.size()); - for(UGraph motif : motifs) - { - List> occ = ex.occurrences(motif); - if(occ == null) - occ = Collections.emptyList(); - - double length = MotifSearchModel.sizeELInst(graph, degrees, motif, occ, true, -1); - - features.add( (baseline - length) > threshold ? 1.0 : 0.0 ); - } - - return features; - } - - /** - * Returns the proportion of 'values' that is larger than or equal to - * the given threshold. - * - * @param values - * @param threshold - * @return - */ - private static double prop(List values, int threshold) - { - int total = 0; - for(int value : values) - { - if(value >= threshold) - break; - - total ++; - } - - return (values.size() - total) / (double) values.size(); - } - - private static FrequencyModel> count(UGraph graph, List sizes, double prob) - { - Comparator natural = Functions.natural(); - FrequencyModel> fm = new FrequencyModel>(); - - for(int size : sizes) - { - AllSubgraphs as = new AllSubgraphs(graph, size, prob); - - for(Set indices : as) - { - UGraph sub = Subgraph.uSubgraphIndices(graph, indices); - sub = Graphs.reorder(sub, Nauty.order(sub, natural)); - - fm.add(sub); - } - } - - return fm; - } - - /** - * Returns a blanked graph with string labels - * @param graph - * @param label - * @return - */ - private static UGraph str(UGraph graph, String label) - { - UGraph result = new MapUTGraph(); - - for(UNode node : graph.nodes()) - result.add(label); - - for(ULink link : graph.links()) - result.get(link.first().index()).connect( - result.get(link.second().index())); - - return result; - } - - /** - * Does a run of the curveball algorithm for each graph and dumps the - * perturbation scores to a CSV file. - * - * @param graphs - */ - private static void estimateMixingTime(Classified> graphs) - throws IOException - { - int n = 100000; int m = 3; - List> scores = new ArrayList>(); - - for(int c : Series.series(m)) - { - UGraph graph = Functions.choose(graphs); - - List s = new ArrayList(n); - - List> - start = USequenceEstimator.adjacencies(graph), - current = USequenceEstimator.adjacencies(graph); - - for(int i : series(n)) - { - USequenceEstimator.step(current); - s.add(perturbation(start, current)); - } - - scores.add(s); - - dot(c, graphs.size()); - } - - // * Write to CSV - - BufferedWriter writer = new BufferedWriter(new FileWriter(new File("./am.csv"))); - for(int row : series(n)) - { - for(int col : series(m)) - writer.write((col != 0? ", " : "") + scores.get(col).get(row)); - - writer.write("\n"); - } - writer.close(); - - } - - public static Map tsv(File file) - throws IOException - { - Map map = new LinkedHashMap(); - - BufferedReader reader = new BufferedReader(new FileReader(file)); - - reader.readLine(); // skip titles - String line = reader.readLine(); - - while(line != null) - { - String[] split = line.split("\\s"); - map.put(split[1], split[2]); - - line = reader.readLine(); - } - - return map; - } -} diff --git a/src/main/java/nl/peterbloem/motive/exec/Run.java b/src/main/java/nl/peterbloem/motive/exec/Run.java index 9fd76fc..11e7b0e 100644 --- a/src/main/java/nl/peterbloem/motive/exec/Run.java +++ b/src/main/java/nl/peterbloem/motive/exec/Run.java @@ -112,47 +112,7 @@ public class Run name="--synth.maxdegree", usage="Maximum degree for an instance node.") private static int synthMaxDegree = 5; - - @Option( - name="--class.prob", - usage="Fanmod probability of expanding a search tree node (1.0 enumerates all subgraphs, lower means a smaller sample, and lower runtime)") - private static double classProb = 0.5; - - @Option( - name="--class.hubs", - usage="Number of hubs to remove from the data (the more hubs removed, the smaller the instances become.") - private static int classHubs = 0; - - @Option( - name="--class.fanmodSamples", - usage="Number of samples from the null model in the FANMOD experiment.") - private static int classFMSamples = 1000; - - @Option( - name="--class.motiveSamples", - usage="Number of subgraphs to sample in the motive experiment.") - private static int classMotiveSamples = 1000000; - - @Option( - name="--class.depth", - usage="Depth to which to extract the instances.") - private static int classDepth = 2; - - @Option( - name="--class.mixingTime", - usage="Mixing time for the curveball sampling algorithm (ie. the number of steps taken in the markov chain for each sample).") - private static int classMixingTime = 10000; - - @Option( - name="--class.numInstances", - usage="The number of instances to use (samples from the total available)") - private static int classNumInstances = 100; - - @Option( - name="--class.sizes", - usage="The motif sizes to use as features") - private static String classSizes = "3,4"; - + /** * Main executable function * @param args @@ -185,50 +145,7 @@ public static void main(String[] args) Global.setNumThreads(threads); Global.log().info("Using " + Global.numThreads() + " concurrent threads"); - if ("class".equals(type.toLowerCase())) - { - - ClassExperiment exp = new ClassExperiment(); - - try { - exp.graph = RDF.readSimple(file); - } catch (IOException e) { - throw new RuntimeException("Could not read RDF input file.", e); - } - - try { - exp.map = ClassExperiment.tsv(classTSV); - } catch (IOException e) { - throw new RuntimeException("Could not read TSV classification file.", e); - } - - exp.prob = classProb; - exp.hubsToRemove = classHubs; - exp.samples = classFMSamples; - exp.motiveSamples = classMotiveSamples; - exp.instanceDepth = classDepth; - exp.mixingTime = classMixingTime; - exp.numInstances = classNumInstances; - - exp.sizes = new ArrayList(); - try{ - for(String elem : classSizes.split(",")) - { - exp.sizes.add(Integer.parseInt(elem)); - } - } catch(RuntimeException e) - { - throw new RuntimeException("Failed to parse sizes argument: " + classSizes + " (does it contain spaces, or non-integers?)." , e); - } - - try { - exp.main(); - } catch (IOException e) { - throw new RuntimeException(e); - } - - - } else if ("preload".equals(type.toLowerCase())) + if ("preload".equals(type.toLowerCase())) { try {