Skip to content

Commit

Permalink
Fixing issue while using Feed iterator (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
chexca authored Mar 31, 2020
1 parent 09f6530 commit 7f7a3cb
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 34 deletions.
13 changes: 11 additions & 2 deletions examples/file_feed_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ async def _get_from_feed_and_enqueue(self):
break
self._cursor = feed.cursor

self._enqueue_files_task.done()

async def _process_files_from_queue(self):
"""Process files put in the queue by _get_from_feed_and_enqueue.
Expand Down Expand Up @@ -83,6 +85,9 @@ async def _process_files_from_queue(self):
self._queue.task_done()
print(file_obj.id)

task = self._worker_tasks.pop(0)
task.done()

def abort(self):
self._aborted = True

Expand All @@ -92,9 +97,12 @@ def cursor(self):
def run(self):

loop = asyncio.get_event_loop()
loop_tasks = []
# Create a task that reads file objects from the feed and puts them in a
# queue.
loop.create_task(self._get_from_feed_and_enqueue())
self._enqueue_files_task = loop.create_task(
self._get_from_feed_and_enqueue())
loop_tasks.append(self._enqueue_files_task)

# Create multiple tasks that read file objects from the queue, download
# the file's content, and create the output files.
Expand All @@ -109,7 +117,8 @@ def run(self):
loop.add_signal_handler(s, self.abort)

# Wait until all worker tasks have completed.
loop.run_until_complete(asyncio.gather(*self._worker_tasks))
loop_tasks.extend(self._worker_tasks)
loop.run_until_complete(asyncio.gather(*loop_tasks))
loop.close()


Expand Down
5 changes: 3 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,10 @@ def test_feed(httpserver):

with new_client(httpserver) as client:
feed = client.feed(FeedType.FILES, cursor='200102030405')
obj = feed.__next__()
feed_iterator = feed.__iter__()
obj = next(feed_iterator)
assert obj.type == 'file'
assert obj.id == 'dummy_file_id_1'
obj = feed.__next__()
obj = next(feed_iterator)
assert obj.type == 'file'
assert obj.id == 'dummy_file_id_2'
58 changes: 29 additions & 29 deletions vt/feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,38 +140,38 @@ def _skip(self, n):
self._batch_cursor += 1

def __iter__(self):
return self
while True:
if self._batch:
next_item = self._batch.readline()
else:
self._get_next_batch()
self._skip(self._batch_skip)
self._batch_skip = 0
next_item = self._batch.readline()
self._batch_cursor += 1
self._count += 1

if next_item:
yield Object.from_dict(json.loads(next_item.decode('utf-8')))
else:
self._batch = None

async def __aiter__(self):
return self
while True:
if self._batch:
next_item = self._batch.readline()
else:
await self._get_next_batch_async()
self._skip(self._batch_skip)
self._batch_skip = 0
next_item = self._batch.readline()
self._batch_cursor += 1
self._count += 1

def __next__(self):
if self._batch:
next_item = self._batch.readline()
else:
next_item = None
if not next_item:
self._get_next_batch()
self._skip(self._batch_skip)
self._batch_skip = 0
next_item = self._batch.readline()
self._batch_cursor += 1
self._count += 1
return Object.from_dict(json.loads(next_item.decode('utf-8')))

async def __anext__(self):
if self._batch:
next_item = self._batch.readline()
else:
next_item = None
if not next_item:
await self._get_next_batch_async()
self._skip(self._batch_skip)
self._batch_skip = 0
next_item = self._batch.readline()
self._batch_cursor += 1
self._count += 1
return Object.from_dict(json.loads(next_item.decode('utf-8')))
if next_item:
yield Object.from_dict(json.loads(next_item.decode('utf-8')))
else:
self._batch = None

@property
def cursor(self):
Expand Down
2 changes: 1 addition & 1 deletion vt/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.5.2'
__version__ = '0.5.3'

0 comments on commit 7f7a3cb

Please sign in to comment.