Skip to content

Commit

Permalink
impl agg
Browse files Browse the repository at this point in the history
  • Loading branch information
chagelo committed Jan 16, 2024
1 parent 63a5fd5 commit 4a721ee
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 8 deletions.
70 changes: 67 additions & 3 deletions src/execution/aggregation_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,84 @@
// Copyright (c) 2015-2021, Carnegie Mellon University Database Group
//
//===----------------------------------------------------------------------===//
#include <cstdint>
#include <memory>
#include <vector>

#include "common/rid.h"
#include "execution/executors/aggregation_executor.h"
#include "execution/plans/aggregation_plan.h"
#include "storage/table/tuple.h"
#include "type/type_id.h"
#include "type/value.h"

namespace bustub {

AggregationExecutor::AggregationExecutor(ExecutorContext *exec_ctx, const AggregationPlanNode *plan,
std::unique_ptr<AbstractExecutor> &&child)
: AbstractExecutor(exec_ctx) {}
: AbstractExecutor(exec_ctx),
plan_(plan),
child_(child.release()),
aht_(plan_->GetAggregates(), plan_->GetAggregateTypes()),
aht_iterator_(aht_.Begin()) {}

void AggregationExecutor::Init() {}
void AggregationExecutor::Init() {
// Add this line
child_->Init();
Tuple child_tuple{};
RID child_rid{};

auto AggregationExecutor::Next(Tuple *tuple, RID *rid) -> bool { return false; }
while (child_->Next(&child_tuple, &child_rid)) {
std::vector<Value> values{};
for (auto &expr : plan_->GetGroupBys()) {
values.push_back(expr->Evaluate(&child_tuple, child_->GetOutputSchema()));
}

auto group_by = AggregateKey{values};

values.clear();
for (auto &expr : plan_->GetAggregates()) {
values.push_back(expr->Evaluate(&child_tuple, child_->GetOutputSchema()));
}
auto aggregate = AggregateValue{values};
aht_.InsertCombine(group_by, aggregate, false);
}
aht_iterator_ = aht_.Begin();
if (aht_iterator_ == aht_.End()) {
aht_.InsertCombine(AggregateKey{}, AggregateValue{}, true);
aht_iterator_ = aht_.Begin();
is_empty_ = true;
}
}

auto AggregationExecutor::Next(Tuple *tuple, RID *rid) -> bool {
if (!plan_->GetGroupBys().empty() && is_empty_) {
return false;
}

if (aht_iterator_ == aht_.End()) {
return false;
}

auto group = aht_iterator_.Key();
auto aggre = aht_iterator_.Val();

std::vector<Value> values{};
values.reserve(plan_->GetGroupBys().size() + aggre.aggregates_.size());
if (!plan_->GetGroupBys().empty()) {
for (auto &gp : group.group_bys_) {
values.emplace_back(std::move(gp));
}
}

for (auto &agg : aggre.aggregates_) {
values.emplace_back(std::move(agg));
}

*tuple = Tuple{values, &GetOutputSchema()};
++aht_iterator_;
return true;
}

auto AggregationExecutor::GetChildExecutor() const -> const AbstractExecutor * { return child_.get(); }

Expand Down
78 changes: 73 additions & 5 deletions src/include/execution/executors/aggregation_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#pragma once

#include <cstdint>
#include <memory>
#include <unordered_map>
#include <utility>
Expand All @@ -24,6 +25,10 @@
#include "execution/expressions/abstract_expression.h"
#include "execution/plans/aggregation_plan.h"
#include "storage/table/tuple.h"
#include "type/limits.h"
#include "type/type.h"
#include "type/type_id.h"
#include "type/value.h"
#include "type/value_factory.h"

namespace bustub {
Expand All @@ -48,7 +53,27 @@ class SimpleAggregationHashTable {
for (const auto &agg_type : agg_types_) {
switch (agg_type) {
case AggregationType::CountStarAggregate:
// Count start starts at zero.
case AggregationType::CountAggregate:
case AggregationType::SumAggregate:
values.emplace_back(ValueFactory::GetIntegerValue(0));
break;
case AggregationType::MinAggregate:
values.emplace_back(ValueFactory::GetIntegerValue(BUSTUB_INT32_MAX));
break;
case AggregationType::MaxAggregate:
// Others starts at null.
values.emplace_back(ValueFactory::GetIntegerValue(BUSTUB_INT32_MIN));
break;
}
}
return {values};
}

auto GenerateInvalidAggregateValue() -> AggregateValue {
std::vector<Value> values{};
for (const auto &agg_type : agg_types_) {
switch (agg_type) {
case AggregationType::CountStarAggregate:
values.emplace_back(ValueFactory::GetIntegerValue(0));
break;
case AggregationType::CountAggregate:
Expand All @@ -62,7 +87,6 @@ class SimpleAggregationHashTable {
}
return {values};
}

/**
* TODO(Student)
*
Expand All @@ -74,22 +98,65 @@ class SimpleAggregationHashTable {
for (uint32_t i = 0; i < agg_exprs_.size(); i++) {
switch (agg_types_[i]) {
case AggregationType::CountStarAggregate:
CountStar(result->aggregates_[i]);
break;
case AggregationType::CountAggregate:
Count(result->aggregates_[i], input.aggregates_[i]);
break;
case AggregationType::SumAggregate:
Sum(result->aggregates_[i], input.aggregates_[i]);
break;
case AggregationType::MinAggregate:
Min(result->aggregates_[i], input.aggregates_[i]);
break;
case AggregationType::MaxAggregate:
Max(result->aggregates_[i], input.aggregates_[i]);
break;
}
}
}

void CountStar(Value &value) { value = value.Add(Value{TypeId::INTEGER, 1}); }

void Count(Value &value, const Value &v) {
if (v.IsNull()) {
return;
}
value = value.Add(Value{TypeId::INTEGER, 1});
}

void Sum(Value &value, const Value &v) {
if (v.IsNull()) {
return;
}
value = value.Add(v);
}

void Min(Value &value, const Value &v) {
if (v.IsNull()) {
return;
}
value = value.Min(v);
}

void Max(Value &value, const Value &v) {
if (v.IsNull()) {
return;
}
value = value.Max(v);
}

/**
* Inserts a value into the hash table and then combines it with the current aggregation.
* @param agg_key the key to be inserted
* @param agg_val the value to be inserted
*/
void InsertCombine(const AggregateKey &agg_key, const AggregateValue &agg_val) {
void InsertCombine(const AggregateKey &agg_key, const AggregateValue &agg_val, const bool &invalid) {
if (ht_.count(agg_key) == 0) {
if (invalid) {
ht_.insert({agg_key, GenerateInvalidAggregateValue()});
return;
}
ht_.insert({agg_key, GenerateInitialAggregateValue()});
}
CombineAggregateValues(&ht_[agg_key], agg_val);
Expand Down Expand Up @@ -201,8 +268,9 @@ class AggregationExecutor : public AbstractExecutor {
/** The child executor that produces tuples over which the aggregation is computed */
std::unique_ptr<AbstractExecutor> child_;
/** Simple aggregation hash table */
// TODO(Student): Uncomment SimpleAggregationHashTable aht_;
SimpleAggregationHashTable aht_;
/** Simple aggregation hash table iterator */
// TODO(Student): Uncomment SimpleAggregationHashTable::Iterator aht_iterator_;
SimpleAggregationHashTable::Iterator aht_iterator_;
bool is_empty_{false};
};
} // namespace bustub
5 changes: 5 additions & 0 deletions src/include/execution/plans/aggregation_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,16 @@ class AggregationPlanNode : public AbstractPlanNode {

BUSTUB_PLAN_NODE_CLONE_WITH_CHILDREN(AggregationPlanNode);


// SELECT min(t.z), max(t.z), sum(t.z) FROM t GROUP BY t.x, t.y;
/** The GROUP BY expressions */
// {t.x, t.y}
std::vector<AbstractExpressionRef> group_bys_;
/** The aggregation expressions */
// {t.z, t.z, t.z}
std::vector<AbstractExpressionRef> aggregates_;
/** The aggregation types */
// {min, max, sum}
std::vector<AggregationType> agg_types_;

protected:
Expand Down

0 comments on commit 4a721ee

Please sign in to comment.