Skip to content

Commit

Permalink
Improving the multiple response documentation in CSVLoader, adding a …
Browse files Browse the repository at this point in the history
…test for the documented behaviour and fixing a small toString bug in MockMultiOutput and MultiLabel. (#306)
  • Loading branch information
Craigacp committed Dec 16, 2022
1 parent d0d6415 commit 5d6b60e
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 11 deletions.
10 changes: 6 additions & 4 deletions Core/src/test/java/org/tribuo/test/MockMultiOutput.java
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,13 @@ public String toString() {
StringBuilder builder = new StringBuilder();

builder.append("(LabelSet={");
for (MockOutput l : labels) {
builder.append(l.toString());
builder.append(',');
if (labels.size() > 0) {
for (MockOutput l : labels) {
builder.append(l.toString());
builder.append(',');
}
builder.deleteCharAt(builder.length() - 1);
}
builder.deleteCharAt(builder.length()-1);
builder.append('}');
if (!Double.isNaN(score)) {
builder.append(",OverallScore=");
Expand Down
32 changes: 31 additions & 1 deletion Data/src/main/java/org/tribuo/data/csv/CSVLoader.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -68,6 +68,12 @@
* {@link org.tribuo.data.columnar.RowProcessor} to cope with your specific input format.
* <p>
* CSVLoader is thread safe and immutable.
* <p>
* Multi-output responses such as {@code MultiLabel} or {@code Regressor} can be processed in
* two different ways: either as a single column of separated values, or as multiple columns.
* If there is a single column, the value is passed directly to the {@link OutputFactory}. If
* there are multiple response columns, the name of each column is concatenated with its value
* in the form 'column-name=column-value', and a list of the concatenated values is passed to
* the {@link OutputFactory}.
* @param <T> The type of the output generated.
*/
public class CSVLoader<T extends Output<T>> {
Expand Down Expand Up @@ -139,6 +145,10 @@ public MutableDataset<T> load(Path csvPath, String responseName, String[] header
* <p>
* The {@code responseNames} set is traversed in iteration order to emit outputs,
* and should be an ordered set to ensure reproducibility.
* <p>
* If there are multiple elements in {@code responseNames} then the responses are
* processed into the form 'column-name=column-value' before being passed to the
* {@link OutputFactory} for conversion into an {@link Output}.
*
* @param csvPath The path to load.
* @param responseNames The names of the response variables.
Expand All @@ -154,6 +164,10 @@ public MutableDataset<T> load(Path csvPath, Set<String> responseNames) throws IO
* <p>
* The {@code responseNames} set is traversed in iteration order to emit outputs,
* and should be an ordered set to ensure reproducibility.
* <p>
* If there are multiple elements in {@code responseNames} then the responses are
* processed into the form 'column-name=column-value' before being passed to the
* {@link OutputFactory} for conversion into an {@link Output}.
*
* @param csvPath The path to load.
* @param responseNames The names of the response variables.
Expand Down Expand Up @@ -220,6 +234,10 @@ public DataSource<T> loadDataSource(URL csvPath, String responseName, String[] h
* <p>
* The {@code responseNames} set is traversed in iteration order to emit outputs,
* and should be an ordered set to ensure reproducibility.
* <p>
* If there are multiple elements in {@code responseNames} then the responses are
* processed into the form 'column-name=column-value' before being passed to the
* {@link OutputFactory} for conversion into an {@link Output}.
*
* @param csvPath The csv to load from.
* @param responseNames The names of the response variables.
Expand All @@ -235,6 +253,10 @@ public DataSource<T> loadDataSource(Path csvPath, Set<String> responseNames) thr
* <p>
* The {@code responseNames} set is traversed in iteration order to emit outputs,
* and should be an ordered set to ensure reproducibility.
* <p>
* If there are multiple elements in {@code responseNames} then the responses are
* processed into the form 'column-name=column-value' before being passed to the
* {@link OutputFactory} for conversion into an {@link Output}.
*
* @param csvPath The csv to load from.
* @param responseNames The names of the response variables.
Expand All @@ -250,6 +272,10 @@ public DataSource<T> loadDataSource(URL csvPath, Set<String> responseNames) thro
* <p>
* The {@code responseNames} set is traversed in iteration order to emit outputs,
* and should be an ordered set to ensure reproducibility.
* <p>
* If there are multiple elements in {@code responseNames} then the responses are
* processed into the form 'column-name=column-value' before being passed to the
* {@link OutputFactory} for conversion into an {@link Output}.
*
* @param csvPath The csv to load from.
* @param responseNames The names of the response variables.
Expand All @@ -266,6 +292,10 @@ public DataSource<T> loadDataSource(Path csvPath, Set<String> responseNames, Str
* <p>
* The {@code responseNames} set is traversed in iteration order to emit outputs,
* and should be an ordered set to ensure reproducibility.
* <p>
* If there are multiple elements in {@code responseNames} then the responses are
* processed into the form 'column-name=column-value' before being passed to the
* {@link OutputFactory} for conversion into an {@link Output}.
*
* @param csvPath The csv to load from.
* @param responseNames The names of the response variables.
Expand Down
12 changes: 10 additions & 2 deletions Data/src/test/java/org/tribuo/data/csv/CSVLoaderTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -82,13 +82,21 @@ public void testLoadMultiOutput() throws IOException {
assertTrue(data.getExample(1).getOutput().contains("R1"));
assertTrue(data.getExample(1).getOutput().contains("R2"));


//
// Row #2: R1=False and R2=False.
// In this case, the labelSet is empty and the labelString is the empty string.
assertEquals(0, data.getExample(2).getOutput().getLabelSet().size());
assertEquals("", data.getExample(2).getOutput().getLabelString());
assertTrue(data.getExample(2).validateExample());

URL singlePath = CSVLoaderTest.class.getResource("/org/tribuo/data/csv/test-multioutput-singlecolumn.csv");
DataSource<MockMultiOutput> singleSource = loader.loadDataSource(singlePath, "Label");
MutableDataset<MockMultiOutput> singleData = new MutableDataset<>(singleSource);
assertEquals(6, singleData.size());

for (int i = 0; i < 6; i++) {
assertEquals(data.getExample(i).getOutput().getLabelString(), singleData.getExample(i).getOutput().getLabelString());
}
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
A,B,C,D,Label
1,2,3,4,"R1"
6,7,8,9,"R1,R2"
6,7,8,9,
2,5,3,4,"R1"
1,2,5,9,"R2"
0,2,5,9,"R2"
Original file line number Diff line number Diff line change
Expand Up @@ -314,11 +314,13 @@ public String toString() {
StringBuilder builder = new StringBuilder();

builder.append("(LabelSet={");
for (Label l : labels) {
builder.append(l.toString());
builder.append(',');
if (labels.size() > 0) {
for (Label l : labels) {
builder.append(l.toString());
builder.append(',');
}
builder.deleteCharAt(builder.length() - 1);
}
builder.deleteCharAt(builder.length()-1);
builder.append('}');
if (!Double.isNaN(score)) {
builder.append(",OverallScore=");
Expand Down

0 comments on commit 5d6b60e

Please sign in to comment.