Skip to content

A more general group map sum #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,26 @@ Aggregating operation on map<string,int> than performs the unions of keys of the
exists in multiple maps


CREATE TABLE docs {
CREATE TABLE docs (
docid int;
word_count map<string, int>
}
)

SELECT map_group_sum(word_count) FROM docs; -- Get the global word frequency

## map_group_sum2

Aggregating operation on map<char,double> that performs the union of the keys of the maps, and sums the values when a key
exists in multiple maps


CREATE TABLE docs (
docid int,
commission_amounts map<char(3), double>
)

SELECT map_group_sum2(commission_amounts) FROM docs; -- Get the total commissions

### Maths

### UDFExponentialSmoothingMovingAverage.
Expand Down
12 changes: 7 additions & 5 deletions ivy.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,22 @@
<exclude org="asm" />
<exclude org="org.eclipse.jdt" />
</dependency>-->
<dependency org="com.google.code.gson" name="gson" rev="2.2.2"/>
<dependency org="com.google.guava" name="guava" rev="16.0.1" />
<dependency org="log4j" name="log4j" rev="1.2.15" force="true">
<exclude org="com.sun.jdmk"/>
<exclude org="com.sun.jmx"/>
<exclude org="javax.jms"/>
</dependency>
<dependency org="org.apache.hadoop" name="hadoop-core" rev="0.20.2"/>
<dependency org="org.apache.hive" name="hive-serde" rev="0.10.0" >
<dependency org="org.apache.hive" name="hive-serde" rev="1.2.1" >
<exclude org="commons-daemon" />
</dependency>
<dependency org="org.apache.hive" name="hive-exec" rev="0.10.0" >
<exclude org="commons-daemon" />
</dependency>
<!-- <dependency org="org.apache.hive" name="hive-exec" rev="0.10.0" > -->
<!-- <exclude org="commons-daemon" /> -->
<!-- </dependency> -->
<dependency org="javax.jdo" name="jdo2-api" rev="2.3-eb" force="true"/>
<dependency org="org.jboss.netty" name="netty" rev="3.2.9" force="true" />
<!-- <dependency org="org.jboss.netty" name="netty" rev="3.2.9" force="true" /> -->
<!--<dependency org="commons-daemon" name="commons-daemon" rev="1.0.15" force="true"/>-->

<!--<dependency org="commons-lang" name="commons-lang" rev="2.5"/>
Expand Down
Binary file added ivy/ivy.jar
Binary file not shown.
132 changes: 132 additions & 0 deletions src/com/dataiku/hive/udf/maps/UDAFMapGroupSum2.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package com.dataiku.hive.udf.maps;

import com.google.common.collect.Maps;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.lazy.LazyFactory;
import org.apache.hadoop.hive.serde2.lazy.LazyMap;
import org.apache.hadoop.hive.serde2.lazy.objectinspector.LazyMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveCharObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.common.type.HiveChar;
import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveCharObjectInspector;

import java.rmi.MarshalledObject;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * Hive UDAF {@code map_group_sum2}.
 *
 * Aggregates a column of {@code map<char(n), decimal>} values: takes the union of the
 * keys across all input maps and sums the decimal values of identical char keys.
 * The result is emitted as a single {@code map<char(n), decimal>}.
 */
public class UDAFMapGroupSum2 extends AbstractGenericUDAFResolver {
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] tis) throws SemanticException {
        // The UDAF takes exactly one (map-typed) argument.
        if (tis.length != 1) {
            throw new UDFArgumentTypeException(tis.length - 1, "Exactly one argument is expected.");
        }
        return new MapGroupSumEvaluator();
    }

    public static class MapGroupSumEvaluator extends GenericUDAFEvaluator {
        private MapObjectInspector originalDataOI;
        private HiveDecimalObjectInspector valueOI;
        private HiveCharObjectInspector keyOI;
        // Declared char(n) length of the map keys, captured in init(). The
        // original code recomputed it in init() but then hard-coded 3 in
        // terminate(), silently truncating keys of any other declared width.
        private int keyLength;

        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);

            originalDataOI = (MapObjectInspector) parameters[0];
            keyOI = (HiveCharObjectInspector) originalDataOI.getMapKeyObjectInspector();
            valueOI = (HiveDecimalObjectInspector) originalDataOI.getMapValueObjectInspector();

            keyLength = ((CharTypeInfo) keyOI.getTypeInfo()).getLength();
            return ObjectInspectorFactory.getStandardMapObjectInspector(
                    new JavaHiveCharObjectInspector(new CharTypeInfo(keyLength)),
                    PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector);
        }

        /** Aggregation state: running union-and-sum of all maps seen so far. */
        static class MapBuffer implements AggregationBuffer {
            Map<HiveChar, HiveDecimal> map = new HashMap<HiveChar, HiveDecimal>();
        }

        @Override
        public void reset(AggregationBuffer ab) throws HiveException {
            ((MapBuffer) ab).map.clear();
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new MapBuffer();
        }

        /**
         * Folds {@code from} into the accumulator {@code m}: keys absent from
         * {@code m} are inserted, values of keys already present are summed.
         * Entries with a null key or null value are skipped.
         */
        protected void mapAppend(Map<HiveChar, HiveDecimal> m, Map<Object, Object> from) {
            if (from == null) {
                return;
            }
            for (Map.Entry<Object, Object> entry : from.entrySet()) {
                Object oKey = entry.getKey();
                Object oValue = entry.getValue();
                if (oKey == null || oValue == null) {
                    continue;
                }
                // Defensive copy of the key: NOTE(review) some (lazy) object
                // inspectors may reuse the returned object across rows, which
                // would corrupt HashMap keys — confirm for the OIs in play.
                HiveChar key = new HiveChar(keyOI.getPrimitiveJavaObject(oKey).toString(), keyLength);
                // HiveDecimal is immutable, so no copy is needed for the value.
                HiveDecimal value = valueOI.getPrimitiveJavaObject(oValue);
                HiveDecimal current = m.get(key);
                m.put(key, current == null ? value : value.add(current));
            }
        }

        @Override
        public void iterate(AggregationBuffer ab, Object[] parameters) throws HiveException {
            assert (parameters.length == 1);
            Object p = parameters[0];
            if (p != null) {
                MapBuffer agg = (MapBuffer) ab;
                @SuppressWarnings("unchecked")
                Map<Object, Object> o = (Map<Object, Object>) this.originalDataOI.getMap(p);
                mapAppend(agg.map, o);
            }
        }

        @Override
        public Object terminatePartial(AggregationBuffer ab) throws HiveException {
            // The partial result has the same shape as the final one.
            return terminate(ab);
        }

        @Override
        public void merge(AggregationBuffer ab, Object p) throws HiveException {
            // Guard against a null partial (the original dereferenced p
            // unconditionally; iterate() already guards the same way).
            if (p == null) {
                return;
            }
            MapBuffer agg = (MapBuffer) ab;
            @SuppressWarnings("unchecked")
            Map<Object, Object> obj = (Map<Object, Object>) this.originalDataOI.getMap(p);
            mapAppend(agg.map, obj);
        }

        @Override
        public Object terminate(AggregationBuffer ab) throws HiveException {
            MapBuffer agg = (MapBuffer) ab;
            Map<HiveChar, HiveDecimal> result = new HashMap<HiveChar, HiveDecimal>();

            for (Map.Entry<HiveChar, HiveDecimal> entry : agg.map.entrySet()) {
                // Normalize the key back to the declared char(n) length
                // (the original hard-coded length 3 here).
                HiveChar oKey = new HiveChar(entry.getKey().toString().trim(), keyLength);
                result.put(oKey, entry.getValue());
            }

            return result;
        }
    }
}