diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 40e8a9f22c..5f08435ac3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -89,7 +89,7 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - - uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4 + - uses: actions/setup-node@1d0ff469b7ec7b3cb9d8673fde0c81c44821de2a # v4 with: node-version: "^20" cache: yarn @@ -194,7 +194,7 @@ jobs: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - - uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4 + - uses: actions/setup-node@1d0ff469b7ec7b3cb9d8673fde0c81c44821de2a # v4 with: node-version: "^22" cache: yarn @@ -217,7 +217,7 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - - uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4 + - uses: actions/setup-node@1d0ff469b7ec7b3cb9d8673fde0c81c44821de2a # v4 with: node-version: "^20" cache: yarn @@ -256,7 +256,7 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - - uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4 + - uses: actions/setup-node@1d0ff469b7ec7b3cb9d8673fde0c81c44821de2a # v4 with: node-version: "^20" cache: yarn diff --git a/.github/workflows/production.yml b/.github/workflows/production.yml index 2dd823256b..4ca9d32e83 100644 --- a/.github/workflows/production.yml +++ b/.github/workflows/production.yml @@ -28,7 +28,7 @@ jobs: run: heroku container:login - name: Release Backend on Heroku - uses: akhileshns/heroku-deploy@e86b991436e126ff9d78399b801a6610a64881c9 + uses: akhileshns/heroku-deploy@c3187cbbeceea824a6f5d9e0e14e2995a611059c with: heroku_api_key: ${{ secrets.HEROKU_API_KEY }} heroku_app_name: mitopen-production diff --git a/.github/workflows/publish-pages.yml b/.github/workflows/publish-pages.yml index b558565bd7..a4f5f6e4a6 100644 --- a/.github/workflows/publish-pages.yml +++ b/.github/workflows/publish-pages.yml @@ -15,7 +15,7 @@ jobs: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - - uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4 + - uses: actions/setup-node@1d0ff469b7ec7b3cb9d8673fde0c81c44821de2a # v4 with: node-version: "^20" cache: yarn diff --git a/.github/workflows/release-candidate.yml b/.github/workflows/release-candidate.yml index e4d39aa9d5..8983c24c0a 100644 --- a/.github/workflows/release-candidate.yml +++ b/.github/workflows/release-candidate.yml @@ -28,7 +28,7 @@ jobs: run: heroku container:login - name: Release Backend on Heroku - uses: akhileshns/heroku-deploy@e86b991436e126ff9d78399b801a6610a64881c9 + uses: akhileshns/heroku-deploy@c3187cbbeceea824a6f5d9e0e14e2995a611059c with: heroku_api_key: ${{ secrets.HEROKU_API_KEY }} heroku_app_name: mitopen-rc diff --git a/.gitignore b/.gitignore index b145bc50cd..a194b08a4d 100644 --- a/.gitignore +++ b/.gitignore @@ -135,3 +135,6 @@ storybook-static/ # ignore local ssl certs certs/ + +# ignore db backups +backups/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 017e65ff62..41e14e3e39 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -74,7 +74,7 @@ repos: - ".*/generated/" additional_dependencies: ["gibberish-detector"] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.9.1" + rev: "v0.9.4" hooks: - id: ruff-format - id: ruff diff --git 
a/Dockerfile b/Dockerfile index 85076f0560..8e0ac5aa7a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,6 +18,10 @@ RUN mkdir /src RUN adduser --disabled-password --gecos "" mitodl RUN mkdir /var/media && chown -R mitodl:mitodl /var/media +# copy in trusted certs +COPY --chmod=644 certs/*.crt /usr/local/share/ca-certificates/ +RUN update-ca-certificates + ## Set some poetry config ENV \ POETRY_VERSION=1.7.1 \ diff --git a/RELEASE.rst b/RELEASE.rst index b156ef95e7..fa9007cefe 100644 --- a/RELEASE.rst +++ b/RELEASE.rst @@ -1,6 +1,22 @@ Release Notes ============= +Version 0.30.2 +-------------- + +- Update dependency litellm to v1.60.8 (#2028) +- Update dependency drf-spectacular to ^0.28.0 (#2027) +- Update nginx Docker tag to v1.27.4 (#2026) +- Skip existing embeddings (#2017) +- Update dependency ruff to v0.9.5 (#2025) +- Update dependency Django to v4.2.19 (#2024) +- Update akhileshns/heroku-deploy digest to c3187cb (#1829) +- [pre-commit.ci] pre-commit autoupdate (#1976) +- Update dependency eslint-config-prettier to v10 (#1995) +- Update actions/setup-node digest to 1d0ff46 (#2009) +- Added SCIM /Bulk API endpoint (#1985) +- Add initial migration for users app (#2013) + Version 0.30.1 (Released February 10, 2025) -------------- diff --git a/profiles/scim/__init__.py b/backups/.keep similarity index 100% rename from profiles/scim/__init__.py rename to backups/.keep diff --git a/certs/.keep b/certs/.keep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docker-compose.services.yml b/docker-compose.services.yml index 83b1c9c658..57d355265b 100644 --- a/docker-compose.services.yml +++ b/docker-compose.services.yml @@ -19,6 +19,7 @@ services: - POSTGRES_PASSWORD=postgres volumes: - pgdata:/var/lib/postgresql + - ./backups:/mnt/backups redis: profiles: diff --git a/frontends/package.json b/frontends/package.json index ddf04c32f9..0aedd7b9a9 100644 --- a/frontends/package.json +++ b/frontends/package.json @@ -50,7 +50,7 @@ "cross-fetch": "^4.0.0", "eslint": "8.57.1", "eslint-config-mitodl": "^2.1.0", - "eslint-config-prettier": "^9.0.0", + "eslint-config-prettier": "^10.0.0", "eslint-import-resolver-typescript": "^3.6.1", "eslint-plugin-import": "^2.29.1", "eslint-plugin-jest": "^28.6.0", diff --git a/main/factories.py b/main/factories.py index 499c599cdb..2ea8eccb1d 100644 --- a/main/factories.py +++ b/main/factories.py @@ -4,7 +4,7 @@ import ulid from django.conf import settings -from factory import LazyFunction, RelatedFactory, SubFactory, Trait +from factory import Faker, LazyFunction, RelatedFactory, SubFactory, Trait from factory.django import DjangoModelFactory from factory.fuzzy import FuzzyText from social_django.models import UserSocialAuth @@ -15,8 +15,8 @@ class UserFactory(DjangoModelFactory): username = LazyFunction(lambda: ulid.new().str) email = FuzzyText(suffix="@example.com") - first_name = FuzzyText() - last_name = FuzzyText() + first_name = Faker("first_name") + last_name = Faker("last_name") profile = RelatedFactory("profiles.factories.ProfileFactory", "user") diff --git a/main/settings.py b/main/settings.py index 6e19e044dd..80ebc048fe 100644 --- a/main/settings.py +++ b/main/settings.py @@ -33,7 +33,7 @@ from main.settings_pluggy import * # noqa: F403 from openapi.settings_spectacular import open_spectacular_settings -VERSION = "0.30.1" +VERSION = "0.30.2" log = logging.getLogger() @@ -107,6 +107,7 @@ "drf_spectacular", # Put our apps after this point "main", + "users", "authentication", "channels", "profiles", @@ -121,6 +122,7 @@ "data_fixtures", 
"vector_search", "ai_chat", + "scim", ) if not get_bool("RUN_DATA_MIGRATIONS", default=False): @@ -140,9 +142,11 @@ "documentationUri": "", }, ], - "USER_ADAPTER": "profiles.scim.adapters.LearnSCIMUser", - "USER_MODEL_GETTER": "profiles.scim.adapters.get_user_model_for_scim", - "USER_FILTER_PARSER": "profiles.scim.filters.LearnUserFilterQuery", + "SERVICE_PROVIDER_CONFIG_MODEL": "scim.config.LearnSCIMServiceProviderConfig", + "USER_ADAPTER": "scim.adapters.LearnSCIMUser", + "USER_MODEL_GETTER": "scim.adapters.get_user_model_for_scim", + "USER_FILTER_PARSER": "scim.filters.LearnUserFilterQuery", + "GET_IS_AUTHENTICATED_PREDICATE": "scim.utils.is_authenticated_predicate", } diff --git a/main/settings_celery.py b/main/settings_celery.py index b043acc0f3..227c5e032d 100644 --- a/main/settings_celery.py +++ b/main/settings_celery.py @@ -131,6 +131,12 @@ "schedule": crontab(minute=30, hour=18), # 2:30pm EST "kwargs": {"period": "daily", "subscription_type": "channel_subscription_type"}, }, + "daily_embed_new_learning_resources": { + "task": "vector_search.tasks.embed_new_learning_resources", + "schedule": get_int( + "EMBED_NEW_RESOURCES_SCHEDULE_SECONDS", 60 * 30 + ), # default is every 30 minutes + }, "send-search-subscription-emails-every-1-days": { "task": "learning_resources_search.tasks.send_subscription_emails", "schedule": crontab(minute=0, hour=19), # 3:00pm EST diff --git a/main/urls.py b/main/urls.py index ef45361d81..b34382d7ba 100644 --- a/main/urls.py +++ b/main/urls.py @@ -17,7 +17,7 @@ from django.conf import settings from django.conf.urls.static import static from django.contrib import admin -from django.urls import include, path, re_path +from django.urls import include, re_path from django.views.generic.base import RedirectView from rest_framework.routers import DefaultRouter @@ -41,7 +41,6 @@ urlpatterns = ( [ # noqa: RUF005 - path("scim/v2/", include("django_scim.urls")), re_path(r"^o/", include("oauth2_provider.urls", namespace="oauth2_provider")), re_path(r"^admin/", admin.site.urls), re_path(r"", include("authentication.urls")), @@ -58,6 +57,7 @@ re_path(r"", include("articles.urls")), re_path(r"", include("testimonials.urls")), re_path(r"", include("news_events.urls")), + re_path(r"", include("scim.urls")), re_path(r"", include(features_router.urls)), re_path(r"^app", RedirectView.as_view(url=settings.APP_BASE_URL)), # Hijack diff --git a/nginx/Dockerfile b/nginx/Dockerfile index 53f7c7f776..87998cfbe2 100644 --- a/nginx/Dockerfile +++ b/nginx/Dockerfile @@ -2,7 +2,7 @@ # it's primary purpose is to emulate heroku-buildpack-nginx's # functionality that compiles config/nginx.conf.erb # See https://github.com/heroku/heroku-buildpack-nginx/blob/fefac6c569f28182b3459cb8e34b8ccafc403fde/bin/start-nginx -FROM nginx:1.27.3 +FROM nginx:1.27.4 # Logs are configured to a relatic path under /etc/nginx # but the container expects /var/log diff --git a/poetry.lock b/poetry.lock index e6542321b4..99c06f832d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1209,6 +1209,20 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "deepmerge" +version = "2.0" +description = "A toolset for deeply merging Python dictionaries." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "deepmerge-2.0-py3-none-any.whl", hash = "sha256:6de9ce507115cff0bed95ff0ce9ecc31088ef50cbdf09bc90a09349a318b3d00"}, + {file = "deepmerge-2.0.tar.gz", hash = "sha256:5c3d86081fbebd04dd5de03626a0607b809a98fb6ccba5770b62466fe940ff20"}, +] + +[package.extras] +dev = ["black", "build", "mypy", "pytest", "pyupgrade", "twine", "validate-pyproject[all]"] + [[package]] name = "defusedxml" version = "0.7.1" @@ -1289,13 +1303,13 @@ static3 = "*" [[package]] name = "django" -version = "4.2.18" +version = "4.2.19" description = "A high-level Python web framework that encourages rapid development and clean, pragmatic design." optional = false python-versions = ">=3.8" files = [ - {file = "Django-4.2.18-py3-none-any.whl", hash = "sha256:ba52eff7e228f1c775d5b0db2ba53d8c49d2f8bfe6ca0234df6b7dd12fb25b19"}, - {file = "Django-4.2.18.tar.gz", hash = "sha256:52ae8eacf635617c0f13b44f749e5ea13dc34262819b2cc8c8636abb08d82c4b"}, + {file = "Django-4.2.19-py3-none-any.whl", hash = "sha256:a104e13f219fc55996a4e416ef7d18ab4eeb44e0aa95174c192f16cda9f94e75"}, + {file = "Django-4.2.19.tar.gz", hash = "sha256:6c833be4b0ca614f0a919472a1028a3bbdeb6f056fa04023aeb923346ba2c306"}, ] [package.dependencies] @@ -1665,13 +1679,13 @@ djangorestframework = ">=3.14.0" [[package]] name = "drf-spectacular" -version = "0.27.2" +version = "0.28.0" description = "Sane and flexible OpenAPI 3 schema generation for Django REST framework" optional = false python-versions = ">=3.7" files = [ - {file = "drf-spectacular-0.27.2.tar.gz", hash = "sha256:a199492f2163c4101055075ebdbb037d59c6e0030692fc83a1a8c0fc65929981"}, - {file = "drf_spectacular-0.27.2-py3-none-any.whl", hash = "sha256:b1c04bf8b2fbbeaf6f59414b4ea448c8787aba4d32f76055c3b13335cf7ec37b"}, + {file = "drf_spectacular-0.28.0-py3-none-any.whl", hash = "sha256:856e7edf1056e49a4245e87a61e8da4baff46c83dbc25be1da2df77f354c7cb4"}, + {file = "drf_spectacular-0.28.0.tar.gz", hash = "sha256:2c778a47a40ab2f5078a7c42e82baba07397bb35b074ae4680721b2805943061"}, ] [package.dependencies] @@ -3243,23 +3257,23 @@ langsmith-pyo3 = ["langsmith-pyo3 (>=0.1.0rc2,<0.2.0)"] [[package]] name = "litellm" -version = "1.59.7" +version = "1.60.8" description = "Library to easily interface with LLM API providers" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ - {file = "litellm-1.59.7-py3-none-any.whl", hash = "sha256:6d934d42560b88b4bbb7374ff9d609379712b258f548c809b7678a9cc6e83661"}, - {file = "litellm-1.59.7.tar.gz", hash = "sha256:e718d725f89c31c8404f793775a81e29a1bc030417771a39b392d474a1dbb47a"}, + {file = "litellm-1.60.8-py3-none-any.whl", hash = "sha256:260bdcc9749c769f1a84dc927abe7c91f6294a97da05abc6b513c5dd2dcf17a1"}, + {file = "litellm-1.60.8.tar.gz", hash = "sha256:4a0aca9bd226d727ca4a41aaf8722f825fc10cf33f37a177a3cceb4ee2c442d8"}, ] [package.dependencies] aiohttp = "*" click = "*" -httpx = ">=0.23.0,<0.28.0" +httpx = ">=0.23.0" importlib-metadata = ">=6.8.0" jinja2 = ">=3.1.2,<4.0.0" jsonschema = ">=4.22.0,<5.0.0" -openai = ">=1.55.3" +openai = ">=1.61.0" pydantic = ">=2.0.0,<3.0.0" python-dotenv = ">=0.2.0" tiktoken = ">=0.7.0" @@ -4510,13 +4524,13 @@ sympy = "*" [[package]] name = "openai" -version = "1.58.1" +version = "1.61.1" description = "The official Python library for the openai API" optional = false python-versions = ">=3.8" files = [ - {file = "openai-1.58.1-py3-none-any.whl", hash = 
"sha256:e2910b1170a6b7f88ef491ac3a42c387f08bd3db533411f7ee391d166571d63c"}, - {file = "openai-1.58.1.tar.gz", hash = "sha256:f5a035fd01e141fc743f4b0e02c41ca49be8fab0866d3b67f5f29b4f4d3c0973"}, + {file = "openai-1.61.1-py3-none-any.whl", hash = "sha256:72b0826240ce26026ac2cd17951691f046e5be82ad122d20a8e1b30ca18bd11e"}, + {file = "openai-1.61.1.tar.gz", hash = "sha256:ce1851507218209961f89f3520e06726c0aa7d0512386f0f977e3ac3e4f2472e"}, ] [package.dependencies] @@ -6612,29 +6626,29 @@ files = [ [[package]] name = "ruff" -version = "0.9.3" +version = "0.9.5" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.9.3-py3-none-linux_armv6l.whl", hash = "sha256:7f39b879064c7d9670197d91124a75d118d00b0990586549949aae80cdc16624"}, - {file = "ruff-0.9.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:a187171e7c09efa4b4cc30ee5d0d55a8d6c5311b3e1b74ac5cb96cc89bafc43c"}, - {file = "ruff-0.9.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c59ab92f8e92d6725b7ded9d4a31be3ef42688a115c6d3da9457a5bda140e2b4"}, - {file = "ruff-0.9.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dc153c25e715be41bb228bc651c1e9b1a88d5c6e5ed0194fa0dfea02b026439"}, - {file = "ruff-0.9.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:646909a1e25e0dc28fbc529eab8eb7bb583079628e8cbe738192853dbbe43af5"}, - {file = "ruff-0.9.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a5a46e09355695fbdbb30ed9889d6cf1c61b77b700a9fafc21b41f097bfbba4"}, - {file = "ruff-0.9.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c4bb09d2bbb394e3730d0918c00276e79b2de70ec2a5231cd4ebb51a57df9ba1"}, - {file = "ruff-0.9.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:96a87ec31dc1044d8c2da2ebbed1c456d9b561e7d087734336518181b26b3aa5"}, - {file = "ruff-0.9.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bb7554aca6f842645022fe2d301c264e6925baa708b392867b7a62645304df4"}, - {file = "ruff-0.9.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cabc332b7075a914ecea912cd1f3d4370489c8018f2c945a30bcc934e3bc06a6"}, - {file = "ruff-0.9.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:33866c3cc2a575cbd546f2cd02bdd466fed65118e4365ee538a3deffd6fcb730"}, - {file = "ruff-0.9.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:006e5de2621304c8810bcd2ee101587712fa93b4f955ed0985907a36c427e0c2"}, - {file = "ruff-0.9.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ba6eea4459dbd6b1be4e6bfc766079fb9b8dd2e5a35aff6baee4d9b1514ea519"}, - {file = "ruff-0.9.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:90230a6b8055ad47d3325e9ee8f8a9ae7e273078a66401ac66df68943ced029b"}, - {file = "ruff-0.9.3-py3-none-win32.whl", hash = "sha256:eabe5eb2c19a42f4808c03b82bd313fc84d4e395133fb3fc1b1516170a31213c"}, - {file = "ruff-0.9.3-py3-none-win_amd64.whl", hash = "sha256:040ceb7f20791dfa0e78b4230ee9dce23da3b64dd5848e40e3bf3ab76468dcf4"}, - {file = "ruff-0.9.3-py3-none-win_arm64.whl", hash = "sha256:800d773f6d4d33b0a3c60e2c6ae8f4c202ea2de056365acfa519aa48acf28e0b"}, - {file = "ruff-0.9.3.tar.gz", hash = "sha256:8293f89985a090ebc3ed1064df31f3b4b56320cdfcec8b60d3295bddb955c22a"}, + {file = "ruff-0.9.5-py3-none-linux_armv6l.whl", hash = "sha256:d466d2abc05f39018d53f681fa1c0ffe9570e6d73cde1b65d23bb557c846f442"}, + {file = "ruff-0.9.5-py3-none-macosx_10_12_x86_64.whl", hash = 
"sha256:38840dbcef63948657fa7605ca363194d2fe8c26ce8f9ae12eee7f098c85ac8a"}, + {file = "ruff-0.9.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d56ba06da53536b575fbd2b56517f6f95774ff7be0f62c80b9e67430391eeb36"}, + {file = "ruff-0.9.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f7cb2a01da08244c50b20ccfaeb5972e4228c3c3a1989d3ece2bc4b1f996001"}, + {file = "ruff-0.9.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:96d5c76358419bc63a671caac70c18732d4fd0341646ecd01641ddda5c39ca0b"}, + {file = "ruff-0.9.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:deb8304636ed394211f3a6d46c0e7d9535b016f53adaa8340139859b2359a070"}, + {file = "ruff-0.9.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:df455000bf59e62b3e8c7ba5ed88a4a2bc64896f900f311dc23ff2dc38156440"}, + {file = "ruff-0.9.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de92170dfa50c32a2b8206a647949590e752aca8100a0f6b8cefa02ae29dce80"}, + {file = "ruff-0.9.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d28532d73b1f3f627ba88e1456f50748b37f3a345d2be76e4c653bec6c3e393"}, + {file = "ruff-0.9.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c746d7d1df64f31d90503ece5cc34d7007c06751a7a3bbeee10e5f2463d52d2"}, + {file = "ruff-0.9.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:11417521d6f2d121fda376f0d2169fb529976c544d653d1d6044f4c5562516ee"}, + {file = "ruff-0.9.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:5b9d71c3879eb32de700f2f6fac3d46566f644a91d3130119a6378f9312a38e1"}, + {file = "ruff-0.9.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:2e36c61145e70febcb78483903c43444c6b9d40f6d2f800b5552fec6e4a7bb9a"}, + {file = "ruff-0.9.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:2f71d09aeba026c922aa7aa19a08d7bd27c867aedb2f74285a2639644c1c12f5"}, + {file = "ruff-0.9.5-py3-none-win32.whl", hash = "sha256:134f958d52aa6fdec3b294b8ebe2320a950d10c041473c4316d2e7d7c2544723"}, + {file = "ruff-0.9.5-py3-none-win_amd64.whl", hash = "sha256:78cc6067f6d80b6745b67498fb84e87d32c6fc34992b52bffefbdae3442967d6"}, + {file = "ruff-0.9.5-py3-none-win_arm64.whl", hash = "sha256:18a29f1a005bddb229e580795627d297dfa99f16b30c7039e73278cf6b5f9fa9"}, + {file = "ruff-0.9.5.tar.gz", hash = "sha256:11aecd7a633932875ab3cb05a484c99970b9d52606ce9ea912b690b02653d56c"}, ] [[package]] @@ -7831,4 +7845,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = "3.12.6" -content-hash = "21a25778ed407405a95e83d4acf4bad0046b0098d8ac525ac533dcf9b3cfb35d" +content-hash = "f733ab133945d808cb1294cb1bb99d9a34b118d57c324fcee30e0680978b0353" diff --git a/profiles/factories.py b/profiles/factories.py index 502cd599f5..7803c232f3 100644 --- a/profiles/factories.py +++ b/profiles/factories.py @@ -1,6 +1,8 @@ """Factories for making test data""" -from factory import Faker, Sequence, SubFactory +import uuid + +from factory import Faker, LazyFunction, SelfAttribute, Sequence, SubFactory from factory.django import DjangoModelFactory from factory.fuzzy import FuzzyChoice from faker.providers import BaseProvider @@ -49,6 +51,9 @@ class ProfileFactory(DjangoModelFactory): [Profile.CertificateDesired.YES.value, Profile.CertificateDesired.NO.value] ) + scim_external_id = LazyFunction(uuid.uuid4) + scim_username = SelfAttribute("user.email") + class Meta: model = Profile diff --git a/profiles/scim/views_test.py b/profiles/scim/views_test.py deleted file mode 
100644 index 5a340dac79..0000000000 --- a/profiles/scim/views_test.py +++ /dev/null @@ -1,117 +0,0 @@ -import json - -from django.contrib.auth import get_user_model -from django.urls import reverse -from django_scim import constants - -User = get_user_model() - - -def test_scim_post_user(staff_client): - """Test that we can create a user via SCIM API""" - user_q = User.objects.filter(profile__scim_external_id="1") - assert not user_q.exists() - - resp = staff_client.post( - reverse("scim:users"), - content_type="application/scim+json", - data=json.dumps( - { - "schemas": [constants.SchemaURI.USER], - "emails": [{"value": "jdoe@example.com", "primary": True}], - "active": True, - "userName": "jdoe", - "externalId": "1", - "name": { - "familyName": "Doe", - "givenName": "John", - }, - "fullName": "John Smith Doe", - "emailOptIn": 1, - } - ), - ) - - assert resp.status_code == 201, f"Error response: {resp.content}" - - user = user_q.first() - - assert user is not None - assert user.email == "jdoe@example.com" - assert user.username == "jdoe" - assert user.first_name == "John" - assert user.last_name == "Doe" - assert user.profile.name == "John Smith Doe" - assert user.profile.email_optin is True - - # test an update - resp = staff_client.put( - f"{reverse('scim:users')}/{user.profile.scim_id}", - content_type="application/scim+json", - data=json.dumps( - { - "schemas": [constants.SchemaURI.USER], - "emails": [{"value": "jsmith@example.com", "primary": True}], - "active": True, - "userName": "jsmith", - "externalId": "1", - "name": { - "familyName": "Smith", - "givenName": "Jimmy", - }, - "fullName": "Jimmy Smith", - "emailOptIn": 0, - } - ), - ) - - assert resp.status_code == 200, f"Error response: {resp.content}" - - user = user_q.first() - - assert user is not None - assert user.email == "jsmith@example.com" - assert user.username == "jsmith" - assert user.first_name == "Jimmy" - assert user.last_name == "Smith" - assert user.profile.name == "Jimmy Smith" - assert user.profile.email_optin is False - - resp = staff_client.patch( - f"{reverse('scim:users')}/{user.profile.scim_id}", - content_type="application/scim+json", - data=json.dumps( - { - "schemas": [constants.SchemaURI.PATCH_OP], - "Operations": [ - { - "op": "replace", - # yes, the value we get from scim-for-keycloak is a JSON encoded string...inside JSON... 
- "value": json.dumps( - { - "schemas": [constants.SchemaURI.USER], - "emailOptIn": 1, - "fullName": "Billy Bob", - "name": { - "givenName": "Billy", - "familyName": "Bob", - }, - } - ), - } - ], - } - ), - ) - - assert resp.status_code == 200, f"Error response: {resp.content}" - - user = user_q.first() - - assert user is not None - assert user.email == "jsmith@example.com" - assert user.username == "jsmith" - assert user.first_name == "Billy" - assert user.last_name == "Bob" - assert user.profile.name == "Billy Bob" - assert user.profile.email_optin is True diff --git a/pyproject.toml b/pyproject.toml index d16a073a9a..6e32e402d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ cffi = "^1.15.1" cryptography = "^44.0.0" dj-database-url = "^2.0.0" dj-static = "^0.0.6" -Django = "4.2.18" +Django = "4.2.19" django-anymail = {extras = ["mailgun"], version = "^12.0"} django-bitfield = "^2.2.0" django-cache-memoize = "^0.2.0" @@ -36,7 +36,7 @@ django-server-status = "^0.7.0" django-storages = "^1.13.2" djangorestframework = "^3.14.0" drf-jwt = "^1.19.2" -drf-spectacular = "^0.27.0" +drf-spectacular = "^0.28.0" feedparser = "^6.0.10" google-api-python-client = "^2.89.0" html5lib = "^1.1" @@ -75,7 +75,7 @@ django-scim2 = "^0.19.1" django-oauth-toolkit = "^2.3.0" youtube-transcript-api = "^0.6.2" posthog = "^3.5.0" -ruff = "0.9.3" +ruff = "0.9.5" dateparser = "^1.2.0" uwsgitop = "^0.12" pytest-lazy-fixtures = "^1.1.1" @@ -83,7 +83,7 @@ pycountry = "^24.6.1" qdrant-client = {extras = ["fastembed"], version = "^1.12.0"} onnxruntime = "1.20.1" openai = "^1.55.3" -litellm = "1.59.7" +litellm = "1.60.8" langchain = "^0.3.11" tiktoken = "^0.8.0" llama-index = "^0.12.6" @@ -91,6 +91,7 @@ llama-index-llms-openai = "^0.3.12" llama-index-agent-openai = "^0.4.1" langchain-experimental = "^0.3.4" langchain-openai = "^0.3.2" +deepmerge = "^2.0" [tool.poetry.group.dev.dependencies] diff --git a/scim/README.md b/scim/README.md new file mode 100644 index 0000000000..017948ca5a --- /dev/null +++ b/scim/README.md @@ -0,0 +1,36 @@ +## SCIM + +## Prerequisites + +- You need the following a local [Keycloak](https://www.keycloak.org/) instance running. Note which major version you are running (should be at least 26.x). 
+ - You should have custom user profile fields setup on your `olapps` realm: + - `fullName`: required, otherwise defaults + - `emailOptIn`: defaults + +## Install the scim-for-keycloak plugin + +Sign up for an account on https://scim-for-keycloak.de and follow the instructions here: https://scim-for-keycloak.de/documentation/installation/install + +## Configure SCIM + +In the SCIM admin console, do the following: + +### Configure Remote SCIM Provider + +- In django-admin, go to OAuth Toolkit and create a new access token +- Go to Remote SCIM Provider +- Click the `+` button +- Specify a base URL for your learn API backend: `http://:8063/scim/v2/` +- At the bottom of the page, click "Use default configuration" +- Add a new authentication method: + - Type: Long Life Bearer Token + - Bearer Token: the access token you created above +- On the Schemas tab, edit the User schema and add these custom attributes: + - Add a `fullName` attribute and set the Custom Attribute Name to `fullName` + - Add an attribute named `emailOptIn` with the following settings: + - Type: integer + - Custom Attribute Name: `emailOptIn` +- On the Realm Assignments tab, assign to the `olapps` realm +- Go to the Synchronization tab and perform one: + - Identifier attribute: email + - Synchronization strategy: Search and Bulk diff --git a/scim/__init__.py b/scim/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiles/scim/adapters.py b/scim/adapters.py similarity index 84% rename from profiles/scim/adapters.py rename to scim/adapters.py index 94c4b012f9..6983d4c480 100644 --- a/profiles/scim/adapters.py +++ b/scim/adapters.py @@ -44,6 +44,7 @@ class LearnSCIMUser(SCIMUser): ("active", None, None): "is_active", ("name", "givenName", None): "first_name", ("name", "familyName", None): "last_name", + ("userName", None, None): "username", } IGNORED_PATHS = { @@ -158,7 +159,7 @@ def delete(self): """ self.obj.is_active = False self.obj.save() - logger.info("Deactivated user id %i", self.obj.user.id) + logger.info("Deactivated user id %i", self.obj.id) def handle_add( self, @@ -193,7 +194,7 @@ def parse_scim_for_keycloak_payload(self, payload: str) -> dict: if isinstance(value, dict): for nested_key, nested_value in value.items(): - result[f"{key}.{nested_key}"] = nested_value + result[self.split_path(f"{key}.{nested_key}")] = nested_value else: result[key] = value @@ -202,11 +203,32 @@ def parse_scim_for_keycloak_payload(self, payload: str) -> dict: def parse_path_and_values( self, path: Optional[str], value: Union[str, list, dict] ) -> list: - if not path and isinstance(value, str): + """Parse the incoming value(s)""" + if isinstance(value, str): # scim-for-keycloak sends this as a noncompliant JSON-encoded string - value = self.parse_scim_for_keycloak_payload(value) + if path is None: + val = json.loads(value) + else: + msg = "Called with a non-null path and a str value" + raise ValueError(msg) + else: + val = value + + results = [] + + for attr_path, attr_value in val.items(): + if isinstance(attr_value, dict): + # nested object, we want to recursively flatten it to `first.second` + results.extend(self.parse_path_and_values(attr_path, attr_value)) + else: + flattened_path = ( + f"{path}.{attr_path}" if path is not None else attr_path + ) + new_path = self.split_path(flattened_path) + new_value = attr_value + results.append((new_path, new_value)) - return super().parse_path_and_values(path, value) + return results def handle_replace( self, @@ -219,22 +241,20 @@ def handle_replace( All operations 
happen within an atomic transaction. """ + if not isinstance(value, dict): # Restructure for use in loop below. value = {path: value} for nested_path, nested_value in (value or {}).items(): if nested_path.first_path in self.ATTR_MAP: - setattr( - self.obj, self.ATTR_MAP.get(nested_path.first_path), nested_value - ) - + setattr(self.obj, self.ATTR_MAP[nested_path.first_path], nested_value) elif nested_path.first_path == ("fullName", None, None): self.obj.profile.name = nested_value elif nested_path.first_path == ("emailOptIn", None, None): self.obj.profile.email_optin = nested_value == 1 elif nested_path.first_path == ("emails", None, None): - self.parse_emails(value) + self.parse_emails(nested_value) elif nested_path.first_path not in self.IGNORED_PATHS: logger.debug( "Ignoring SCIM update for path: %s", nested_path.first_path diff --git a/scim/apps.py b/scim/apps.py new file mode 100644 index 0000000000..7cfdae6bfa --- /dev/null +++ b/scim/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class ScimConfig(AppConfig): + name = "scim" diff --git a/scim/config.py b/scim/config.py new file mode 100644 index 0000000000..49da726497 --- /dev/null +++ b/scim/config.py @@ -0,0 +1,13 @@ +from django_scim.models import SCIMServiceProviderConfig + + +class LearnSCIMServiceProviderConfig(SCIMServiceProviderConfig): + """Custom provider config""" + + def to_dict(self): + result = super().to_dict() + + result["bulk"]["supported"] = True + result["filter"]["supported"] = True + + return result diff --git a/scim/constants.py b/scim/constants.py new file mode 100644 index 0000000000..c51546dabe --- /dev/null +++ b/scim/constants.py @@ -0,0 +1,7 @@ +"""SCIM constants""" + + +class SchemaURI: + BULK_REQUEST = "urn:ietf:params:scim:api:messages:2.0:BulkRequest" + + BULK_RESPONSE = "urn:ietf:params:scim:api:messages:2.0:BulkResponse" diff --git a/profiles/scim/filters.py b/scim/filters.py similarity index 51% rename from profiles/scim/filters.py rename to scim/filters.py index 699752be01..a971e7fe9f 100644 --- a/profiles/scim/filters.py +++ b/scim/filters.py @@ -2,16 +2,27 @@ from django_scim.filters import UserFilterQuery +from scim.parser.queries.sql import PatchedSQLQuery + class LearnUserFilterQuery(UserFilterQuery): """Filters for users""" + query_class = PatchedSQLQuery + attr_map: dict[tuple[Optional[str], Optional[str], Optional[str]], str] = { ("userName", None, None): "auth_user.username", + ("emails", "value", None): "auth_user.email", ("active", None, None): "auth_user.is_active", - ("name", "formatted", None): "profiles_profile.name", + ("fullName", None, None): "profiles_profile.name", + ("name", "givenName", None): "auth_user.first_name", + ("name", "familyName", None): "auth_user.last_name", } joins: tuple[str, ...] 
= ( "INNER JOIN profiles_profile ON profiles_profile.user_id = auth_user.id", ) + + @classmethod + def search(cls, filter_query, request=None): + return super().search(filter_query, request=request) diff --git a/scim/forms.py b/scim/forms.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/scim/parser/__init__.py b/scim/parser/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/scim/parser/queries/__init__.py b/scim/parser/queries/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/scim/parser/queries/sql.py b/scim/parser/queries/sql.py new file mode 100644 index 0000000000..d6d8fcde52 --- /dev/null +++ b/scim/parser/queries/sql.py @@ -0,0 +1,15 @@ +from scim2_filter_parser.lexer import SCIMLexer +from scim2_filter_parser.parser import SCIMParser +from scim2_filter_parser.queries.sql import SQLQuery + +from scim.parser.transpilers.sql import PatchedTranspiler + + +class PatchedSQLQuery(SQLQuery): + """Patched SQLQuery to use the patch transpiler""" + + def build_where_sql(self): + self.token_stream = SCIMLexer().tokenize(self.filter) + self.ast = SCIMParser().parse(self.token_stream) + self.transpiler = PatchedTranspiler(self.attr_map) + self.where_sql, self.params_dict = self.transpiler.transpile(self.ast) diff --git a/scim/parser/transpilers/__init__.py b/scim/parser/transpilers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/scim/parser/transpilers/sql.py b/scim/parser/transpilers/sql.py new file mode 100644 index 0000000000..e095cce2b3 --- /dev/null +++ b/scim/parser/transpilers/sql.py @@ -0,0 +1,20 @@ +import string + +from scim2_filter_parser.transpilers.sql import Transpiler + + +class PatchedTranspiler(Transpiler): + """ + This is a fixed version of the upstream sql transpiler that converts SCIM queries + to SQL queries. + + Specifically it fixes the upper limit of 26 conditions for the search endpoint due + to the upstream library using the ascii alphabet for query parameters. 
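+    For example, the 27th parameter placeholder becomes "za" and the 28th "zb", so filters are no longer capped at 26 conditions.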
+ """ + + def get_next_id(self): + """Convert the current index to a base26 string""" + chars = string.ascii_lowercase + index = len(self.params) + + return (chars[-1] * int(index / len(chars))) + chars[index % len(chars)] diff --git a/scim/urls.py b/scim/urls.py new file mode 100644 index 0000000000..88221d982e --- /dev/null +++ b/scim/urls.py @@ -0,0 +1,17 @@ +"""URL configurations for profiles""" + +from django.urls import include, re_path + +from scim import views + +ol_scim_urls = ( + [ + re_path("^Bulk$", views.BulkView.as_view(), name="bulk"), + ], + "ol-scim", +) + +urlpatterns = [ + re_path("^scim/v2/", include(ol_scim_urls)), + re_path("^scim/v2/", include("django_scim.urls", namespace="scim")), +] diff --git a/scim/utils.py b/scim/utils.py new file mode 100644 index 0000000000..d6d0a4aee1 --- /dev/null +++ b/scim/utils.py @@ -0,0 +1,6 @@ +"""Utils""" + + +def is_authenticated_predicate(user): + """Verify that the user is active and staff""" + return user.is_authenticated and user.is_active and user.is_staff diff --git a/scim/views.py b/scim/views.py new file mode 100644 index 0000000000..72dcc09a65 --- /dev/null +++ b/scim/views.py @@ -0,0 +1,160 @@ +"""SCIM view customizations""" + +import copy +import json +import logging +from http import HTTPStatus +from urllib.parse import urlparse + +from django.http import HttpRequest, HttpResponse +from django.urls import Resolver404, resolve +from django_scim import constants as djs_constants +from django_scim import exceptions +from django_scim import views as djs_views + +from scim import constants + +log = logging.getLogger() + + +class InMemoryHttpRequest(HttpRequest): + """ + A spoofed HttpRequest that only exists in-memory. + It does not implement all features of HttpRequest and is only used + for the bulk SCIM operations here so we can reuse view implementations. + """ + + def __init__(self, request, path, method, body): + super().__init__() + + self.META = copy.deepcopy( + { + key: value + for key, value in request.META.items() + if not key.startswith(("wsgi", "uwsgi")) + } + ) + self.path = path + self.method = method + self.content_type = djs_constants.SCIM_CONTENT_TYPE + + # normally HttpRequest would read this in, but we already have the value + self._body = body + + +class BulkView(djs_views.SCIMView): + http_method_names = ["post"] + + def post(self, request, *args, **kwargs): # noqa: ARG002 + body = self.load_body(request.body) + + if body.get("schemas") != [constants.SchemaURI.BULK_REQUEST]: + msg = "Invalid schema uri. Must be SearchRequest." + raise exceptions.BadRequestError(msg) + + fail_on_errors = body.get("failOnErrors", None) + + if fail_on_errors is not None and isinstance(int, fail_on_errors): + msg = "Invalid failOnErrors. Must be an integer." 
+ raise exceptions.BadRequestError(msg) + + operations = body.get("Operations") + + results = self._attempt_operations(request, operations, fail_on_errors) + + response = { + "schemas": [constants.SchemaURI.BULK_RESPONSE], + "Operations": results, + } + + content = json.dumps(response) + + return HttpResponse( + content=content, + content_type=djs_constants.SCIM_CONTENT_TYPE, + status=HTTPStatus.OK, + ) + + def _attempt_operations(self, request, operations, fail_on_errors): + """Attempt to run the operations that were passed""" + responses = [] + num_errors = 0 + + for operation in operations: + # per spec, if we've hit the error threshold stop processing and return + if fail_on_errors is not None and num_errors >= fail_on_errors: + break + + op_response = self._attempt_operation(request, operation) + + # if the operation returned a non-2xx status code, record it as a failure + if int(op_response.get("status")) >= HTTPStatus.MULTIPLE_CHOICES: + num_errors += 1 + + responses.append(op_response) + + return responses + + def _attempt_operation(self, bulk_request, operation): + """Attempt an operation as part of a bulk request""" + + method = operation.get("method") + bulk_id = operation.get("bulkId") + path = operation.get("path") + data = operation.get("data") + + try: + url_match = resolve(path, urlconf="django_scim.urls") + except Resolver404: + return self._operation_error( + method, + bulk_id, + HTTPStatus.NOT_IMPLEMENTED, + "Endpoint is not supported for /Bulk", + ) + + # this is an ephemeral request not tied to the real request directly + op_request = InMemoryHttpRequest( + bulk_request, path, method, json.dumps(data).encode(djs_constants.ENCODING) + ) + + op_response = url_match.func(op_request, *url_match.args, **url_match.kwargs) + result = { + "method": method, + "bulkId": bulk_id, + "status": str(op_response.status_code), + } + + location = None + + if op_response.status_code >= HTTPStatus.BAD_REQUEST and op_response.content: + result["response"] = json.loads(op_response.content.decode("utf-8")) + + location = op_response.headers.get("Location", None) + + if location is not None: + result["location"] = location + # this is a custom field that the scim-for-keycloak plugin requires + try: + path = urlparse(location).path + location_match = resolve(path) + # this URL will be something like /scim/v2/Users/12345 + # resolving it gives the uuid + result["id"] = location_match.kwargs["uuid"] + except Resolver404: + log.exception("Unable to resolve resource url: %s", location) + + return result + + def _operation_error(self, method, bulk_id, status_code, detail): + """Return a failure response""" + status_code = str(status_code) + return { + "method": method, + "status": status_code, + "bulkId": bulk_id, + "response": { + "schemas": [djs_constants.SchemaURI.ERROR], + "status": status_code, + "detail": detail, + }, + } diff --git a/scim/views_test.py b/scim/views_test.py new file mode 100644 index 0000000000..c1ee44f9db --- /dev/null +++ b/scim/views_test.py @@ -0,0 +1,415 @@ +import itertools +import json +import operator +import random +from collections.abc import Callable +from functools import reduce +from types import SimpleNamespace + +import pytest +from anys import ANY_STR +from deepmerge import always_merger +from django.contrib.auth import get_user_model +from django.test import Client +from django.urls import reverse +from django_scim import constants as djs_constants + +from main.factories import UserFactory +from scim import constants + +User = get_user_model() + + +@pytest.fixture +def
scim_client(staff_user): + """Test client for scim""" + client = Client() + client.force_login(staff_user) + return client + + +def test_scim_user_post(scim_client): + """Test that we can create a user via SCIM API""" + user_q = User.objects.filter(profile__scim_external_id="1") + assert not user_q.exists() + + resp = scim_client.post( + reverse("scim:users"), + content_type="application/scim+json", + data=json.dumps( + { + "schemas": [djs_constants.SchemaURI.USER], + "emails": [{"value": "jdoe@example.com", "primary": True}], + "active": True, + "userName": "jdoe", + "externalId": "1", + "name": { + "familyName": "Doe", + "givenName": "John", + }, + "fullName": "John Smith Doe", + "emailOptIn": 1, + } + ), + ) + + assert resp.status_code == 201, f"Error response: {resp.content}" + + user = user_q.first() + + assert user is not None + assert user.email == "jdoe@example.com" + assert user.username == "jdoe" + assert user.first_name == "John" + assert user.last_name == "Doe" + assert user.profile.name == "John Smith Doe" + assert user.profile.email_optin is True + + +def test_scim_user_put(scim_client): + """Test that a user can be updated via PUT""" + user = UserFactory.create() + + resp = scim_client.put( + f"{reverse('scim:users')}/{user.profile.scim_id}", + content_type="application/scim+json", + data=json.dumps( + { + "schemas": [djs_constants.SchemaURI.USER], + "emails": [{"value": "jsmith@example.com", "primary": True}], + "active": True, + "userName": "jsmith", + "externalId": "1", + "name": { + "familyName": "Smith", + "givenName": "Jimmy", + }, + "fullName": "Jimmy Smith", + "emailOptIn": 0, + } + ), + ) + + assert resp.status_code == 200, f"Error response: {resp.content}" + + user.refresh_from_db() + + assert user.email == "jsmith@example.com" + assert user.username == "jsmith" + assert user.first_name == "Jimmy" + assert user.last_name == "Smith" + assert user.profile.name == "Jimmy Smith" + assert user.profile.email_optin is False + + +def test_scim_user_patch(scim_client): + """Test that a user can be updated via PATCH""" + user = UserFactory.create() + + resp = scim_client.patch( + f"{reverse('scim:users')}/{user.profile.scim_id}", + content_type="application/scim+json", + data=json.dumps( + { + "schemas": [djs_constants.SchemaURI.PATCH_OP], + "Operations": [ + { + "op": "replace", + # yes, the value we get from scim-for-keycloak is a JSON encoded string...inside JSON... 
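+ # LearnSCIMUser.parse_path_and_values() decodes this string with json.loads before applying the update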
+ "value": json.dumps( + { + "schemas": [djs_constants.SchemaURI.USER], + "emailOptIn": 1, + "fullName": "Billy Bob", + "name": { + "givenName": "Billy", + "familyName": "Bob", + }, + } + ), + } + ], + } + ), + ) + + assert resp.status_code == 200, f"Error response: {resp.content}" + + user_updated = User.objects.get(pk=user.id) + + assert user_updated.email == user.email + assert user_updated.username == user.username + assert user_updated.first_name == "Billy" + assert user_updated.last_name == "Bob" + assert user_updated.profile.name == "Billy Bob" + assert user_updated.profile.email_optin is True + + +def _user_to_scim_payload(user): + """Test util to serialize a user to a SCIM representation""" + return { + "schemas": [djs_constants.SchemaURI.USER], + "emails": [{"value": user.email, "primary": True}], + "userName": user.username, + "emailOptIn": 1 if user.profile.email_optin else 0, + "fullName": user.profile.name, + "name": { + "givenName": user.first_name, + "familyName": user.last_name, + }, + } + + +USER_FIELD_TYPES: dict[str, type] = { + "username": str, + "email": str, + "first_name": str, + "last_name": str, + "profile.email_optin": bool, + "profile.name": str, +} + +USER_FIELDS_TO_SCIM: dict[str, Callable] = { + "username": lambda value: {"userName": value}, + "email": lambda value: {"emails": [{"value": value, "primary": True}]}, + "first_name": lambda value: {"name": {"givenName": value}}, + "last_name": lambda value: {"name": {"familyName": value}}, + "profile.email_optin": lambda value: {"emailOptIn": 1 if value else 0}, + "profile.name": lambda value: {"fullName": value}, +} + + +def _post_operation(data, bulk_id_gen): + """Operation for a bulk POST""" + bulk_id = str(next(bulk_id_gen)) + return SimpleNamespace( + payload={ + "method": "post", + "bulkId": bulk_id, + "path": "/Users", + "data": _user_to_scim_payload(data), + }, + user=None, + expected_user_state=data, + expected_response={ + "method": "post", + "location": ANY_STR, + "bulkId": bulk_id, + "status": "201", + "id": ANY_STR, + }, + ) + + +def _put_operation(user, data, bulk_id_gen): + """Operation for a bulk PUT""" + bulk_id = str(next(bulk_id_gen)) + return SimpleNamespace( + payload={ + "method": "put", + "bulkId": bulk_id, + "path": f"/Users/{user.profile.scim_id}", + "data": _user_to_scim_payload(data), + }, + user=user, + expected_user_state=data, + expected_response={ + "method": "put", + "location": ANY_STR, + "bulkId": bulk_id, + "status": "200", + "id": str(user.profile.scim_id), + }, + ) + + +def _patch_operation(user, data, fields_to_patch, bulk_id_gen): + """Operation for a bulk PUT""" + + def _expected_patch_value(field): + field_getter = operator.attrgetter(field) + return field_getter(data if field in fields_to_patch else user) + + bulk_id = str(next(bulk_id_gen)) + field_updates = [ + mk_scim_value(operator.attrgetter(user_path)(data)) + for user_path, mk_scim_value in USER_FIELDS_TO_SCIM.items() + if user_path in fields_to_patch + ] + + return SimpleNamespace( + payload={ + "method": "patch", + "bulkId": bulk_id, + "path": f"/Users/{user.profile.scim_id}", + "data": { + "schemas": [djs_constants.SchemaURI.PATCH_OP], + "Operations": [ + { + "op": "replace", + "value": reduce(always_merger.merge, field_updates, {}), + } + ], + }, + }, + user=user, + expected_user_state=SimpleNamespace( + email=_expected_patch_value("email"), + username=_expected_patch_value("username"), + first_name=_expected_patch_value("first_name"), + last_name=_expected_patch_value("last_name"), + profile=SimpleNamespace( + 
name=_expected_patch_value("profile.name"), + email_optin=_expected_patch_value("profile.email_optin"), + ), + ), + expected_response={ + "method": "patch", + "location": ANY_STR, + "bulkId": bulk_id, + "status": "200", + "id": str(user.profile.scim_id), + }, + ) + + +def _delete_operation(user, bulk_id_gen): + """Operation for a bulk DELETE""" + bulk_id = str(next(bulk_id_gen)) + return SimpleNamespace( + payload={ + "method": "delete", + "bulkId": bulk_id, + "path": f"/Users/{user.profile.scim_id}", + }, + user=user, + expected_user_state=None, + expected_response={ + "method": "delete", + "bulkId": bulk_id, + "status": "204", + }, + ) + + +@pytest.fixture +def bulk_test_data(): + """Test data for the /Bulk API tests""" + existing_users = UserFactory.create_batch(500) + remaining_users = set(existing_users) + + users_to_put = random.sample(sorted(remaining_users, key=lambda user: user.id), 100) + remaining_users = remaining_users - set(users_to_put) + + users_to_patch = random.sample( + sorted(remaining_users, key=lambda user: user.id), 100 + ) + remaining_users = remaining_users - set(users_to_patch) + + users_to_delete = random.sample( + sorted(remaining_users, key=lambda user: user.id), 100 + ) + remaining_users = remaining_users - set(users_to_delete) + + user_post_data = UserFactory.build_batch(100) + user_put_data = UserFactory.build_batch(len(users_to_put)) + user_patch_data = UserFactory.build_batch(len(users_to_patch)) + + bulk_id_gen = itertools.count() + + post_operations = [_post_operation(data, bulk_id_gen) for data in user_post_data] + put_operations = [ + _put_operation(user, data, bulk_id_gen) + for user, data in zip(users_to_put, user_put_data) + ] + patch_operations = [ + _patch_operation(user, patch_data, fields_to_patch, bulk_id_gen) + for user, patch_data, fields_to_patch in [ + ( + user, + patch_data, + # random number of field updates + list( + random.sample( + list(USER_FIELDS_TO_SCIM.keys()), + random.randint(1, len(USER_FIELDS_TO_SCIM.keys())), # noqa: S311 + ) + ), + ) + for user, patch_data in zip(users_to_patch, user_patch_data) + ] + ] + delete_operations = [ + _delete_operation(user, bulk_id_gen) for user in users_to_delete + ] + + operations = [ + *post_operations, + *patch_operations, + *put_operations, + *delete_operations, + ] + random.shuffle(operations) + + return SimpleNamespace( + existing_users=existing_users, + remaining_users=remaining_users, + post_operations=post_operations, + patch_operations=patch_operations, + put_operations=put_operations, + delete_operations=delete_operations, + operations=operations, + ) + + +def test_bulk_post(scim_client, bulk_test_data): + """Verify that bulk operations work as expected""" + user_count = User.objects.count() + + resp = scim_client.post( + reverse("ol-scim:bulk"), + content_type="application/scim+json", + data=json.dumps( + { + "schemas": [constants.SchemaURI.BULK_REQUEST], + "Operations": [ + operation.payload for operation in bulk_test_data.operations + ], + } + ), + ) + + assert resp.status_code == 200 + + # singular user is the staff user + assert User.objects.count() == user_count + len(bulk_test_data.post_operations) + + results_by_bulk_id = { + op_result["bulkId"]: op_result for op_result in resp.json()["Operations"] + } + + for operation in bulk_test_data.operations: + assert ( + results_by_bulk_id[operation.payload["bulkId"]] + == operation.expected_response + ) + + if operation in bulk_test_data.delete_operations: + user = User.objects.get(id=operation.user.id) + assert not user.is_active + 
else: + if operation in bulk_test_data.post_operations: + user = User.objects.get(username=operation.expected_user_state.username) + else: + user = User.objects.get(id=operation.user.id) + + for key, key_type in USER_FIELD_TYPES.items(): + attr_getter = operator.attrgetter(key) + + actual_value = attr_getter(user) + expected_value = attr_getter(operation.expected_user_state) + + if key_type is bool or key_type is None: + assert actual_value is expected_value + else: + assert actual_value == expected_value diff --git a/users/__init__.py b/users/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/users/apps.py b/users/apps.py new file mode 100644 index 0000000000..88f7b1798e --- /dev/null +++ b/users/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class UsersConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "users" diff --git a/users/migrations/0001_initial.py b/users/migrations/0001_initial.py new file mode 100644 index 0000000000..b90bc3412a --- /dev/null +++ b/users/migrations/0001_initial.py @@ -0,0 +1,9 @@ +# Generated by Django 4.2.18 on 2025-02-05 19:41 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [] + + operations = [] diff --git a/users/migrations/__init__.py b/users/migrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vector_search/management/commands/create_qdrant_collections.py b/vector_search/management/commands/create_qdrant_collections.py index 762abc7d26..17b3a419ed 100644 --- a/vector_search/management/commands/create_qdrant_collections.py +++ b/vector_search/management/commands/create_qdrant_collections.py @@ -3,7 +3,7 @@ from django.core.management.base import BaseCommand from vector_search.utils import ( - create_qdrand_collections, + create_qdrant_collections, ) @@ -26,8 +26,8 @@ def handle(self, *args, **options): # noqa: ARG002 """Create Qdrant collections""" if options["force"]: - create_qdrand_collections(force_recreate=True) + create_qdrant_collections(force_recreate=True) else: - create_qdrand_collections(force_recreate=False) + create_qdrant_collections(force_recreate=False) self.stdout.write("Created Qdrant collections") diff --git a/vector_search/management/commands/generate_embeddings.py b/vector_search/management/commands/generate_embeddings.py index 324654da71..3e75f6fe4f 100644 --- a/vector_search/management/commands/generate_embeddings.py +++ b/vector_search/management/commands/generate_embeddings.py @@ -6,7 +6,7 @@ from main.utils import clear_search_cache, now_in_utc from vector_search.tasks import embed_learning_resources_by_id, start_embed_resources from vector_search.utils import ( - create_qdrand_collections, + create_qdrant_collections, ) @@ -42,6 +42,12 @@ def add_arguments(self, parser): action="store_true", help="Skip embedding content files", ) + parser.add_argument( + "--overwrite", + dest="overwrite", + action="store_true", + help="Force overwrite existing embeddings", + ) for object_type in sorted(LEARNING_RESOURCE_TYPES): parser.add_argument( @@ -71,7 +77,7 @@ def handle(self, *args, **options): # noqa: ARG002 self.stdout.write(f" --{object_type}s") return if options["recreate_collections"]: - create_qdrand_collections(force_recreate=True) + create_qdrant_collections(force_recreate=True) if options["resource-ids"]: task = embed_learning_resources_by_id.delay( [ @@ -79,10 +85,13 @@ def handle(self, *args, **options): # noqa: ARG002 for resource_id in options["resource-ids"].split(",") ], 
skip_content_files=options["skip_content_files"], + overwrite=options["overwrite"], ) else: task = start_embed_resources.delay( - indexes_to_update, skip_content_files=options["skip_content_files"] + indexes_to_update, + skip_content_files=options["skip_content_files"], + overwrite=options["overwrite"], ) self.stdout.write( f"Started celery task {task} to index content for the following" diff --git a/vector_search/tasks.py b/vector_search/tasks.py index dcba089cb6..1a5c5948e3 100644 --- a/vector_search/tasks.py +++ b/vector_search/tasks.py @@ -31,10 +31,7 @@ chunks, now_in_utc, ) -from vector_search.constants import ( - RESOURCES_COLLECTION_NAME, -) -from vector_search.utils import embed_learning_resources, filter_existing_qdrant_points +from vector_search.utils import embed_learning_resources log = logging.getLogger(__name__) @@ -46,7 +43,7 @@ retry_backoff=True, rate_limit="600/m", ) -def generate_embeddings(ids, resource_type): +def generate_embeddings(ids, resource_type, overwrite): """ Generate learning resource embeddings and index in Qdrant @@ -57,7 +54,7 @@ def generate_embeddings(ids, resource_type): """ try: with wrap_retry_exception(*SEARCH_CONN_EXCEPTIONS): - embed_learning_resources(ids, resource_type) + embed_learning_resources(ids, resource_type, overwrite) except (RetryError, Ignore): raise except SystemExit as err: @@ -69,7 +66,7 @@ def generate_embeddings(ids, resource_type): @app.task(bind=True) -def start_embed_resources(self, indexes, skip_content_files): +def start_embed_resources(self, indexes, skip_content_files, overwrite): """ Celery task to embed all learning resources for given indexes @@ -89,7 +86,7 @@ def start_embed_resources(self, indexes, skip_content_files): blocklisted_ids = load_course_blocklist() index_tasks = [ - generate_embeddings.si(ids, COURSE_TYPE) + generate_embeddings.si(ids, COURSE_TYPE, overwrite) for ids in chunks( Course.objects.filter(learning_resource__published=True) .exclude(learning_resource__readable_id=blocklisted_ids) @@ -123,10 +120,7 @@ def start_embed_resources(self, indexes, skip_content_files): ) index_tasks = index_tasks + [ - generate_embeddings.si( - ids, - CONTENT_FILE_TYPE, - ) + generate_embeddings.si(ids, CONTENT_FILE_TYPE, overwrite) for ids in chunks( run_contentfiles, chunk_size=settings.QDRANT_CHUNK_SIZE, @@ -150,10 +144,7 @@ def start_embed_resources(self, indexes, skip_content_files): chunk_size=settings.QDRANT_CHUNK_SIZE, ): index_tasks.append( - generate_embeddings.si( - ids, - resource_type, - ) + generate_embeddings.si(ids, resource_type, overwrite) ) except: # noqa: E722 error = "start_embed_resources threw an error" @@ -166,7 +157,7 @@ def start_embed_resources(self, indexes, skip_content_files): @app.task(bind=True) -def embed_learning_resources_by_id(self, ids, skip_content_files): +def embed_learning_resources_by_id(self, ids, skip_content_files, overwrite): """ Celery task to embed specific resources @@ -190,10 +181,7 @@ def embed_learning_resources_by_id(self, ids, skip_content_files): embed_resources = resources.filter(resource_type=resource_type) [ index_tasks.append( - generate_embeddings.si( - chunk_ids, - resource_type, - ) + generate_embeddings.si(chunk_ids, resource_type, overwrite) ) for chunk_ids in chunks( embed_resources.order_by("id").values_list("id", flat=True), @@ -216,10 +204,7 @@ def embed_learning_resources_by_id(self, ids, skip_content_files): ).order_by("id") content_ids = run_contentfiles.values_list("id", flat=True) index_tasks = index_tasks + [ - generate_embeddings.si( - ids, - 
CONTENT_FILE_TYPE, - ) + generate_embeddings.si(ids, CONTENT_FILE_TYPE, overwrite) for ids in chunks( content_ids, chunk_size=settings.QDRANT_CHUNK_SIZE, @@ -249,27 +234,19 @@ def embed_new_learning_resources(self): published=True, created_on__gt=since, ).exclude(resource_type=CONTENT_FILE_TYPE) - existing_readable_ids = [ - learning_resource.readable_id for learning_resource in new_learning_resources - ] - filtered_readable_ids = filter_existing_qdrant_points( - values=existing_readable_ids, - lookup_field="readable_id", - collection_name=RESOURCES_COLLECTION_NAME, - ) - filtered_resources = LearningResource.objects.filter( - readable_id__in=filtered_readable_ids + + resource_types = list( + new_learning_resources.values_list("resource_type", flat=True) ) - resource_types = list(filtered_resources.values_list("resource_type", flat=True)) tasks = [] for resource_type in resource_types: tasks.extend( [ - generate_embeddings.si(ids, resource_type) + generate_embeddings.si(ids, resource_type, overwrite=False) for ids in chunks( - filtered_resources.filter(resource_type=resource_type).values_list( - "id", flat=True - ), + new_learning_resources.filter( + resource_type=resource_type + ).values_list("id", flat=True), chunk_size=settings.QDRANT_CHUNK_SIZE, ) ] diff --git a/vector_search/tasks_test.py b/vector_search/tasks_test.py index 9c09acc615..7e87d18368 100644 --- a/vector_search/tasks_test.py +++ b/vector_search/tasks_test.py @@ -66,9 +66,13 @@ def test_start_embed_resources(mocker, mocked_celery, index): ) with pytest.raises(mocked_celery.replace_exception_class): - start_embed_resources.delay([index], skip_content_files=True) + start_embed_resources.delay([index], skip_content_files=True, overwrite=True) - generate_embeddings_mock.si.assert_called_once_with(resource_ids, index) + generate_embeddings_mock.si.assert_called_once_with( + resource_ids, + index, + True, # noqa: FBT003 + ) assert mocked_celery.replace.call_count == 1 assert mocked_celery.replace.call_args[0][1] == mocked_celery.chain.return_value @@ -101,7 +105,7 @@ def test_start_embed_resources_without_settings(mocker, mocked_celery, index): generate_embeddings_mock = mocker.patch( "vector_search.tasks.generate_embeddings", autospec=True ) - start_embed_resources.delay([index], skip_content_files=True) + start_embed_resources.delay([index], skip_content_files=True, overwrite=True) generate_embeddings_mock.si.assert_not_called() @@ -172,7 +176,9 @@ def test_embed_learning_resources_by_id(mocker, mocked_celery): content_ids.append(cf.id) with pytest.raises(mocked_celery.replace_exception_class): - embed_learning_resources_by_id.delay(resource_ids, skip_content_files=False) + embed_learning_resources_by_id.delay( + resource_ids, skip_content_files=False, overwrite=True + ) for mock_call in generate_embeddings_mock.si.mock_calls[1:]: assert mock_call.args[0][0] in content_ids assert mock_call.args[1] == "content_file" diff --git a/vector_search/utils.py b/vector_search/utils.py index 054ae4e1d4..43f43eccca 100644 --- a/vector_search/utils.py +++ b/vector_search/utils.py @@ -63,7 +63,7 @@ def points_generator( yield models.PointStruct(id=idx, payload=payload, vector=point_vector) -def create_qdrand_collections(force_recreate): +def create_qdrant_collections(force_recreate): """ Create or recreate QDrant collections @@ -174,8 +174,10 @@ def _process_resource_embeddings(serialized_resources): docs.append( f"{doc.get('title')} {doc.get('description')} {doc.get('full_description')}" ) - embeddings = encoder.embed_documents(docs) - 
return points_generator(ids, metadata, embeddings, vector_name) + if len(docs) > 0: + embeddings = encoder.embed_documents(docs) + return points_generator(ids, metadata, embeddings, vector_name) + return None def _chunk_documents(encoder, texts, metadatas): @@ -282,10 +284,12 @@ def _process_content_embeddings(serialized_content): except Exception as e: # noqa: BLE001 msg = f"Exceeded multi-vector max size: {e}" logger.warning(msg) - return points_generator(ids, metadata, embeddings, vector_name) + if ids: + return points_generator(ids, metadata, embeddings, vector_name) + return None -def embed_learning_resources(ids, resource_type): +def embed_learning_resources(ids, resource_type, overwrite): """ Embed learning resources @@ -296,20 +300,47 @@ def embed_learning_resources(ids, resource_type): client = qdrant_client() - resources_collection_name = RESOURCES_COLLECTION_NAME - content_files_collection_name = CONTENT_FILES_COLLECTION_NAME - - create_qdrand_collections(force_recreate=False) + create_qdrant_collections(force_recreate=False) if resource_type != CONTENT_FILE_TYPE: - serialized_resources = serialize_bulk_learning_resources(ids) - collection_name = resources_collection_name + serialized_resources = list(serialize_bulk_learning_resources(ids)) + points = [ + (vector_point_id(serialized["readable_id"]), serialized) + for serialized in serialized_resources + ] + if not overwrite: + filtered_point_ids = filter_existing_qdrant_points_by_ids( + [point[0] for point in points], + collection_name=RESOURCES_COLLECTION_NAME, + ) + serialized_resources = [ + point[1] for point in points if point[0] in filtered_point_ids + ] + + collection_name = RESOURCES_COLLECTION_NAME points = _process_resource_embeddings(serialized_resources) else: - serialized_resources = serialize_bulk_content_files(ids) - collection_name = content_files_collection_name + serialized_resources = list(serialize_bulk_content_files(ids)) + collection_name = CONTENT_FILES_COLLECTION_NAME + points = [ + ( + vector_point_id( + f"{doc['resource_readable_id']}.{doc['run_readable_id']}.{doc['key']}.0" + ), + doc, + ) + for doc in serialized_resources + ] + if not overwrite: + filtered_point_ids = filter_existing_qdrant_points_by_ids( + [point[0] for point in points], + collection_name=CONTENT_FILES_COLLECTION_NAME, + ) + serialized_resources = [ + point[1] for point in points if point[0] in filtered_point_ids + ] points = _process_content_embeddings(serialized_resources) - - client.upload_points(collection_name, points=points, wait=False) + if points: + client.upload_points(collection_name, points=points, wait=False) def _resource_vector_hits(search_result): @@ -395,6 +426,17 @@ def vector_search( } +def document_exists(document, collection_name=RESOURCES_COLLECTION_NAME): + client = qdrant_client() + count_result = client.count( + collection_name=collection_name, + count_filter=models.Filter( + must=qdrant_query_conditions(document, collection_name=collection_name) + ), + ) + return count_result.count > 0 + + def qdrant_query_conditions(params, collection_name=RESOURCES_COLLECTION_NAME): """ Generate Qdrant query conditions from query params @@ -432,6 +474,21 @@ def qdrant_query_conditions(params, collection_name=RESOURCES_COLLECTION_NAME): return conditions +def filter_existing_qdrant_points_by_ids( + point_ids, collection_name=RESOURCES_COLLECTION_NAME +): + """ + Return only points that dont exist in qdrant + """ + client = qdrant_client() + response = client.retrieve( + collection_name=collection_name, + ids=point_ids, + ) + 
existing = [record.id for record in response] + return [point_id for point_id in point_ids if point_id not in existing] + + def filter_existing_qdrant_points( values, lookup_field="readable_id", diff --git a/vector_search/utils_test.py b/vector_search/utils_test.py index fb7e83a3ec..f2d9d6273f 100644 --- a/vector_search/utils_test.py +++ b/vector_search/utils_test.py @@ -13,7 +13,7 @@ from vector_search.encoders.utils import dense_encoder from vector_search.utils import ( _chunk_documents, - create_qdrand_collections, + create_qdrant_collections, embed_learning_resources, filter_existing_qdrant_points, qdrant_query_conditions, @@ -35,8 +35,14 @@ def test_vector_point_id_used_for_embed(mocker, content_type): "vector_search.utils.qdrant_client", return_value=mock_qdrant, ) - - embed_learning_resources([resource.id for resource in resources], content_type) + if content_type == "learning_resource": + mocker.patch( + "vector_search.utils.filter_existing_qdrant_points", + return_value=[r.readable_id for r in resources], + ) + embed_learning_resources( + [resource.id for resource in resources], content_type, overwrite=True + ) if content_type == "learning_resource": point_ids = [vector_point_id(resource.readable_id) for resource in resources] @@ -47,12 +53,49 @@ def test_vector_point_id_used_for_embed(mocker, content_type): ) for resource in serialize_bulk_content_files([r.id for r in resources]) ] - assert sorted( [p.id for p in mock_qdrant.upload_points.mock_calls[0].kwargs["points"]] ) == sorted(point_ids) +@pytest.mark.parametrize("content_type", ["learning_resource", "content_file"]) +def test_embed_learning_resources_no_overwrite(mocker, content_type): + # test when overwrite flag is false we dont re-embed existing resources + if content_type == "learning_resource": + resources = LearningResourceFactory.create_batch(5) + else: + resources = ContentFileFactory.create_batch(5, content="test content") + mock_qdrant = mocker.patch("qdrant_client.QdrantClient") + mocker.patch( + "vector_search.utils.qdrant_client", + return_value=mock_qdrant, + ) + if content_type == "learning_resource": + # filter out 3 resources that are already embedded + mocker.patch( + "vector_search.utils.filter_existing_qdrant_points_by_ids", + return_value=[vector_point_id(r.readable_id) for r in resources[0:2]], + ) + else: + # all contentfiles exist in qdrant + mocker.patch( + "vector_search.utils.filter_existing_qdrant_points_by_ids", + return_value=[ + vector_point_id( + f"{doc['resource_readable_id']}.{doc['run_readable_id']}.{doc['key']}.0" + ) + for doc in serialize_bulk_content_files([r.id for r in resources[0:3]]) + ], + ) + embed_learning_resources( + [resource.id for resource in resources], content_type, overwrite=False + ) + if content_type == "learning_resource": + assert len(list(mock_qdrant.upload_points.mock_calls[0].kwargs["points"])) == 2 + else: + assert len(list(mock_qdrant.upload_points.mock_calls[0].kwargs["points"])) == 3 + + def test_filter_existing_qdrant_points(mocker): """ Test that filter_existing_qdrant_points filters out @@ -96,7 +139,7 @@ def test_filter_existing_qdrant_points(mocker): assert filtered_resources.count() == 7 -def test_force_create_qdrand_collections(mocker): +def test_force_create_qdrant_collections(mocker): """ Test that the force flag will recreate collections even if they exist @@ -107,7 +150,7 @@ def test_force_create_qdrand_collections(mocker): return_value=mock_qdrant, ) mock_qdrant.collection_exists.return_value = True - create_qdrand_collections(force_recreate=True) 
+ create_qdrant_collections(force_recreate=True) assert ( mock_qdrant.recreate_collection.mock_calls[0].kwargs["collection_name"] == RESOURCES_COLLECTION_NAME @@ -126,7 +169,7 @@ def test_force_create_qdrand_collections(mocker): ) -def test_auto_create_qdrand_collections(mocker): +def test_auto_create_qdrant_collections(mocker): """ Test that collections will get autocreated if they don't exist @@ -137,7 +180,7 @@ def test_auto_create_qdrand_collections(mocker): return_value=mock_qdrant, ) mock_qdrant.collection_exists.return_value = False - create_qdrand_collections(force_recreate=False) + create_qdrant_collections(force_recreate=False) assert ( mock_qdrant.recreate_collection.mock_calls[0].kwargs["collection_name"] == RESOURCES_COLLECTION_NAME @@ -167,7 +210,7 @@ def test_skip_creating_qdrand_collections(mocker): return_value=mock_qdrant, ) mock_qdrant.collection_exists.return_value = False - create_qdrand_collections(force_recreate=False) + create_qdrant_collections(force_recreate=False) assert ( mock_qdrant.recreate_collection.mock_calls[0].kwargs["collection_name"] == RESOURCES_COLLECTION_NAME diff --git a/yarn.lock b/yarn.lock index 9f5e96a086..cb5715cb65 100644 --- a/yarn.lock +++ b/yarn.lock @@ -9126,14 +9126,14 @@ __metadata: languageName: node linkType: hard -"eslint-config-prettier@npm:^9.0.0": - version: 9.1.0 - resolution: "eslint-config-prettier@npm:9.1.0" +"eslint-config-prettier@npm:^10.0.0": + version: 10.0.1 + resolution: "eslint-config-prettier@npm:10.0.1" peerDependencies: eslint: ">=7.0.0" bin: - eslint-config-prettier: bin/cli.js - checksum: 10/411e3b3b1c7aa04e3e0f20d561271b3b909014956c4dba51c878bf1a23dbb8c800a3be235c46c4732c70827276e540b6eed4636d9b09b444fd0a8e07f0fcd830 + eslint-config-prettier: build/bin/cli.js + checksum: 10/ba6875df0fc4fd3c7c6e2ec9c2e6a224462f7afc662f4cf849775c598a3571c1be136a9b683b12971653b3dcf3f31472aaede3076524b46ec9a77582630158e5 languageName: node linkType: hard @@ -10070,7 +10070,7 @@ __metadata: cross-fetch: "npm:^4.0.0" eslint: "npm:8.57.1" eslint-config-mitodl: "npm:^2.1.0" - eslint-config-prettier: "npm:^9.0.0" + eslint-config-prettier: "npm:^10.0.0" eslint-import-resolver-typescript: "npm:^3.6.1" eslint-plugin-import: "npm:^2.29.1" eslint-plugin-jest: "npm:^28.6.0"
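
Usage sketch (not part of the patch above): a minimal, hedged illustration of how the new overwrite parameter on the embedding tasks is expected to behave. The resource ids and the "course" index name below are illustrative assumptions, not values taken from the diff.

# Hedged sketch, assuming a configured Celery broker and worker.
from vector_search.tasks import embed_learning_resources_by_id, start_embed_resources

# Re-embed an entire index, replacing points that already exist in Qdrant.
start_embed_resources.delay(["course"], skip_content_files=True, overwrite=True)

# Embed specific resources (hypothetical ids); with overwrite=False,
# embed_learning_resources() filters out points that already exist via
# filter_existing_qdrant_points_by_ids() before encoding any documents.
embed_learning_resources_by_id.delay([123, 456], skip_content_files=False, overwrite=False)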