Skip to content

Commit

Permalink
Added gini score implementation to checkpoint1
Browse files Browse the repository at this point in the history
  • Loading branch information
paujla committed Sep 25, 2017
1 parent 07e61f4 commit 993a5c6
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 0 deletions.
1 change: 1 addition & 0 deletions checkpoint1/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Please go to the following link: https://docs.google.com/a/springer.com/presentation/d/1cC3opKIVAHYFXjWB1sTE5FO7aii9xPBeOw5qC18puuU/edit?usp=sharing
Empty file added checkpoint1/build.gradle
Empty file.
Binary file not shown.
Binary file not shown.
25 changes: 25 additions & 0 deletions checkpoint1/src/main/kotlin/com/springernature/ctree4k/Feature.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.springernature.ctree4k

/** Identifier of a feature, wrapping a string [value]. */
data class FeatureId(
    val value: String
)

/** A numeric [value] together with the [id] of the feature it belongs to. */
data class Feature(
    val value: Double,
    val id: FeatureId
)

/** A class label, carrying a numeric [value] and a [name]. */
data class InstanceClass(
    val value: Int,
    val name: String
)

/** A set of [InstanceClass] labels. */
data class InstanceClasses(
    val value: Set<InstanceClass>
)

/**
 * Something that carries a list of [Feature]s and belongs to one [InstanceClass].
 */
interface Instance {

    /** The features of this instance. */
    val features: List<Feature>

    /** The class this instance belongs to. */
    val instanceClass: InstanceClass

    /**
     * Looks up the first feature in [features] whose id equals [featureId].
     *
     * @throws IllegalArgumentException if no feature has that id
     */
    fun get(featureId: FeatureId): Feature =
        features.firstOrNull { candidate -> candidate.id == featureId }
            ?: throw IllegalArgumentException("Could not find feature in Instance[$this] with id[$featureId]")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.springernature.ctree4k

/** One of the two halves of a split, holding the [instances] that fell into it. */
data class SplitHalf(
    val instances: List<Instance>
)

/**
 * Computes the weighted Gini impurity of a two-way split.
 *
 * For each half the impurity is `1 - sum(ratio * ratio)`, where `ratio` is the fraction of that
 * half's instances belonging to each class in `allClasses`. Each half's impurity is weighted by
 * the half's share of all instances in the split, and the weighted values are summed.
 */
class GiniImpurityScore : (Pair<SplitHalf, SplitHalf>, InstanceClasses) -> GiniScore {

    /**
     * @param bothSplits the two halves of the split
     * @param allClasses the classes whose ratios are measured in each half
     * @return the summed, size-weighted impurity; `0.0` when both halves are empty
     */
    override fun invoke(bothSplits: Pair<SplitHalf, SplitHalf>, allClasses: InstanceClasses): GiniScore {
        val totalSize = bothSplits.count()

        val weightedImpurities = bothSplits.toList()
            // An empty half has zero weight, but its class ratios would be 0/0 (NaN) and
            // NaN would propagate into the final sum, so empty halves are left out.
            .filter { half: List<Instance> -> half.isNotEmpty() }
            .map { half ->
                val sumOfSquaredRatios = allClasses.value
                    .map { instanceClass: InstanceClass ->
                        val ratio = half.ratioThatAre(instanceClass)
                        ratio * ratio
                    }
                    .sum()

                (1.0 - sumOfSquaredRatios) * (half.size.toDouble() / totalSize)
            }

        // Summing an empty list gives 0.0, which covers the case where both halves are empty.
        return GiniScore(weightedImpurities.sum())
    }
}

/** Result of a Gini impurity calculation, wrapping the numeric score in [value]. */
data class GiniScore(
    val value: Double
)

/**
 * Fraction of the instances in this list whose class equals [instanceClass].
 *
 * @return the matching count divided by the list size; `0.0` for an empty list,
 *   because 0/0 would otherwise produce NaN
 */
private fun List<Instance>.ratioThatAre(instanceClass: InstanceClass): Double {
    if (isEmpty()) return 0.0

    val matching = count { it.instanceClass == instanceClass }
    return matching.toDouble() / size
}

/**
 * Unwraps both halves into their instance lists, first half first.
 *
 * NOTE(review): this has the same name as the standard library's `Pair<T, T>.toList()`;
 * callers in this file expect the `List<List<Instance>>` returned here.
 */
private fun Pair<SplitHalf, SplitHalf>.toList(): List<List<Instance>> =
    listOf(first.instances, second.instances)

/** Total number of instances across both halves of the split. */
private fun Pair<SplitHalf, SplitHalf>.count(): Int {
    val (left, right) = this
    return left.instances.size + right.instances.size
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.springernature.ctree4k

import com.natpryce.hamkrest.assertion.assertThat
import com.natpryce.hamkrest.equalTo
import org.junit.Test

/**
 * Checks [GiniImpurityScore] against hand-computed scores for splits of triangles and circles.
 */
class GiniImpurityScoreTest {

    val triangleClass = InstanceClass(0, "triangle")
    val circleClass = InstanceClass(1, "circle")
    val allClasses = InstanceClasses(setOf(triangleClass, circleClass))

    val triangle = createInstance(triangleClass)
    val circle = createInstance(circleClass)

    val giniScore = GiniImpurityScore()

    // Both halves hold one triangle and one circle.
    @Test fun `returns 0,5`() {
        val split = SplitHalf(listOf(triangle, circle)) to SplitHalf(listOf(triangle, circle))

        assertThat(giniScore(split, allClasses), equalTo(GiniScore(0.5)))
    }

    // Each half holds a single class only.
    @Test fun `returns 0`() {
        val split = SplitHalf(listOf(triangle, triangle)) to SplitHalf(listOf(circle, circle))

        assertThat(giniScore(split, allClasses), equalTo(GiniScore(0.0)))
    }

    // A two-triangle half next to an eight-instance half with four of each class.
    @Test fun `returns 0,4`() {
        val onlyTriangles = SplitHalf(listOf(triangle, triangle))
        val evenMix = SplitHalf(
            listOf(
                circle, circle, circle, circle,
                triangle, triangle, triangle, triangle
            )
        )

        assertThat(giniScore(onlyTriangles to evenMix, allClasses), equalTo(GiniScore(0.4)))
    }
}

/**
 * Builds an [Instance] with no features that reports [instanceClass] as its class.
 */
private fun createInstance(instanceClass: InstanceClass): Instance {
    // Hold the parameter under a distinct name so the overriding property below
    // does not read as if it referred to itself.
    val assignedClass = instanceClass

    return object : Instance {
        override val features: List<Feature> = emptyList()
        override val instanceClass: InstanceClass = assignedClass
    }
}
1 change: 1 addition & 0 deletions settings.gradle
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
rootProject.name = 'ctree4k'
include 'checkpoint0'
include 'checkpoint1'

0 comments on commit 993a5c6

Please sign in to comment.