Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Debug graph for multi-tokenization #114

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ public List<? extends TokenBase> tokenize(String text) {
}

/**
 * Tokenizes the text into several alternative token lists.
 * <p>
 * This is a thin public entry point; all work is delegated to
 * {@code createMultiTokenList}.
 *
 * @param text  text to tokenize
 * @param maxCount  forwarded unchanged to {@code createMultiTokenList}
 * @param costSlack  forwarded unchanged to {@code createMultiTokenList}
 * @param <T>  token type
 * @return the token lists produced by {@code createMultiTokenList}
 */
public <T extends TokenBase> List<List<T>> multiTokenize(String text, int maxCount, int costSlack) {
    List<List<T>> tokenLists = createMultiTokenList(text, maxCount, costSlack);
    return tokenLists;
}

Expand Down Expand Up @@ -258,6 +257,26 @@ public void debugTokenize(OutputStream outputStream, String text) throws IOExcep
outputStream.flush();
}

/**
 * Multi tokenizes the provided text and outputs the corresponding Viterbi lattice and the
 * Viterbi paths to the provided output stream.
 * <p>
 * The output is written in <a href="https://en.wikipedia.org/wiki/DOT_(graph_description_language)">DOT</a> format,
 * encoded as UTF-8. The stream is flushed but not closed.
 * <p>
 * This method is not thread safe.
 *
 * @param outputStream  output stream to write to
 * @param text  text to tokenize
 * @param maxCount  passed to the multi-path Viterbi search (presumably the maximum number of paths)
 * @param costSlack  passed to the multi-path Viterbi search (presumably the allowed cost slack)
 * @throws java.io.IOException if an error occurs when writing the lattice and paths
 */
public void debugMultiTokenize(OutputStream outputStream, String text, int maxCount, int costSlack) throws IOException {
    ViterbiLattice lattice = viterbiBuilder.build(text);
    List<List<ViterbiNode>> bestPaths =
        viterbiSearcher.searchMultiple(lattice, maxCount, costSlack).getTokenizedResultsList();

    String dot = viterbiFormatter.multiFormat(lattice, bestPaths);
    outputStream.write(dot.getBytes(StandardCharsets.UTF_8));
    outputStream.flush();
}

/**
* Writes the Viterbi lattice for the provided text to an output stream
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import com.atilika.kuromoji.dict.ConnectionCosts;

import java.awt.Color;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -30,49 +32,114 @@ public class ViterbiFormatter {

// Connection cost table handed to the constructor.
private ConnectionCosts costs;
// Created empty by the constructor; NOTE(review): presumably maps node ids to nodes — its use is outside the visible code.
private Map<String, ViterbiNode> nodeMap;
// Single best path stored as node id -> id of the following node.
// NOTE(review): appears to be superseded by bestPathsMap — confirm whether it is still needed.
private Map<String, String> bestPathMap;
// For every edge that lies on at least one best path: the indices of the paths that use it.
private Map<PathEdge, List<Integer>> bestPathsMap;
// One colour per best path; the list index is the path index stored in bestPathsMap.
private List<Color> edgeColors;

// NOTE(review): presumably records whether the BOS node has been seen while formatting — confirm against formatNodes.
private boolean foundBOS;

private class PathEdge {
private final ViterbiNode from;
private final ViterbiNode to;

public PathEdge(ViterbiNode from, ViterbiNode to) {
this.from = from;
this.to = to;
}

public ViterbiNode getFrom() {
return from;
}

public ViterbiNode getTo() {
return to;
}

@Override
public int hashCode() {
return this.from.hashCode()*101 + this.to.hashCode();
}

@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (!PathEdge.class.isAssignableFrom(obj.getClass())) {
return false;
}
final PathEdge other = (PathEdge) obj;
return this.from.equals(other.getFrom()) && this.to.equals(other.getTo());
}
}

/**
 * Creates a formatter with empty lookup tables.
 *
 * @param costs  connection cost table stored for later use by the formatter
 */
public ViterbiFormatter(ConnectionCosts costs) {
    this.costs = costs;

    // All lookup structures start out empty.
    nodeMap = new HashMap<String, ViterbiNode>();
    bestPathMap = new HashMap<String, String>();
    bestPathsMap = new HashMap<PathEdge, List<Integer>>();
    edgeColors = new ArrayList<Color>();
}

/**
 * Formats the lattice as a DOT graph without highlighting any path.
 *
 * @param lattice  lattice to format
 * @return the DOT representation of the lattice
 */
public String format(ViterbiLattice lattice) {
    // Pass an explicit empty path list rather than null: multiFormat reads the
    // size of the list, so a null argument would fail with a NullPointerException.
    return multiFormat(lattice, new ArrayList<List<ViterbiNode>>());
}

/**
 * Formats the lattice as a DOT graph and highlights a single best path.
 *
 * @param lattice  lattice to format
 * @param bestPath  path to highlight; {@code null} means no path is highlighted
 * @return the DOT representation of the lattice
 */
public String format(ViterbiLattice lattice, List<ViterbiNode> bestPath) {
    List<List<ViterbiNode>> bestPaths = new ArrayList<>();
    // A null path must not be wrapped as a list element: the path map
    // initialisation walks every element and would dereference it.
    if (bestPath != null) {
        bestPaths.add(bestPath);
    }
    return multiFormat(lattice, bestPaths);
}

/**
 * Formats the lattice as a DOT graph and highlights every given path in its own colour.
 * <p>
 * The output consists of the header, a legend with one entry per path, the lattice
 * nodes and the trailer, in that order.
 *
 * @param lattice  lattice to format
 * @param bestPaths  paths to highlight; {@code null} is treated like an empty list
 * @return the DOT representation of the lattice
 */
public String multiFormat(ViterbiLattice lattice, List<List<ViterbiNode>> bestPaths) {
    // initBestPathMap accepts null, so the colour count must tolerate it as well.
    int pathCount = (bestPaths == null) ? 0 : bestPaths.size();
    generateColors(pathCount);

    initBestPathMap(bestPaths);

    StringBuilder builder = new StringBuilder();
    builder.append(formatHeader());
    builder.append(formatLegend());
    builder.append(formatNodes(lattice));
    builder.append(formatTrailer());
    return builder.toString();
}

private void initBestPathMap(List<ViterbiNode> bestPath) {
this.bestPathMap.clear();
/**
 * Rebuilds the edge colour palette with {@code count} colours.
 * <p>
 * Colours share one saturation and brightness and are spread evenly around the hue
 * circle, starting at hue 0.33. Hues above 1.0 are fine because
 * {@link Color#getHSBColor(float, float, float)} only uses the fractional part.
 *
 * @param count  number of colours to generate; zero or less yields an empty palette
 */
void generateColors(int count) {
    // Start from scratch: without this, colours from earlier calls on the same
    // formatter would pile up and show as extra legend entries.
    this.edgeColors.clear();

    final float hue = 0.33f;
    final float saturation = 0.71f;
    final float brightness = 0.88f;
    for (int i = 0; i < count; i++) {
        float angle = i * (1f / (float) count);
        this.edgeColors.add(Color.getHSBColor(hue + angle, saturation, brightness));
    }
}

private void initBestPathMap(List<List<ViterbiNode>> bestPaths) {
this.bestPathsMap.clear();

if (bestPath == null) {
if (bestPaths == null) {
return;
}
for (int i = 0; i < bestPath.size() - 1; i++) {
ViterbiNode from = bestPath.get(i);
ViterbiNode to = bestPath.get(i + 1);

String fromId = getNodeId(from);
String toId = getNodeId(to);
for (int i = 0; i < bestPaths.size(); i++) {
List<ViterbiNode> path = bestPaths.get(i);
for (int j = 0; j < path.size() - 1; j++) {
ViterbiNode from = path.get(j);
ViterbiNode to = path.get(j + 1);

PathEdge pathEdge = new PathEdge(from, to);
addPathEdge(pathEdge, i);
}
}
}

assert this.bestPathMap.containsKey(fromId) == false;
assert this.bestPathMap.containsValue(toId) == false;
this.bestPathMap.put(fromId, toId);
/**
 * Records that the path with the given index runs over the given edge.
 *
 * @param pathEdge  edge on the path
 * @param pathId  index of the path that uses the edge
 */
private void addPathEdge(PathEdge pathEdge, Integer pathId) {
    List<Integer> pathIds = this.bestPathsMap.get(pathEdge);
    if (pathIds == null) {
        // First path seen for this edge: create its index list.
        pathIds = new ArrayList<Integer>();
        this.bestPathsMap.put(pathEdge, pathIds);
    }
    pathIds.add(pathId);
}

private String formatNodes(ViterbiLattice lattice) {
Expand Down Expand Up @@ -135,15 +202,51 @@ private String formatTrailer() {


private String formatEdge(ViterbiNode from, ViterbiNode to) {
if (this.bestPathMap.containsKey(getNodeId(from)) &&
this.bestPathMap.get(getNodeId(from)).equals(getNodeId(to))) {
return formatEdge(from, to, "color=\"#40e050\" fontcolor=\"#40a050\" penwidth=3 fontsize=20 ");
PathEdge pathEdge = new PathEdge(from, to);
if (this.bestPathsMap.containsKey(pathEdge)) {
List<Integer> colorIndices = this.bestPathsMap.get(pathEdge);

String attributes = formatEdgeAttributes(colorIndices);

return formatEdge(from, to, attributes);

} else {
return formatEdge(from, to, "");
}
}

/**
 * Builds the attribute string for an edge that lies on one or more best paths.
 * <p>
 * The {@code color} attribute lists the colours of all given paths separated by
 * {@code ':'}. The {@code fontcolor} is the colour of the first path, or empty
 * when no index is given.
 *
 * @param colorIndices  indices into the edge colour palette, one per path using the edge
 * @return the attribute string
 */
private String formatEdgeAttributes(List<Integer> colorIndices) {
    StringBuilder colorList = new StringBuilder();
    String firstHex = null;

    for (Integer colorIndex : colorIndices) {
        String hex = this.getColorHex(colorIndex);
        if (firstHex == null) {
            firstHex = hex;
        } else {
            colorList.append(":");
        }
        colorList.append(hex);
    }

    String fontColor = (firstHex == null) ? "" : firstHex;
    return "color=\"" + colorList + "\" fontcolor=\"" + fontColor + "\" penwidth=3 fontsize=20";
}

/**
 * Returns the palette colour at the given index as a lower-case {@code #rrggbb} string.
 *
 * @param colorIndex  index into the edge colour palette
 * @return the colour in hexadecimal notation
 */
private String getColorHex(int colorIndex) {
    final Color paletteColor = this.edgeColors.get(colorIndex);
    return String.format(
        "#%02x%02x%02x",
        paletteColor.getRed(),
        paletteColor.getGreen(),
        paletteColor.getBlue());
}

private String formatEdge(ViterbiNode from, ViterbiNode to, String attributes) {
StringBuilder builder = new StringBuilder();
Expand Down Expand Up @@ -196,6 +299,36 @@ private String formatNodeLabel(ViterbiNode node) {
return builder.toString();
}

/**
 * Builds a DOT legend cluster with one entry per palette colour.
 * <p>
 * Each entry is a text node "Path i", an invisible partner node, and an edge
 * between the two drawn in the colour of path {@code i}.
 *
 * @return the legend subgraph
 */
private String formatLegend() {
    final int pathCount = edgeColors.size();

    StringBuilder legend = new StringBuilder("subgraph cluster_legend {\n");
    legend.append("label = \"Legend\";\n");

    // Caption nodes.
    for (int path = 0; path < pathCount; path++) {
        legend.append("path").append(path).append(" [ ")
            .append("style=\"plaintext\" shape=\"none\" ")
            .append("label=\"Path ").append(path).append("\"")
            .append(" ]\n");
    }

    // Unlabelled end points for the sample edges.
    for (int path = 0; path < pathCount; path++) {
        legend.append("right").append(path).append(" [ ")
            .append("style=\"plaintext\" shape=\"none\" ")
            .append("label=\"\"")
            .append(" ]\n");
    }

    // Sample edges in the colour of each path.
    for (int path = 0; path < pathCount; path++) {
        legend.append("path").append(path).append(":e")
            .append(" -> ")
            .append("right").append(path).append(":w")
            .append(" [ color=\"").append(this.getColorHex(path)).append("\" ]\n");
    }

    legend.append("}\n");
    return legend.toString();
}

/**
 * Derives the identifier used for a node from its hash code.
 *
 * @param node  node to identify
 * @return the decimal string form of {@code node.hashCode()}
 */
private String getNodeId(ViterbiNode node) {
    final int hash = node.hashCode();
    return Integer.toString(hash);
}
Expand Down