Skip to content

Commit

Permalink
support weighted average aggregation function (#3052)
Browse files: browse the repository at this point in the history
* support wavg aggregate function

* fixing function handlers for wavg

* fixing pct java standard test

* updating data for wavg function

* reusing processing of aggregation function

* fixing compilation of empty list

* Improving wavg functionality with documentation and tests

* fixing pct test for java binding adapter
  • Loading branch information
gs-gunjan authored Sep 16, 2024
1 parent 8ddd43b commit d360a4f
Show file tree
Hide file tree
Showing 11 changed files with 287 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1828,6 +1828,11 @@ private void registerAggregations()
h("meta::pure::functions::math::sum_Integer_MANY__Integer_1_", false, ps -> res("Integer", "one"), ps -> typeMany(ps.get(0), "Integer")),
h("meta::pure::functions::math::sum_Number_MANY__Number_1_", false, ps -> res("Number", "one"), ps -> typeMany(ps.get(0), "Number")));

register(m(m(h("meta::pure::functions::math::wavg_Number_MANY__Number_MANY__Float_1_", false, ps -> res("Float", "one"), ps -> typeMany(ps.get(0), "Number"))),
m(h("meta::pure::functions::math::wavg_WavgRowMapper_MANY__Float_1_", false, ps -> res("Float", "one"), ps -> typeMany(ps.get(0), "meta::pure::functions::math::wavgUtility::WavgRowMapper")))));

register(h("meta::pure::functions::math::wavgUtility::wavgRowMapper_Number_$0_1$__Number_$0_1$__WavgRowMapper_1_", false, ps -> res("meta::pure::functions::math::wavgUtility::WavgRowMapper", "one"), ps -> typeZeroOne(ps.get(0), "Number")));

register(h("meta::pure::functions::math::variance_Number_MANY__Boolean_1__Number_1_", false, ps -> res("Number", "one")));

register(m(m(h("meta::pure::functions::math::percentile_Number_MANY__Float_1__Boolean_1__Boolean_1__Number_$0_1$_", false, ps -> res("Number", "zeroOne"), ps -> ps.size() == 4)),
Expand Down Expand Up @@ -2919,7 +2924,6 @@ private Map<String, Dispatch> buildDispatch()
map.put("meta::pure::functions::relation::pivot_Relation_1__ColSpecArray_1__AggColSpecArray_1__Relation_1_", (List<ValueSpecification> ps) -> ps.size() == 3 && isOne(ps.get(0)._multiplicity()) && Sets.immutable.with("Nil", "Relation", "RelationElementAccessor", "TDS", "RelationStoreAccessor").contains(ps.get(0)._genericType()._rawType()._name()) && isOne(ps.get(1)._multiplicity()) && ("Nil".equals(ps.get(1)._genericType()._rawType()._name()) || "ColSpecArray".equals(ps.get(1)._genericType()._rawType()._name())) && isOne(ps.get(2)._multiplicity()) && ("Nil".equals(ps.get(2)._genericType()._rawType()._name()) || "AggColSpecArray".equals(ps.get(2)._genericType()._rawType()._name())));
map.put("meta::pure::functions::relation::pivot_Relation_1__ColSpec_1__AggColSpecArray_1__Relation_1_", (List<ValueSpecification> ps) -> ps.size() == 3 && isOne(ps.get(0)._multiplicity()) && Sets.immutable.with("Nil", "Relation", "RelationElementAccessor", "TDS", "RelationStoreAccessor").contains(ps.get(0)._genericType()._rawType()._name()) && isOne(ps.get(1)._multiplicity()) && ("Nil".equals(ps.get(1)._genericType()._rawType()._name()) || "ColSpec".equals(ps.get(1)._genericType()._rawType()._name())) && isOne(ps.get(2)._multiplicity()) && ("Nil".equals(ps.get(2)._genericType()._rawType()._name()) || "AggColSpecArray".equals(ps.get(2)._genericType()._rawType()._name())));
map.put("meta::pure::functions::relation::pivot_Relation_1__ColSpec_1__AggColSpec_1__Relation_1_", (List<ValueSpecification> ps) -> ps.size() == 3 && isOne(ps.get(0)._multiplicity()) && Sets.immutable.with("Nil", "Relation", "RelationElementAccessor", "TDS", "RelationStoreAccessor").contains(ps.get(0)._genericType()._rawType()._name()) && isOne(ps.get(1)._multiplicity()) && ("Nil".equals(ps.get(1)._genericType()._rawType()._name()) || "ColSpec".equals(ps.get(1)._genericType()._rawType()._name())) && isOne(ps.get(2)._multiplicity()) && ("Nil".equals(ps.get(2)._genericType()._rawType()._name()) || "AggColSpec".equals(ps.get(2)._genericType()._rawType()._name())));

map.put("meta::pure::functions::relation::cumulativeDistribution_Relation_1___Window_1__T_1__Float_1_", (List<ValueSpecification> ps) -> ps.size() == 3 && isOne(ps.get(0)._multiplicity()) && Sets.immutable.with("Nil", "Relation", "RelationElementAccessor", "TDS", "RelationStoreAccessor").contains(ps.get(0)._genericType()._rawType()._name()) && isOne(ps.get(1)._multiplicity()) && ("Nil".equals(ps.get(1)._genericType()._rawType()._name()) || "_Window".equals(ps.get(1)._genericType()._rawType()._name())) && isOne(ps.get(2)._multiplicity()));
map.put("meta::pure::functions::relation::denseRank_Relation_1___Window_1__T_1__Integer_1_", (List<ValueSpecification> ps) -> ps.size() == 3 && isOne(ps.get(0)._multiplicity()) && Sets.immutable.with("Nil", "Relation", "RelationElementAccessor", "TDS", "RelationStoreAccessor").contains(ps.get(0)._genericType()._rawType()._name()) && isOne(ps.get(1)._multiplicity()) && ("Nil".equals(ps.get(1)._genericType()._rawType()._name()) || "_Window".equals(ps.get(1)._genericType()._rawType()._name())) && isOne(ps.get(2)._multiplicity()));
map.put("meta::pure::functions::relation::first_Relation_1___Window_1__T_1__T_$0_1$_", (List<ValueSpecification> ps) -> ps.size() == 3 && isOne(ps.get(0)._multiplicity()) && Sets.immutable.with("Nil", "Relation", "RelationElementAccessor", "TDS", "RelationStoreAccessor").contains(ps.get(0)._genericType()._rawType()._name()) && isOne(ps.get(1)._multiplicity()) && ("Nil".equals(ps.get(1)._genericType()._rawType()._name()) || "_Window".equals(ps.get(1)._genericType()._rawType()._name())) && isOne(ps.get(2)._multiplicity()));
Expand All @@ -2933,7 +2937,9 @@ private Map<String, Dispatch> buildDispatch()
map.put("meta::pure::functions::relation::percentRank_Relation_1___Window_1__T_1__Float_1_", (List<ValueSpecification> ps) -> ps.size() == 3 && isOne(ps.get(0)._multiplicity()) && Sets.immutable.with("Nil", "Relation", "RelationElementAccessor", "TDS", "RelationStoreAccessor").contains(ps.get(0)._genericType()._rawType()._name()) && isOne(ps.get(1)._multiplicity()) && ("Nil".equals(ps.get(1)._genericType()._rawType()._name()) || "_Window".equals(ps.get(1)._genericType()._rawType()._name())) && isOne(ps.get(2)._multiplicity()));
map.put("meta::pure::functions::relation::rank_Relation_1___Window_1__T_1__Integer_1_", (List<ValueSpecification> ps) -> ps.size() == 3 && isOne(ps.get(0)._multiplicity()) && Sets.immutable.with("Nil", "Relation", "RelationElementAccessor", "TDS", "RelationStoreAccessor").contains(ps.get(0)._genericType()._rawType()._name()) && isOne(ps.get(1)._multiplicity()) && ("Nil".equals(ps.get(1)._genericType()._rawType()._name()) || "_Window".equals(ps.get(1)._genericType()._rawType()._name())) && isOne(ps.get(2)._multiplicity()));
map.put("meta::pure::functions::relation::rowNumber_Relation_1__T_1__Integer_1_", (List<ValueSpecification> ps) -> ps.size() == 2 && isOne(ps.get(0)._multiplicity()) && Sets.immutable.with("Nil", "Relation", "RelationElementAccessor", "TDS", "RelationStoreAccessor").contains(ps.get(0)._genericType()._rawType()._name()) && isOne(ps.get(1)._multiplicity()));

map.put("meta::pure::functions::math::wavg_Number_MANY__Number_MANY__Float_1_", (List<ValueSpecification> ps) -> ps.size() == 2 && Sets.immutable.with("Nil","Number","Integer","Decimal","Float").contains(ps.get(0)._genericType()._rawType()._name()) && Sets.immutable.with("Nil","Number","Integer","Decimal","Float").contains(ps.get(1)._genericType()._rawType()._name()));
map.put("meta::pure::functions::math::wavg_WavgRowMapper_MANY__Float_1_", (List<ValueSpecification> ps) -> ps.size() == 1 && ("Nil".equals(ps.get(0)._genericType()._rawType()._name()) || "WavgRowMapper".equals(ps.get(0)._genericType()._rawType()._name())));
map.put("meta::pure::functions::math::wavgUtility::wavgRowMapper_Number_$0_1$__Number_$0_1$__WavgRowMapper_1_", (List<ValueSpecification> ps) -> ps.size() == 2 && matchZeroOne(ps.get(0)._multiplicity()) && Sets.immutable.with("Nil","Number","Integer","Decimal","Float").contains(ps.get(0)._genericType()._rawType()._name()) && matchZeroOne(ps.get(1)._multiplicity()) && Sets.immutable.with("Nil","Number","Integer","Decimal","Float").contains(ps.get(1)._genericType()._rawType()._name()));

// ------------------------------------------------------------------------------------------------
// Please do not update the following code manually! Please check with the team when introducing
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

// Copyright 2022 Goldman Sachs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

function <<test.Test>> meta::pure::functions::math::tests::wavg::testWavg():Boolean[1]
{
    // Weighted sum: 200*0.25 + 300*0.35 + 250*0.15 + 850*0.1 + 100*0.15 = 292.5.
    // The weights add up to 1, so the weighted average equals the weighted sum.
    let quantities = [200,300,250,850,100];
    let weights = [0.25,0.35,0.15,0.1,0.15];
    assertEq(292.5, $quantities->wavg($weights));
}
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ function meta::pure::router::routing::processLambda(i:InstanceValue[1], routed:E
f:FunctionDefinition<Any>[1]| if($state.shouldBeRouted,
|$func->functionType().parameters->evaluateAndDeactivate()->map(
p | let class = $p.genericType.rawType;
if (!$class->toOne()->instanceOf(DataType) && $class != TDSRow && (!$class->toOne()->instanceOf(RelationType)),
if (!$class->toOne()->instanceOf(DataType) && $class != TDSRow && $class != meta::pure::functions::math::wavgUtility::WavgRowMapper && (!$class->toOne()->instanceOf(RelationType)),
| let map = $routed->filter(v|$v->evaluateAndDeactivate().value.genericType.rawType->toOne()->_subTypeOf($class->toOne()););
assert(!$map->isEmpty(),| 'Error mapping not found for class '+$class.name->toOne()+' cache:\''+$routed->cast(@meta::pure::router::store::metamodel::StoreMappingRoutedValueSpecification).sets.class.name->joinStrings(', ')+'\'');
pair($p.name, $map->at(0)->cast(@Any));,
Expand Down Expand Up @@ -785,6 +785,9 @@ function meta::pure::router::routing::shouldStopFunctions(extensions:meta::pure:
averageRank_Any_MANY__Map_1_,
denseRank_Any_MANY__Map_1_,
rank_Any_MANY__Map_1_,
wavg_Number_MANY__Number_MANY__Float_1_,
wavg_WavgRowMapper_MANY__Float_1_,
meta::pure::functions::math::wavgUtility::wavgRowMapper_Number_$0_1$__Number_$0_1$__WavgRowMapper_1_,
rowNumber_Any_MANY__Map_1_,
max_Float_MANY__Float_$0_1$_,
max_Integer_MANY__Integer_$0_1$_,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright 2024 Goldman Sachs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import meta::pure::test::pct::*;

function
<<PCT.function>>
{
    doc.doc='Performs weighted average on a given data field with its respective weight field',
    PCT.grammarDoc='wavg is used by wavg(Number[*], Number[*])'
}
meta::pure::functions::math::wavg(numbers:Number[*], weights:Number[*]):Float[1]
{
    // Result is sum(numbers[i] * weights[i]) / sum(weights).
    // Fails when either collection is empty, when the two collections differ in size,
    // or when the weights sum to zero.
    // The 0.0 that follows each fail(...) is there so the branch yields a Float like the computing branch does.
    if($numbers->isEmpty() || $weights->isEmpty(),
       | fail('Error: Mean of an empty set.');
         0.0;,
       | if($numbers->size() == $weights->size(),
            | let denominator = $weights->sum();
              let numerator = $numbers->zip($weights)->map(pair | $pair.first * $pair.second)->sum();
              assert($denominator != 0.0, |'Weighted Average can\'t be performed as sum of weight column values equal 0');
              $numerator / $denominator;,
            | fail('Error: The data and weight column must have the same number of values');
              0.0;
         );
    );
}

function meta::pure::functions::math::wavg(wavgRows: meta::pure::functions::math::wavgUtility::WavgRowMapper[*]):Float[1]
{
    // Unpacks the row holders into separate quantity and weight collections and
    // delegates to the two-collection overload.
    // NOTE(review): quantity and weight are both declared [0..1]; presumably a row missing one of them
    // leaves the two collections with different sizes or out of step - confirm how callers handle such rows.
    let quantities = $wavgRows.quantity;
    let weights = $wavgRows.weight;
    meta::pure::functions::math::wavg($quantities, $weights);
}

// Holder for one (quantity, weight) pair; both properties are optional ([0..1]).
// wavg(WavgRowMapper[*]) reads quantity as the values to average and weight as their weights.
Class meta::pure::functions::math::wavgUtility::WavgRowMapper
{
// Value to be averaged.
quantity: Number[0..1];
// Weight applied to quantity.
weight: Number[0..1];
}

function meta::pure::functions::math::wavgUtility::wavgRowMapper(quantity:Number[0..1], weight:Number[0..1]):meta::pure::functions::math::wavgUtility::WavgRowMapper[1]
{
    // Bundles a quantity with its weight into a single WavgRowMapper instance.
    let row = ^meta::pure::functions::math::wavgUtility::WavgRowMapper(quantity = $quantity, weight = $weight);
    $row;
}

// PCT test: groups an inline TDS by grp and computes one weighted average per group.
// The map step wraps each row's quantity and weight with wavgRowMapper; the reduce step calls wavg() on the wrapped rows.
// The expression is not run directly: it is passed to the supplied function f, which evaluates it.
function <<PCT.test>> meta::pure::functions::math::tests::wavg::testSimpleGroupByWavg<T|m>(f:Function<{Function<{->T[m]}>[1]->T[m]}>[1]):Boolean[1]
{
let expr = {
|#TDS
id, grp, name, quantity, weight
1, 2, A, 200, 0.5
2, 1, B, 100, 0.45
3, 3, C, 250, 0.25
4, 4, D, 700, 1
5, 2, E, 100, 0.5
6, 1, F, 500, 0.15
7, 3, G, 400, 0.75
8, 1, H, 150, 0.4
9, 5, I, 350, 1
#->groupBy(~grp, ~wavgCol : x | meta::pure::functions::math::wavgUtility::wavgRowMapper($x.quantity, $x.weight) : y | $y->wavg())
};

// Evaluate the expression through the supplied function.
let res = $f->eval($expr);

// Expected value per group is sum(quantity * weight) / sum(weight),
// e.g. grp 1 = (100*0.45 + 500*0.15 + 150*0.4) / (0.45 + 0.15 + 0.4) = 180.0.
// The result is sorted by grp before rendering so the comparison does not depend on group order.
assertEquals( '#TDS\n'+
' grp,wavgCol\n'+
' 1,180.0\n'+
' 2,150.0\n'+
' 3,362.5\n'+
' 4,700.0\n'+
' 5,350.0\n'+
'#', $res->sort(~grp->ascending())->toString());
}

// PCT test: groups an inline TDS by grp and computes two weighted averages of quantity per group,
// wavgCol1 weighted by the weight column and wavgCol2 weighted by the weight1 column.
// The expression is not run directly: it is passed to the supplied function f, which evaluates it.
function <<PCT.test>> meta::pure::functions::math::tests::wavg::testSimpleGroupByMultipleWavg<T|m>(f:Function<{Function<{->T[m]}>[1]->T[m]}>[1]):Boolean[1]
{
let expr = {
|#TDS
id, grp, name, quantity, weight, weight1
1, 2, A, 200, 0.5, 0.75
2, 1, B, 100, 0.45, 0.35
3, 3, C, 250, 0.25, 0.50
4, 4, D, 700, 1, 1
5, 2, E, 100, 0.5, 0.25
6, 1, F, 500, 0.15, 0.25
7, 3, G, 400, 0.75, 0.50
8, 1, H, 150, 0.4, 0.4
9, 5, I, 350, 1, 1
#->groupBy(~grp, ~[wavgCol1 : x | meta::pure::functions::math::wavgUtility::wavgRowMapper($x.quantity, $x.weight) : y | $y->wavg(),
wavgCol2 : x | meta::pure::functions::math::wavgUtility::wavgRowMapper($x.quantity, $x.weight1) : y | $y->wavg()])
};

// Evaluate the expression through the supplied function.
let res = $f->eval($expr);

// Expected value per group and column is sum(quantity * w) / sum(w) for that column's weights,
// e.g. grp 1 / wavgCol2 = (100*0.35 + 500*0.25 + 150*0.4) / (0.35 + 0.25 + 0.4) = 220.0.
// The result is sorted by grp before rendering so the comparison does not depend on group order.
assertEquals( '#TDS\n'+
' grp,wavgCol1,wavgCol2\n'+
' 1,180.0,220.0\n'+
' 2,150.0,175.0\n'+
' 3,362.5,325.0\n'+
' 4,700.0,700.0\n'+
' 5,350.0,350.0\n'+
'#', $res->sort(~grp->ascending())->toString());
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ public class Test_JAVA_StandardFunction_PCT extends PCTReportConfiguration
one("meta::pure::functions::math::tests::stdDev::testSimpleGroupByStandardDeviationSample_Function_1__Boolean_1_", "\"meta::pure::functions::relation::groupBy_Relation_1__ColSpecArray_1__AggColSpecArray_1__Relation_1_ is not supported yet!\""),

one("meta::pure::functions::math::tests::variance::testSimpleGroupByVariancePopulation_Function_1__Boolean_1_", "\"meta::pure::functions::relation::groupBy_Relation_1__ColSpecArray_1__AggColSpecArray_1__Relation_1_ is not supported yet!\""),
one("meta::pure::functions::math::tests::variance::testSimpleGroupByVarianceSample_Function_1__Boolean_1_", "\"meta::pure::functions::relation::groupBy_Relation_1__ColSpecArray_1__AggColSpecArray_1__Relation_1_ is not supported yet!\"")
one("meta::pure::functions::math::tests::variance::testSimpleGroupByVarianceSample_Function_1__Boolean_1_", "\"meta::pure::functions::relation::groupBy_Relation_1__ColSpecArray_1__AggColSpecArray_1__Relation_1_ is not supported yet!\""),
one("meta::pure::functions::math::tests::wavg::testSimpleGroupByWavg_Function_1__Boolean_1_", "\"meta::pure::functions::relation::groupBy_Relation_1__ColSpec_1__AggColSpec_1__Relation_1_ is not supported yet!\""),
one("meta::pure::functions::math::tests::wavg::testSimpleGroupByMultipleWavg_Function_1__Boolean_1_", "\"meta::pure::functions::relation::groupBy_Relation_1__ColSpec_1__AggColSpecArray_1__Relation_1_ is not supported yet!\"")
);

public static Test suite()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1665,7 +1665,7 @@ function <<test.Test>> meta::pure::executionPlan::tests::testDatabaseConnectionS

let resultConnection = $result.rootExecutionNode->cast(@RelationalInstantiationExecutionNode).executionNodes->at(0)->cast(@SQLExecutionNode).connection->cast(@meta::external::store::relational::runtime::TestDatabaseConnection);

assertSize($resultConnection.testDataSetupSqls, 56);
assertSize($resultConnection.testDataSetupSqls, 58);
}

function <<test.Test>> meta::pure::executionPlan::tests::testDatabaseConnectionSQLPopulation():Boolean[1]
Expand All @@ -1685,7 +1685,7 @@ function <<test.Test>> meta::pure::executionPlan::tests::testDatabaseConnectionS
.connection->cast(@meta::external::store::relational::runtime::RelationalDatabaseConnection).datasourceSpecification
->cast(@meta::pure::alloy::connections::alloy::specification::LocalH2DatasourceSpecification);

assertSize($resultConnection.testDataSetupSqls, 56);
assertSize($resultConnection.testDataSetupSqls, 58);
}

function <<test.Test>> meta::pure::executionPlan::tests::tdsJoinOneDBOneExpression():Boolean[1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1269,3 +1269,25 @@ function <<test.Test>> meta::relational::tests::groupBy::testGroupByEmptyColsNon
assertEquals([], $result.values.rows->map(r|$r.values->makeString('|')));
assertSameSQL('select "Firm Count" as "Firm Count" from (select count("root".LEGALNAME) as "Firm Count" from firmTable as "root") as "subselect" where "Firm Count" > 10', $result);
}

// Relational test: groups WeightedTrade by tradeDate shifted one day forward and aggregates
// a count of ids plus a weighted average of quantity by weight (via wavgRowMapper + wavg).
// Checks the row count, the generated SQL, and the resulting rows.
function <<test.Test>> meta::relational::tests::groupBy::testGroupByWithWavgAggregation() : Boolean[1]
{
let mapping = meta::relational::tests::simpleRelationalMapping;
let runtime = meta::external::store::relational::tests::testRuntime();

// Run the grouped query against the test runtime with the relational extensions.
let result = execute({|
WeightedTrade.all()
->groupBy(
[ x | $x.tradeDate->adjust(1, DurationUnit.DAYS)],
[
agg(x| $x.id, y|$y->count()),
agg(x | $x.quantity->meta::pure::functions::math::wavgUtility::wavgRowMapper($x.weight), y | $y->wavg())
],
['tradeDate', 'id count', 'Weighted Average (Trade)']
)
->sort(asc('tradeDate'))}, $mapping, $runtime, meta::relational::extension::relationalExtensions());

// Four result rows are expected, one per shifted trade date.
assertEquals(4, $result.values->at(0).rows->size());
// The weighted average is expected to render in SQL as (1.0 * sum(quantity * weight)) / sum(weight).
assertSameSQL('select dateadd(DAY, 1, "root".tradeDate) as "tradeDate", count("root".ID) as "id count", ((1.0 * sum(("root".quantity * "root".weight))) / sum("root".weight)) as "Weighted Average (Trade)" from weightedTradeTable as "root" group by "tradeDate" order by "tradeDate"',$result);
// Rows are compared as comma-joined strings, irrespective of order.
// NOTE(review): the last expected value (33.49999999999999) is an exact floating-point rendering;
// presumably it depends on the test database's arithmetic - confirm it is stable across environments.
assertSameElements(['2014-12-02,3,138.8', '2014-12-03,2,27.5', '2014-12-04,2,33.8', '2014-12-05,2,33.49999999999999'],$result.values->at(0).rows->map(r| $r.values->makeString(',')));
}
Loading

0 comments on commit d360a4f

Please sign in to comment.