diff --git a/jarbas/core/tests/__init__.py b/jarbas/core/tests/__init__.py index bb2d13a..9ca0f51 100644 --- a/jarbas/core/tests/__init__.py +++ b/jarbas/core/tests/__init__.py @@ -1,9 +1,76 @@ +from io import StringIO from datetime import date from random import randrange +from unittest.mock import Mock, call + from django.utils import timezone +from django.test import TestCase as DjangoTestCase + +from jarbas.core.models import Reimbursement, Tweet + + +class TestCase(DjangoTestCase): + + def serializer(self, command, input, expected): + serialized = command.serialize(input) + self.assertEqual(serialized, expected) + + def main(self, command, update, schedule_update, custom_method): + custom_method.return_value = (range(21), range(21, 43)) + command.main() + update.assert_has_calls([call()] * 2) + schedule_update.assert_has_calls(call(i) for i in range(42)) + + def schedule_update_non_existing_record(self, command, content, get): + get.side_effect = Reimbursement.DoesNotExist + command.queue = [] + command.schedule_update(content) + get.assert_called_once_with(document_id=42) + self.assertEqual([], command.queue) + + def update(self, command, fields, print_, bulk_update): + command.count = 40 + command.queue = list(range(2)) + command.update() + bulk_update.assert_called_with([0, 1], update_fields=fields) + print_.assert_called_with('42 reimbursements updated.', end='\r') + self.assertEqual(42, command.count) + + def handler_with_options(self, command, print_, exits, main, custom_command): + command.handle(dataset=self.file_name, batch_size=42) + main.assert_called_once_with() + print_.assert_called_once_with('0 reimbursements updated.') + self.assertEqual(command.path, self.file_name) + self.assertEqual(command.batch_size, 42) + + def handler_without_options(self, command, print_, exits, main, custom_command): + command.handle(dataset=self.file_name, batch_size=4096) + main.assert_called_once_with() + print_.assert_called_once_with('0 reimbursements updated.') + self.assertEqual(command.path, self.file_name) + self.assertEqual(command.batch_size, 4096) + + def handler_with_non_existing_file(self, command, exists, update, custom_command): + exists.return_value = False + with self.assertRaises(FileNotFoundError): + command.handle(dataset='suspicions.xz', batch_size=4096) + update.assert_not_called() + + def new_command(self, command, custom_command, serialize, rows, lzma, print_): + serialize.return_value = '.' + lzma.return_value = StringIO() + rows.return_value = range(42) + command.batch_size = 10 + command.path = self.file_name + expected = [['.'] * 10, ['.'] * 10, ['.'] * 10, ['.'] * 10, ['.'] * 2] + self.assertEqual(expected, list(custom_command)) + self.assertEqual(42, serialize.call_count) -from jarbas.core.models import Tweet + def add_arguments(self, command): + mock = Mock() + command.add_arguments(mock) + self.assertEqual(2, mock.add_argument.call_count) suspicions = { diff --git a/jarbas/core/tests/test_companies_command.py b/jarbas/core/tests/test_companies_command.py index f9995d6..1d2ebc1 100644 --- a/jarbas/core/tests/test_companies_command.py +++ b/jarbas/core/tests/test_companies_command.py @@ -2,17 +2,16 @@ from io import StringIO from unittest.mock import patch -from django.test import TestCase - from jarbas.core.management.commands.companies import Command from jarbas.core.models import Activity, Company -from jarbas.core.tests import sample_company_data +from jarbas.core.tests import TestCase, sample_company_data class TestCommand(TestCase): def setUp(self): self.command = Command() + self.file_name = 'companies.xz' class TestSerializer(TestCommand): @@ -23,7 +22,7 @@ def test_to_email(self): self.assertEqual(self.command.to_email('jane@example.com'), expected) def test_serializer(self): - company = { + input = { 'email': 'ahoy', 'opening': '31/12/1969', 'situation_date': '31/12/1969', @@ -39,7 +38,7 @@ def test_serializer(self): 'latitude': 3.1415, 'longitude': -42.0 } - self.assertEqual(self.command.serialize(company), expected) + self.serializer(self.command, input, expected) class TestCreate(TestCommand): @@ -83,25 +82,25 @@ def test_save_companies(self, create, print_count, serialize, save_activities, r class TestConventionMethods(TestCommand): @patch('jarbas.core.management.commands.companies.print') - @patch('jarbas.core.management.commands.companies.LoadCommand.drop_all') + @patch('jarbas.core.management.commands.companies.Command.drop_all') @patch('jarbas.core.management.commands.companies.Command.save_companies') @patch('jarbas.core.management.commands.companies.Command.print_count') - def test_handler_without_options(self, print_count, save_companies, drop_all, print_): + def test_handler_with_options(self, print_count, save_companies, drop_all, print_): print_count.return_value = 0 - self.command.handle(dataset='companies.xz') + self.command.handle(dataset=self.file_name, drop=True) print_.assert_called_with('Starting with 0 companies') + self.assertEqual(2, drop_all.call_count) self.assertEqual(1, save_companies.call_count) - self.assertEqual(1, print_count.call_count) - self.assertEqual('companies.xz', self.command.path) - drop_all.assert_not_called() @patch('jarbas.core.management.commands.companies.print') - @patch('jarbas.core.management.commands.companies.Command.drop_all') + @patch('jarbas.core.management.commands.companies.LoadCommand.drop_all') @patch('jarbas.core.management.commands.companies.Command.save_companies') @patch('jarbas.core.management.commands.companies.Command.print_count') - def test_handler_with_options(self, print_count, save_companies, drop_all, print_): + def test_handler_without_options(self, print_count, save_companies, drop_all, print_): print_count.return_value = 0 - self.command.handle(dataset='companies.xz', drop=True) + self.command.handle(dataset=self.file_name) print_.assert_called_with('Starting with 0 companies') - self.assertEqual(2, drop_all.call_count) self.assertEqual(1, save_companies.call_count) + self.assertEqual(1, print_count.call_count) + self.assertEqual(self.file_name, self.command.path) + drop_all.assert_not_called() diff --git a/jarbas/core/tests/test_company_model.py b/jarbas/core/tests/test_company_model.py index 022c06c..9fe56a6 100644 --- a/jarbas/core/tests/test_company_model.py +++ b/jarbas/core/tests/test_company_model.py @@ -1,4 +1,5 @@ from django.test import TestCase + from jarbas.core.models import Activity, Company from jarbas.core.tests import sample_activity_data, sample_company_data diff --git a/jarbas/core/tests/test_load_command.py b/jarbas/core/tests/test_load_command.py index 01445d8..567b7fa 100644 --- a/jarbas/core/tests/test_load_command.py +++ b/jarbas/core/tests/test_load_command.py @@ -1,47 +1,44 @@ from datetime import date from unittest.mock import Mock, patch -from django.test import TestCase - from jarbas.core.management.commands import LoadCommand from jarbas.core.models import Activity -from jarbas.core.tests import sample_activity_data +from jarbas.core.tests import TestCase, sample_activity_data -class TestStaticMethods(TestCase): +class TestCommand(TestCase): def setUp(self): - self.cmd = LoadCommand() + self.command = LoadCommand() + +class TestStaticMethods(TestCommand): def test_get_model_name(self): - self.assertEqual('Activity', self.cmd.get_model_name(Activity)) + self.assertEqual('Activity', self.command.get_model_name(Activity)) def test_to_date(self): expected = date(1991, 7, 22) - self.assertEqual(self.cmd.to_date('22/7/91'), expected) - self.assertEqual(self.cmd.to_date('1991-07-22 03:15:00+0300'), expected) - self.assertEqual(self.cmd.to_date('22/13/91'), None) - self.assertEqual(self.cmd.to_date('aa/7/91'), None) - self.assertEqual(self.cmd.to_date('22/07/16'), date(2016, 7, 22)) + self.assertEqual(self.command.to_date('22/7/91'), expected) + self.assertEqual(self.command.to_date('1991-07-22 03:15:00+0300'), expected) + self.assertEqual(self.command.to_date('22/13/91'), None) + self.assertEqual(self.command.to_date('aa/7/91'), None) + self.assertEqual(self.command.to_date('22/07/16'), date(2016, 7, 22)) def test_to_number(self): - self.assertIsNone(self.cmd.to_number('')) - self.assertIsNone(self.cmd.to_number('NaN')) - self.assertIsNone(self.cmd.to_number('nan')) - self.assertEqual(1.0, self.cmd.to_number('1')) - self.assertEqual(1.2, self.cmd.to_number('1.2')) - self.assertEqual(1, self.cmd.to_number('1', int)) - self.assertEqual(1, self.cmd.to_number('1.0', int)) + self.assertIsNone(self.command.to_number('')) + self.assertIsNone(self.command.to_number('NaN')) + self.assertIsNone(self.command.to_number('nan')) + self.assertEqual(1.0, self.command.to_number('1')) + self.assertEqual(1.2, self.command.to_number('1.2')) + self.assertEqual(1, self.command.to_number('1', int)) + self.assertEqual(1, self.command.to_number('1.0', int)) -class TestPrintCount(TestCase): - - def setUp(self): - self.cmd = LoadCommand() +class TestPrintCount(TestCommand): @patch('jarbas.core.management.commands.print') def test_print_no_records(self, mock_print): - self.cmd.print_count(Activity) + self.command.print_count(Activity) arg = 'Current count: 0 Activitys ' kwargs = {'end': '\r'} mock_print.assert_called_with(arg, **kwargs) @@ -49,20 +46,20 @@ def test_print_no_records(self, mock_print): @patch('jarbas.core.management.commands.print') def test_print_with_records(self, mock_print): Activity.objects.create(**sample_activity_data) - self.cmd.print_count(Activity) + self.command.print_count(Activity) arg = 'Current count: 1 Activitys ' kwargs = {'end': '\r'} mock_print.assert_called_with(arg, **kwargs) @patch('jarbas.core.management.commands.print') def test_print_with_permanent_keyword_arg(self, mock_print): - self.cmd.print_count(Activity, permanent=True) + self.command.print_count(Activity, permanent=True) arg = 'Current count: 0 Activitys ' kwargs = {'end': '\n'} mock_print.assert_called_with(arg, **kwargs) -class TestDropAll(TestCase): +class TestDropAll(TestCommand): @patch('jarbas.core.management.commands.print') def test_drop_all(self, mock_print): @@ -73,12 +70,10 @@ def test_drop_all(self, mock_print): self.assertEqual(0, Activity.objects.count()) -class TestAddArguments(TestCase): +class TestAddArguments(TestCommand): def test_add_arguments(self): - mock = Mock() - LoadCommand().add_arguments(mock) - self.assertEqual(2, mock.add_argument.call_count) + self.add_arguments(self.command) def test_add_arguments_without_drop_all(self): mock = Mock() diff --git a/jarbas/core/tests/test_receipts_command.py b/jarbas/core/tests/test_receipts_command.py index f1ed490..8dd9786 100644 --- a/jarbas/core/tests/test_receipts_command.py +++ b/jarbas/core/tests/test_receipts_command.py @@ -1,13 +1,19 @@ from unittest.mock import Mock, call, patch -from django.test import TestCase from django.db.models import QuerySet from requests.exceptions import ConnectionError from jarbas.core.management.commands.receipts import Command +from jarbas.core.tests import TestCase -class TestCommandHandler(TestCase): +class TestCommand(TestCase): + + def setUp(self): + self.command = Command() + + +class TestCommandHandler(TestCommand): @patch('jarbas.core.management.commands.receipts.Command.get_queryset') @patch('jarbas.core.management.commands.receipts.Command.fetch') @@ -17,105 +23,94 @@ class TestCommandHandler(TestCase): @patch('jarbas.core.management.commands.receipts.print') def test_handler_with_queryset(self, print_, sleep, print_pause, print_count, fetch, get_queryset): get_queryset.side_effect = (True, True, True, False) - command = Command() - command.handle(batch_size=3, pause=42) + self.command.handle(batch_size=3, pause=42) print_.assert_has_calls((call('Loading…'), call('Done!'))) print_pause.assert_has_calls((call(), call())) print_count.assert_called_once_with(permanent=True) sleep.assert_has_calls([call(42)] * 2) self.assertEqual(3, fetch.call_count) - self.assertEqual(3, command.batch) - self.assertEqual(42, command.pause) - self.assertEqual(0, command.count) + self.assertEqual(3, self.command.batch) + self.assertEqual(42, self.command.pause) + self.assertEqual(0, self.command.count) @patch('jarbas.core.management.commands.receipts.Command.get_queryset') @patch('jarbas.core.management.commands.receipts.Command.fetch') @patch('jarbas.core.management.commands.receipts.print') def test_handler_without_queryset(self, print_, fetch, get_queryset): get_queryset.return_value = False - command = Command() - command.handle(batch_size=42, pause=1) + self.command.handle(batch_size=42, pause=1) print_.assert_has_calls([ call('Loading…'), call('Nothing to fetch.') ]) get_queryset.assert_called_once_with() fetch.assert_not_called() - self.assertEqual(42, command.batch) - self.assertEqual(1, command.pause) - self.assertEqual(0, command.count) + self.assertEqual(42, self.command.batch) + self.assertEqual(1, self.command.pause) + self.assertEqual(0, self.command.count) def test_add_arguments(self): - parser = Mock() - command = Command() - command.add_arguments(parser) - self.assertEqual(2, parser.add_argument.call_count) + self.add_arguments(self.command) -class TestCommandMethods(TestCase): +class TestCommandMethods(TestCommand): @patch('jarbas.core.management.commands.receipts.Command.update') @patch('jarbas.core.management.commands.receipts.Command.bulk_update') @patch('jarbas.core.management.commands.receipts.Command.print_count') def test_fetch(self, print_count, bulk_update, update): - command = Command() - command.count = 0 - command.queryset = (1, 2, 3) - command.queue = [] - command.fetch() + self.command.count = 0 + self.command.queryset = (1, 2, 3) + self.command.queue = [] + self.command.fetch() print_count.assert_has_calls((call(), call(), call())) update.assert_has_calls(call(i) for i in range(1, 4)) - self.assertEqual(3, command.count) + self.assertEqual(3, self.command.count) bulk_update.assert_called_once_with() @patch.object(QuerySet, '__getitem__') @patch.object(QuerySet, 'filter', return_value=QuerySet()) def test_get_queryset(self, filter_, getitem): - command = Command() - command.batch = 42 - command.get_queryset() + self.command.batch = 42 + self.command.get_queryset() filter_.assert_called_once_with(receipt_fetched=False) getitem.assert_called_once_with(slice(None, 42)) def test_update(self): reimbursement = Mock() - command = Command() - command.queue = [] - command.update(reimbursement) + self.command.queue = [] + self.command.update(reimbursement) reimbursement.get_receipt_url.assert_called_once_with(bulk=True) - self.assertEqual(1, len(command.queue)) + self.assertEqual(1, len(self.command.queue)) def test_update_with_error(self): reimbursement = Mock() reimbursement.get_receipt_url.side_effect = ConnectionError() - command = Command() - command.queue = [] - command.update(reimbursement) + self.command.queue = [] + self.command.update(reimbursement) reimbursement.get_receipt_url.assert_called_once_with(bulk=True) - self.assertEqual(0, len(command.queue)) + self.assertEqual(0, len(self.command.queue)) @patch('jarbas.core.management.commands.receipts.bulk_update') @patch('jarbas.core.management.commands.receipts.Command.print_saving') def test_bulk_update(self, print_saving, bulk_update): - command = Command() - command.queue = [1, 2, 3] - command.bulk_update() + self.command.queue = [1, 2, 3] + self.command.bulk_update() fields = ['receipt_url', 'receipt_fetched'] bulk_update.assert_called_once_with([1, 2, 3], update_fields=fields) - self.assertEqual([], command.queue) + self.assertEqual([], self.command.queue) print_saving.assert_called_once_with() -class TestCommandPrintMethods(TestCase): +class TestCommandPrintMethods(TestCommand): def test_count_msg(self): - command = Command() - command.count = 42 - self.assertEqual('42 receipt URLs fetched', command.count_msg()) + self.command.count = 42 + self.assertEqual('42 receipt URLs fetched', self.command.count_msg()) @patch('jarbas.core.management.commands.receipts.print') def test_print_msg(self, print_): - Command.print_msg('42') + self.command.print_msg('42') print_.assert_has_calls(( call('\x1b[1A\x1b[2K\x1b[1A'), call('42') @@ -123,30 +118,27 @@ def test_print_msg(self, print_): @patch('jarbas.core.management.commands.receipts.print') def test_print_permanent_msg(self, print_): - Command.print_msg('42', permanent=True) + self.command.print_msg('42', permanent=True) print_.assert_called_once_with('42') @patch('jarbas.core.management.commands.receipts.Command.count_msg') @patch('jarbas.core.management.commands.receipts.Command.print_msg') def test_print_count(self, print_msg, count_msg): count_msg.return_value = '42' - command = Command() - command.print_count() - command.print_count(permanent=True) + self.command.print_count() + self.command.print_count(permanent=True) print_msg.assert_has_calls((call('42'), call('42', permanent=True))) @patch('jarbas.core.management.commands.receipts.Command.count_msg') @patch('jarbas.core.management.commands.receipts.Command.print_msg') def test_print_pause(self, print_msg, count_msg): count_msg.return_value = '42' - command = Command() - command.print_pause() + self.command.print_pause() print_msg.assert_called_once_with('42 (Taking a break to avoid being blocked…)') @patch('jarbas.core.management.commands.receipts.Command.count_msg') @patch('jarbas.core.management.commands.receipts.Command.print_msg') def test_print_saving(self, print_msg, count_msg): count_msg.return_value = '42' - command = Command() - command.print_saving() + self.command.print_saving() print_msg.assert_called_once_with('42 (Saving the URLs to the database…)') diff --git a/jarbas/core/tests/test_receipts_text_command.py b/jarbas/core/tests/test_receipts_text_command.py index 56fc4a2..a0db77a 100644 --- a/jarbas/core/tests/test_receipts_text_command.py +++ b/jarbas/core/tests/test_receipts_text_command.py @@ -1,16 +1,15 @@ -from io import StringIO -from unittest.mock import Mock, call, patch - -from django.test import TestCase +from unittest.mock import patch from jarbas.core.management.commands.receipts_text import Command from jarbas.core.models import Reimbursement +from jarbas.core.tests import TestCase class TestCommand(TestCase): def setUp(self): self.command = Command() + self.file_name = 'receipts-texts.xz' class TestSerializer(TestCommand): @@ -25,7 +24,7 @@ def test_serializer(self): 'document_id': '42', 'text': 'lorem ipsum' } - self.assertEqual(self.command.serialize(input), expected) + self.serializer(self.command, input, expected) def test_serializer_without_text(self): expected = { @@ -36,7 +35,7 @@ def test_serializer_without_text(self): input = { 'document_id': '42', } - self.assertEqual(self.command.serialize(input), expected) + self.serializer(self.command, input, expected) class TestCustomMethods(TestCommand): @@ -45,10 +44,7 @@ class TestCustomMethods(TestCommand): @patch('jarbas.core.management.commands.receipts_text.Command.schedule_update') @patch('jarbas.core.management.commands.receipts_text.Command.update') def test_main(self, update, schedule_update, receipts): - receipts.return_value = (range(21), range(21, 43)) - self.command.main() - update.assert_has_calls([call()] * 2) - schedule_update.assert_has_calls(call(i) for i in range(42)) + self.main(self.command, update, schedule_update, receipts) @patch.object(Reimbursement.objects, 'get') def test_schedule_update_existing_record(self, get): @@ -62,26 +58,18 @@ def test_schedule_update_existing_record(self, get): self.command.schedule_update(content) get.assert_called_once_with(document_id=42) self.assertEqual(content['receipt_text'], reimbursement.receipt_text) + self.assertEqual([reimbursement], self.command.queue) @patch.object(Reimbursement.objects, 'get') def test_schedule_update_non_existing_record(self, get): - get.side_effect = Reimbursement.DoesNotExist content = {'document_id': 42} - self.command.queue = [] - self.command.schedule_update(content) - get.assert_called_once_with(document_id=42) - self.assertEqual([], self.command.queue) + self.schedule_update_non_existing_record(self.command, content, get) @patch('jarbas.core.management.commands.receipts_text.bulk_update') @patch('jarbas.core.management.commands.receipts_text.print') def test_update(self, print_, bulk_update): - self.command.count = 40 - self.command.queue = list(range(2)) - self.command.update() fields = ['receipt_text',] - bulk_update.assert_called_with([0, 1], update_fields=fields) - print_.assert_called_with('42 reimbursements updated.', end='\r') - self.assertEqual(42, self.command.count) + self.update(self.command, fields, print_, bulk_update) class TestConventionMethods(TestCommand): @@ -91,31 +79,20 @@ class TestConventionMethods(TestCommand): @patch('jarbas.core.management.commands.receipts_text.os.path.exists') @patch('jarbas.core.management.commands.receipts_text.print') def test_handler_with_options(self, print_, exists, main, receipts): - self.command.handle(dataset='receipts-texts.xz', batch_size=42) - main.assert_called_once_with() - print_.assert_called_once_with('0 reimbursements updated.') - self.assertEqual(self.command.path, 'receipts-texts.xz') - self.assertEqual(self.command.batch_size, 42) + self.handler_with_options(self.command, print_, exists, main, receipts) @patch('jarbas.core.management.commands.receipts_text.Command.receipts') @patch('jarbas.core.management.commands.receipts_text.Command.main') @patch('jarbas.core.management.commands.receipts_text.os.path.exists') @patch('jarbas.core.management.commands.receipts_text.print') def test_handler_without_options(self, print_, exists, main, receipts): - self.command.handle(dataset='receipts-texts.xz', batch_size=4096) - main.assert_called_once_with() - print_.assert_called_once_with('0 reimbursements updated.') - self.assertEqual(self.command.path, 'receipts-texts.xz') - self.assertEqual(self.command.batch_size, 4096) + self.handler_without_options(self.command, print_, exists, main, receipts) @patch('jarbas.core.management.commands.receipts_text.Command.receipts') @patch('jarbas.core.management.commands.receipts_text.Command.main') @patch('jarbas.core.management.commands.receipts_text.os.path.exists') def test_handler_with_non_existing_file(self, exists, update, receipts): - exists.return_value = False - with self.assertRaises(FileNotFoundError): - self.command.handle(dataset='receipts-text.xz', batch_size=4096) - update.assert_not_called() + self.handler_with_non_existing_file(self.command, exists, update, receipts) class TestFileLoader(TestCommand): @@ -125,19 +102,11 @@ class TestFileLoader(TestCommand): @patch('jarbas.core.management.commands.receipts_text.csv.DictReader') @patch('jarbas.core.management.commands.receipts_text.Command.serialize') def test_receipts(self, serialize, rows, lzma, print_): - serialize.return_value = '.' - lzma.return_value = StringIO() - rows.return_value = range(42) - self.command.batch_size = 10 - self.command.path = 'receipts-text.xz' - expected = [['.'] * 10, ['.'] * 10, ['.'] * 10, ['.'] * 10, ['.'] * 2] - self.assertEqual(expected, list(self.command.receipts())) - self.assertEqual(42, serialize.call_count) + self.new_command(self.command, self.command.receipts(), + serialize, rows, lzma, print_) -class TestAddArguments(TestCase): +class TestAddArguments(TestCommand): def test_add_arguments(self): - mock = Mock() - Command().add_arguments(mock) - self.assertEqual(2, mock.add_argument.call_count) + self.add_arguments(self.command) diff --git a/jarbas/core/tests/test_reimbursements_command.py b/jarbas/core/tests/test_reimbursements_command.py index 998bf90..620f73b 100644 --- a/jarbas/core/tests/test_reimbursements_command.py +++ b/jarbas/core/tests/test_reimbursements_command.py @@ -2,10 +2,9 @@ from io import StringIO from unittest.mock import MagicMock, call, patch -from django.test import TestCase - from jarbas.core.management.commands.reimbursements import Command from jarbas.core.models import Reimbursement +from jarbas.core.tests import TestCase class TestCommand(TestCase): @@ -85,8 +84,7 @@ def test_serializer(self): 'reimbursement_value_total': 'NaN', 'year': '1970' } - self.maxDiff = 2 ** 10 - self.assertEqual(self.command.serialize(input), expected) + self.serializer(self.command, input, expected) class TestCreate(TestCommand): diff --git a/jarbas/core/tests/test_suspicions_command.py b/jarbas/core/tests/test_suspicions_command.py index 77a6191..0d0fca7 100644 --- a/jarbas/core/tests/test_suspicions_command.py +++ b/jarbas/core/tests/test_suspicions_command.py @@ -1,16 +1,15 @@ -from io import StringIO -from unittest.mock import Mock, call, patch - -from django.test import TestCase +from unittest.mock import patch from jarbas.core.management.commands.suspicions import Command from jarbas.core.models import Reimbursement +from jarbas.core.tests import TestCase class TestCommand(TestCase): def setUp(self): self.command = Command() + self.file_name = 'suspicions.xz' class TestSerializer(TestCommand): @@ -32,7 +31,8 @@ def test_serializer(self): 'hypothesis_3': 'True', 'probability': '0.38' } - self.assertEqual(self.command.serialize(input), expected) + self.serializer(self.command, input, expected) + def test_serializer_without_probability(self): expected = { @@ -50,7 +50,7 @@ def test_serializer_without_probability(self): 'hypothesis_2': 'False', 'hypothesis_3': 'True' } - self.assertEqual(self.command.serialize(input), expected) + self.serializer(self.command, input, expected) def test_serializer_without_suspicions(self): expected = { @@ -65,7 +65,7 @@ def test_serializer_without_suspicions(self): 'hypothesis_2': 'False', 'hypothesis_3': 'False' } - self.assertEqual(self.command.serialize(input), expected) + self.serializer(self.command, input, expected) class TestCustomMethods(TestCommand): @@ -74,10 +74,8 @@ class TestCustomMethods(TestCommand): @patch('jarbas.core.management.commands.suspicions.Command.schedule_update') @patch('jarbas.core.management.commands.suspicions.Command.update') def test_main(self, update, schedule_update, suspicions): - suspicions.return_value = (range(21), range(21, 43)) - self.command.main() - update.assert_has_calls([call()] * 2) - schedule_update.assert_has_calls(call(i) for i in range(42)) + self.main(self.command, update, schedule_update, suspicions) + @patch.object(Reimbursement.objects, 'get') def test_schedule_update_existing_record(self, get): @@ -90,30 +88,21 @@ def test_schedule_update_existing_record(self, get): } self.command.queue = [] self.command.schedule_update(content) - get.assert_called_once_with(document_id=42) - self.assertEqual(0.618, reimbursement.probability) - self.assertEqual({'answer': 42}, reimbursement.suspicions) + get.assert_called_once_with(document_id=content['document_id']) + self.assertEqual(content['probability'], reimbursement.probability) + self.assertEqual(content['suspicions'], reimbursement.suspicions) self.assertEqual([reimbursement], self.command.queue) @patch.object(Reimbursement.objects, 'get') def test_schedule_update_non_existing_record(self, get): - get.side_effect = Reimbursement.DoesNotExist content = {'document_id': 42} - self.command.queue = [] - self.command.schedule_update(content) - get.assert_called_once_with(document_id=42) - self.assertEqual([], self.command.queue) + self.schedule_update_non_existing_record(self.command, content, get) @patch('jarbas.core.management.commands.suspicions.bulk_update') @patch('jarbas.core.management.commands.suspicions.print') def test_update(self, print_, bulk_update): - self.command.count = 40 - self.command.queue = list(range(2)) - self.command.update() fields = ['probability', 'suspicions'] - bulk_update.assert_called_with([0, 1], update_fields=fields) - print_.assert_called_with('42 reimbursements updated.', end='\r') - self.assertEqual(42, self.command.count) + self.update(self.command, fields, print_, bulk_update) def test_bool(self): self.assertTrue(self.command.bool('True')) @@ -138,31 +127,20 @@ class TestConventionMethods(TestCommand): @patch('jarbas.core.management.commands.suspicions.os.path.exists') @patch('jarbas.core.management.commands.suspicions.print') def test_handler_with_options(self, print_, exists, main, suspicions): - self.command.handle(dataset='suspicions.xz', batch_size=42) - main.assert_called_once_with() - print_.assert_called_once_with('0 reimbursements updated.') - self.assertEqual(self.command.path, 'suspicions.xz') - self.assertEqual(self.command.batch_size, 42) + self.handler_with_options(self.command, print_, exists, main, suspicions) @patch('jarbas.core.management.commands.suspicions.Command.suspicions') @patch('jarbas.core.management.commands.suspicions.Command.main') @patch('jarbas.core.management.commands.suspicions.os.path.exists') @patch('jarbas.core.management.commands.suspicions.print') def test_handler_without_options(self, print_, exists, main, suspicions): - self.command.handle(dataset='suspicions.xz', batch_size=4096) - main.assert_called_once_with() - print_.assert_called_once_with('0 reimbursements updated.') - self.assertEqual(self.command.path, 'suspicions.xz') - self.assertEqual(self.command.batch_size, 4096) + self.handler_without_options(self.command, print_, exists, main, suspicions) @patch('jarbas.core.management.commands.suspicions.Command.suspicions') @patch('jarbas.core.management.commands.suspicions.Command.main') @patch('jarbas.core.management.commands.suspicions.os.path.exists') def test_handler_with_non_existing_file(self, exists, update, suspicions): - exists.return_value = False - with self.assertRaises(FileNotFoundError): - self.command.handle(dataset='suspicions.xz', batch_size=4096) - update.assert_not_called() + self.handler_with_non_existing_file(self.command, exists, update, suspicions) class TestFileLoader(TestCommand): @@ -172,19 +150,11 @@ class TestFileLoader(TestCommand): @patch('jarbas.core.management.commands.suspicions.csv.DictReader') @patch('jarbas.core.management.commands.suspicions.Command.serialize') def test_suspicions(self, serialize, rows, lzma, print_): - serialize.return_value = '.' - lzma.return_value = StringIO() - rows.return_value = range(42) - self.command.batch_size = 10 - self.command.path = 'suspicions.xz' - expected = [['.'] * 10, ['.'] * 10, ['.'] * 10, ['.'] * 10, ['.'] * 2] - self.assertEqual(expected, list(self.command.suspicions())) - self.assertEqual(42, serialize.call_count) + self.new_command(self.command, self.command.suspicions(), + serialize, rows, lzma, print_) -class TestAddArguments(TestCase): +class TestAddArguments(TestCommand): def test_add_arguments(self): - mock = Mock() - Command().add_arguments(mock) - self.assertEqual(2, mock.add_argument.call_count) + self.add_arguments(self.command) diff --git a/jarbas/core/tests/test_tweets_command.py b/jarbas/core/tests/test_tweets_command.py index 6c2130a..a9d3a76 100644 --- a/jarbas/core/tests/test_tweets_command.py +++ b/jarbas/core/tests/test_tweets_command.py @@ -2,12 +2,11 @@ from itertools import permutations from unittest.mock import MagicMock, PropertyMock, patch -from django.test import TestCase from mixer.backend.django import mixer from jarbas.core.models import Reimbursement, Tweet from jarbas.core.management.commands.tweets import Command -from jarbas.core.tests import random_tweet_status +from jarbas.core.tests import TestCase, random_tweet_status KEYS = (