diff --git a/course_discovery/apps/course_metadata/management/commands/populate_product_catalog.py b/course_discovery/apps/course_metadata/management/commands/populate_product_catalog.py index 3ceeef149d..a020452a55 100644 --- a/course_discovery/apps/course_metadata/management/commands/populate_product_catalog.py +++ b/course_discovery/apps/course_metadata/management/commands/populate_product_catalog.py @@ -46,7 +46,7 @@ def add_arguments(self, parser): dest='product_source', type=str, required=False, - help='The product source to filter the products' + help='The comma-separated product source str to filter the products' ) parser.add_argument( '--use_gspread_client', @@ -86,7 +86,7 @@ def get_products(self, product_type, product_source): queryset = queryset.filter(type__slug=CourseType.BOOTCAMP_2U) if product_source: - queryset = queryset.filter(product_source__slug=product_source) + queryset = queryset.filter(product_source__slug__in=product_source.split(',')) queryset = queryset.annotate( num_orgs=Count('authoring_organizations') @@ -109,7 +109,7 @@ def get_products(self, product_type, product_source): .select_related('partner', 'type', 'primary_subject_override', 'language_override') if product_source: - queryset = queryset.filter(product_source__slug=product_source) + queryset = queryset.filter(product_source__slug__in=product_source.split(',')) queryset = queryset.annotate( num_orgs=Count('authoring_organizations') diff --git a/course_discovery/apps/course_metadata/management/commands/tests/test_populate_product_catalog.py b/course_discovery/apps/course_metadata/management/commands/tests/test_populate_product_catalog.py index b2eec9cf54..d160c57271 100644 --- a/course_discovery/apps/course_metadata/management/commands/tests/test_populate_product_catalog.py +++ b/course_discovery/apps/course_metadata/management/commands/tests/test_populate_product_catalog.py @@ -27,6 +27,7 @@ def setUp(self): self.organization = OrganizationFactory(partner=self.partner) self.course_type = CourseTypeFactory(slug=CourseType.AUDIT) self.source = SourceFactory.create(slug="edx") + self.source_2 = SourceFactory.create(slug="test-source") self.courses = CourseFactory.create_batch( 2, product_source=self.source, @@ -346,6 +347,91 @@ def test_populate_product_catalog_with_degrees_having_overrides(self): self.assertIn(degree.primary_subject_override.name, row["Subjects"]) self.assertEqual(row["Languages"], degree.language_override.code) + def test_populate_product_catalog_supports_multiple_product_sources(self): + """ + Test that the populate_product_catalog command supports multiple product sources. + """ + marketable_degree = DegreeFactory.create( + partner=self.partner, + additional_metadata=None, + type=self.program_type, + status=ProgramStatus.Active, + marketing_slug="valid-marketing-slug", + title="Marketable Degree", + authoring_organizations=[self.organization], + card_image=factory.django.ImageField(), + product_source=self.source, + ) + marketable_degree_2 = DegreeFactory.create( + partner=self.partner, + additional_metadata=None, + type=self.program_type, + status=ProgramStatus.Active, + marketing_slug="valid-marketing-slug", + title="Marketable Degree - with different product sources", + authoring_organizations=[self.organization], + card_image=factory.django.ImageField(), + language_override=None, + product_source=self.source_2, + ) + + with NamedTemporaryFile() as output_csv: + call_command( + "populate_product_catalog", + product_type="degree", + output_csv=output_csv.name, + product_source="edx", + gspread_client_flag=False, + ) + + with open(output_csv.name, "r") as output_csv_file: + csv_reader = csv.DictReader(output_csv_file) + rows = list(csv_reader) + + # Check that the marketable degree is in the CSV for the specified product source + matching_rows = [ + row for row in rows if row["UUID"] == str(marketable_degree.uuid.hex) + ] + self.assertEqual( + len(matching_rows), 1, f"Marketable degree '{marketable_degree.title}' should be in the CSV", + ) + + # Check that the marketable degree with different product sources is not in the CSV + matching_rows = [ + row for row in rows if row["UUID"] == str(marketable_degree_2.uuid.hex) + ] + self.assertEqual( + len(matching_rows), 0, + f"'{marketable_degree_2.title}' with different product sources should not be in the CSV", + ) + + with NamedTemporaryFile() as output_csv: + call_command( + "populate_product_catalog", + product_type="degree", + output_csv=output_csv.name, + product_source="edx,test-source", + gspread_client_flag=False, + ) + + with open(output_csv.name, "r") as output_csv_file: + csv_reader = csv.DictReader(output_csv_file) + rows = list(csv_reader) + + # Check that the marketable degree is in the CSV for the specified product sources + matching_rows = [ + row for row in rows if row["UUID"] == str(marketable_degree.uuid.hex) + ] + self.assertEqual( + len(matching_rows), 1, f"Marketable degree '{marketable_degree.title}' should be in the CSV", + ) + matching_rows = [ + row for row in rows if row["UUID"] == str(marketable_degree_2.uuid.hex) + ] + self.assertEqual( + len(matching_rows), 1, f"'{marketable_degree_2.title}' should be in the CSV", + ) + @mock.patch( "course_discovery.apps.course_metadata.management.commands.populate_product_catalog.Command.get_products" )