diff --git a/cmd/feed_pull.go b/cmd/feed_pull.go index 2a7652b..7ab7237 100644 --- a/cmd/feed_pull.go +++ b/cmd/feed_pull.go @@ -45,12 +45,11 @@ func newFeedPullCommand() *cobra.Command { return err } - entriesReadStatus := false var ( errs []error n int s = newPullSpinner(rawIDs) - ch = db.PullFeeds(cmd.Context(), ids, &entriesReadStatus) + ch = db.PullFeeds(cmd.Context(), ids, true) ) s.Start() diff --git a/internal/datastore/datastore.go b/internal/datastore/datastore.go index 97eeaa7..f0139db 100644 --- a/internal/datastore/datastore.go +++ b/internal/datastore/datastore.go @@ -52,7 +52,7 @@ type Datastore interface { PullFeeds( ctx context.Context, ids []entity.ID, - isRead *bool, + onlyUnread bool, ) ( results <-chan entity.PullResult, ) diff --git a/internal/datastore/sqlite_pull_feeds.go b/internal/datastore/sqlite_pull_feeds.go index ab69558..f4e27a9 100644 --- a/internal/datastore/sqlite_pull_feeds.go +++ b/internal/datastore/sqlite_pull_feeds.go @@ -16,7 +16,7 @@ import ( func (db *SQLite) PullFeeds( ctx context.Context, ids []entity.ID, - entryReadStatus *bool, + onlyUnread bool, ) <-chan entity.PullResult { var ( @@ -49,7 +49,7 @@ func (db *SQLite) PullFeeds( chs := make([]<-chan entity.PullResult, len(pks)) for i, pk := range pks { - chs[i] = pullFeedEntries(ctx, tx, pk, db.parser, entryReadStatus) + chs[i] = pullFeedEntries(ctx, tx, pk, db.parser, onlyUnread) } for pr := range merge(chs) { @@ -161,7 +161,7 @@ func pullFeedEntries( tx *sql.Tx, pk pullKey, parser Parser, - entryReadStatus *bool, + onlyUnread bool, ) chan entity.PullResult { pullTime := time.Now().UTC() @@ -188,6 +188,11 @@ func pullFeedEntries( return pk.err(err) } + var entryReadStatus *bool + if onlyUnread { + entryReadStatus = pointer(false) + } + entries, err := getEntries(ctx, tx, []ID{pk.feedID}, entryReadStatus, nil) if err != nil { return pk.err(err) diff --git a/internal/datastore/sqlite_pull_feeds_test.go b/internal/datastore/sqlite_pull_feeds_test.go index 1c9b3dd..e18c3a3 100644 --- a/internal/datastore/sqlite_pull_feeds_test.go +++ b/internal/datastore/sqlite_pull_feeds_test.go @@ -30,7 +30,7 @@ func TestPullFeedsAllOkEmptyDB(t *testing.T) { ParseURLWithContext(gomock.Any(), gomock.Any()). MaxTimes(0) - c := db.PullFeeds(context.Background(), nil, nil) + c := db.PullFeeds(context.Background(), nil, true) a.Empty(c) } @@ -69,7 +69,7 @@ func TestPullFeedsAllOkEmptyEntries(t *testing.T) { MaxTimes(1). Return(toGFeed(t, dbFeeds[1]), nil) - c := db.PullFeeds(context.Background(), nil, nil) + c := db.PullFeeds(context.Background(), nil, true) got := make([]entity.PullResult, 0) for res := range c { @@ -183,7 +183,7 @@ func TestPullFeedsAllOkNoNewEntries(t *testing.T) { MaxTimes(1). Return(toGFeed(t, pulledFeeds[1]), nil) - c := db.PullFeeds(context.Background(), nil, pointer(false)) + c := db.PullFeeds(context.Background(), nil, true) got := make([]entity.PullResult, 0) for res := range c { @@ -204,12 +204,12 @@ func TestPullFeedsAllOkNoNewEntries(t *testing.T) { a.ElementsMatch(want, got) } -func TestPullFeedsAllOkSomeNewEntries(t *testing.T) { +func TestPullFeedsAllOkSomeNewEntriesAll(t *testing.T) { t.Parallel() a := assert.New(t) db, dbFeeds, keys, pulledFeeds := setupComplexDBFixture(t) - c := db.PullFeeds(context.Background(), nil, nil) + c := db.PullFeeds(context.Background(), nil, false) got := make([]entity.PullResult, 0) for res := range c { @@ -310,12 +310,12 @@ func TestPullFeedsAllOkSomeNewEntries(t *testing.T) { a.ElementsMatch(want, got) } -func TestPullFeedsAllOkSomeNewEntriesUnread(t *testing.T) { +func TestPullFeedsAllOkSomeNewEntriesOnlyUnread(t *testing.T) { t.Parallel() a := assert.New(t) db, dbFeeds, keys, pulledFeeds := setupComplexDBFixture(t) - c := db.PullFeeds(context.Background(), nil, pointer(false)) + c := db.PullFeeds(context.Background(), nil, true) got := make([]entity.PullResult, 0) for res := range c { @@ -483,7 +483,7 @@ func TestPullFeedsSelectedOkSomeNewEntries(t *testing.T) { MaxTimes(1). Return(toGFeed(t, pulledFeed), nil) - c := db.PullFeeds(context.Background(), []ID{keys[pulledFeed.title].ID}, pointer(false)) + c := db.PullFeeds(context.Background(), []ID{keys[pulledFeed.title].ID}, true) got := make([]entity.PullResult, 0) for res := range c { diff --git a/internal/server/datastore_mock_test.go b/internal/server/datastore_mock_test.go index 52e351b..e762a09 100644 --- a/internal/server/datastore_mock_test.go +++ b/internal/server/datastore_mock_test.go @@ -193,17 +193,17 @@ func (mr *MockDatastoreMockRecorder) ListFeeds(ctx any) *gomock.Call { } // PullFeeds mocks base method. -func (m *MockDatastore) PullFeeds(ctx context.Context, ids []entity.ID, isRead *bool) <-chan entity.PullResult { +func (m *MockDatastore) PullFeeds(ctx context.Context, ids []entity.ID, onlyUnread bool) <-chan entity.PullResult { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PullFeeds", ctx, ids, isRead) + ret := m.ctrl.Call(m, "PullFeeds", ctx, ids, onlyUnread) ret0, _ := ret[0].(<-chan entity.PullResult) return ret0 } // PullFeeds indicates an expected call of PullFeeds. -func (mr *MockDatastoreMockRecorder) PullFeeds(ctx, ids, isRead any) *gomock.Call { +func (mr *MockDatastoreMockRecorder) PullFeeds(ctx, ids, onlyUnread any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PullFeeds", reflect.TypeOf((*MockDatastore)(nil).PullFeeds), ctx, ids, isRead) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PullFeeds", reflect.TypeOf((*MockDatastore)(nil).PullFeeds), ctx, ids, onlyUnread) } // MockeditableTable is a mock of editableTable interface. diff --git a/internal/server/service.go b/internal/server/service.go index 1e7b156..ab1f863 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -127,8 +127,7 @@ func (svc *service) PullFeeds( } // TODO: Expose isRead in proto. - isRead := false - ch := svc.ds.PullFeeds(stream.Context(), ids, &isRead) + ch := svc.ds.PullFeeds(stream.Context(), ids, true) for pr := range ch { payload, err := convert(pr)