Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
SiberiaWolfP committed Apr 5, 2024
1 parent 8720721 commit 8e34267
Showing 1 changed file with 63 additions and 49 deletions.
112 changes: 63 additions & 49 deletions duckpgq/src/duckpgq/operators/physical_path_finding_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,28 @@ void PathFindingLocalState::CreateCSR(DataChunk &input,

class GlobalBFSState {
public:
GlobalBFSState() = default;

GlobalBFSState(int64_t v_size_, idx_t pairs_size_, int64_t *src_, int64_t *dst_,
UnifiedVectorFormat &vdata_src_, UnifiedVectorFormat &vdata_dst_)
: iter(1), v_size(v_size_), src(src_), dst(dst_), vdata_src(std::move(vdata_src_)),
vdata_dst(std::move(vdata_dst_)), started_searches(0),
seen(v_size_), visit1(v_size_), visit2(v_size_),
change(false), result(LogicalTypeId::BIGINT, true, true, pairs_size_) {

void init(int64_t v_size_, idx_t pairs_size_, int64_t *src_, int64_t *dst_,
UnifiedVectorFormat &vdata_src_, UnifiedVectorFormat &vdata_dst_) {
iter = 1;
v_size = v_size_;
src = src_;
dst = dst_;
vdata_src = std::move(vdata_src_);
vdata_dst = std::move(vdata_dst_);
started_searches = 0;
for (auto i = 0; i < LANE_LIMIT; i++) {
lane_to_num[i] = -1;
}
seen.resize(v_size_);
visit1.resize(v_size_);
visit2.resize(v_size_);
for (auto i = 0; i < v_size; i++) {
seen[i] = 0;
visit1[i] = 0;
}
change = false;
result = make_uniq<Vector>(LogicalType::BIGINT, true, true, pairs_size_);
}

void clear() {
Expand Down Expand Up @@ -156,7 +167,7 @@ class GlobalBFSState {
vector<std::bitset<LANE_LIMIT>> visit1;
vector<std::bitset<LANE_LIMIT>> visit2;
bool change;
Vector result;
unique_ptr<Vector> result;
};

class PathFindingGlobalState : public GlobalSinkState {
Expand All @@ -173,9 +184,9 @@ class PathFindingGlobalState : public GlobalSinkState {

PathFindingGlobalState(PathFindingGlobalState &prev)
: GlobalSinkState(prev), global_tasks(prev.global_tasks),
global_csr(std::move(prev.global_csr)), child(prev.child + 1) {

}
global_csr(std::move(prev.global_csr)),
global_bfs_state(std::move(prev.global_bfs_state)), child(prev.child + 1) {
}

void Sink(DataChunk &input, PathFindingLocalState &lstate) const {
lstate.Sink(input, *global_csr);
Expand All @@ -189,10 +200,10 @@ class PathFindingGlobalState : public GlobalSinkState {
ColumnDataAppendState append_state;

unique_ptr<GlobalCompressedSparseRow> global_csr;
size_t child;

// state for BFS
unique_ptr<GlobalBFSState> global_bfs_state;
GlobalBFSState global_bfs_state;

size_t child;
};

unique_ptr<GlobalSinkState>
Expand Down Expand Up @@ -368,11 +379,11 @@ class PhysicalBFSTask : public ExecutorTask {
}

TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override {
auto& change = state.global_bfs_state->change;
auto& v_size = state.global_bfs_state->v_size;
auto& seen = state.global_bfs_state->seen;
auto& visit = state.global_bfs_state->iter & 1 ? state.global_bfs_state->visit1 : state.global_bfs_state->visit2;
auto& next = state.global_bfs_state->iter & 1 ? state.global_bfs_state->visit2 : state.global_bfs_state->visit1;
auto& change = state.global_bfs_state.change;
auto& v_size = state.global_bfs_state.v_size;
auto& seen = state.global_bfs_state.seen;
auto& visit = state.global_bfs_state.iter & 1 ? state.global_bfs_state.visit1 : state.global_bfs_state.visit2;
auto& next = state.global_bfs_state.iter & 1 ? state.global_bfs_state.visit2 : state.global_bfs_state.visit1;
int64_t *v = (int64_t *)state.global_csr->v;
vector<int64_t> &e = state.global_csr->e;

Expand Down Expand Up @@ -431,19 +442,19 @@ class BFSIterativeEvent : public BasePipelineEvent {
void FinishEvent() override {
auto& bfs_state = gstate.global_bfs_state;

auto result_data = FlatVector::GetData<int64_t>(bfs_state->result);
ValidityMask &result_validity = FlatVector::Validity(bfs_state->result);
auto result_data = FlatVector::GetData<int64_t>(*bfs_state.result);
ValidityMask &result_validity = FlatVector::Validity(*bfs_state.result);

if (bfs_state->change) {
if (bfs_state.change) {
// detect lanes that finished
for (int64_t lane = 0; lane < LANE_LIMIT; lane++) {
int64_t search_num = bfs_state->lane_to_num[lane];
int64_t search_num = bfs_state.lane_to_num[lane];
if (search_num >= 0) { // active lane
int64_t dst_pos = bfs_state->vdata_dst.sel->get_index(search_num);
if (bfs_state->seen[bfs_state->dst[dst_pos]][lane]) {
int64_t dst_pos = bfs_state.vdata_dst.sel->get_index(search_num);
if (bfs_state.seen[bfs_state.dst[dst_pos]][lane]) {
result_data[search_num] =
bfs_state->iter; /* found at iter => iter = path length */
bfs_state->lane_to_num[lane] = -1; // mark inactive
bfs_state.iter; /* found at iter => iter = path length */
bfs_state.lane_to_num[lane] = -1; // mark inactive
}
}
}
Expand All @@ -453,16 +464,16 @@ class BFSIterativeEvent : public BasePipelineEvent {
} else {
// no changes anymore: any still active searches have no path
for (int64_t lane = 0; lane < LANE_LIMIT; lane++) {
int64_t search_num = bfs_state->lane_to_num[lane];
int64_t search_num = bfs_state.lane_to_num[lane];
if (search_num >= 0) { // active lane
result_validity.SetInvalid(search_num);
result_data[search_num] = (int64_t)-1; /* no path */
bfs_state->lane_to_num[lane] = -1; // mark inactive
bfs_state.lane_to_num[lane] = -1; // mark inactive
}
}

// if remaining pairs, schedule the BFS for the next batch
if (bfs_state->started_searches < gstate.global_tasks.Count()) {
if (bfs_state.started_searches < gstate.global_tasks.Count()) {
PhysicalPathFinding::ScheduleBFSTasks(*pipeline, *this, gstate);
}
}
Expand All @@ -480,6 +491,7 @@ PhysicalPathFinding::Finalize(Pipeline &pipeline, Event &event,
if (global_tasks.Count() != 0) {
DataChunk all_pairs;
DataChunk pairs;
global_tasks.InitializeScanChunk(all_pairs);
global_tasks.InitializeScanChunk(pairs);
ColumnDataScanState scan_state;
global_tasks.InitializeScan(scan_state);
Expand All @@ -498,52 +510,54 @@ PhysicalPathFinding::Finalize(Pipeline &pipeline, Event &event,
auto src_data = FlatVector::GetData<int64_t>(src);
auto dst_data = FlatVector::GetData<int64_t>(dst);

gstate.global_bfs_state = make_uniq<GlobalBFSState>(csr->v_size,
global_tasks.Count(), src_data, dst_data, vdata_src, vdata_dst);
// gstate.global_bfs_state = make_uniq<GlobalBFSState>(csr->v_size,
// global_tasks.Count(), src_data, dst_data, vdata_src, vdata_dst);
gstate.global_bfs_state.init(csr->v_size, global_tasks.Count(), src_data, dst_data, vdata_src, vdata_dst);

// Schedule the first round of BFS tasks
if (all_pairs.size() > 0) {
ScheduleBFSTasks(pipeline, event, gstate);
}

// debug print
gstate.global_bfs_state.result->Print(global_tasks.Count());
}

// debug print
gstate.global_bfs_state->result.Print(global_tasks.Count());

// Move to the next input child
++gstate.child;

return SinkFinalizeType::READY;
}

void ScheduleBFSTasks(Pipeline &pipeline, Event &event, GlobalSinkState &state) {
void PhysicalPathFinding::ScheduleBFSTasks(Pipeline &pipeline, Event &event, GlobalSinkState &state) {
auto &gstate = state.Cast<PathFindingGlobalState>();
auto &bfs_state = gstate.global_bfs_state;

// for every batch of pairs, schedule a BFS task
bfs_state->clear();
bfs_state.clear();

// remaining pairs
if (bfs_state->started_searches < gstate.global_tasks.Count()) {
if (bfs_state.started_searches < gstate.global_tasks.Count()) {

auto result_data = FlatVector::GetData<int64_t>(bfs_state->result);
auto& result_validity = FlatVector::Validity(bfs_state->result);
auto result_data = FlatVector::GetData<int64_t>(*bfs_state.result);
auto& result_validity = FlatVector::Validity(*bfs_state.result);

for (int64_t lane = 0; lane < LANE_LIMIT; lane++) {
bfs_state->lane_to_num[lane] = -1;
while (bfs_state->started_searches < gstate.global_tasks.Count()) {
int64_t search_num = bfs_state->started_searches++;
int64_t src_pos = bfs_state->vdata_src.sel->get_index(search_num);
int64_t dst_pos = bfs_state->vdata_src.sel->get_index(search_num);
if (!bfs_state->vdata_src.validity.RowIsValid(src_pos)) {
bfs_state.lane_to_num[lane] = -1;
while (bfs_state.started_searches < gstate.global_tasks.Count()) {
int64_t search_num = bfs_state.started_searches++;
int64_t src_pos = bfs_state.vdata_src.sel->get_index(search_num);
int64_t dst_pos = bfs_state.vdata_src.sel->get_index(search_num);
if (!bfs_state.vdata_src.validity.RowIsValid(src_pos)) {
result_validity.SetInvalid(search_num);
result_data[search_num] = (uint64_t)-1; /* no path */
} else if (bfs_state->src[src_pos] == bfs_state->dst[dst_pos]) {
} else if (bfs_state.src[src_pos] == bfs_state.dst[dst_pos]) {
result_data[search_num] =
(uint64_t)0; // path of length 0 does not require a search
} else {
bfs_state->visit1[bfs_state->src[src_pos]][lane] = true;
bfs_state->lane_to_num[lane] = search_num; // active lane
bfs_state.visit1[bfs_state.src[src_pos]][lane] = true;
bfs_state.lane_to_num[lane] = search_num; // active lane
break;
}
}
Expand Down

0 comments on commit 8e34267

Please sign in to comment.