diff --git a/copier/copier.go b/copier/copier.go index 206c72f6..329dfdca 100644 --- a/copier/copier.go +++ b/copier/copier.go @@ -279,7 +279,7 @@ func (copier *Copier) isMarked(ent tl.Entity) bool { // CopyEntity performs validation and saves errors and warnings. // An entity error means the entity was not not written because it had an error or was filtered out; not fatal. // A write error should be considered fatal and should stop any further write attempts. -// Any errors and warnings are added to the Result. +// Any errors and warnings are added to the copier result. func (copier *Copier) CopyEntity(ent tl.Entity) (error, error) { var expandedEntities []tl.Entity expanded := false @@ -295,86 +295,92 @@ func (copier *Copier) CopyEntity(ent tl.Entity) (error, error) { expandedEntities = append(expandedEntities, ent) } for _, ent := range expandedEntities { + efn := ent.Filename() + sid := ent.EntityID() if err := copier.checkEntity(ent); err != nil { return err, nil } - if _, err := copier.addEntity(ent); err != nil { + eid, err := copier.Writer.AddEntity(ent) + if err != nil { + copier.sublogger.Error().Err(err).Str("filename", efn).Str("source_id", sid).Msgf("critical error: failed to write -- entity dump %#v", ent) return nil, err } + copier.EntityMap.Set(efn, sid, eid) + copier.result.EntityCount[efn]++ + for _, v := range copier.afterWriters { + if err := v.AfterWrite(eid, ent, copier.EntityMap); err != nil { + return nil, err + } + } } return nil, nil } // CopyEntities validates a slice of entities and writes those that pass validation. func (copier *Copier) CopyEntities(ents []tl.Entity) error { - var okEnts []tl.Entity - var expandedEntities []tl.Entity + okEnts := make([]tl.Entity, 0, len(ents)) for _, ent := range ents { expanded := false for _, f := range copier.expandFilters { if exp, ok, err := f.Expand(ent, copier.EntityMap); err != nil { - // + // skip } else if ok { expanded = true - expandedEntities = append(expandedEntities, exp...) + if err := copier.checkEntity(ent); err == nil { + okEnts = append(okEnts, exp...) + } } } if !expanded { - expandedEntities = append(expandedEntities, ent) - } - } - for _, ent := range expandedEntities { - if err := copier.checkEntity(ent); err == nil { - okEnts = append(okEnts, ent) + if err := copier.checkEntity(ent); err == nil { + okEnts = append(okEnts, ent) + } } } - return copier.writeBatch(okEnts) -} - -// writeBatch handles writing a batch of entities, all of the same kind. -func (copier *Copier) writeBatch(ents []tl.Entity) error { - if len(ents) == 0 { + if len(okEnts) == 0 { return nil } - efn := ents[0].Filename() - sids := []string{} - for _, ent := range ents { - sids = append(sids, ent.EntityID()) + efn := okEnts[0].Filename() + sids := make([]string, len(okEnts)) + for i, ent := range okEnts { + sids[i] = ent.EntityID() } - // OK, Save - eids, err := copier.Writer.AddEntities(ents) + eids, err := copier.Writer.AddEntities(okEnts) if err != nil { - copier.sublogger.Error().Err(err).Str("filename", efn).Msgf("critical error: failed to write %d entities", len(ents)) + copier.sublogger.Error().Err(err).Str("filename", efn).Msgf("critical error: failed to write %d entities", len(okEnts)) return err } for i, eid := range eids { - sid := sids[i] // copier.sublogger.Trace().Str("filename", efn).Str("source_id", sid).Str("output_id", eid).Msg("saved") + sid := sids[i] copier.EntityMap.Set(efn, sid, eid) } - copier.result.EntityCount[efn] += len(ents) + copier.result.EntityCount[efn] += len(okEnts) + // AfterWriters for i, eid := range eids { for _, v := range copier.afterWriters { - if err := v.AfterWrite(eid, ents[i], copier.EntityMap); err != nil { + if err := v.AfterWrite(eid, okEnts[i], copier.EntityMap); err != nil { return err } } } - // Return an emtpy slice and no error return nil } // checkBatch adds an entity to the current batch and calls writeBatch if above batch size. -func (copier *Copier) checkBatch(ents []tl.Entity, ent tl.Entity) ([]tl.Entity, error) { - if err := copier.checkEntity(ent); err != nil { - return ents, nil +func (copier *Copier) checkBatch(ents []tl.Entity, ent tl.Entity, flush bool) ([]tl.Entity, error) { + if ent != nil { + ents = append(ents, ent) } - ents = append(ents, ent) - if len(ents) < copier.BatchSize { - return ents, nil + if len(ents) >= copier.BatchSize { + flush = true } - return nil, copier.writeBatch(ents) + if flush { + err := copier.CopyEntities(ents) + return nil, err + } + return ents, nil } // checkEntity is the main filter and validation check. @@ -456,26 +462,6 @@ func (copier *Copier) checkEntity(ent tl.Entity) error { return nil } -func (copier *Copier) addEntity(ent tl.Entity) (string, error) { - // OK, Save - efn := ent.Filename() - sid := ent.EntityID() - eid, err := copier.Writer.AddEntity(ent) - if err != nil { - copier.sublogger.Error().Err(err).Str("filename", efn).Str("source_id", sid).Msgf("critical error: failed to write -- entity dump %#v", ent) - return "", err - } - copier.EntityMap.Set(efn, sid, eid) - copier.result.EntityCount[efn]++ - // AfterWriters - for _, v := range copier.afterWriters { - if err := v.AfterWrite(eid, ent, copier.EntityMap); err != nil { - return "", err - } - } - return eid, nil -} - ////////////////////////////////// ////////// Copy Methods ////////// ////////////////////////////////// @@ -911,7 +897,7 @@ func (copier *Copier) copyCalendars() error { cds := svc.CalendarDates() for i := range cds { cds[i].ServiceID = cid - if bt, btErr = copier.checkBatch(bt, &cds[i]); btErr != nil { + if bt, btErr = copier.checkBatch(bt, &cds[i], false); btErr != nil { return btErr } } @@ -919,7 +905,7 @@ func (copier *Copier) copyCalendars() error { copier.result.GeneratedCount["calendar.txt"]++ } } - if btErr = copier.writeBatch(bt); btErr != nil { + if _, btErr = copier.checkBatch(bt, nil, true); btErr != nil { return btErr } // Attempt to copy duplicate services @@ -1056,7 +1042,7 @@ func (copier *Copier) copyTripsAndStopTimes() error { } else { for i := range trip.StopTimes { var err error - stbt, err = copier.checkBatch(stbt, &trip.StopTimes[i]) + stbt, err = copier.checkBatch(stbt, &trip.StopTimes[i], false) if err != nil { return err } @@ -1064,7 +1050,7 @@ func (copier *Copier) copyTripsAndStopTimes() error { } } } - if err := copier.writeBatch(stbt); err != nil { + if _, err := copier.checkBatch(stbt, nil, true); err != nil { return err } diff --git a/copier/copier_test.go b/copier/copier_test.go index 10a85dce..9a8e3433 100644 --- a/copier/copier_test.go +++ b/copier/copier_test.go @@ -59,3 +59,79 @@ func TestCopier_Expand(t *testing.T) { assert.Equal(t, 1, agencyIds["test:2"]) assert.Equal(t, 1, agencyIds["test:3"]) } + +//////// + +// TODO: figure out why the fast benchmark is fast and the slow benchmark is slow +// This relates to copier.checkBatch: why is it faster when +// checkEntity is BEFORE appending to the batch slice, +// vs. appending always and then calling checkEntity during +// other filtering (as in CopyEntity) +var wtfBatchSize = 1_000_000 + +func BenchmarkWtfSlow(b *testing.B) { + testWtfCopyEntities := func(ents []tl.Entity) { + okEnts := make([]tl.Entity, 0, len(ents)) + for _, ent := range ents { + if err := testWtfCheck(ent); err != nil { + okEnts = append(okEnts, ent) + } + } + testWtfWriteEntities(okEnts) + } + testWtfCheckBatch := func(ents []tl.Entity, ent tl.Entity) []tl.Entity { + if len(ents) >= wtfBatchSize || ent == nil { + testWtfCopyEntities(ents) + return nil + } + ents = append(ents, ent) + return ents + } + b.ResetTimer() + for n := 0; n < b.N; n++ { + var ents []tl.Entity + for i := 0; i < wtfBatchSize; i++ { + ents = testWtfCheckBatch(ents, &tl.StopTime{}) + } + testWtfCheckBatch(ents, nil) + } +} + +func BenchmarkWtfFast(b *testing.B) { + testWtfCopyEntities := func(ents []tl.Entity) { + testWtfWriteEntities(ents) + } + testWtfCheckBatch := func(ents []tl.Entity, ent tl.Entity) []tl.Entity { + if len(ents) >= wtfBatchSize || ent == nil { + testWtfCopyEntities(ents) + return nil + } + if err := testWtfCheck(ent); err == nil { + ents = append(ents, ent) + } + return ents + } + b.ResetTimer() + for n := 0; n < b.N; n++ { + var ents []tl.Entity + for i := 0; i < wtfBatchSize; i++ { + ents = testWtfCheckBatch(ents, &tl.StopTime{}) + } + testWtfCheckBatch(ents, nil) + } +} + +func testWtfCheck(ent tl.Entity) error { + a := ent.Filename() + b := ent.EntityID() + _ = a + _ = b + return nil +} + +func testWtfWriteEntities(ents []tl.Entity) { + // fmt.Println("writing:", len(ents)) + b := len(ents) + _ = b + _ = ents +}