Skip to content

Commit

Permalink
Add jvmfile name and string as field heuristics
Browse files Browse the repository at this point in the history
  • Loading branch information
wu-hui committed May 29, 2024
1 parent ca9a22b commit 93135a8
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,4 @@ mvn com.coveo:fmt-maven-plugin:format

[1]: https://cloud.google.com/docs/authentication/getting-started#creating_a_service_account
[2]: https://maven.apache.org/settings.html#Active_Profiles
[3]: https://github.com/GoogleCloudPlatform/java-docs-samples/blob/main/SAMPLE_FORMAT.md
[3]: https://github.com/GoogleCloudPlatform/java-docs-samples/blob/main/SAMPLE_FORMAT.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public Pipeline toPipeline() {
.toPipeline()
.aggregate(
this.aggregateFieldList.stream()
.map(PipelineUtilsKt::toPipelineAggregatorTarget)
.map(PipelineUtils::toPipelineAggregatorTarget)
.toArray(AggregatorTarget[]::new));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import com.google.firestore.v1.Value
import java.util.Date
import javax.annotation.Nonnull

/**
* Result from a {@code Pipeline} execution.
*/
class PipelineResult
internal constructor(
private val rpcContext: FirestoreRpcContext<*>?,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
@file:JvmName("PipelineUtils")
package com.google.cloud.firestore

import com.google.cloud.firestore.Query.ComparisonFilterInternal
Expand All @@ -9,7 +10,6 @@ import com.google.cloud.firestore.pipeline.AggregatorTarget
import com.google.cloud.firestore.pipeline.Constant
import com.google.cloud.firestore.pipeline.Field
import com.google.cloud.firestore.pipeline.Function
import com.google.cloud.firestore.pipeline.Function.Companion.count
import com.google.cloud.firestore.pipeline.Function.Companion.countAll
import com.google.cloud.firestore.pipeline.Function.Companion.not
import com.google.firestore.v1.Cursor
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
@file:JvmName("Pipelines")
package com.google.cloud.firestore

import com.google.api.core.ApiFuture
Expand Down Expand Up @@ -185,6 +186,14 @@ class Pipeline private constructor(private val stages: List<Stage>, private val
return projMap
}

private fun fieldNamesToMap(vararg fields: String): Map<String, Expr> {
val projMap = mutableMapOf<String, Expr>()
for (field in fields) {
projMap[field] = Field.of(field)
}
return projMap
}

fun addFields(vararg fields: Selectable): Pipeline {
return Pipeline(stages.plus(AddFields(projectablesToMap(*fields))), name)
}
Expand All @@ -193,6 +202,10 @@ class Pipeline private constructor(private val stages: List<Stage>, private val
return Pipeline(stages.plus(Select(projectablesToMap(*projections))), name)
}

fun select(vararg fields: String): Pipeline {
return Pipeline(stages.plus(Select(fieldNamesToMap(*fields))), name)
}

fun <T> filter(condition: T): Pipeline where T : Expr, T : Function.FilterCondition {
return Pipeline(stages.plus(Filter(condition)), name)
}
Expand Down Expand Up @@ -234,7 +247,7 @@ class Pipeline private constructor(private val stages: List<Stage>, private val
return PaginatingPipeline(this, pageSize, orders.toList())
}

fun genericOperation(name: String, params: Map<String, Any>? = null): Pipeline {
fun genericStage(name: String, params: Map<String, Any>? = null): Pipeline {
return this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package com.google.cloud.firestore;

import static com.google.cloud.firestore.PipelineUtilsKt.toPaginatedPipeline;
import static com.google.cloud.firestore.PipelineUtilsKt.toPipelineFilterCondition;
import static com.google.cloud.firestore.PipelineUtils.toPaginatedPipeline;
import static com.google.cloud.firestore.PipelineUtils.toPipelineFilterCondition;
import static com.google.common.collect.Lists.reverse;
import static com.google.firestore.v1.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS;
import static com.google.firestore.v1.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS_ANY;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.google.cloud.firestore;

import static com.google.cloud.firestore.pipeline.ExpressionsKt.exprToValue;
import static com.google.cloud.firestore.pipeline.Expressions.exprToValue;

import com.google.cloud.Timestamp;
import com.google.cloud.firestore.pipeline.Expr;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
@file:JvmName("Expressions")
package com.google.cloud.firestore.pipeline

import com.google.cloud.Timestamp
Expand Down Expand Up @@ -414,28 +415,56 @@ open class Function(val name: String, val params: List<Expr>) : Expr {

@JvmStatic fun equal(left: Expr, right: Any) = Equal(left, Constant.of(right))

@JvmStatic
fun equal(left: String, right: Expr) = Equal(Field.of(left), right)

@JvmStatic
fun equal(left: String, right: Any) = Equal(Field.of(left), Constant.of(right))

@JvmStatic fun notEqual(left: Expr, right: Expr) = NotEqual(left, right)

@JvmStatic fun notEqual(left: Expr, right: Any) = NotEqual(left, Constant.of(right))

@JvmStatic fun notEqual(left: String, right: Expr) = NotEqual(Field.of(left), right)

@JvmStatic fun notEqual(left: String, right: Any) = NotEqual(Field.of(left), Constant.of(right))

@JvmStatic fun greaterThan(left: Expr, right: Expr) = GreaterThan(left, right)

@JvmStatic fun greaterThan(left: Expr, right: Any) = GreaterThan(left, Constant.of(right))

@JvmStatic fun greaterThan(left: String, right: Expr) = GreaterThan(Field.of(left), right)

@JvmStatic fun greaterThan(left: String, right: Any) = GreaterThan(Field.of(left), Constant.of(right))

@JvmStatic fun greaterThanOrEqual(left: Expr, right: Expr) = GreaterThanOrEqual(left, right)

@JvmStatic
fun greaterThanOrEqual(left: Expr, right: Any) = GreaterThanOrEqual(left, Constant.of(right))

@JvmStatic fun greaterThanOrEqual(left: String, right: Expr) = GreaterThanOrEqual(Field.of(left), right)

@JvmStatic
fun greaterThanOrEqual(left: String, right: Any) = GreaterThanOrEqual(Field.of(left), Constant.of(right))

@JvmStatic fun lessThan(left: Expr, right: Expr) = LessThan(left, right)

@JvmStatic fun lessThan(left: Expr, right: Any) = LessThan(left, Constant.of(right))

@JvmStatic fun lessThan(left: String, right: Expr) = LessThan(Field.of(left), right)

@JvmStatic fun lessThan(left: String, right: Any) = LessThan(Field.of(left), Constant.of(right))

@JvmStatic fun lessThanOrEqual(left: Expr, right: Expr) = LessThanOrEqual(left, right)

@JvmStatic
fun lessThanOrEqual(left: Expr, right: Any) = LessThanOrEqual(left, Constant.of(right))

@JvmStatic fun lessThanOrEqual(left: String, right: Expr) = LessThanOrEqual(Field.of(left), right)

@JvmStatic
fun lessThanOrEqual(left: String, right: Any) = LessThanOrEqual(Field.of(left), Constant.of(right))

@JvmStatic
fun inAny(left: Expr, values: List<Any>) =
In(
Expand All @@ -448,6 +477,18 @@ open class Function(val name: String, val params: List<Expr>) : Expr {
},
)

@JvmStatic
fun inAny(left: String, values: List<Any>) =
In(
Field.of(left),
values.map {
when (it) {
is Expr -> it
else -> Constant.of(it)
}
},
)

@JvmStatic
fun notInAny(left: Expr, values: List<Any>) =
Not(
Expand All @@ -462,6 +503,21 @@ open class Function(val name: String, val params: List<Expr>) : Expr {
)
)

@JvmStatic
fun notInAny(left: String, values: List<Any>) =
Not(
In(
Field.of(left),
values.map {
when (it) {
is Expr -> it
else -> Constant.of(it)
}
},
)
)


@JvmStatic
fun <T> and(left: T, right: T) where T : FilterCondition, T : Expr = And(listOf(left, right))

Expand All @@ -478,9 +534,16 @@ open class Function(val name: String, val params: List<Expr>) : Expr {

@JvmStatic fun arrayContains(expr: Expr, element: Expr) = ArrayContains(expr, element)

@JvmStatic
fun arrayContains(field: String, element: Expr) = ArrayContains(Field.of(field), element)

@JvmStatic
fun arrayContains(expr: Expr, element: Any) = ArrayContains(expr, Constant.of(element))

@JvmStatic
fun arrayContains(field: String, element: Any) =
ArrayContains(Field.of(field), Constant.of(element))

@JvmStatic
fun arrayContainsAny(expr: Expr, vararg elements: Expr) =
ArrayContainsAny(expr, elements.toList())
Expand All @@ -489,22 +552,51 @@ open class Function(val name: String, val params: List<Expr>) : Expr {
fun arrayContainsAny(expr: Expr, vararg elements: Any) =
ArrayContainsAny(expr, elements.toList().map { Constant.of(it) })

@JvmStatic
fun arrayContainsAny(field: String, vararg elements: Expr) =
ArrayContainsAny(Field.of(field), elements.toList())

@JvmStatic
fun arrayContainsAny(field: String, vararg elements: Any) =
ArrayContainsAny(Field.of(field), elements.toList().map { Constant.of(it) })

@JvmStatic fun isNaN(expr: Expr) = IsNaN(expr)

@JvmStatic
fun isNaN(field: String) = IsNaN(Field.of(field))

@JvmStatic fun isNull(expr: Expr) = IsNull(expr)

@JvmStatic
fun isNull(field: String) = IsNull(Field.of(field))

@JvmStatic fun not(expr: Expr) = Not(expr)

@JvmStatic fun sum(expr: Expr) = Sum(expr, false)

@JvmStatic
fun sum(field: String) = Sum(Field.of(field), false)

@JvmStatic fun avg(expr: Expr) = Avg(expr, false)

@JvmStatic
fun avg(field: String) = Avg(Field.of(field), false)

@JvmStatic fun min(expr: Expr) = Sum(expr, false)

@JvmStatic
fun min(field: String) = Sum(Field.of(field), false)

@JvmStatic fun max(expr: Expr) = Avg(expr, false)

@JvmStatic
fun max(field: String) = Avg(Field.of(field), false)

@JvmStatic fun count(expr: Expr) = Count(expr, false)

@JvmStatic
fun count(field: String) = Count(Field.of(field), false)

@JvmStatic fun countAll() = Count(null, false)

@JvmStatic fun cosineDistance(expr: Expr, other: Expr) = CosineDistance(expr, other)
Expand All @@ -513,18 +605,39 @@ open class Function(val name: String, val params: List<Expr>) : Expr {
fun cosineDistance(expr: Expr, other: DoubleArray) =
CosineDistance(expr, Constant.ofVector(other))

@JvmStatic
fun cosineDistance(field: String, other: Expr) = CosineDistance(Field.of(field), other)

@JvmStatic
fun cosineDistance(field: String, other: DoubleArray) =
CosineDistance(Field.of(field), Constant.ofVector(other))

@JvmStatic fun dotProductDistance(expr: Expr, other: Expr) = CosineDistance(expr, other)

@JvmStatic
fun dotProductDistance(expr: Expr, other: DoubleArray) =
CosineDistance(expr, Constant.ofVector(other))

@JvmStatic
fun dotProductDistance(field: String, other: Expr) = CosineDistance(Field.of(field), other)

@JvmStatic
fun dotProductDistance(field: String, other: DoubleArray) =
CosineDistance(Field.of(field), Constant.ofVector(other))

@JvmStatic fun euclideanDistance(expr: Expr, other: Expr) = EuclideanDistance(expr, other)

@JvmStatic
fun euclideanDistance(expr: Expr, other: DoubleArray) =
EuclideanDistance(expr, Constant.ofVector(other))

@JvmStatic
fun euclideanDistance(field: String, other: Expr) = EuclideanDistance(Field.of(field), other)

@JvmStatic
fun euclideanDistance(field: String, other: DoubleArray) =
EuclideanDistance(Field.of(field), Constant.ofVector(other))

@JvmStatic fun function(name: String, params: List<Expr>) = Generic(name, params)
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
@file:JvmName("Stages")
package com.google.cloud.firestore.pipeline

import com.google.cloud.firestore.DocumentReference
Expand Down

0 comments on commit 93135a8

Please sign in to comment.