diff --git a/.gitignore b/.gitignore
index e67f88e8..dcf450f2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,4 +32,6 @@ __pycache__
 repro
 java_test/test.parquet
 java_test/test.arrow
-java_test/simulation*
\ No newline at end of file
+java_test/simulation*
+.idea
+.calva
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 00000000..8f2b7113
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+    "java.compile.nullAnalysis.mode": "disabled"
+}
\ No newline at end of file
diff --git a/java_test/java/jtest/ConfigTest.java b/java_test/java/jtest/ConfigTest.java
new file mode 100644
index 00000000..7b197458
--- /dev/null
+++ b/java_test/java/jtest/ConfigTest.java
@@ -0,0 +1,59 @@
+package jtest;
+
+import com.oracle.labs.mlrg.olcut.config.Configurable;
+import com.oracle.labs.mlrg.olcut.config.ConfigurationData;
+import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
+import com.oracle.labs.mlrg.olcut.config.DescribeConfigurable;
+import com.oracle.labs.mlrg.olcut.config.property.Property;
+import com.oracle.labs.mlrg.olcut.config.property.SimpleProperty;
+import org.tribuo.classification.dtree.CARTClassificationTrainer;
+import org.tribuo.classification.dtree.impurity.Entropy;
+import org.tribuo.classification.dtree.impurity.GiniIndex;
+import org.tribuo.classification.dtree.impurity.LabelImpurity;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.SortedMap;
+
+/**
+ * Scratch program that builds an OLCUT configuration for a
+ * {@link CARTClassificationTrainer} in code: it adds a configuration named
+ * {@code "entropy"} for the {@link Entropy} class, adds a configuration for the
+ * trainer under its own class name with three string-valued properties, then
+ * looks the trainer up by that name and prints it.
+ */
+public class ConfigTest {
+
+    /** Empty placeholder; currently does nothing. */
+    public static final void describe() {
+    }
+
+    /**
+     * Builds the configuration and prints the looked-up trainer to standard output.
+     *
+     * @param args unused
+     */
+    public static void main(String[] args) {
+        Class<CARTClassificationTrainer> clazz = CARTClassificationTrainer.class;
+        String trainerName = clazz.getName();
+
+        // Neither fieldInfos nor dc is read afterwards; the calls are retained so
+        // the same OLCUT code runs as before.
+        SortedMap<String, ?> fieldInfos = DescribeConfigurable.generateFieldInfo(clazz);
+        ConfigurationManager cm = new ConfigurationManager();
+        DescribeConfigurable dc = new DescribeConfigurable();
+
+        Map<String, Property> properties = Map.of(
+                "maxDepth", new SimpleProperty("20"),
+                "minImpurityDecrease", new SimpleProperty("0.5"),
+                // NOTE(review): presumably resolved against the "entropy"
+                // configuration added below -- confirm.
+                "impurity", new SimpleProperty("entropy"));
+
+        cm.addConfiguration(new ConfigurationData("entropy", Entropy.class.getName()));
+        cm.addConfiguration(new ConfigurationData(trainerName, trainerName, properties));
+
+        Configurable lookup = cm.lookup(trainerName);
+        System.out.println(lookup);
+    }
+}
diff --git a/test/tech/v3/dataset/categorical_test.clj b/test/tech/v3/dataset/categorical_test.clj
index 851ea030..87ce9fef 100644
--- a/test/tech/v3/dataset/categorical_test.clj
+++ b/test/tech/v3/dataset/categorical_test.clj
@@ -87,3 +87,63 @@
                 (get :y)
                 distinct
                 set))))
+
+
+(defn- =-invert-cat
+  "Builds a dataset whose `:target` column is `[target-1 target-2]`, inverts it
+  with `ds-cat/invert-categorical-map` using the lookup table
+  `{:one lookup-one :two lookup-two}` and `result-datatype`, and returns true
+  when the resulting `:target` column equals `expected-result`."
+  [target-1 target-2
+   lookup-one lookup-two
+   result-datatype
+   expected-result]
+  (let [ds (ds/->dataset {:target [target-1 target-2]})
+        inverted (ds-cat/invert-categorical-map
+                  ds
+                  {:lookup-table {:one lookup-one
+                                  :two lookup-two}
+                   :src-column :target
+                   :result-datatype result-datatype})]
+    (= expected-result (:target inverted))))
+
+(deftest invert-cat--works
+  (is (=-invert-cat 1 2
+                    1 2
+                    :int
+                    [:one :two]))
+  ;; TODO - should pass ?
+  (is (=-invert-cat 1.0 2.0
+                    1 2
+                    :int
+                    [:one :two]))
+
+  ;; TODO - should pass ?
+  (is (=-invert-cat 1.99999 2.99999
+                    1 2
+                    :int
+                    [:one :two]))
+
+  ;; TODO - should pass ?
+  (is (=-invert-cat 1.2 1.3
+                    1 2
+                    :int
+                    [:one :one])))
+
+(deftest invert-cat--throws
+  (is (thrown? Exception
+               (=-invert-cat 1 2
+                             4 5
+                             :int
+                             [:one :two])))
+  ;; => Execution error at tech.v3.dataset.categorical/invert-categorical-map$fn (categorical.clj:177).
+  ;; Unable to find src value for numeric value 1
+
+  (is (thrown? Exception
+               (=-invert-cat 1 2
+                             1.0 2.0
+                             :int
+                             [:one :two]))))
+;; => Execution error at tech.v3.dataset.categorical/invert-categorical-map$fn (categorical.clj:177).
+;; Unable to find src value for numeric value 1
+