diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..58bfecda --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,32 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Steps to reproduce the behavior: +1. Go to our [off-the-shelf samples](https://github.com/AzureAD/microsoft-authentication-library-for-python/tree/dev/sample) and pick one that is closest to your usage scenario. You should not need to modify the sample. +2. Follow the description of the sample, typically at the beginning of it, to prepare a `config.json` containing your test configurations +3. Run such sample, typically by `python sample.py config.json` +4. See the error +5. In this bug report, tell us the sample you choose, paste the content of the config.json with your test setup (which you can choose to skip your credentials, and/or mail it to our developer's email). + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**What you see instead** +Paste the sample output, or add screenshots to help explain your problem. + +**The MSAL Python version you are using** +Paste the output of this +`python -c "import msal; print(msal.__version__)"` + +**Additional context** +Add any other context about the problem here. diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml new file mode 100644 index 00000000..10afc207 --- /dev/null +++ b/.github/workflows/python-package.yml @@ -0,0 +1,92 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: CI/CD + +on: + push: + pull_request: + branches: [ dev ] + + # This guards against unknown PR until a community member vet it and label it. + types: [ labeled ] + +jobs: + ci: + env: + # Fake a TRAVIS env so that the pre-existing test cases would behave like before + TRAVIS: true + LAB_APP_CLIENT_ID: ${{ secrets.LAB_APP_CLIENT_ID }} + LAB_APP_CLIENT_SECRET: ${{ secrets.LAB_APP_CLIENT_SECRET }} + LAB_OBO_CLIENT_SECRET: ${{ secrets.LAB_OBO_CLIENT_SECRET }} + LAB_OBO_CONFIDENTIAL_CLIENT_ID: ${{ secrets.LAB_OBO_CONFIDENTIAL_CLIENT_ID }} + LAB_OBO_PUBLIC_CLIENT_ID: ${{ secrets.LAB_OBO_PUBLIC_CLIENT_ID }} + + # Derived from https://docs.github.com/en/actions/guides/building-and-testing-python#starting-with-the-python-workflow-template + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [2.7, 3.5, 3.6, 3.7, 3.8, 3.9, "3.10", "3.11.0-alpha.5"] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + # Derived from https://github.com/actions/cache/blob/main/examples.md#using-pip-to-get-cache-location + # However, a before-and-after test shows no improvement in this repo, + # possibly because the bottlenect was not in downloading those small python deps. + - name: Get pip cache dir from pip 20.1+ + id: pip-cache + run: | + echo "::set-output name=dir::$(pip cache dir)" + - name: pip cache + uses: actions/cache@v2 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt') }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + #flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + #flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + pytest + + cd: + needs: ci + if: github.event_name == 'push' && (startsWith(github.ref, 'refs/tags') || github.ref == 'refs/heads/main') + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Build a package for release + run: | + python -m pip install build --user + python -m build --sdist --wheel --outdir dist/ . + - name: Publish to TestPyPI + uses: pypa/gh-action-pypi-publish@v1.4.2 + if: github.ref == 'refs/heads/main' + with: + user: __token__ + password: ${{ secrets.TEST_PYPI_API_TOKEN }} + repository_url: https://test.pypi.org/legacy/ + - name: Publish to PyPI + if: startsWith(github.ref, 'refs/tags') + uses: pypa/gh-action-pypi-publish@v1.4.2 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.gitignore b/.gitignore index eb93430d..18dae08c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,48 +1,61 @@ -.DS_Store -*.py[co] +# Python cache +__pycache__/ +*.pyc -# Packages -*.egg -*.egg-info -dist -build -eggs -parts -var -sdist -develop-eggs -.installed.cfg +# PTVS analysis +.ptvs/ +*.pyproj -# Installer logs -pip-log.txt +# Build results +/bin/ +/obj/ +/dist/ +/MANIFEST -# Unit test / coverage reports -.coverage -.tox +# Result of running python setup.py install/pip install -e +/build/ +/msal.egg-info/ -#Translations -*.mo +# Test results +/TestResults/ -#Mr Developer -.mr.developer.cfg +# User-specific files +*.suo +*.user +*.sln.docstates +/tests/config.py -# Emacs backup files -*~ +# Windows image file caches +Thumbs.db +ehthumbs.db -# IDEA / PyCharm IDE -.idea/ +# Folder config file +Desktop.ini -# vim -*.vim -*.swp +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Mac desktop service store files +.DS_Store -# Virtualenvs -env* +.idea +src/build +*.iml +/doc/_build -docs/source/reference/services -tests/coverage.xml -tests/nosetests.xml +# Virtual Environments +/env* +.venv/ +docs/_build/ +# Visual Studio Files +/.vs/* +/tests/.vs/* + +# vim files +*.swp # The test configuration file(s) could potentially contain credentials tests/config.json + +.env \ No newline at end of file diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..85917242 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,46 @@ +sudo: false +language: python +python: + - "2.7" + - "3.5" + - "3.6" +# Borrowed from https://github.com/travis-ci/travis-ci/issues/9815 +# Enable 3.7 without globally enabling sudo and dist: xenial for other build jobs +matrix: + include: + - python: 3.7 + dist: xenial + sudo: true + - python: 3.8 + dist: xenial + sudo: true + +install: + - pip install -r requirements.txt +script: + - python -m unittest discover -s tests + +deploy: + - # test pypi + provider: pypi + distributions: "sdist bdist_wheel" + server: https://test.pypi.org/legacy/ + user: "nugetaad" + password: + secure: KkjKySJujYxx31B15mlAZr2Jo4P99LcrMj3uON/X/WMXAqYVcVsYJ6JSzUvpNnCAgk+1hc24Qp6nibQHV824yiK+eG4qV+lpzkEEedkRx6NOW/h09OkT+pOSVMs0kcIhz7FzqChpl+jf6ZZpb13yJpQg2LoZIA4g8UdYHHFidWt4m5u1FZ9LPCqQ0OT3gnKK4qb0HIDaECfz5GYzrelLLces0PPwj1+X5eb38xUVtbkA1UJKLGKI882D8Rq5eBdbnDGsfDnF6oU+EBnGZ7o6HVQLdBgagDoVdx7yoXyntULeNxTENMTOZJEJbncQwxRgeEqJWXTTEW57O6Jo5uiHEpJA9lAePlRbS+z6BPDlnQogqOdTsYS0XMfOpYE0/r3cbtPUjETOmGYQxjQzfrFBfM7jaWnUquymZRYqCQ66VDo3I/ykNOCoM9qTmWt5L/MFfOZyoxLHnDThZBdJ3GXHfbivg+v+vOfY1gG8e2H2lQY+/LIMIJibF+MS4lJgrB81dcNdBzyxMNByuWQjSL1TY7un0QzcRcZz2NLrFGg8+9d67LQq4mK5ySimc6zdgnanuROU02vGr1EApT6D/qUItiulFgWqInNKrFXE9q74UP/WSooZPoLa3Du8y5s4eKerYYHQy5eSfIC8xKKDU8MSgoZhwQhCUP46G9Nsty0PYQc= + on: + branch: master + tags: false + condition: $TRAVIS_PYTHON_VERSION = "2.7" + + - # production pypi + provider: pypi + distributions: "sdist bdist_wheel" + user: "nugetaad" + password: + secure: KkjKySJujYxx31B15mlAZr2Jo4P99LcrMj3uON/X/WMXAqYVcVsYJ6JSzUvpNnCAgk+1hc24Qp6nibQHV824yiK+eG4qV+lpzkEEedkRx6NOW/h09OkT+pOSVMs0kcIhz7FzqChpl+jf6ZZpb13yJpQg2LoZIA4g8UdYHHFidWt4m5u1FZ9LPCqQ0OT3gnKK4qb0HIDaECfz5GYzrelLLces0PPwj1+X5eb38xUVtbkA1UJKLGKI882D8Rq5eBdbnDGsfDnF6oU+EBnGZ7o6HVQLdBgagDoVdx7yoXyntULeNxTENMTOZJEJbncQwxRgeEqJWXTTEW57O6Jo5uiHEpJA9lAePlRbS+z6BPDlnQogqOdTsYS0XMfOpYE0/r3cbtPUjETOmGYQxjQzfrFBfM7jaWnUquymZRYqCQ66VDo3I/ykNOCoM9qTmWt5L/MFfOZyoxLHnDThZBdJ3GXHfbivg+v+vOfY1gG8e2H2lQY+/LIMIJibF+MS4lJgrB81dcNdBzyxMNByuWQjSL1TY7un0QzcRcZz2NLrFGg8+9d67LQq4mK5ySimc6zdgnanuROU02vGr1EApT6D/qUItiulFgWqInNKrFXE9q74UP/WSooZPoLa3Du8y5s4eKerYYHQy5eSfIC8xKKDU8MSgoZhwQhCUP46G9Nsty0PYQc= + on: + branch: master + tags: true + condition: $TRAVIS_PYTHON_VERSION = "2.7" + diff --git a/LICENSE b/LICENSE index 844a00f9..e7a9ff04 100644 --- a/LICENSE +++ b/LICENSE @@ -1,7 +1,6 @@ The MIT License (MIT) -Copyright (c) 2021 Ray Luo -Copyright (c) 2018-2021 Microsoft Corporation. +Copyright (c) Microsoft Corporation. All rights reserved. This code is licensed under the MIT License. @@ -22,4 +21,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. +THE SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 00000000..9088b60a --- /dev/null +++ b/README.md @@ -0,0 +1,147 @@ +# Microsoft Authentication Library (MSAL) for Python + +| `dev` branch | Reference Docs | # of Downloads per different platforms | # of Downloads per recent MSAL versions | +|---------------|---------------|----------------------------------------|-----------------------------------------| + [![Build status](https://github.com/AzureAD/microsoft-authentication-library-for-python/actions/workflows/python-package.yml/badge.svg?branch=dev)](https://github.com/AzureAD/microsoft-authentication-library-for-python/actions) | [![Documentation Status](https://readthedocs.org/projects/msal-python/badge/?version=latest)](https://msal-python.readthedocs.io/en/latest/?badge=latest) | [![Downloads](https://pepy.tech/badge/msal)](https://pypistats.org/packages/msal) | [![Download monthly](https://pepy.tech/badge/msal/month)](https://pepy.tech/project/msal) + +The Microsoft Authentication Library for Python enables applications to integrate with the [Microsoft identity platform](https://aka.ms/aaddevv2). It allows you to sign in users or apps with Microsoft identities ([Azure AD](https://azure.microsoft.com/services/active-directory/), [Microsoft Accounts](https://account.microsoft.com) and [Azure AD B2C](https://azure.microsoft.com/services/active-directory-b2c/) accounts) and obtain tokens to call Microsoft APIs such as [Microsoft Graph](https://graph.microsoft.io/) or your own APIs registered with the Microsoft identity platform. It is built using industry standard OAuth2 and OpenID Connect protocols + +Not sure whether this is the SDK you are looking for your app? There are other Microsoft Identity SDKs +[here](https://github.com/AzureAD/microsoft-authentication-library-for-python/wiki/Microsoft-Authentication-Client-Libraries). + +Quick links: + +| [Getting Started](https://docs.microsoft.com/azure/active-directory/develop/quickstart-v2-python-webapp) | [Docs](https://github.com/AzureAD/microsoft-authentication-library-for-python/wiki) | [Samples](https://aka.ms/aaddevsamplesv2) | [Support](README.md#community-help-and-support) | [Feedback](https://forms.office.com/r/TMjZkDbzjY) | +| --- | --- | --- | --- | --- | + +## Scenarios supported + +Click on the following thumbnail to visit a large map with clickable links to proper samples. + +[![Map effect won't work inside github's markdown file, so we have to use a thumbnail here to lure audience to a real static website](https://raw.githubusercontent.com/AzureAD/microsoft-authentication-library-for-python/dev/docs/thumbnail.png)](https://msal-python.readthedocs.io/en/latest/) + +## Installation + +You can find MSAL Python on [Pypi](https://pypi.org/project/msal/). +1. If you haven't already, [install and/or upgrade the pip](https://pip.pypa.io/en/stable/installing/) + of your Python environment to a recent version. We tested with pip 18.1. +2. As usual, just run `pip install msal`. + +## Versions + +This library follows [Semantic Versioning](http://semver.org/). + +You can find the changes for each version under +[Releases](https://github.com/AzureAD/microsoft-authentication-library-for-python/releases). + +## Usage + +Before using MSAL Python (or any MSAL SDKs, for that matter), you will have to +[register your application with the Microsoft identity platform](https://docs.microsoft.com/azure/active-directory/develop/quickstart-v2-register-an-app). + +Acquiring tokens with MSAL Python follows this 3-step pattern. +(Note: That is the high level conceptual pattern. +There will be some variations for different flows. They are demonstrated in +[runnable samples hosted right in this repo](https://github.com/AzureAD/microsoft-authentication-library-for-python/tree/dev/sample). +) + + +1. MSAL proposes a clean separation between + [public client applications, and confidential client applications](https://tools.ietf.org/html/rfc6749#section-2.1). + So you will first create either a `PublicClientApplication` or a `ConfidentialClientApplication` instance, + and ideally reuse it during the lifecycle of your app. The following example shows a `PublicClientApplication`: + + ```python + from msal import PublicClientApplication + app = PublicClientApplication( + "your_client_id", + authority="https://login.microsoftonline.com/Enter_the_Tenant_Name_Here") + ``` + + Later, each time you would want an access token, you start by: + ```python + result = None # It is just an initial value. Please follow instructions below. + ``` + +2. The API model in MSAL provides you explicit control on how to utilize token cache. + This cache part is technically optional, but we highly recommend you to harness the power of MSAL cache. + It will automatically handle the token refresh for you. + + ```python + # We now check the cache to see + # whether we already have some accounts that the end user already used to sign in before. + accounts = app.get_accounts() + if accounts: + # If so, you could then somehow display these accounts and let end user choose + print("Pick the account you want to use to proceed:") + for a in accounts: + print(a["username"]) + # Assuming the end user chose this one + chosen = accounts[0] + # Now let's try to find a token in cache for this account + result = app.acquire_token_silent(["your_scope"], account=chosen) + ``` + +3. Either there is no suitable token in the cache, or you chose to skip the previous step, + now it is time to actually send a request to AAD to obtain a token. + There are different methods based on your client type and scenario. Here we demonstrate a placeholder flow. + + ```python + if not result: + # So no suitable token exists in cache. Let's get a new one from AAD. + result = app.acquire_token_by_one_of_the_actual_method(..., scopes=["User.Read"]) + if "access_token" in result: + print(result["access_token"]) # Yay! + else: + print(result.get("error")) + print(result.get("error_description")) + print(result.get("correlation_id")) # You may need this when reporting a bug + ``` + +Refer the [Wiki](https://github.com/AzureAD/microsoft-authentication-library-for-python/wiki) pages for more details on the MSAL Python functionality and usage. + +## Migrating from ADAL + +If your application is using ADAL Python, we recommend you to update to use MSAL Python. No new feature work will be done in ADAL Python. + +See the [ADAL to MSAL migration](https://github.com/AzureAD/microsoft-authentication-library-for-python/wiki/Migrate-to-MSAL-Python) guide. + +## Roadmap + +You can follow the latest updates and plans for MSAL Python in the [Roadmap](https://github.com/AzureAD/microsoft-authentication-library-for-python/wiki/Roadmap) published on our Wiki. + +## Samples and Documentation + +MSAL Python supports multiple [application types and authentication scenarios](https://docs.microsoft.com/azure/active-directory/develop/authentication-flows-app-scenarios). +The generic documents on +[Auth Scenarios](https://docs.microsoft.com/azure/active-directory/develop/authentication-scenarios) +and +[Auth protocols](https://docs.microsoft.com/azure/active-directory/develop/active-directory-v2-protocols) +are recommended reading. + +We provide a [full suite of sample applications](https://aka.ms/aaddevsamplesv2) and [documentation](https://aka.ms/aaddevv2) to help you get started with learning the Microsoft identity platform. + +## Community Help and Support + +We leverage Stack Overflow to work with the community on supporting Azure Active Directory and its SDKs, including this one! +We highly recommend you ask your questions on Stack Overflow (we're all on there!) +Also browser existing issues to see if someone has had your question before. + +We recommend you use the "msal" tag so we can see it! +Here is the latest Q&A on Stack Overflow for MSAL: +[http://stackoverflow.com/questions/tagged/msal](http://stackoverflow.com/questions/tagged/msal) + +## Submit Feedback +We'd like your thoughts on this library. Please complete [this short survey.](https://forms.office.com/r/TMjZkDbzjY) + +## Security Reporting + +If you find a security issue with our libraries or services please report it to [secure@microsoft.com](mailto:secure@microsoft.com) with as much detail as possible. Your submission may be eligible for a bounty through the [Microsoft Bounty](http://aka.ms/bugbounty) program. Please do not post security issues to GitHub Issues or any other public site. We will contact you shortly upon receiving the information. We encourage you to get notifications of when security incidents occur by visiting [this page](https://technet.microsoft.com/security/dd252948) and subscribing to Security Advisory Alerts. + +## Contributing + +All code is licensed under the MIT license and we triage actively on GitHub. We enthusiastically welcome contributions and feedback. Please read the [contributing guide](./contributing.md) before starting. + +## We Value and Adhere to the Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. diff --git a/contributing.md b/contributing.md new file mode 100644 index 00000000..e78c1ce1 --- /dev/null +++ b/contributing.md @@ -0,0 +1,122 @@ +# CONTRIBUTING + +Azure Active Directory SDK projects welcomes new contributors. This document will guide you +through the process. + +### CONTRIBUTOR LICENSE AGREEMENT + +Please visit [https://cla.microsoft.com/](https://cla.microsoft.com/) and sign the Contributor License +Agreement. You only need to do that once. We can not look at your code until you've submitted this request. + + +### FORK + +Fork this project on GitHub and check out your copy. + +Example for Project Foo (which can be any ADAL or MSAL or just any library): + +``` +$ git clone git@github.com:username/project-foo.git +$ cd project-foo +$ git remote add upstream git@github.com:AzureAD/project-foo.git +``` + +No need to decide if you want your feature or bug fix to go into the dev branch +or the master branch. **All bug fixes and new features should go into the dev branch.** + +The master branch is effectively frozen; patches that change the SDKs +protocols or API surface area or affect the run-time behavior of the SDK will be rejected. + +Some of our SDKs have bundled dependencies that are not part of the project proper. +Any changes to files in those directories or its subdirectories should be sent to their respective projects. +Do not send your patch to us, we cannot accept it. + +In case of doubt, open an issue in the [issue tracker](issues). + +Especially do so if you plan to work on a major change in functionality. Nothing is more +frustrating than seeing your hard work go to waste because your vision +does not align with our goals for the SDK. + + +### BRANCH + +Okay, so you have decided on the proper branch. Create a feature branch +and start hacking: + +``` +$ git checkout -b my-feature-branch +``` + +### COMMIT + +Make sure git knows your name and email address: + +``` +$ git config --global user.name "J. Random User" +$ git config --global user.email "j.random.user@example.com" +``` + +Writing good commit logs is important. A commit log should describe what +changed and why. Follow these guidelines when writing one: + +1. The first line should be 50 characters or less and contain a short + description of the change prefixed with the name of the changed + subsystem (e.g. "net: add localAddress and localPort to Socket"). +2. Keep the second line blank. +3. Wrap all other lines at 72 columns. + +A good commit log looks like this: + +``` +fix: explaining the commit in one line + +Body of commit message is a few lines of text, explaining things +in more detail, possibly giving some background about the issue +being fixed, etc etc. + +The body of the commit message can be several paragraphs, and +please do proper word-wrap and keep columns shorter than about +72 characters or so. That way `git log` will show things +nicely even when it is indented. +``` + +The header line should be meaningful; it is what other people see when they +run `git shortlog` or `git log --oneline`. + +Check the output of `git log --oneline files_that_you_changed` to find out +what directories your changes touch. + + +### REBASE + +Use `git rebase` (not `git merge`) to sync your work from time to time. + +``` +$ git fetch upstream +$ git rebase upstream/v0.1 # or upstream/master +``` + + +### TEST + +Bug fixes and features should come with tests. Add your tests in the +test directory. This varies by repository but often follows the same convention of /src/test. Look at other tests to see how they should be +structured (license boilerplate, common includes, etc.). + + +Make sure that all tests pass. + + +### PUSH + +``` +$ git push origin my-feature-branch +``` + +Go to https://github.com/username/microsoft-authentication-library-for-***.git and select your feature branch. Click +the 'Pull Request' button and fill out the form. + +Pull requests are usually reviewed within a few days. If there are comments +to address, apply your changes in a separate commit and push that to your +feature branch. Post a comment in the pull request afterwards; GitHub does +not send out notifications when you add commits. diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..298ea9e2 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,19 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 00000000..810dfc02 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- +# +# Configuration file for the Sphinx documentation builder. +# +# This file does only contain a selection of the most common options. For a +# full list see the documentation: +# http://www.sphinx-doc.org/en/master/config + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +from datetime import date +import os +import sys +sys.path.insert(0, os.path.abspath('..')) + + +# -- Project information ----------------------------------------------------- + +project = u'MSAL Python' +copyright = u'{0}, Microsoft'.format(date.today().year) +author = u'Microsoft' + +# The short X.Y version +from msal import __version__ as version +# The full version, including alpha/beta/rc tags +release = version + + +# -- General configuration --------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.githubpages', +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [u'_build', 'Thumbs.db', '.DS_Store'] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = None + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +# html_theme = 'alabaster' +html_theme = 'furo' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +html_theme_options = { + "light_css_variables": { + "font-stack": "'Segoe UI', SegoeUI, 'Helvetica Neue', Helvetica, Arial, sans-serif", + "font-stack--monospace": "SFMono-Regular, Consolas, 'Liberation Mono', Menlo, Courier, monospace", + }, +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. +# +# The default sidebars (for documents that don't match any pattern) are +# defined by theme itself. Builtin themes are using these templates by +# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', +# 'searchbox.html']``. +# +# html_sidebars = {} + + +# -- Options for HTMLHelp output --------------------------------------------- + +# Output file base name for HTML help builder. +htmlhelp_basename = 'MSALPythondoc' + + +# -- Options for LaTeX output ------------------------------------------------ + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'MSALPython.tex', u'MSAL Python Documentation', + u'Microsoft', 'manual'), +] + + +# -- Options for manual page output ------------------------------------------ + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'msalpython', u'MSAL Python Documentation', + [author], 1) +] + + +# -- Options for Texinfo output ---------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'MSALPython', u'MSAL Python Documentation', + author, 'MSALPython', 'One line description of project.', + 'Miscellaneous'), +] + + +# -- Options for Epub output ------------------------------------------------- + +# Bibliographic Dublin Core info. +epub_title = project + +# The unique identifier of the text. This can be a ISBN number +# or the project homepage. +# +# epub_identifier = '' + +# A unique identification for the text. +# +# epub_uid = '' + +# A list of files that should not be packed into the epub file. +epub_exclude_files = ['search.html'] + + +# -- Extension configuration ------------------------------------------------- \ No newline at end of file diff --git a/docs/daemon-app.svg b/docs/daemon-app.svg new file mode 100644 index 00000000..8f1af659 --- /dev/null +++ b/docs/daemon-app.svg @@ -0,0 +1,1074 @@ + + + +image/svg+xml + + + + + + + + + + + + + + + + + + Page-1 + + + + + Web app (Was Websites).1277 + Daemon Web app + + Sheet.1002 + + Sheet.1003 + + + + + Sheet.1004 + + Sheet.1005 + + Sheet.1006 + + Sheet.1007 + + + + + Sheet.1008 + + + + Sheet.1009 + + Sheet.1010 + + + + + Sheet.1011 + + Sheet.1012 + + + + + Sheet.1013 + + Sheet.1014 + + + + + Sheet.1015 + + Sheet.1016 + + + + + Sheet.1017 + + Sheet.1018 + + + + + Sheet.1019 + + Sheet.1020 + + + + + Sheet.1021 + + + + Sheet.1022 + + + + + Sheet.1023 + + Sheet.1024 + + + + + Sheet.1025 + + Sheet.1026 + + + + + Sheet.1027 + + Sheet.1028 + + + + + + + + DaemonWeb app + + + + API App.1305 + Daemon API App + + + + DaemonAPI App + + + Microsoft Enterprise desktop virtualization.1317 + Daemon Desktop App + + Sheet.1031 + + + + Sheet.1032 + + + + Sheet.1033 + + + + Sheet.1034 + + + + Sheet.1035 + + + + Sheet.1036 + + + + Sheet.1037 + + + + Sheet.1038 + + + + + + DaemonDesktop App + + + + Certificate.1337 + Secret + + Sheet.1040 + + + + Sheet.1041 + + Sheet.1042 + + Sheet.1043 + + + + + Sheet.1044 + + Sheet.1045 + + + + + Sheet.1046 + + Sheet.1047 + + + + + + + + Secret + + + + Arrow (Azure Poster Style).1346 + + + + Arrow (Azure Poster Style).1348 + + + + API App.1350 + Daemon Web API + + + + Daemon Web API + + + Certificate.1385 + Secret + + Sheet.1052 + + + + Sheet.1053 + + Sheet.1054 + + Sheet.1055 + + + + + Sheet.1056 + + Sheet.1057 + + + + + Sheet.1058 + + Sheet.1059 + + + + + + + + Secret + + + + Certificate.1416 + Secret + + Sheet.1061 + + + + Sheet.1062 + + Sheet.1063 + + Sheet.1064 + + + + + Sheet.1065 + + Sheet.1066 + + + + + Sheet.1067 + + Sheet.1068 + + + + + + + + Secret + + + + Arrow (Azure Poster Style).1507 + + + + Sheet.1215 + Client Credentials flow + + + + Client Credentials flow + + + \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 00000000..95b89b98 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,113 @@ +MSAL Python documentation +========================= + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + :hidden: + + MSAL Documentation + GitHub Repository + +You can find high level conceptual documentations in the project +`README `_. + +Scenarios +========= + +There are many `different application scenarios `_. +MSAL Python supports some of them. +**The following diagram serves as a map. Locate your application scenario on the map.** +**If the corresponding icon is clickable, it will bring you to an MSAL Python sample for that scenario.** + +* Most authentication scenarios acquire tokens on behalf of signed-in users. + + .. raw:: html + + + + + + Web app + Web app + Desktop App + + Browserless app + + +* There are also daemon apps. In these scenarios, applications acquire tokens on behalf of themselves with no user. + + .. raw:: html + + + + + + Daemon App acquires token for themselves + + +* There are other less common samples, such for ADAL-to-MSAL migration, + `available inside the project code base + `_. + + +API +=== + +The following section is the API Reference of MSAL Python. + +.. note:: + + Only APIs and their parameters documented in this section are part of public API, + with guaranteed backward compatibility for the entire 1.x series. + + Other modules in the source code are all considered as internal helpers, + which could change at anytime in the future, without prior notice. + +MSAL proposes a clean separation between +`public client applications and confidential client applications +`_. + +They are implemented as two separated classes, +with different methods for different authentication scenarios. + +PublicClientApplication +----------------------- + +.. autoclass:: msal.PublicClientApplication + :members: + :inherited-members: + + .. automethod:: __init__ + +ConfidentialClientApplication +----------------------------- + +.. autoclass:: msal.ConfidentialClientApplication + :members: + :inherited-members: + + .. automethod:: __init__ + +TokenCache +---------- + +One of the parameters accepted by +both `PublicClientApplication` and `ConfidentialClientApplication` +is the `TokenCache`. + +.. autoclass:: msal.TokenCache + :members: + +You can subclass it to add new behavior, such as, token serialization. +See `SerializableTokenCache` for example. + +.. autoclass:: msal.SerializableTokenCache + :members: diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 00000000..27f573b8 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..d5de57fe --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,2 @@ +furo +-r ../requirements.txt \ No newline at end of file diff --git a/docs/scenarios-with-users.svg b/docs/scenarios-with-users.svg new file mode 100644 index 00000000..fffdec47 --- /dev/null +++ b/docs/scenarios-with-users.svg @@ -0,0 +1,2789 @@ + + + +image/svg+xml + + + + + + + + + + + + + + + + + + Page-1 + + + + + Web app (Was Websites).1073 + Single Page Application + + Sheet.1074 + + Sheet.1075 + + + + + Sheet.1076 + + Sheet.1077 + + Sheet.1078 + + Sheet.1079 + + + + + Sheet.1080 + + + + Sheet.1081 + + Sheet.1082 + + + + + Sheet.1083 + + Sheet.1084 + + + + + Sheet.1085 + + Sheet.1086 + + + + + Sheet.1087 + + Sheet.1088 + + + + + Sheet.1089 + + Sheet.1090 + + + + + Sheet.1091 + + Sheet.1092 + + + + + Sheet.1093 + + + + Sheet.1094 + + + + + Sheet.1095 + + Sheet.1096 + + + + + Sheet.1097 + + Sheet.1098 + + + + + Sheet.1099 + + Sheet.1100 + + + + + + + + Single Page Application + + + + Web app (Was Websites).1101 + Web app + + Sheet.1102 + + Sheet.1103 + + + + + Sheet.1104 + + Sheet.1105 + + Sheet.1106 + + Sheet.1107 + + + + + Sheet.1108 + + + + Sheet.1109 + + Sheet.1110 + + + + + Sheet.1111 + + Sheet.1112 + + + + + Sheet.1113 + + Sheet.1114 + + + + + Sheet.1115 + + Sheet.1116 + + + + + Sheet.1117 + + Sheet.1118 + + + + + Sheet.1119 + + Sheet.1120 + + + + + Sheet.1121 + + + + Sheet.1122 + + + + + Sheet.1123 + + Sheet.1124 + + + + + Sheet.1125 + + Sheet.1126 + + + + + Sheet.1127 + + Sheet.1128 + + + + + + + + Web app + + + + API App.1129 + API App + + + + API App + + + IoT Hub.1130 + Browserless app + + Sheet.1131 + + Sheet.1132 + + Sheet.1133 + + Sheet.1134 + + + Sheet.1135 + + + + + Sheet.1136 + + + + + + + Browserlessapp + + + + Mobile App (Was Mobile Services).1137 + Mobile App + + Sheet.1138 + + + + Sheet.1139 + + + + + + MobileApp + + + + Arrow (Azure Poster Style).1140 + + + + Microsoft Enterprise desktop virtualization.1141 + Desktop App + + Sheet.1142 + + + + Sheet.1143 + + + + Sheet.1144 + + + + Sheet.1145 + + + + Sheet.1146 + + + + Sheet.1147 + + + + Sheet.1148 + + + + Sheet.1149 + + + + + + Desktop App + + + + User Permissions.1150 + + Sheet.1151 + + Sheet.1152 + + Sheet.1153 + + Sheet.1154 + + + + + Sheet.1155 + + Sheet.1156 + + + + + + + Sheet.1157 + + Sheet.1158 + + + + Sheet.1159 + + + + Sheet.1160 + + + + + + Arrow (Azure Poster Style).1170 + + + + Arrow (Azure Poster Style).1171 + + + + Arrow (Azure Poster Style).1172 + + + + Arrow (Azure Poster Style).1173 + + + + API App.1174 + API App + + + + API App + + + Arrow (Azure Poster Style).1175 + + + + User Permissions.1176 + + Sheet.1177 + + Sheet.1178 + + Sheet.1179 + + Sheet.1180 + + + + + Sheet.1181 + + Sheet.1182 + + + + + + + Sheet.1183 + + Sheet.1184 + + + + Sheet.1185 + + + + Sheet.1186 + + + + + + User Permissions.1187 + + Sheet.1188 + + Sheet.1189 + + Sheet.1190 + + Sheet.1191 + + + + + Sheet.1192 + + Sheet.1193 + + + + + + + Sheet.1194 + + Sheet.1195 + + + + Sheet.1196 + + + + Sheet.1197 + + + + + + User Permissions.1198 + + Sheet.1199 + + Sheet.1200 + + Sheet.1201 + + Sheet.1202 + + + + + Sheet.1203 + + Sheet.1204 + + + + + + + Sheet.1205 + + Sheet.1206 + + + + Sheet.1207 + + + + Sheet.1208 + + + + + + User Permissions.1218 + + Sheet.1219 + + Sheet.1220 + + Sheet.1221 + + Sheet.1222 + + + + + Sheet.1223 + + Sheet.1224 + + + + + + + Sheet.1225 + + Sheet.1226 + + + + Sheet.1227 + + + + Sheet.1228 + + + + + + Web app (Was Websites).1425 + Single Page Application + + Sheet.1426 + + Sheet.1427 + + + + + Sheet.1428 + + Sheet.1429 + + Sheet.1430 + + Sheet.1431 + + + + + Sheet.1432 + + + + Sheet.1433 + + Sheet.1434 + + + + + Sheet.1435 + + Sheet.1436 + + + + + Sheet.1437 + + Sheet.1438 + + + + + Sheet.1439 + + Sheet.1440 + + + + + Sheet.1441 + + Sheet.1442 + + + + + Sheet.1443 + + Sheet.1444 + + + + + Sheet.1445 + + + + Sheet.1446 + + + + + Sheet.1447 + + Sheet.1448 + + + + + Sheet.1449 + + Sheet.1450 + + + + + Sheet.1451 + + Sheet.1452 + + + + + + + + Single Page Application + + + + User Permissions.1453 + + Sheet.1454 + + Sheet.1455 + + Sheet.1456 + + Sheet.1457 + + + + + Sheet.1458 + + Sheet.1459 + + + + + + + Sheet.1460 + + Sheet.1461 + + + + Sheet.1462 + + + + Sheet.1463 + + + + + + Web app (Was Websites).1464 + Web app + + Sheet.1465 + + Sheet.1466 + + + + + Sheet.1467 + + Sheet.1468 + + Sheet.1469 + + Sheet.1470 + + + + + Sheet.1471 + + + + Sheet.1472 + + Sheet.1473 + + + + + Sheet.1474 + + Sheet.1475 + + + + + Sheet.1476 + + Sheet.1477 + + + + + Sheet.1478 + + Sheet.1479 + + + + + Sheet.1480 + + Sheet.1481 + + + + + Sheet.1482 + + Sheet.1483 + + + + + Sheet.1484 + + + + Sheet.1485 + + + + + Sheet.1486 + + Sheet.1487 + + + + + Sheet.1488 + + Sheet.1489 + + + + + Sheet.1490 + + Sheet.1491 + + + + + + + + Web app + + + + User Permissions.1492 + + Sheet.1493 + + Sheet.1494 + + Sheet.1495 + + Sheet.1496 + + + + + Sheet.1497 + + Sheet.1498 + + + + + + + Sheet.1499 + + Sheet.1500 + + + + Sheet.1501 + + + + Sheet.1502 + + + + + + Pentagon.1683 + + + + + + + User Permissions.1684 + + Sheet.1685 + + Sheet.1686 + + Sheet.1687 + + Sheet.1688 + + + + + Sheet.1689 + + Sheet.1690 + + + + + + + Sheet.1691 + + Sheet.1692 + + + + Sheet.1693 + + + + Sheet.1694 + + + + + + \ No newline at end of file diff --git a/docs/thumbnail.png b/docs/thumbnail.png new file mode 100644 index 00000000..e1606e91 Binary files /dev/null and b/docs/thumbnail.png differ diff --git a/msal/__init__.py b/msal/__init__.py new file mode 100644 index 00000000..4e2faaed --- /dev/null +++ b/msal/__init__.py @@ -0,0 +1,36 @@ +#------------------------------------------------------------------------------ +# +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions : +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +#------------------------------------------------------------------------------ + +from .application import ( + __version__, + ClientApplication, + ConfidentialClientApplication, + PublicClientApplication, + ) +from .oauth2cli.oidc import Prompt +from .token_cache import TokenCache, SerializableTokenCache + diff --git a/msal/application.py b/msal/application.py new file mode 100644 index 00000000..829c35b0 --- /dev/null +++ b/msal/application.py @@ -0,0 +1,1773 @@ +import functools +import json +import time +try: # Python 2 + from urlparse import urljoin +except: # Python 3 + from urllib.parse import urljoin +import logging +import sys +import warnings +from threading import Lock +import os + +from .oauth2cli import Client, JwtAssertionCreator +from .oauth2cli.oidc import decode_part +from .authority import Authority +from .mex import send_request as mex_send_request +from .wstrust_request import send_request as wst_send_request +from .wstrust_response import * +from .token_cache import TokenCache +import msal.telemetry +from .region import _detect_region +from .throttled_http_client import ThrottledHttpClient +from .cloudshell import _is_running_in_cloud_shell + + +# The __init__.py will import this. Not the other way around. +__version__ = "1.18.0" # When releasing, also check and bump our dependencies's versions if needed + +logger = logging.getLogger(__name__) +_AUTHORITY_TYPE_CLOUDSHELL = "CLOUDSHELL" + +def extract_certs(public_cert_content): + # Parses raw public certificate file contents and returns a list of strings + # Usage: headers = {"x5c": extract_certs(open("my_cert.pem").read())} + public_certificates = re.findall( + r'-----BEGIN CERTIFICATE-----(?P[^-]+)-----END CERTIFICATE-----', + public_cert_content, re.I) + if public_certificates: + return [cert.strip() for cert in public_certificates] + # The public cert tags are not found in the input, + # let's make best effort to exclude a private key pem file. + if "PRIVATE KEY" in public_cert_content: + raise ValueError( + "We expect your public key but detect a private key instead") + return [public_cert_content.strip()] + + +def _merge_claims_challenge_and_capabilities(capabilities, claims_challenge): + # Represent capabilities as {"access_token": {"xms_cc": {"values": capabilities}}} + # and then merge/add it into incoming claims + if not capabilities: + return claims_challenge + claims_dict = json.loads(claims_challenge) if claims_challenge else {} + for key in ["access_token"]: # We could add "id_token" if we'd decide to + claims_dict.setdefault(key, {}).update(xms_cc={"values": capabilities}) + return json.dumps(claims_dict) + + +def _str2bytes(raw): + # A conversion based on duck-typing rather than six.text_type + try: + return raw.encode(encoding="utf-8") + except: + return raw + + +def _clean_up(result): + if isinstance(result, dict): + result.pop("refresh_in", None) # MSAL handled refresh_in, customers need not + return result + + +def _preferred_browser(): + """Register Edge and return a name suitable for subsequent webbrowser.get(...) + when appropriate. Otherwise return None. + """ + # On Linux, only Edge will provide device-based Conditional Access support + if sys.platform != "linux": # On other platforms, we have no browser preference + return None + browser_path = "/usr/bin/microsoft-edge" # Use a full path owned by sys admin + # Note: /usr/bin/microsoft-edge, /usr/bin/microsoft-edge-stable, etc. + # are symlinks that point to the actual binaries which are found under + # /opt/microsoft/msedge/msedge or /opt/microsoft/msedge-beta/msedge. + # Either method can be used to detect an Edge installation. + user_has_no_preference = "BROWSER" not in os.environ + user_wont_mind_edge = "microsoft-edge" in os.environ.get("BROWSER", "") # Note: + # BROWSER could contain "microsoft-edge" or "/path/to/microsoft-edge". + # Python documentation (https://docs.python.org/3/library/webbrowser.html) + # does not document the name being implicitly register, + # so there is no public API to know whether the ENV VAR browser would work. + # Therefore, we would not bother examine the env var browser's type. + # We would just register our own Edge instance. + if (user_has_no_preference or user_wont_mind_edge) and os.path.exists(browser_path): + try: + import webbrowser # Lazy import. Some distro may not have this. + browser_name = "msal-edge" # Avoid popular name "microsoft-edge" + # otherwise `BROWSER="microsoft-edge"; webbrowser.get("microsoft-edge")` + # would return a GenericBrowser instance which won't work. + try: + registration_available = isinstance( + webbrowser.get(browser_name), webbrowser.BackgroundBrowser) + except webbrowser.Error: + registration_available = False + if not registration_available: + logger.debug("Register %s with %s", browser_name, browser_path) + # By registering our own browser instance with our own name, + # rather than populating a process-wide BROWSER enn var, + # this approach does not have side effect on non-MSAL code path. + webbrowser.register( # Even double-register happens to work fine + browser_name, None, webbrowser.BackgroundBrowser(browser_path)) + return browser_name + except ImportError: + pass # We may still proceed + return None + + +class _ClientWithCcsRoutingInfo(Client): + + def initiate_auth_code_flow(self, **kwargs): + if kwargs.get("login_hint"): # eSTS could have utilized this as-is, but nope + kwargs["X-AnchorMailbox"] = "UPN:%s" % kwargs["login_hint"] + return super(_ClientWithCcsRoutingInfo, self).initiate_auth_code_flow( + client_info=1, # To be used as CSS Routing info + **kwargs) + + def obtain_token_by_auth_code_flow( + self, auth_code_flow, auth_response, **kwargs): + # Note: the obtain_token_by_browser() is also covered by this + assert isinstance(auth_code_flow, dict) and isinstance(auth_response, dict) + headers = kwargs.pop("headers", {}) + client_info = json.loads( + decode_part(auth_response["client_info"]) + ) if auth_response.get("client_info") else {} + if "uid" in client_info and "utid" in client_info: + # Note: The value of X-AnchorMailbox is also case-insensitive + headers["X-AnchorMailbox"] = "Oid:{uid}@{utid}".format(**client_info) + return super(_ClientWithCcsRoutingInfo, self).obtain_token_by_auth_code_flow( + auth_code_flow, auth_response, headers=headers, **kwargs) + + def obtain_token_by_username_password(self, username, password, **kwargs): + headers = kwargs.pop("headers", {}) + headers["X-AnchorMailbox"] = "upn:{}".format(username) + return super(_ClientWithCcsRoutingInfo, self).obtain_token_by_username_password( + username, password, headers=headers, **kwargs) + + +class ClientApplication(object): + + ACQUIRE_TOKEN_SILENT_ID = "84" + ACQUIRE_TOKEN_BY_REFRESH_TOKEN = "85" + ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID = "301" + ACQUIRE_TOKEN_ON_BEHALF_OF_ID = "523" + ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID = "622" + ACQUIRE_TOKEN_FOR_CLIENT_ID = "730" + ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID = "832" + ACQUIRE_TOKEN_INTERACTIVE = "169" + GET_ACCOUNTS_ID = "902" + REMOVE_ACCOUNT_ID = "903" + + ATTEMPT_REGION_DISCOVERY = True # "TryAutoDetect" + + def __init__( + self, client_id, + client_credential=None, authority=None, validate_authority=True, + token_cache=None, + http_client=None, + verify=True, proxies=None, timeout=None, + client_claims=None, app_name=None, app_version=None, + client_capabilities=None, + azure_region=None, # Note: We choose to add this param in this base class, + # despite it is currently only needed by ConfidentialClientApplication. + # This way, it holds the same positional param place for PCA, + # when we would eventually want to add this feature to PCA in future. + exclude_scopes=None, + http_cache=None, + ): + """Create an instance of application. + + :param str client_id: Your app has a client_id after you register it on AAD. + + :param Union[str, dict] client_credential: + For :class:`PublicClientApplication`, you simply use `None` here. + For :class:`ConfidentialClientApplication`, + it can be a string containing client secret, + or an X509 certificate container in this form:: + + { + "private_key": "...-----BEGIN PRIVATE KEY-----...", + "thumbprint": "A1B2C3D4E5F6...", + "public_certificate": "...-----BEGIN CERTIFICATE-----... (Optional. See below.)", + "passphrase": "Passphrase if the private_key is encrypted (Optional. Added in version 1.6.0)", + } + + *Added in version 0.5.0*: + public_certificate (optional) is public key certificate + which will be sent through 'x5c' JWT header only for + subject name and issuer authentication to support cert auto rolls. + + Per `specs `_, + "the certificate containing + the public key corresponding to the key used to digitally sign the + JWS MUST be the first certificate. This MAY be followed by + additional certificates, with each subsequent certificate being the + one used to certify the previous one." + However, your certificate's issuer may use a different order. + So, if your attempt ends up with an error AADSTS700027 - + "The provided signature value did not match the expected signature value", + you may try use only the leaf cert (in PEM/str format) instead. + + *Added in version 1.13.0*: + It can also be a completely pre-signed assertion that you've assembled yourself. + Simply pass a container containing only the key "client_assertion", like this:: + + { + "client_assertion": "...a JWT with claims aud, exp, iss, jti, nbf, and sub..." + } + + :param dict client_claims: + *Added in version 0.5.0*: + It is a dictionary of extra claims that would be signed by + by this :class:`ConfidentialClientApplication` 's private key. + For example, you can use {"client_ip": "x.x.x.x"}. + You may also override any of the following default claims:: + + { + "aud": the_token_endpoint, + "iss": self.client_id, + "sub": same_as_issuer, + "exp": now + 10_min, + "iat": now, + "jti": a_random_uuid + } + + :param str authority: + A URL that identifies a token authority. It should be of the format + ``https://login.microsoftonline.com/your_tenant`` + By default, we will use ``https://login.microsoftonline.com/common`` + + *Changed in version 1.17*: you can also use predefined constant + and a builder like this:: + + from msal.authority import ( + AuthorityBuilder, + AZURE_US_GOVERNMENT, AZURE_CHINA, AZURE_PUBLIC) + my_authority = AuthorityBuilder(AZURE_PUBLIC, "contoso.onmicrosoft.com") + # Now you get an equivalent of + # "https://login.microsoftonline.com/contoso.onmicrosoft.com" + + # You can feed such an authority to msal's ClientApplication + from msal import PublicClientApplication + app = PublicClientApplication("my_client_id", authority=my_authority, ...) + + :param bool validate_authority: (optional) Turns authority validation + on or off. This parameter default to true. + :param TokenCache cache: + Sets the token cache used by this ClientApplication instance. + By default, an in-memory cache will be created and used. + :param http_client: (optional) + Your implementation of abstract class HttpClient + Defaults to a requests session instance. + Since MSAL 1.11.0, the default session would be configured + to attempt one retry on connection error. + If you are providing your own http_client, + it will be your http_client's duty to decide whether to perform retry. + + :param verify: (optional) + It will be passed to the + `verify parameter in the underlying requests library + `_ + This does not apply if you have chosen to pass your own Http client + :param proxies: (optional) + It will be passed to the + `proxies parameter in the underlying requests library + `_ + This does not apply if you have chosen to pass your own Http client + :param timeout: (optional) + It will be passed to the + `timeout parameter in the underlying requests library + `_ + This does not apply if you have chosen to pass your own Http client + :param app_name: (optional) + You can provide your application name for Microsoft telemetry purposes. + Default value is None, means it will not be passed to Microsoft. + :param app_version: (optional) + You can provide your application version for Microsoft telemetry purposes. + Default value is None, means it will not be passed to Microsoft. + :param list[str] client_capabilities: (optional) + Allows configuration of one or more client capabilities, e.g. ["CP1"]. + + Client capability is meant to inform the Microsoft identity platform + (STS) what this client is capable for, + so STS can decide to turn on certain features. + For example, if client is capable to handle *claims challenge*, + STS can then issue CAE access tokens to resources + knowing when the resource emits *claims challenge* + the client will be capable to handle. + + Implementation details: + Client capability is implemented using "claims" parameter on the wire, + for now. + MSAL will combine them into + `claims parameter `_. + + 4. An app which already onboard to the region's allow-list. + + This parameter defaults to None, which means region behavior remains off. + + App developer can opt in to a regional endpoint, + by provide its region name, such as "westus", "eastus2". + You can find a full list of regions by running + ``az account list-locations -o table``, or referencing to + `this doc `_. + + An app running inside Azure Functions and Azure VM can use a special keyword + ``ClientApplication.ATTEMPT_REGION_DISCOVERY`` to auto-detect region. + + .. note:: + + Setting ``azure_region`` to non-``None`` for an app running + outside of Azure Function/VM could hang indefinitely. + + You should consider opting in/out region behavior on-demand, + by loading ``azure_region=None`` or ``azure_region="westus"`` + or ``azure_region=True`` (which means opt-in and auto-detect) + from your per-deployment configuration, and then do + ``app = ConfidentialClientApplication(..., azure_region=azure_region)``. + + Alternatively, you can configure a short timeout, + or provide a custom http_client which has a short timeout. + That way, the latency would be under your control, + but still less performant than opting out of region feature. + + New in version 1.12.0. + + :param list[str] exclude_scopes: (optional) + Historically MSAL hardcodes `offline_access` scope, + which would allow your app to have prolonged access to user's data. + If that is unnecessary or undesirable for your app, + now you can use this parameter to supply an exclusion list of scopes, + such as ``exclude_scopes = ["offline_access"]``. + + :param dict http_cache: + MSAL has long been caching tokens in the ``token_cache``. + Recently, MSAL also introduced a concept of ``http_cache``, + by automatically caching some finite amount of non-token http responses, + so that *long-lived* + ``PublicClientApplication`` and ``ConfidentialClientApplication`` + would be more performant and responsive in some situations. + + This ``http_cache`` parameter accepts any dict-like object. + If not provided, MSAL will use an in-memory dict. + + If your app is a command-line app (CLI), + you would want to persist your http_cache across different CLI runs. + The following recipe shows a way to do so:: + + # Just add the following lines at the beginning of your CLI script + import sys, atexit, pickle + http_cache_filename = sys.argv[0] + ".http_cache" + try: + with open(http_cache_filename, "rb") as f: + persisted_http_cache = pickle.load(f) # Take a snapshot + except ( + FileNotFoundError, # Or IOError in Python 2 + pickle.UnpicklingError, # A corrupted http cache file + ): + persisted_http_cache = {} # Recover by starting afresh + atexit.register(lambda: pickle.dump( + # When exit, flush it back to the file. + # It may occasionally overwrite another process's concurrent write, + # but that is fine. Subsequent runs will reach eventual consistency. + persisted_http_cache, open(http_cache_file, "wb"))) + + # And then you can implement your app as you normally would + app = msal.PublicClientApplication( + "your_client_id", + ..., + http_cache=persisted_http_cache, # Utilize persisted_http_cache + ..., + #token_cache=..., # You may combine the old token_cache trick + # Please refer to token_cache recipe at + # https://msal-python.readthedocs.io/en/latest/#msal.SerializableTokenCache + ) + app.acquire_token_interactive(["your", "scope"], ...) + + Content inside ``http_cache`` are cheap to obtain. + There is no need to share them among different apps. + + Content inside ``http_cache`` will contain no tokens nor + Personally Identifiable Information (PII). Encryption is unnecessary. + + New in version 1.16.0. + """ + self.client_id = client_id + self.client_credential = client_credential + self.client_claims = client_claims + self._client_capabilities = client_capabilities + + if exclude_scopes and not isinstance(exclude_scopes, list): + raise ValueError( + "Invalid exclude_scopes={}. It need to be a list of strings.".format( + repr(exclude_scopes))) + self._exclude_scopes = frozenset(exclude_scopes or []) + if "openid" in self._exclude_scopes: + raise ValueError( + 'Invalid exclude_scopes={}. You can not opt out "openid" scope'.format( + repr(exclude_scopes))) + + if http_client: + self.http_client = http_client + else: + import requests # Lazy load + + self.http_client = requests.Session() + self.http_client.verify = verify + self.http_client.proxies = proxies + # Requests, does not support session - wide timeout + # But you can patch that (https://github.com/psf/requests/issues/3341): + self.http_client.request = functools.partial( + self.http_client.request, timeout=timeout) + + # Enable a minimal retry. Better than nothing. + # https://github.com/psf/requests/blob/v2.25.1/requests/adapters.py#L94-L108 + a = requests.adapters.HTTPAdapter(max_retries=1) + self.http_client.mount("http://", a) + self.http_client.mount("https://", a) + self.http_client = ThrottledHttpClient( + self.http_client, + {} if http_cache is None else http_cache, # Default to an in-memory dict + ) + + self.app_name = app_name + self.app_version = app_version + + # Here the self.authority will not be the same type as authority in input + try: + self.authority = Authority( + authority or "https://login.microsoftonline.com/common/", + self.http_client, validate_authority=validate_authority) + except ValueError: # Those are explicit authority validation errors + raise + except Exception: # The rest are typically connection errors + if validate_authority and azure_region: + # Since caller opts in to use region, here we tolerate connection + # errors happened during authority validation at non-region endpoint + self.authority = Authority( + authority or "https://login.microsoftonline.com/common/", + self.http_client, validate_authority=False) + else: + raise + + self.token_cache = token_cache or TokenCache() + self._region_configured = azure_region + self._region_detected = None + self.client, self._regional_client = self._build_client( + client_credential, self.authority) + self.authority_groups = None + self._telemetry_buffer = {} + self._telemetry_lock = Lock() + + def _decorate_scope( + self, scopes, + reserved_scope=frozenset(['openid', 'profile', 'offline_access'])): + if not isinstance(scopes, (list, set, tuple)): + raise ValueError("The input scopes should be a list, tuple, or set") + scope_set = set(scopes) # Input scopes is typically a list. Copy it to a set. + if scope_set & reserved_scope: + # These scopes are reserved for the API to provide good experience. + # We could make the developer pass these and then if they do they will + # come back asking why they don't see refresh token or user information. + raise ValueError( + "API does not accept {} value as user-provided scopes".format( + reserved_scope)) + if self.client_id in scope_set: + if len(scope_set) > 1: + # We make developers pass their client id, so that they can express + # the intent that they want the token for themselves (their own + # app). + # If we do not restrict them to passing only client id then they + # could write code where they expect an id token but end up getting + # access_token. + raise ValueError("Client Id can only be provided as a single scope") + decorated = set(reserved_scope) # Make a writable copy + else: + decorated = scope_set | reserved_scope + decorated -= self._exclude_scopes + return list(decorated) + + def _build_telemetry_context( + self, api_id, correlation_id=None, refresh_reason=None): + return msal.telemetry._TelemetryContext( + self._telemetry_buffer, self._telemetry_lock, api_id, + correlation_id=correlation_id, refresh_reason=refresh_reason) + + def _get_regional_authority(self, central_authority): + self._region_detected = self._region_detected or _detect_region( + self.http_client if self._region_configured is not None else None) + if (self._region_configured != self.ATTEMPT_REGION_DISCOVERY + and self._region_configured != self._region_detected): + logger.warning('Region configured ({}) != region detected ({})'.format( + repr(self._region_configured), repr(self._region_detected))) + region_to_use = ( + self._region_detected + if self._region_configured == self.ATTEMPT_REGION_DISCOVERY + else self._region_configured) # It will retain the None i.e. opted out + logger.debug('Region to be used: {}'.format(repr(region_to_use))) + if region_to_use: + regional_host = ("{}.r.login.microsoftonline.com".format(region_to_use) + if central_authority.instance in ( + # The list came from https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/358/files#r629400328 + "login.microsoftonline.com", + "login.windows.net", + "sts.windows.net", + ) + else "{}.{}".format(region_to_use, central_authority.instance)) + return Authority( + "https://{}/{}".format(regional_host, central_authority.tenant), + self.http_client, + validate_authority=False) # The central_authority has already been validated + return None + + def _build_client(self, client_credential, authority, skip_regional_client=False): + client_assertion = None + client_assertion_type = None + default_headers = { + "x-client-sku": "MSAL.Python", "x-client-ver": __version__, + "x-client-os": sys.platform, + "x-client-cpu": "x64" if sys.maxsize > 2 ** 32 else "x86", + "x-ms-lib-capability": "retry-after, h429", + } + if self.app_name: + default_headers['x-app-name'] = self.app_name + if self.app_version: + default_headers['x-app-ver'] = self.app_version + default_body = {"client_info": 1} + if isinstance(client_credential, dict): + assert (("private_key" in client_credential + and "thumbprint" in client_credential) or + "client_assertion" in client_credential) + client_assertion_type = Client.CLIENT_ASSERTION_TYPE_JWT + if 'client_assertion' in client_credential: + client_assertion = client_credential['client_assertion'] + else: + headers = {} + if 'public_certificate' in client_credential: + headers["x5c"] = extract_certs(client_credential['public_certificate']) + if not client_credential.get("passphrase"): + unencrypted_private_key = client_credential['private_key'] + else: + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.backends import default_backend + unencrypted_private_key = serialization.load_pem_private_key( + _str2bytes(client_credential["private_key"]), + _str2bytes(client_credential["passphrase"]), + backend=default_backend(), # It was a required param until 2020 + ) + assertion = JwtAssertionCreator( + unencrypted_private_key, algorithm="RS256", + sha1_thumbprint=client_credential.get("thumbprint"), headers=headers) + client_assertion = assertion.create_regenerative_assertion( + audience=authority.token_endpoint, issuer=self.client_id, + additional_claims=self.client_claims or {}) + else: + default_body['client_secret'] = client_credential + central_configuration = { + "authorization_endpoint": authority.authorization_endpoint, + "token_endpoint": authority.token_endpoint, + "device_authorization_endpoint": + authority.device_authorization_endpoint or + urljoin(authority.token_endpoint, "devicecode"), + } + central_client = _ClientWithCcsRoutingInfo( + central_configuration, + self.client_id, + http_client=self.http_client, + default_headers=default_headers, + default_body=default_body, + client_assertion=client_assertion, + client_assertion_type=client_assertion_type, + on_obtaining_tokens=lambda event: self.token_cache.add(dict( + event, environment=authority.instance)), + on_removing_rt=self.token_cache.remove_rt, + on_updating_rt=self.token_cache.update_rt) + + regional_client = None + if (client_credential # Currently regional endpoint only serves some CCA flows + and not skip_regional_client): + regional_authority = self._get_regional_authority(authority) + if regional_authority: + regional_configuration = { + "authorization_endpoint": regional_authority.authorization_endpoint, + "token_endpoint": regional_authority.token_endpoint, + "device_authorization_endpoint": + regional_authority.device_authorization_endpoint or + urljoin(regional_authority.token_endpoint, "devicecode"), + } + regional_client = _ClientWithCcsRoutingInfo( + regional_configuration, + self.client_id, + http_client=self.http_client, + default_headers=default_headers, + default_body=default_body, + client_assertion=client_assertion, + client_assertion_type=client_assertion_type, + on_obtaining_tokens=lambda event: self.token_cache.add(dict( + event, environment=authority.instance)), + on_removing_rt=self.token_cache.remove_rt, + on_updating_rt=self.token_cache.update_rt) + return central_client, regional_client + + def initiate_auth_code_flow( + self, + scopes, # type: list[str] + redirect_uri=None, + state=None, # Recommended by OAuth2 for CSRF protection + prompt=None, + login_hint=None, # type: Optional[str] + domain_hint=None, # type: Optional[str] + claims_challenge=None, + max_age=None, + response_mode=None, # type: Optional[str] + ): + """Initiate an auth code flow. + + Later when the response reaches your redirect_uri, + you can use :func:`~acquire_token_by_auth_code_flow()` + to complete the authentication/authorization. + + :param list scopes: + It is a list of case-sensitive strings. + :param str redirect_uri: + Optional. If not specified, server will use the pre-registered one. + :param str state: + An opaque value used by the client to + maintain state between the request and callback. + If absent, this library will automatically generate one internally. + :param str prompt: + By default, no prompt value will be sent, not even "none". + You will have to specify a value explicitly. + Its valid values are defined in Open ID Connect specs + https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest + :param str login_hint: + Optional. Identifier of the user. Generally a User Principal Name (UPN). + :param domain_hint: + Can be one of "consumers" or "organizations" or your tenant domain "contoso.com". + If included, it will skip the email-based discovery process that user goes + through on the sign-in page, leading to a slightly more streamlined user experience. + More information on possible values + `here `_ and + `here `_. + + :param int max_age: + OPTIONAL. Maximum Authentication Age. + Specifies the allowable elapsed time in seconds + since the last time the End-User was actively authenticated. + If the elapsed time is greater than this value, + Microsoft identity platform will actively re-authenticate the End-User. + + MSAL Python will also automatically validate the auth_time in ID token. + + New in version 1.15. + + :param str response_mode: + OPTIONAL. Specifies the method with which response parameters should be returned. + The default value is equivalent to ``query``, which is still secure enough in MSAL Python + (because MSAL Python does not transfer tokens via query parameter in the first place). + For even better security, we recommend using the value ``form_post``. + In "form_post" mode, response parameters + will be encoded as HTML form values that are transmitted via the HTTP POST method and + encoded in the body using the application/x-www-form-urlencoded format. + Valid values can be either "form_post" for HTTP POST to callback URI or + "query" (the default) for HTTP GET with parameters encoded in query string. + More information on possible values + `here ` + and `here ` + + :return: + The auth code flow. It is a dict in this form:: + + { + "auth_uri": "https://...", // Guide user to visit this + "state": "...", // You may choose to verify it by yourself, + // or just let acquire_token_by_auth_code_flow() + // do that for you. + "...": "...", // Everything else are reserved and internal + } + + The caller is expected to:: + + 1. somehow store this content, typically inside the current session, + 2. guide the end user (i.e. resource owner) to visit that auth_uri, + 3. and then relay this dict and subsequent auth response to + :func:`~acquire_token_by_auth_code_flow()`. + """ + client = _ClientWithCcsRoutingInfo( + {"authorization_endpoint": self.authority.authorization_endpoint}, + self.client_id, + http_client=self.http_client) + flow = client.initiate_auth_code_flow( + redirect_uri=redirect_uri, state=state, login_hint=login_hint, + prompt=prompt, + scope=self._decorate_scope(scopes), + domain_hint=domain_hint, + claims=_merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge), + max_age=max_age, + response_mode=response_mode, + ) + flow["claims_challenge"] = claims_challenge + return flow + + def get_authorization_request_url( + self, + scopes, # type: list[str] + login_hint=None, # type: Optional[str] + state=None, # Recommended by OAuth2 for CSRF protection + redirect_uri=None, + response_type="code", # Could be "token" if you use Implicit Grant + prompt=None, + nonce=None, + domain_hint=None, # type: Optional[str] + claims_challenge=None, + **kwargs): + """Constructs a URL for you to start a Authorization Code Grant. + + :param list[str] scopes: (Required) + Scopes requested to access a protected API (a resource). + :param str state: Recommended by OAuth2 for CSRF protection. + :param str login_hint: + Identifier of the user. Generally a User Principal Name (UPN). + :param str redirect_uri: + Address to return to upon receiving a response from the authority. + :param str response_type: + Default value is "code" for an OAuth2 Authorization Code grant. + + You could use other content such as "id_token" or "token", + which would trigger an Implicit Grant, but that is + `not recommended `_. + + :param str prompt: + By default, no prompt value will be sent, not even "none". + You will have to specify a value explicitly. + Its valid values are defined in Open ID Connect specs + https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest + :param nonce: + A cryptographically random value used to mitigate replay attacks. See also + `OIDC specs `_. + :param domain_hint: + Can be one of "consumers" or "organizations" or your tenant domain "contoso.com". + If included, it will skip the email-based discovery process that user goes + through on the sign-in page, leading to a slightly more streamlined user experience. + More information on possible values + `here `_ and + `here `_. + :param claims_challenge: + The claims_challenge parameter requests specific claims requested by the resource provider + in the form of a claims_challenge directive in the www-authenticate header to be + returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. + It is a string of a JSON object which contains lists of claims being requested from these locations. + + :return: The authorization url as a string. + """ + authority = kwargs.pop("authority", None) # Historically we support this + if authority: + warnings.warn( + "We haven't decided if this method will accept authority parameter") + # The previous implementation is, it will use self.authority by default. + # Multi-tenant app can use new authority on demand + the_authority = Authority( + authority, + self.http_client + ) if authority else self.authority + + client = _ClientWithCcsRoutingInfo( + {"authorization_endpoint": the_authority.authorization_endpoint}, + self.client_id, + http_client=self.http_client) + warnings.warn( + "Change your get_authorization_request_url() " + "to initiate_auth_code_flow()", DeprecationWarning) + with warnings.catch_warnings(record=True): + return client.build_auth_request_uri( + response_type=response_type, + redirect_uri=redirect_uri, state=state, login_hint=login_hint, + prompt=prompt, + scope=self._decorate_scope(scopes), + nonce=nonce, + domain_hint=domain_hint, + claims=_merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge), + ) + + def acquire_token_by_auth_code_flow( + self, auth_code_flow, auth_response, scopes=None, **kwargs): + """Validate the auth response being redirected back, and obtain tokens. + + It automatically provides nonce protection. + + :param dict auth_code_flow: + The same dict returned by :func:`~initiate_auth_code_flow()`. + :param dict auth_response: + A dict of the query string received from auth server. + :param list[str] scopes: + Scopes requested to access a protected API (a resource). + + Most of the time, you can leave it empty. + + If you requested user consent for multiple resources, here you will + need to provide a subset of what you required in + :func:`~initiate_auth_code_flow()`. + + OAuth2 was designed mostly for singleton services, + where tokens are always meant for the same resource and the only + changes are in the scopes. + In AAD, tokens can be issued for multiple 3rd party resources. + You can ask authorization code for multiple resources, + but when you redeem it, the token is for only one intended + recipient, called audience. + So the developer need to specify a scope so that we can restrict the + token to be issued for the corresponding audience. + + :return: + * A dict containing "access_token" and/or "id_token", among others, + depends on what scope was used. + (See https://tools.ietf.org/html/rfc6749#section-5.1) + * A dict containing "error", optionally "error_description", "error_uri". + (It is either `this `_ + or `that `_) + * Most client-side data error would result in ValueError exception. + So the usage pattern could be without any protocol details:: + + def authorize(): # A controller in a web app + try: + result = msal_app.acquire_token_by_auth_code_flow( + session.get("flow", {}), request.args) + if "error" in result: + return render_template("error.html", result) + use(result) # Token(s) are available in result and cache + except ValueError: # Usually caused by CSRF + pass # Simply ignore them + return redirect(url_for("index")) + """ + self._validate_ssh_cert_input_data(kwargs.get("data", {})) + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID) + response =_clean_up(self.client.obtain_token_by_auth_code_flow( + auth_code_flow, + auth_response, + scope=self._decorate_scope(scopes) if scopes else None, + headers=telemetry_context.generate_headers(), + data=dict( + kwargs.pop("data", {}), + claims=_merge_claims_challenge_and_capabilities( + self._client_capabilities, + auth_code_flow.pop("claims_challenge", None))), + **kwargs)) + telemetry_context.update_telemetry(response) + return response + + def acquire_token_by_authorization_code( + self, + code, + scopes, # Syntactically required. STS accepts empty value though. + redirect_uri=None, + # REQUIRED, if the "redirect_uri" parameter was included in the + # authorization request as described in Section 4.1.1, and their + # values MUST be identical. + nonce=None, + claims_challenge=None, + **kwargs): + """The second half of the Authorization Code Grant. + + :param code: The authorization code returned from Authorization Server. + :param list[str] scopes: (Required) + Scopes requested to access a protected API (a resource). + + If you requested user consent for multiple resources, here you will + typically want to provide a subset of what you required in AuthCode. + + OAuth2 was designed mostly for singleton services, + where tokens are always meant for the same resource and the only + changes are in the scopes. + In AAD, tokens can be issued for multiple 3rd party resources. + You can ask authorization code for multiple resources, + but when you redeem it, the token is for only one intended + recipient, called audience. + So the developer need to specify a scope so that we can restrict the + token to be issued for the corresponding audience. + + :param nonce: + If you provided a nonce when calling :func:`get_authorization_request_url`, + same nonce should also be provided here, so that we'll validate it. + An exception will be raised if the nonce in id token mismatches. + + :param claims_challenge: + The claims_challenge parameter requests specific claims requested by the resource provider + in the form of a claims_challenge directive in the www-authenticate header to be + returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. + It is a string of a JSON object which contains lists of claims being requested from these locations. + + :return: A dict representing the json response from AAD: + + - A successful response would contain "access_token" key, + - an error response would contain "error" and usually "error_description". + """ + # If scope is absent on the wire, STS will give you a token associated + # to the FIRST scope sent during the authorization request. + # So in theory, you can omit scope here when you were working with only + # one scope. But, MSAL decorates your scope anyway, so they are never + # really empty. + assert isinstance(scopes, list), "Invalid parameter type" + self._validate_ssh_cert_input_data(kwargs.get("data", {})) + warnings.warn( + "Change your acquire_token_by_authorization_code() " + "to acquire_token_by_auth_code_flow()", DeprecationWarning) + with warnings.catch_warnings(record=True): + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID) + response = _clean_up(self.client.obtain_token_by_authorization_code( + code, redirect_uri=redirect_uri, + scope=self._decorate_scope(scopes), + headers=telemetry_context.generate_headers(), + data=dict( + kwargs.pop("data", {}), + claims=_merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge)), + nonce=nonce, + **kwargs)) + telemetry_context.update_telemetry(response) + return response + + def get_accounts(self, username=None): + """Get a list of accounts which previously signed in, i.e. exists in cache. + + An account can later be used in :func:`~acquire_token_silent` + to find its tokens. + + :param username: + Filter accounts with this username only. Case insensitive. + :return: A list of account objects. + Each account is a dict. For now, we only document its "username" field. + Your app can choose to display those information to end user, + and allow user to choose one of his/her accounts to proceed. + """ + accounts = self._find_msal_accounts(environment=self.authority.instance) + if not accounts: # Now try other aliases of this authority instance + for alias in self._get_authority_aliases(self.authority.instance): + accounts = self._find_msal_accounts(environment=alias) + if accounts: + break + if username: + # Federated account["username"] from AAD could contain mixed case + lowercase_username = username.lower() + accounts = [a for a in accounts + if a["username"].lower() == lowercase_username] + if not accounts: + logger.debug(( # This would also happen when the cache is empty + "get_accounts(username='{}') finds no account. " + "If tokens were acquired without 'profile' scope, " + "they would contain no username for filtering. " + "Consider calling get_accounts(username=None) instead." + ).format(username)) + # Does not further filter by existing RTs here. It probably won't matter. + # Because in most cases Accounts and RTs co-exist. + # Even in the rare case when an RT is revoked and then removed, + # acquire_token_silent() would then yield no result, + # apps would fall back to other acquire methods. This is the standard pattern. + return accounts + + def _find_msal_accounts(self, environment): + interested_authority_types = [ + TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS] + if _is_running_in_cloud_shell(): + interested_authority_types.append(_AUTHORITY_TYPE_CLOUDSHELL) + grouped_accounts = { + a.get("home_account_id"): # Grouped by home tenant's id + { # These are minimal amount of non-tenant-specific account info + "home_account_id": a.get("home_account_id"), + "environment": a.get("environment"), + "username": a.get("username"), + + # The following fields for backward compatibility, for now + "authority_type": a.get("authority_type"), + "local_account_id": a.get("local_account_id"), # Tenant-specific + "realm": a.get("realm"), # Tenant-specific + } + for a in self.token_cache.find( + TokenCache.CredentialType.ACCOUNT, + query={"environment": environment}) + if a["authority_type"] in interested_authority_types + } + return list(grouped_accounts.values()) + + def _get_authority_aliases(self, instance): + if not self.authority_groups: + resp = self.http_client.get( + "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", + headers={'Accept': 'application/json'}) + resp.raise_for_status() + self.authority_groups = [ + set(group['aliases']) for group in json.loads(resp.text)['metadata']] + for group in self.authority_groups: + if instance in group: + return [alias for alias in group if alias != instance] + return [] + + def remove_account(self, account): + """Sign me out and forget me from token cache""" + self._forget_me(account) + + def _sign_out(self, home_account): + # Remove all relevant RTs and ATs from token cache + owned_by_home_account = { + "environment": home_account["environment"], + "home_account_id": home_account["home_account_id"],} # realm-independent + app_metadata = self._get_app_metadata(home_account["environment"]) + # Remove RTs/FRTs, and they are realm-independent + for rt in [rt for rt in self.token_cache.find( + TokenCache.CredentialType.REFRESH_TOKEN, query=owned_by_home_account) + # Do RT's app ownership check as a precaution, in case family apps + # and 3rd-party apps share same token cache, although they should not. + if rt["client_id"] == self.client_id or ( + app_metadata.get("family_id") # Now let's settle family business + and rt.get("family_id") == app_metadata["family_id"]) + ]: + self.token_cache.remove_rt(rt) + for at in self.token_cache.find( # Remove ATs + # Regardless of realm, b/c we've removed realm-independent RTs anyway + TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_home_account): + # To avoid the complexity of locating sibling family app's AT, + # we skip AT's app ownership check. + # It means ATs for other apps will also be removed, it is OK because: + # * non-family apps are not supposed to share token cache to begin with; + # * Even if it happens, we keep other app's RT already, so SSO still works + self.token_cache.remove_at(at) + + def _forget_me(self, home_account): + # It implies signout, and then also remove all relevant accounts and IDTs + self._sign_out(home_account) + owned_by_home_account = { + "environment": home_account["environment"], + "home_account_id": home_account["home_account_id"],} # realm-independent + for idt in self.token_cache.find( # Remove IDTs, regardless of realm + TokenCache.CredentialType.ID_TOKEN, query=owned_by_home_account): + self.token_cache.remove_idt(idt) + for a in self.token_cache.find( # Remove Accounts, regardless of realm + TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account): + self.token_cache.remove_account(a) + + def _acquire_token_by_cloud_shell(self, scopes, data=None): + from .cloudshell import _obtain_token + response = _obtain_token( + self.http_client, scopes, client_id=self.client_id, data=data) + if "error" not in response: + self.token_cache.add(dict( + client_id=self.client_id, + scope=response["scope"].split() if "scope" in response else scopes, + token_endpoint=self.authority.token_endpoint, + response=response.copy(), + data=data or {}, + authority_type=_AUTHORITY_TYPE_CLOUDSHELL, + )) + return response + + def acquire_token_silent( + self, + scopes, # type: List[str] + account, # type: Optional[Account] + authority=None, # See get_authorization_request_url() + force_refresh=False, # type: Optional[boolean] + claims_challenge=None, + **kwargs): + """Acquire an access token for given account, without user interaction. + + It is done either by finding a valid access token from cache, + or by finding a valid refresh token from cache and then automatically + use it to redeem a new access token. + + This method will combine the cache empty and refresh error + into one return value, `None`. + If your app does not care about the exact token refresh error during + token cache look-up, then this method is easier and recommended. + + Internally, this method calls :func:`~acquire_token_silent_with_error`. + + :param claims_challenge: + The claims_challenge parameter requests specific claims requested by the resource provider + in the form of a claims_challenge directive in the www-authenticate header to be + returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. + It is a string of a JSON object which contains lists of claims being requested from these locations. + + :return: + - A dict containing no "error" key, + and typically contains an "access_token" key, + if cache lookup succeeded. + - None when cache lookup does not yield a token. + """ + result = self.acquire_token_silent_with_error( + scopes, account, authority=authority, force_refresh=force_refresh, + claims_challenge=claims_challenge, **kwargs) + return result if result and "error" not in result else None + + def acquire_token_silent_with_error( + self, + scopes, # type: List[str] + account, # type: Optional[Account] + authority=None, # See get_authorization_request_url() + force_refresh=False, # type: Optional[boolean] + claims_challenge=None, + **kwargs): + """Acquire an access token for given account, without user interaction. + + It is done either by finding a valid access token from cache, + or by finding a valid refresh token from cache and then automatically + use it to redeem a new access token. + + This method will differentiate cache empty from token refresh error. + If your app cares the exact token refresh error during + token cache look-up, then this method is suitable. + Otherwise, the other method :func:`~acquire_token_silent` is recommended. + + :param list[str] scopes: (Required) + Scopes requested to access a protected API (a resource). + :param account: + one of the account object returned by :func:`~get_accounts`, + or use None when you want to find an access token for this client. + :param force_refresh: + If True, it will skip Access Token look-up, + and try to find a Refresh Token to obtain a new Access Token. + :param claims_challenge: + The claims_challenge parameter requests specific claims requested by the resource provider + in the form of a claims_challenge directive in the www-authenticate header to be + returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. + It is a string of a JSON object which contains lists of claims being requested from these locations. + :return: + - A dict containing no "error" key, + and typically contains an "access_token" key, + if cache lookup succeeded. + - None when there is simply no token in the cache. + - A dict containing an "error" key, when token refresh failed. + """ + assert isinstance(scopes, list), "Invalid parameter type" + self._validate_ssh_cert_input_data(kwargs.get("data", {})) + correlation_id = msal.telemetry._get_new_correlation_id() + if authority: + warnings.warn("We haven't decided how/if this method will accept authority parameter") + # the_authority = Authority( + # authority, + # self.http_client, + # ) if authority else self.authority + result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( + scopes, account, self.authority, force_refresh=force_refresh, + claims_challenge=claims_challenge, + correlation_id=correlation_id, + **kwargs) + if result and "error" not in result: + return result + final_result = result + for alias in self._get_authority_aliases(self.authority.instance): + if not self.token_cache.find( + self.token_cache.CredentialType.REFRESH_TOKEN, + # target=scopes, # MUST NOT filter by scopes, because: + # 1. AAD RTs are scope-independent; + # 2. therefore target is optional per schema; + query={"environment": alias}): + # Skip heavy weight logic when RT for this alias doesn't exist + continue + the_authority = Authority( + "https://" + alias + "/" + self.authority.tenant, + self.http_client, + validate_authority=False) + result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( + scopes, account, the_authority, force_refresh=force_refresh, + claims_challenge=claims_challenge, + correlation_id=correlation_id, + **kwargs) + if result: + if "error" not in result: + return result + final_result = result + if final_result and final_result.get("suberror"): + final_result["classification"] = { # Suppress these suberrors, per #57 + "bad_token": "", + "token_expired": "", + "protection_policy_required": "", + "client_mismatch": "", + "device_authentication_failed": "", + }.get(final_result["suberror"], final_result["suberror"]) + return final_result + + def _acquire_token_silent_from_cache_and_possibly_refresh_it( + self, + scopes, # type: List[str] + account, # type: Optional[Account] + authority, # This can be different than self.authority + force_refresh=False, # type: Optional[boolean] + claims_challenge=None, + correlation_id=None, + **kwargs): + access_token_from_cache = None + if not (force_refresh or claims_challenge): # Bypass AT when desired or using claims + query={ + "client_id": self.client_id, + "environment": authority.instance, + "realm": authority.tenant, + "home_account_id": (account or {}).get("home_account_id"), + } + key_id = kwargs.get("data", {}).get("key_id") + if key_id: # Some token types (SSH-certs, POP) are bound to a key + query["key_id"] = key_id + matches = self.token_cache.find( + self.token_cache.CredentialType.ACCESS_TOKEN, + target=scopes, + query=query) + now = time.time() + refresh_reason = msal.telemetry.AT_ABSENT + for entry in matches: + expires_in = int(entry["expires_on"]) - now + if expires_in < 5*60: # Then consider it expired + refresh_reason = msal.telemetry.AT_EXPIRED + continue # Removal is not necessary, it will be overwritten + logger.debug("Cache hit an AT") + access_token_from_cache = { # Mimic a real response + "access_token": entry["secret"], + "token_type": entry.get("token_type", "Bearer"), + "expires_in": int(expires_in), # OAuth2 specs defines it as int + } + if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging + refresh_reason = msal.telemetry.AT_AGING + break # With a fallback in hand, we break here to go refresh + self._build_telemetry_context(-1).hit_an_access_token() + return access_token_from_cache # It is still good as new + else: + refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge + assert refresh_reason, "It should have been established at this point" + try: + if account and account.get("authority_type") == _AUTHORITY_TYPE_CLOUDSHELL: + return self._acquire_token_by_cloud_shell( + scopes, data=kwargs.get("data")) + result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + authority, self._decorate_scope(scopes), account, + refresh_reason=refresh_reason, claims_challenge=claims_challenge, + correlation_id=correlation_id, + **kwargs)) + if (result and "error" not in result) or (not access_token_from_cache): + return result + except: # The exact HTTP exception is transportation-layer dependent + # Typically network error. Potential AAD outage? + if not access_token_from_cache: # It means there is no fall back option + raise # We choose to bubble up the exception + return access_token_from_cache + + def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + self, authority, scopes, account, **kwargs): + query = { + "environment": authority.instance, + "home_account_id": (account or {}).get("home_account_id"), + # "realm": authority.tenant, # AAD RTs are tenant-independent + } + app_metadata = self._get_app_metadata(authority.instance) + if not app_metadata: # Meaning this app is now used for the first time. + # When/if we have a way to directly detect current app's family, + # we'll rewrite this block, to support multiple families. + # For now, we try existing RTs (*). If it works, we are in that family. + # (*) RTs of a different app/family are not supposed to be + # shared with or accessible by us in the first place. + at = self._acquire_token_silent_by_finding_specific_refresh_token( + authority, scopes, + dict(query, family_id="1"), # A hack, we have only 1 family for now + rt_remover=lambda rt_item: None, # NO-OP b/c RTs are likely not mine + break_condition=lambda response: # Break loop when app not in family + # Based on an AAD-only behavior mentioned in internal doc here + # https://msazure.visualstudio.com/One/_git/ESTS-Docs/pullrequest/1138595 + "client_mismatch" in response.get("error_additional_info", []), + **kwargs) + if at and "error" not in at: + return at + last_resp = None + if app_metadata.get("family_id"): # Meaning this app belongs to this family + last_resp = at = self._acquire_token_silent_by_finding_specific_refresh_token( + authority, scopes, dict(query, family_id=app_metadata["family_id"]), + **kwargs) + if at and "error" not in at: + return at + # Either this app is an orphan, so we will naturally use its own RT; + # or all attempts above have failed, so we fall back to non-foci behavior. + return self._acquire_token_silent_by_finding_specific_refresh_token( + authority, scopes, dict(query, client_id=self.client_id), + **kwargs) or last_resp + + def _get_app_metadata(self, environment): + apps = self.token_cache.find( # Use find(), rather than token_cache.get(...) + TokenCache.CredentialType.APP_METADATA, query={ + "environment": environment, "client_id": self.client_id}) + return apps[0] if apps else {} + + def _acquire_token_silent_by_finding_specific_refresh_token( + self, authority, scopes, query, + rt_remover=None, break_condition=lambda response: False, + refresh_reason=None, correlation_id=None, claims_challenge=None, + **kwargs): + matches = self.token_cache.find( + self.token_cache.CredentialType.REFRESH_TOKEN, + # target=scopes, # AAD RTs are scope-independent + query=query) + logger.debug("Found %d RTs matching %s", len(matches), query) + + response = None # A distinguishable value to mean cache is empty + if not matches: # Then exit early to avoid expensive operations + return response + client, _ = self._build_client( + # Potentially expensive if building regional client + self.client_credential, authority, skip_regional_client=True) + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_SILENT_ID, + correlation_id=correlation_id, refresh_reason=refresh_reason) + for entry in sorted( # Since unfit RTs would not be aggressively removed, + # we start from newer RTs which are more likely fit. + matches, + key=lambda e: int(e.get("last_modification_time", "0")), + reverse=True): + logger.debug("Cache attempts an RT") + headers = telemetry_context.generate_headers() + if query.get("home_account_id"): # Then use it as CCS Routing info + headers["X-AnchorMailbox"] = "Oid:{}".format( # case-insensitive value + query["home_account_id"].replace(".", "@")) + response = client.obtain_token_by_refresh_token( + entry, rt_getter=lambda token_item: token_item["secret"], + on_removing_rt=lambda rt_item: None, # Disable RT removal, + # because an invalid_grant could be caused by new MFA policy, + # the RT could still be useful for other MFA-less scope or tenant + on_obtaining_tokens=lambda event: self.token_cache.add(dict( + event, + environment=authority.instance, + skip_account_creation=True, # To honor a concurrent remove_account() + )), + scope=scopes, + headers=headers, + data=dict( + kwargs.pop("data", {}), + claims=_merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge)), + **kwargs) + telemetry_context.update_telemetry(response) + if "error" not in response: + return response + logger.debug("Refresh failed. {error}: {error_description}".format( + error=response.get("error"), + error_description=response.get("error_description"), + )) + if break_condition(response): + break + return response # Returns the latest error (if any), or just None + + def _validate_ssh_cert_input_data(self, data): + if data.get("token_type") == "ssh-cert": + if not data.get("req_cnf"): + raise ValueError( + "When requesting an SSH certificate, " + "you must include a string parameter named 'req_cnf' " + "containing the public key in JWK format " + "(https://tools.ietf.org/html/rfc7517).") + if not data.get("key_id"): + raise ValueError( + "When requesting an SSH certificate, " + "you must include a string parameter named 'key_id' " + "which identifies the key in the 'req_cnf' argument.") + + def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs): + """Acquire token(s) based on a refresh token (RT) obtained from elsewhere. + + You use this method only when you have old RTs from elsewhere, + and now you want to migrate them into MSAL. + Calling this method results in new tokens automatically storing into MSAL. + + You do NOT need to use this method if you are already using MSAL. + MSAL maintains RT automatically inside its token cache, + and an access token can be retrieved + when you call :func:`~acquire_token_silent`. + + :param str refresh_token: The old refresh token, as a string. + + :param list scopes: + The scopes associate with this old RT. + Each scope needs to be in the Microsoft identity platform (v2) format. + See `Scopes not resources `_. + + :return: + * A dict contains "error" and some other keys, when error happened. + * A dict contains no "error" key means migration was successful. + """ + self._validate_ssh_cert_input_data(kwargs.get("data", {})) + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_BY_REFRESH_TOKEN, + refresh_reason=msal.telemetry.FORCE_REFRESH) + response = _clean_up(self.client.obtain_token_by_refresh_token( + refresh_token, + scope=self._decorate_scope(scopes), + headers=telemetry_context.generate_headers(), + rt_getter=lambda rt: rt, + on_updating_rt=False, + on_removing_rt=lambda rt_item: None, # No OP + **kwargs)) + telemetry_context.update_telemetry(response) + return response + + def acquire_token_by_username_password( + self, username, password, scopes, claims_challenge=None, **kwargs): + """Gets a token for a given resource via user credentials. + + See this page for constraints of Username Password Flow. + https://github.com/AzureAD/microsoft-authentication-library-for-python/wiki/Username-Password-Authentication + + :param str username: Typically a UPN in the form of an email address. + :param str password: The password. + :param list[str] scopes: + Scopes requested to access a protected API (a resource). + :param claims_challenge: + The claims_challenge parameter requests specific claims requested by the resource provider + in the form of a claims_challenge directive in the www-authenticate header to be + returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. + It is a string of a JSON object which contains lists of claims being requested from these locations. + + :return: A dict representing the json response from AAD: + + - A successful response would contain "access_token" key, + - an error response would contain "error" and usually "error_description". + """ + scopes = self._decorate_scope(scopes) + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID) + headers = telemetry_context.generate_headers() + data = dict( + kwargs.pop("data", {}), + claims=_merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge)) + if not self.authority.is_adfs: + user_realm_result = self.authority.user_realm_discovery( + username, correlation_id=headers[msal.telemetry.CLIENT_REQUEST_ID]) + if user_realm_result.get("account_type") == "Federated": + response = _clean_up(self._acquire_token_by_username_password_federated( + user_realm_result, username, password, scopes=scopes, + data=data, + headers=headers, **kwargs)) + telemetry_context.update_telemetry(response) + return response + response = _clean_up(self.client.obtain_token_by_username_password( + username, password, scope=scopes, + headers=headers, + data=data, + **kwargs)) + telemetry_context.update_telemetry(response) + return response + + def _acquire_token_by_username_password_federated( + self, user_realm_result, username, password, scopes=None, **kwargs): + wstrust_endpoint = {} + if user_realm_result.get("federation_metadata_url"): + wstrust_endpoint = mex_send_request( + user_realm_result["federation_metadata_url"], + self.http_client) + if wstrust_endpoint is None: + raise ValueError("Unable to find wstrust endpoint from MEX. " + "This typically happens when attempting MSA accounts. " + "More details available here. " + "https://github.com/AzureAD/microsoft-authentication-library-for-python/wiki/Username-Password-Authentication") + logger.debug("wstrust_endpoint = %s", wstrust_endpoint) + wstrust_result = wst_send_request( + username, password, + user_realm_result.get("cloud_audience_urn", "urn:federation:MicrosoftOnline"), + wstrust_endpoint.get("address", + # Fallback to an AAD supplied endpoint + user_realm_result.get("federation_active_auth_url")), + wstrust_endpoint.get("action"), self.http_client) + if not ("token" in wstrust_result and "type" in wstrust_result): + raise RuntimeError("Unsuccessful RSTR. %s" % wstrust_result) + GRANT_TYPE_SAML1_1 = 'urn:ietf:params:oauth:grant-type:saml1_1-bearer' + grant_type = { + SAML_TOKEN_TYPE_V1: GRANT_TYPE_SAML1_1, + SAML_TOKEN_TYPE_V2: self.client.GRANT_TYPE_SAML2, + WSS_SAML_TOKEN_PROFILE_V1_1: GRANT_TYPE_SAML1_1, + WSS_SAML_TOKEN_PROFILE_V2: self.client.GRANT_TYPE_SAML2 + }.get(wstrust_result.get("type")) + if not grant_type: + raise RuntimeError( + "RSTR returned unknown token type: %s", wstrust_result.get("type")) + self.client.grant_assertion_encoders.setdefault( # Register a non-standard type + grant_type, self.client.encode_saml_assertion) + return self.client.obtain_token_by_assertion( + wstrust_result["token"], grant_type, scope=scopes, + on_obtaining_tokens=lambda event: self.token_cache.add(dict( + event, + environment=self.authority.instance, + username=username, # Useful in case IDT contains no such info + )), + **kwargs) + + +class PublicClientApplication(ClientApplication): # browser app or mobile app + + DEVICE_FLOW_CORRELATION_ID = "_correlation_id" + + def __init__(self, client_id, client_credential=None, **kwargs): + if client_credential is not None: + raise ValueError("Public Client should not possess credentials") + super(PublicClientApplication, self).__init__( + client_id, client_credential=None, **kwargs) + + def acquire_token_interactive( + self, + scopes, # type: list[str] + prompt=None, + login_hint=None, # type: Optional[str] + domain_hint=None, # type: Optional[str] + claims_challenge=None, + timeout=None, + port=None, + extra_scopes_to_consent=None, + max_age=None, + **kwargs): + """Acquire token interactively i.e. via a local browser. + + Prerequisite: In Azure Portal, configure the Redirect URI of your + "Mobile and Desktop application" as ``http://localhost``. + + :param list scopes: + It is a list of case-sensitive strings. + :param str prompt: + By default, no prompt value will be sent, not even "none". + You will have to specify a value explicitly. + Its valid values are defined in Open ID Connect specs + https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest + :param str login_hint: + Optional. Identifier of the user. Generally a User Principal Name (UPN). + :param domain_hint: + Can be one of "consumers" or "organizations" or your tenant domain "contoso.com". + If included, it will skip the email-based discovery process that user goes + through on the sign-in page, leading to a slightly more streamlined user experience. + More information on possible values + `here `_ and + `here `_. + + :param claims_challenge: + The claims_challenge parameter requests specific claims requested by the resource provider + in the form of a claims_challenge directive in the www-authenticate header to be + returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. + It is a string of a JSON object which contains lists of claims being requested from these locations. + + :param int timeout: + This method will block the current thread. + This parameter specifies the timeout value in seconds. + Default value ``None`` means wait indefinitely. + + :param int port: + The port to be used to listen to an incoming auth response. + By default we will use a system-allocated port. + (The rest of the redirect_uri is hard coded as ``http://localhost``.) + + :param list extra_scopes_to_consent: + "Extra scopes to consent" is a concept only available in AAD. + It refers to other resources you might want to prompt to consent for, + in the same interaction, but for which you won't get back a + token for in this particular operation. + + :param int max_age: + OPTIONAL. Maximum Authentication Age. + Specifies the allowable elapsed time in seconds + since the last time the End-User was actively authenticated. + If the elapsed time is greater than this value, + Microsoft identity platform will actively re-authenticate the End-User. + + MSAL Python will also automatically validate the auth_time in ID token. + + New in version 1.15. + + :return: + - A dict containing no "error" key, + and typically contains an "access_token" key. + - A dict containing an "error" key, when token refresh failed. + """ + self._validate_ssh_cert_input_data(kwargs.get("data", {})) + if _is_running_in_cloud_shell() and prompt == "none": + return self._acquire_token_by_cloud_shell( + scopes, data=kwargs.pop("data", {})) + claims = _merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge) + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_INTERACTIVE) + response = _clean_up(self.client.obtain_token_by_browser( + scope=self._decorate_scope(scopes) if scopes else None, + extra_scope_to_consent=extra_scopes_to_consent, + redirect_uri="http://localhost:{port}".format( + # Hardcode the host, for now. AAD portal rejects 127.0.0.1 anyway + port=port or 0), + prompt=prompt, + login_hint=login_hint, + max_age=max_age, + timeout=timeout, + auth_params={ + "claims": claims, + "domain_hint": domain_hint, + }, + data=dict(kwargs.pop("data", {}), claims=claims), + headers=telemetry_context.generate_headers(), + browser_name=_preferred_browser(), + **kwargs)) + telemetry_context.update_telemetry(response) + return response + + def initiate_device_flow(self, scopes=None, **kwargs): + """Initiate a Device Flow instance, + which will be used in :func:`~acquire_token_by_device_flow`. + + :param list[str] scopes: + Scopes requested to access a protected API (a resource). + :return: A dict representing a newly created Device Flow object. + + - A successful response would contain "user_code" key, among others + - an error response would contain some other readable key/value pairs. + """ + correlation_id = msal.telemetry._get_new_correlation_id() + flow = self.client.initiate_device_flow( + scope=self._decorate_scope(scopes or []), + headers={msal.telemetry.CLIENT_REQUEST_ID: correlation_id}, + **kwargs) + flow[self.DEVICE_FLOW_CORRELATION_ID] = correlation_id + return flow + + def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs): + """Obtain token by a device flow object, with customizable polling effect. + + :param dict flow: + A dict previously generated by :func:`~initiate_device_flow`. + By default, this method's polling effect will block current thread. + You can abort the polling loop at any time, + by changing the value of the flow's "expires_at" key to 0. + :param claims_challenge: + The claims_challenge parameter requests specific claims requested by the resource provider + in the form of a claims_challenge directive in the www-authenticate header to be + returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. + It is a string of a JSON object which contains lists of claims being requested from these locations. + + :return: A dict representing the json response from AAD: + + - A successful response would contain "access_token" key, + - an error response would contain "error" and usually "error_description". + """ + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID, + correlation_id=flow.get(self.DEVICE_FLOW_CORRELATION_ID)) + response = _clean_up(self.client.obtain_token_by_device_flow( + flow, + data=dict( + kwargs.pop("data", {}), + code=flow["device_code"], # 2018-10-4 Hack: + # during transition period, + # service seemingly need both device_code and code parameter. + claims=_merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge), + ), + headers=telemetry_context.generate_headers(), + **kwargs)) + telemetry_context.update_telemetry(response) + return response + + +class ConfidentialClientApplication(ClientApplication): # server-side web app + + def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): + """Acquires token for the current confidential client, not for an end user. + + :param list[str] scopes: (Required) + Scopes requested to access a protected API (a resource). + :param claims_challenge: + The claims_challenge parameter requests specific claims requested by the resource provider + in the form of a claims_challenge directive in the www-authenticate header to be + returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. + It is a string of a JSON object which contains lists of claims being requested from these locations. + + :return: A dict representing the json response from AAD: + + - A successful response would contain "access_token" key, + - an error response would contain "error" and usually "error_description". + """ + # TBD: force_refresh behavior + if self.authority.tenant.lower() in ["common", "organizations"]: + warnings.warn( + "Using /common or /organizations authority " + "in acquire_token_for_client() is unreliable. " + "Please use a specific tenant instead.", DeprecationWarning) + self._validate_ssh_cert_input_data(kwargs.get("data", {})) + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_FOR_CLIENT_ID) + client = self._regional_client or self.client + response = _clean_up(client.obtain_token_for_client( + scope=scopes, # This grant flow requires no scope decoration + headers=telemetry_context.generate_headers(), + data=dict( + kwargs.pop("data", {}), + claims=_merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge)), + **kwargs)) + telemetry_context.update_telemetry(response) + return response + + def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs): + """Acquires token using on-behalf-of (OBO) flow. + + The current app is a middle-tier service which was called with a token + representing an end user. + The current app can use such token (a.k.a. a user assertion) to request + another token to access downstream web API, on behalf of that user. + See `detail docs here `_ . + + The current middle-tier app has no user interaction to obtain consent. + See how to gain consent upfront for your middle-tier app from this article. + https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-on-behalf-of-flow#gaining-consent-for-the-middle-tier-application + + :param str user_assertion: The incoming token already received by this app + :param list[str] scopes: Scopes required by downstream API (a resource). + :param claims_challenge: + The claims_challenge parameter requests specific claims requested by the resource provider + in the form of a claims_challenge directive in the www-authenticate header to be + returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. + It is a string of a JSON object which contains lists of claims being requested from these locations. + + :return: A dict representing the json response from AAD: + + - A successful response would contain "access_token" key, + - an error response would contain "error" and usually "error_description". + """ + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID) + # The implementation is NOT based on Token Exchange + # https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-16 + response = _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521 + user_assertion, + self.client.GRANT_TYPE_JWT, # IDTs and AAD ATs are all JWTs + scope=self._decorate_scope(scopes), # Decoration is used for: + # 1. Explicitly requesting an RT, without relying on AAD default + # behavior, even though it currently still issues an RT. + # 2. Requesting an IDT (which would otherwise be unavailable) + # so that the calling app could use id_token_claims to implement + # their own cache mapping, which is likely needed in web apps. + data=dict( + kwargs.pop("data", {}), + requested_token_use="on_behalf_of", + claims=_merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge)), + headers=telemetry_context.generate_headers(), + # TBD: Expose a login_hint (or ccs_routing_hint) param for web app + **kwargs)) + telemetry_context.update_telemetry(response) + return response diff --git a/msal/authority.py b/msal/authority.py new file mode 100644 index 00000000..4fb6e829 --- /dev/null +++ b/msal/authority.py @@ -0,0 +1,177 @@ +import json +try: + from urllib.parse import urlparse +except ImportError: # Fall back to Python 2 + from urlparse import urlparse +import logging + +from .exceptions import MsalServiceError + + +logger = logging.getLogger(__name__) + +# Endpoints were copied from here +# https://docs.microsoft.com/en-us/azure/active-directory/develop/authentication-national-cloud#azure-ad-authentication-endpoints +AZURE_US_GOVERNMENT = "login.microsoftonline.us" +AZURE_CHINA = "login.chinacloudapi.cn" +AZURE_PUBLIC = "login.microsoftonline.com" + +WORLD_WIDE = 'login.microsoftonline.com' # There was an alias login.windows.net +WELL_KNOWN_AUTHORITY_HOSTS = set([ + WORLD_WIDE, + AZURE_CHINA, + 'login-us.microsoftonline.com', + AZURE_US_GOVERNMENT, + ]) +WELL_KNOWN_B2C_HOSTS = [ + "b2clogin.com", + "b2clogin.cn", + "b2clogin.us", + "b2clogin.de", + ] + + +class AuthorityBuilder(object): + def __init__(self, instance, tenant): + """A helper to save caller from doing string concatenation. + + Usage is documented in :func:`application.ClientApplication.__init__`. + """ + self._instance = instance.rstrip("/") + self._tenant = tenant.strip("/") + + def __str__(self): + return "https://{}/{}".format(self._instance, self._tenant) + + +class Authority(object): + """This class represents an (already-validated) authority. + + Once constructed, it contains members named "*_endpoint" for this instance. + TODO: It will also cache the previously-validated authority instances. + """ + _domains_without_user_realm_discovery = set([]) + + @property + def http_client(self): # Obsolete. We will remove this eventually + warnings.warn( + "authority.http_client might be removed in MSAL Python 1.21+", DeprecationWarning) + return self._http_client + + def __init__(self, authority_url, http_client, validate_authority=True): + """Creates an authority instance, and also validates it. + + :param validate_authority: + The Authority validation process actually checks two parts: + instance (a.k.a. host) and tenant. We always do a tenant discovery. + This parameter only controls whether an instance discovery will be + performed. + """ + self._http_client = http_client + if isinstance(authority_url, AuthorityBuilder): + authority_url = str(authority_url) + authority, self.instance, tenant = canonicalize(authority_url) + parts = authority.path.split('/') + is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or ( + len(parts) == 3 and parts[2].lower().startswith("b2c_")) + if (tenant != "adfs" and (not is_b2c) and validate_authority + and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS): + payload = instance_discovery( + "https://{}{}/oauth2/v2.0/authorize".format( + self.instance, authority.path), + self._http_client) + if payload.get("error") == "invalid_instance": + raise ValueError( + "invalid_instance: " + "The authority you provided, %s, is not whitelisted. " + "If it is indeed your legit customized domain name, " + "you can turn off this check by passing in " + "validate_authority=False" + % authority_url) + tenant_discovery_endpoint = payload['tenant_discovery_endpoint'] + else: + tenant_discovery_endpoint = ( + 'https://{}{}{}/.well-known/openid-configuration'.format( + self.instance, + authority.path, # In B2C scenario, it is "/tenant/policy" + "" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint + )) + try: + openid_config = tenant_discovery( + tenant_discovery_endpoint, + self._http_client) + except ValueError: + raise ValueError( + "Unable to get authority configuration for {}. " + "Authority would typically be in a format of " + "https://login.microsoftonline.com/your_tenant " + "Also please double check your tenant name or GUID is correct.".format( + authority_url)) + logger.debug("openid_config = %s", openid_config) + self.authorization_endpoint = openid_config['authorization_endpoint'] + self.token_endpoint = openid_config['token_endpoint'] + self.device_authorization_endpoint = openid_config.get('device_authorization_endpoint') + _, _, self.tenant = canonicalize(self.token_endpoint) # Usually a GUID + self.is_adfs = self.tenant.lower() == 'adfs' + + def user_realm_discovery(self, username, correlation_id=None, response=None): + # It will typically return a dict containing "ver", "account_type", + # "federation_protocol", "cloud_audience_urn", + # "federation_metadata_url", "federation_active_auth_url", etc. + if self.instance not in self.__class__._domains_without_user_realm_discovery: + resp = response or self._http_client.get( + "https://{netloc}/common/userrealm/{username}?api-version=1.0".format( + netloc=self.instance, username=username), + headers={'Accept': 'application/json', + 'client-request-id': correlation_id},) + if resp.status_code != 404: + resp.raise_for_status() + return json.loads(resp.text) + self.__class__._domains_without_user_realm_discovery.add(self.instance) + return {} # This can guide the caller to fall back normal ROPC flow + + +def canonicalize(authority_url): + # Returns (url_parsed_result, hostname_in_lowercase, tenant) + authority = urlparse(authority_url) + parts = authority.path.split("/") + if authority.scheme != "https" or len(parts) < 2 or not parts[1]: + raise ValueError( + "Your given address (%s) should consist of " + "an https url with a minimum of one segment in a path: e.g. " + "https://login.microsoftonline.com/ " + "or https://.b2clogin.com/.onmicrosoft.com/policy" + % authority_url) + return authority, authority.hostname, parts[1] + +def instance_discovery(url, http_client, **kwargs): + resp = http_client.get( # Note: This URL seemingly returns V1 endpoint only + 'https://{}/common/discovery/instance'.format( + WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too + # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103 + # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 + ), + params={'authorization_endpoint': url, 'api-version': '1.0'}, + **kwargs) + return json.loads(resp.text) + +def tenant_discovery(tenant_discovery_endpoint, http_client, **kwargs): + # Returns Openid Configuration + resp = http_client.get(tenant_discovery_endpoint, **kwargs) + if resp.status_code == 200: + payload = json.loads(resp.text) # It could raise ValueError + if 'authorization_endpoint' in payload and 'token_endpoint' in payload: + return payload # Happy path + raise ValueError("OIDC Discovery does not provide enough information") + if 400 <= resp.status_code < 500: + # Nonexist tenant would hit this path + # e.g. https://login.microsoftonline.com/nonexist_tenant/v2.0/.well-known/openid-configuration + raise ValueError( + "OIDC Discovery endpoint rejects our request. Error: {}".format( + resp.text # Expose it as-is b/c OIDC defines no error response format + )) + # Transient network error would hit this path + resp.raise_for_status() + raise RuntimeError( # A fallback here, in case resp.raise_for_status() is no-op + "Unable to complete OIDC Discovery: %d, %s" % (resp.status_code, resp.text)) + diff --git a/msal/cloudshell.py b/msal/cloudshell.py new file mode 100644 index 00000000..f4feaf44 --- /dev/null +++ b/msal/cloudshell.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. + +"""This module wraps Cloud Shell's IMDS-like interface inside an OAuth2-like helper""" +import base64 +import json +import logging +import os +import time +try: # Python 2 + from urlparse import urlparse +except: # Python 3 + from urllib.parse import urlparse +from .oauth2cli.oidc import decode_part + + +logger = logging.getLogger(__name__) + + +def _is_running_in_cloud_shell(): + return os.environ.get("AZUREPS_HOST_ENVIRONMENT", "").startswith("cloud-shell") + + +def _scope_to_resource(scope): # This is an experimental reasonable-effort approach + cloud_shell_supported_audiences = [ + "https://analysis.windows.net/powerbi/api", # Came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json + "https://pas.windows.net/CheckMyAccess/Linux/.default", # Cloud Shell accepts it as-is + ] + for a in cloud_shell_supported_audiences: + if scope.startswith(a): + return a + u = urlparse(scope) + if u.scheme: + return "{}://{}".format(u.scheme, u.netloc) + return scope # There is no much else we can do here + + +def _obtain_token(http_client, scopes, client_id=None, data=None): + resp = http_client.post( + "http://localhost:50342/oauth2/token", + data=dict( + data or {}, + resource=" ".join(map(_scope_to_resource, scopes))), + headers={"Metadata": "true"}, + ) + if resp.status_code >= 300: + logger.debug("Cloud Shell IMDS error: %s", resp.text) + cs_error = json.loads(resp.text).get("error", {}) + return {k: v for k, v in { + "error": cs_error.get("code"), + "error_description": cs_error.get("message"), + }.items() if v} + imds_payload = json.loads(resp.text) + BEARER = "Bearer" + oauth2_response = { + "access_token": imds_payload["access_token"], + "expires_in": int(imds_payload["expires_in"]), + "token_type": imds_payload.get("token_type", BEARER), + } + expected_token_type = (data or {}).get("token_type", BEARER) + if oauth2_response["token_type"] != expected_token_type: + return { # Generate a normal error (rather than an intrusive exception) + "error": "broker_error", + "error_description": "token_type {} is not supported by this version of Azure Portal".format( + expected_token_type), + } + parts = imds_payload["access_token"].split(".") + + # The following default values are useful in SSH Cert scenario + client_info = { # Default value, in case the real value will be unavailable + "uid": "user", + "utid": "cloudshell", + } + now = time.time() + preferred_username = "currentuser@cloudshell" + oauth2_response["id_token_claims"] = { # First 5 claims are required per OIDC + "iss": "cloudshell", + "sub": "user", + "aud": client_id, + "exp": now + 3600, + "iat": now, + "preferred_username": preferred_username, # Useful as MSAL account's username + } + + if len(parts) == 3: # Probably a JWT. Use it to derive client_info and id token. + try: + # Data defined in https://docs.microsoft.com/en-us/azure/active-directory/develop/access-tokens#payload-claims + jwt_payload = json.loads(decode_part(parts[1])) + client_info = { + # Mimic a real home_account_id, + # so that this pseudo account and a real account would interop. + "uid": jwt_payload.get("oid", "user"), + "utid": jwt_payload.get("tid", "cloudshell"), + } + oauth2_response["id_token_claims"] = { + "iss": jwt_payload["iss"], + "sub": jwt_payload["sub"], # Could use oid instead + "aud": client_id, + "exp": jwt_payload["exp"], + "iat": jwt_payload["iat"], + "preferred_username": jwt_payload.get("preferred_username") # V2 + or jwt_payload.get("unique_name") # V1 + or preferred_username, + } + except ValueError: + logger.debug("Unable to decode jwt payload: %s", parts[1]) + oauth2_response["client_info"] = base64.b64encode( + # Mimic a client_info, so that MSAL would create an account + json.dumps(client_info).encode("utf-8")).decode("utf-8") + oauth2_response["id_token_claims"]["tid"] = client_info["utid"] # TBD + + ## Note: Decided to not surface resource back as scope, + ## because they would cause the downstream OAuth2 code path to + ## cache the token with a different scope and won't hit them later. + #if imds_payload.get("resource"): + # oauth2_response["scope"] = imds_payload["resource"] + if imds_payload.get("refresh_token"): + oauth2_response["refresh_token"] = imds_payload["refresh_token"] + return oauth2_response + diff --git a/msal/exceptions.py b/msal/exceptions.py new file mode 100644 index 00000000..5e9ee151 --- /dev/null +++ b/msal/exceptions.py @@ -0,0 +1,38 @@ +#------------------------------------------------------------------------------ +# +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions : +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +#------------------------------------------------------------------------------ + +class MsalError(Exception): + # Define the template in Unicode to accommodate possible Unicode variables + msg = u'An unspecified error' + + def __init__(self, *args, **kwargs): + super(MsalError, self).__init__(self.msg.format(**kwargs), *args) + self.kwargs = kwargs + +class MsalServiceError(MsalError): + msg = u"{error}: {error_description}" + diff --git a/msal/individual_cache.py b/msal/individual_cache.py new file mode 100644 index 00000000..4c6fa00e --- /dev/null +++ b/msal/individual_cache.py @@ -0,0 +1,286 @@ +from functools import wraps +import time +try: + from collections.abc import MutableMapping # Python 3.3+ +except ImportError: + from collections import MutableMapping # Python 2.7+ +import heapq +from threading import Lock + + +class _ExpiringMapping(MutableMapping): + _INDEX = "_index_" + + def __init__(self, mapping=None, capacity=None, expires_in=None, lock=None, + *args, **kwargs): + """Items in this mapping can have individual shelf life, + just like food items in your refrigerator have their different shelf life + determined by each food, not by the refrigerator. + + Expired items will be automatically evicted. + The clean-up will be done at each time when adding a new item, + or when looping or counting the entire mapping. + (This is better than being done indecisively by a background thread, + which might not always happen before your accessing the mapping.) + + This implementation uses no dependency other than Python standard library. + + :param MutableMapping mapping: + A dict-like key-value mapping, which needs to support __setitem__(), + __getitem__(), __delitem__(), get(), pop(). + + The default mapping is an in-memory dict. + + You could potentially supply a file-based dict-like object, too. + This implementation deliberately avoid mapping.__iter__(), + which could be slow on a file-based mapping. + + :param int capacity: + How many items this mapping will hold. + When you attempt to add new item into a full mapping, + it will automatically delete the item that is expiring soonest. + + The default value is None, which means there is no capacity limit. + + :param int expires_in: + How many seconds an item would expire and be purged from this mapping. + Also known as time-to-live (TTL). + You can also use :func:`~set()` to provide per-item expires_in value. + + :param Lock lock: + A locking mechanism with context manager interface. + If no lock is provided, a threading.Lock will be used. + But you may want to supply a different lock, + if your customized mapping is being shared differently. + """ + super(_ExpiringMapping, self).__init__(*args, **kwargs) + self._mapping = mapping if mapping is not None else {} + self._capacity = capacity + self._expires_in = expires_in + self._lock = Lock() if lock is None else lock + + def _validate_key(self, key): + if key == self._INDEX: + raise ValueError("key {} is a reserved keyword in {}".format( + key, self.__class__.__name__)) + + def set(self, key, value, expires_in): + # This method's name was chosen so that it matches its cousin __setitem__(), + # and it also complements the counterpart get(). + # The downside is such a name shadows the built-in type set in this file, + # but you can overcome that by defining a global alias for set. + """It sets the key-value pair into this mapping, with its per-item expires_in. + + It will take O(logN) time, because it will run some maintenance. + This worse-than-constant time is acceptable, because in a cache scenario, + __setitem__() would only be called during a cache miss, + which would already incur an expensive target function call anyway. + + By the way, most other methods of this mapping still have O(1) constant time. + """ + with self._lock: + self._set(key, value, expires_in) + + def _set(self, key, value, expires_in): + # This internal implementation powers both set() and __setitem__(), + # so that they don't depend on each other. + self._validate_key(key) + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + self._maintenance(sequence, timestamps) # O(logN) + now = int(time.time()) + expires_at = now + expires_in + entry = [expires_at, now, key] + is_new_item = key not in timestamps + is_beyond_capacity = self._capacity and len(timestamps) >= self._capacity + if is_new_item and is_beyond_capacity: + self._drop_indexed_entry(timestamps, heapq.heappushpop(sequence, entry)) + else: # Simply add new entry. The old one would become a harmless orphan. + heapq.heappush(sequence, entry) + timestamps[key] = [expires_at, now] # It overwrites existing key, if any + self._mapping[key] = value + self._mapping[self._INDEX] = sequence, timestamps + + def _maintenance(self, sequence, timestamps): # O(logN) + """It will modify input sequence and timestamps in-place""" + now = int(time.time()) + while sequence: # Clean up expired items + expires_at, created_at, key = sequence[0] + if created_at <= now < expires_at: # Then all remaining items are fresh + break + self._drop_indexed_entry(timestamps, sequence[0]) # It could error out + heapq.heappop(sequence) # Only pop it after a successful _drop_indexed_entry() + while self._capacity is not None and len(timestamps) > self._capacity: + self._drop_indexed_entry(timestamps, sequence[0]) # It could error out + heapq.heappop(sequence) # Only pop it after a successful _drop_indexed_entry() + + def _drop_indexed_entry(self, timestamps, entry): + """For an entry came from index, drop it from timestamps and self._mapping""" + expires_at, created_at, key = entry + if [expires_at, created_at] == timestamps.get(key): # So it is not an orphan + self._mapping.pop(key, None) # It could raise exception + timestamps.pop(key, None) # This would probably always succeed + + def __setitem__(self, key, value): + """Implements the __setitem__(). + + Same characteristic as :func:`~set()`, + but use class-wide expires_in which was specified by :func:`~__init__()`. + """ + if self._expires_in is None: + raise ValueError("Need a numeric value for expires_in during __init__()") + with self._lock: + self._set(key, value, self._expires_in) + + def __getitem__(self, key): # O(1) + """If the item you requested already expires, KeyError will be raised.""" + self._validate_key(key) + with self._lock: + # Skip self._maintenance(), because it would need O(logN) time + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + expires_at, created_at = timestamps[key] # Would raise KeyError accordingly + now = int(time.time()) + if not created_at <= now < expires_at: + self._mapping.pop(key, None) + timestamps.pop(key, None) + self._mapping[self._INDEX] = sequence, timestamps + raise KeyError("{} {}".format( + key, + "expired" if now >= expires_at else "created in the future?", + )) + return self._mapping[key] # O(1) + + def __delitem__(self, key): # O(1) + """If the item you requested already expires, KeyError will be raised.""" + self._validate_key(key) + with self._lock: + # Skip self._maintenance(), because it would need O(logN) time + self._mapping.pop(key, None) # O(1) + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + del timestamps[key] # O(1) + self._mapping[self._INDEX] = sequence, timestamps + + def __len__(self): # O(logN) + """Drop all expired items and return the remaining length""" + with self._lock: + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + self._maintenance(sequence, timestamps) # O(logN) + self._mapping[self._INDEX] = sequence, timestamps + return len(timestamps) # Faster than iter(self._mapping) when it is on disk + + def __iter__(self): + """Drop all expired items and return an iterator of the remaining items""" + with self._lock: + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + self._maintenance(sequence, timestamps) # O(logN) + self._mapping[self._INDEX] = sequence, timestamps + return iter(timestamps) # Faster than iter(self._mapping) when it is on disk + + +class _IndividualCache(object): + # The code structure below can decorate both function and method. + # It is inspired by https://stackoverflow.com/a/9417088 + # We may potentially switch to build upon + # https://github.com/micheles/decorator/blob/master/docs/documentation.md#statement-of-the-problem + def __init__(self, mapping=None, key_maker=None, expires_in=None): + """Constructs a cache decorator that allows item-by-item control on + how to cache the return value of the decorated function. + + :param MutableMapping mapping: + The cached items will be stored inside. + You'd want to use a ExpiringMapping + if you plan to utilize the ``expires_in`` behavior. + + If nothing is provided, an in-memory dict will be used, + but it will provide no expiry functionality. + + .. note:: + + When using this class as a decorator, + your mapping needs to be available at "compile" time, + so it would typically be a global-, module- or class-level mapping:: + + module_mapping = {} + + @IndividualCache(mapping=module_mapping, ...) + def foo(): + ... + + If you want to use a mapping available only at run-time, + you have to manually decorate your function at run-time, too:: + + def foo(): + ... + + def bar(runtime_mapping): + foo = IndividualCache(mapping=runtime_mapping...)(foo) + + :param callable key_maker: + A callable which should have signature as + ``lambda function, args, kwargs: "return a string as key"``. + + If key_maker happens to return ``None``, the cache will be bypassed, + the underlying function will be invoked directly, + and the invoke result will not be cached either. + + :param callable expires_in: + The default value is ``None``, + which means the content being cached has no per-item expiry, + and will subject to the underlying mapping's global expiry time. + + It can be an integer indicating + how many seconds the result will be cached. + In particular, if the value is 0, + it means the result expires after zero second (i.e. immediately), + therefore the result will *not* be cached. + (Mind the difference between ``expires_in=0`` and ``expires_in=None``.) + + Or it can be a callable with the signature as + ``lambda function=function, args=args, kwargs=kwargs, result=result: 123`` + to calculate the expiry on the fly. + Its return value will be interpreted in the same way as above. + """ + self._mapping = mapping if mapping is not None else {} + self._key_maker = key_maker or (lambda function, args, kwargs: ( + function, # This default implementation uses function as part of key, + # so that the cache is partitioned by function. + # However, you could have many functions to use same namespace, + # so different decorators could share same cache. + args, + tuple(kwargs.items()), # raw kwargs is not hashable + )) + self._expires_in = expires_in + + def __call__(self, function): + + @wraps(function) + def wrapper(*args, **kwargs): + key = self._key_maker(function, args, kwargs) + if key is None: # Then bypass the cache + return function(*args, **kwargs) + + now = int(time.time()) + try: + return self._mapping[key] + except KeyError: + # We choose to NOT call function(...) in this block, otherwise + # potential exception from function(...) would become a confusing + # "During handling of the above exception, another exception occurred" + pass + value = function(*args, **kwargs) + + expires_in = self._expires_in( + function=function, + args=args, + kwargs=kwargs, + result=value, + ) if callable(self._expires_in) else self._expires_in + if expires_in == 0: + return value + if expires_in is None: + self._mapping[key] = value + else: + self._mapping.set(key, value, expires_in) + return value + + return wrapper + diff --git a/msal/mex.py b/msal/mex.py new file mode 100644 index 00000000..edecba37 --- /dev/null +++ b/msal/mex.py @@ -0,0 +1,137 @@ +#------------------------------------------------------------------------------ +# +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions : +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +#------------------------------------------------------------------------------ + +try: + from urllib.parse import urlparse +except: + from urlparse import urlparse +try: + from xml.etree import cElementTree as ET +except ImportError: + from xml.etree import ElementTree as ET +import logging + + +logger = logging.getLogger(__name__) + +def _xpath_of_root(route_to_leaf): + # Construct an xpath suitable to find a root node which has a specified leaf + return '/'.join(route_to_leaf + ['..'] * (len(route_to_leaf)-1)) + + +def send_request(mex_endpoint, http_client, **kwargs): + mex_resp = http_client.get(mex_endpoint, **kwargs) + mex_resp.raise_for_status() + try: + return Mex(mex_resp.text).get_wstrust_username_password_endpoint() + except ET.ParseError: + logger.exception( + "Malformed MEX document: %s, %s", mex_resp.status_code, mex_resp.text) + raise + + +class Mex(object): + + NS = { # Also used by wstrust_*.py + 'wsdl': 'http://schemas.xmlsoap.org/wsdl/', + 'sp': 'http://docs.oasis-open.org/ws-sx/ws-securitypolicy/200702', + 'sp2005': 'http://schemas.xmlsoap.org/ws/2005/07/securitypolicy', + 'wsu': 'http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-utility-1.0.xsd', + 'wsa': 'http://www.w3.org/2005/08/addressing', # Duplicate? + 'wsa10': 'http://www.w3.org/2005/08/addressing', + 'http': 'http://schemas.microsoft.com/ws/06/2004/policy/http', + 'soap12': 'http://schemas.xmlsoap.org/wsdl/soap12/', + 'wsp': 'http://schemas.xmlsoap.org/ws/2004/09/policy', + 's': 'http://www.w3.org/2003/05/soap-envelope', + 'wst': 'http://docs.oasis-open.org/ws-sx/ws-trust/200512', + 'trust': "http://docs.oasis-open.org/ws-sx/ws-trust/200512", # Duplicate? + 'saml': "urn:oasis:names:tc:SAML:1.0:assertion", + 'wst2005': 'http://schemas.xmlsoap.org/ws/2005/02/trust', # was named "t" + } + ACTION_13 = 'http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issue' + ACTION_2005 = 'http://schemas.xmlsoap.org/ws/2005/02/trust/RST/Issue' + + def __init__(self, mex_document): + self.dom = ET.fromstring(mex_document) + + def _get_policy_ids(self, components_to_leaf, binding_xpath): + id_attr = '{%s}Id' % self.NS['wsu'] + return set(["#{}".format(policy.get(id_attr)) + for policy in self.dom.findall(_xpath_of_root(components_to_leaf), self.NS) + # If we did not find any binding, this is potentially bad. + if policy.find(binding_xpath, self.NS) is not None]) + + def _get_username_password_policy_ids(self): + path = ['wsp:Policy', 'wsp:ExactlyOne', 'wsp:All', + 'sp:SignedEncryptedSupportingTokens', 'wsp:Policy', + 'sp:UsernameToken', 'wsp:Policy', 'sp:WssUsernameToken10'] + policies = self._get_policy_ids(path, './/sp:TransportBinding') + path2005 = ['wsp:Policy', 'wsp:ExactlyOne', 'wsp:All', + 'sp2005:SignedSupportingTokens', 'wsp:Policy', + 'sp2005:UsernameToken', 'wsp:Policy', 'sp2005:WssUsernameToken10'] + policies.update(self._get_policy_ids(path2005, './/sp2005:TransportBinding')) + return policies + + def _get_iwa_policy_ids(self): + return self._get_policy_ids( + ['wsp:Policy', 'wsp:ExactlyOne', 'wsp:All', 'http:NegotiateAuthentication'], + './/sp2005:TransportBinding') + + def _get_bindings(self): + bindings = {} # {binding_name: {"policy_uri": "...", "version": "..."}} + for binding in self.dom.findall("wsdl:binding", self.NS): + if (binding.find('soap12:binding', self.NS).get("transport") != + 'http://schemas.xmlsoap.org/soap/http'): + continue + action = binding.find( + 'wsdl:operation/soap12:operation', self.NS).get("soapAction") + for pr in binding.findall("wsp:PolicyReference", self.NS): + bindings[binding.get("name")] = { + "policy_uri": pr.get("URI"), "action": action} + return bindings + + def _get_endpoints(self, bindings, policy_ids): + endpoints = [] + for port in self.dom.findall('wsdl:service/wsdl:port', self.NS): + binding_name = port.get("binding").split(':')[-1] # Should have 2 parts + binding = bindings.get(binding_name) + if binding and binding["policy_uri"] in policy_ids: + address = port.find('wsa10:EndpointReference/wsa10:Address', self.NS) + if address is not None and address.text.lower().startswith("https://"): + endpoints.append( + {"address": address.text, "action": binding["action"]}) + return endpoints + + def get_wstrust_username_password_endpoint(self): + """Returns {"address": "https://...", "action": "the soapAction value"}""" + endpoints = self._get_endpoints( + self._get_bindings(), self._get_username_password_policy_ids()) + for e in endpoints: + if e["action"] == self.ACTION_13: + return e # Historically, we prefer ACTION_13 a.k.a. WsTrust13 + return endpoints[0] if endpoints else None + diff --git a/oauth2cli/__init__.py b/msal/oauth2cli/__init__.py similarity index 100% rename from oauth2cli/__init__.py rename to msal/oauth2cli/__init__.py diff --git a/oauth2cli/assertion.py b/msal/oauth2cli/assertion.py similarity index 99% rename from oauth2cli/assertion.py rename to msal/oauth2cli/assertion.py index f01bb2d0..419bb14e 100644 --- a/oauth2cli/assertion.py +++ b/msal/oauth2cli/assertion.py @@ -4,8 +4,6 @@ import uuid import logging -import jwt - logger = logging.getLogger(__name__) @@ -99,6 +97,7 @@ def create_normal_assertion( Parameters are defined in https://tools.ietf.org/html/rfc7523#section-3 Key-value pairs in additional_claims will be added into payload as-is. """ + import jwt # Lazy loading now = time.time() payload = { 'aud': audience, diff --git a/oauth2cli/authcode.py b/msal/oauth2cli/authcode.py similarity index 100% rename from oauth2cli/authcode.py rename to msal/oauth2cli/authcode.py diff --git a/oauth2cli/http.py b/msal/oauth2cli/http.py similarity index 100% rename from oauth2cli/http.py rename to msal/oauth2cli/http.py diff --git a/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py similarity index 99% rename from oauth2cli/oauth2.py rename to msal/oauth2cli/oauth2.py index 6cb31bbb..90f576af 100644 --- a/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -17,8 +17,6 @@ import string import hashlib -import requests - from .authcode import AuthCodeReceiver as _AuthCodeReceiver try: @@ -158,6 +156,8 @@ def __init__( "when http_client is in use") self._http_client = http_client else: + import requests # Lazy loading + self._http_client = requests.Session() self._http_client.verify = True if verify is None else verify self._http_client.proxies = proxies diff --git a/oauth2cli/oidc.py b/msal/oauth2cli/oidc.py similarity index 100% rename from oauth2cli/oidc.py rename to msal/oauth2cli/oidc.py diff --git a/msal/region.py b/msal/region.py new file mode 100644 index 00000000..c540dc71 --- /dev/null +++ b/msal/region.py @@ -0,0 +1,45 @@ +import os +import logging + +logger = logging.getLogger(__name__) + + +def _detect_region(http_client=None): + region = os.environ.get("REGION_NAME", "").replace(" ", "").lower() # e.g. westus2 + if region: + return region + if http_client: + return _detect_region_of_azure_vm(http_client) # It could hang for minutes + return None + + +def _detect_region_of_azure_vm(http_client): + url = ( + "http://169.254.169.254/metadata/instance" + + # Utilize the "route parameters" feature to obtain region as a string + # https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#route-parameters + "/compute/location?format=text" + + # Location info is available since API version 2017-04-02 + # https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#response-1 + "&api-version=2021-01-01" + ) + logger.info( + "Connecting to IMDS {}. " + "It may take a while if you are running outside of Azure. " + "You should consider opting in/out region behavior on-demand, " + 'by loading a boolean flag "is_deployed_in_azure" ' + 'from your per-deployment config and then do ' + '"app = ConfidentialClientApplication(..., ' + 'azure_region=is_deployed_in_azure)"'.format(url)) + try: + # https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#instance-metadata + resp = http_client.get(url, headers={"Metadata": "true"}) + except: + logger.info( + "IMDS {} unavailable. Perhaps not running in Azure VM?".format(url)) + return None + else: + return resp.text.strip() + diff --git a/msal/telemetry.py b/msal/telemetry.py new file mode 100644 index 00000000..b07ab3ed --- /dev/null +++ b/msal/telemetry.py @@ -0,0 +1,78 @@ +import uuid +import logging + + +logger = logging.getLogger(__name__) + +CLIENT_REQUEST_ID = 'client-request-id' +CLIENT_CURRENT_TELEMETRY = "x-client-current-telemetry" +CLIENT_LAST_TELEMETRY = "x-client-last-telemetry" +NON_SILENT_CALL = 0 +FORCE_REFRESH = 1 +AT_ABSENT = 2 +AT_EXPIRED = 3 +AT_AGING = 4 +RESERVED = 5 + + +def _get_new_correlation_id(): + return str(uuid.uuid4()) + + +class _TelemetryContext(object): + """It is used for handling the telemetry context for current OAuth2 "exchange".""" + # https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview?path=%2FTelemetry%2FMSALServerSideTelemetry.md&_a=preview + _SUCCEEDED = "succeeded" + _FAILED = "failed" + _FAILURE_SIZE = "failure_size" + _CURRENT_HEADER_SIZE_LIMIT = 100 + _LAST_HEADER_SIZE_LIMIT = 350 + + def __init__(self, buffer, lock, api_id, correlation_id=None, refresh_reason=None): + self._buffer = buffer + self._lock = lock + self._api_id = api_id + self._correlation_id = correlation_id or _get_new_correlation_id() + self._refresh_reason = refresh_reason or NON_SILENT_CALL + logger.debug("Generate or reuse correlation_id: %s", self._correlation_id) + + def generate_headers(self): + with self._lock: + current = "4|{api_id},{cache_refresh}|".format( + api_id=self._api_id, cache_refresh=self._refresh_reason) + if len(current) > self._CURRENT_HEADER_SIZE_LIMIT: + logger.warning( + "Telemetry header greater than {} will be truncated by AAD".format( + self._CURRENT_HEADER_SIZE_LIMIT)) + failures = self._buffer.get(self._FAILED, []) + return { + CLIENT_REQUEST_ID: self._correlation_id, + CLIENT_CURRENT_TELEMETRY: current, + CLIENT_LAST_TELEMETRY: "4|{succeeded}|{failed_requests}|{errors}|".format( + succeeded=self._buffer.get(self._SUCCEEDED, 0), + failed_requests=",".join("{a},{c}".format(**f) for f in failures), + errors=",".join(f["e"] for f in failures), + ) + } + + def hit_an_access_token(self): + with self._lock: + self._buffer[self._SUCCEEDED] = self._buffer.get(self._SUCCEEDED, 0) + 1 + + def update_telemetry(self, auth_result): + if auth_result: + with self._lock: + if "error" in auth_result: + self._record_failure(auth_result["error"]) + else: # Telemetry sent successfully. Reset buffer + self._buffer.clear() # This won't work: self._buffer = {} + + def _record_failure(self, error): + simulation = len(",{api_id},{correlation_id},{error}".format( + api_id=self._api_id, correlation_id=self._correlation_id, error=error)) + if self._buffer.get(self._FAILURE_SIZE, 0) + simulation < self._LAST_HEADER_SIZE_LIMIT: + self._buffer[self._FAILURE_SIZE] = self._buffer.get( + self._FAILURE_SIZE, 0) + simulation + self._buffer.setdefault(self._FAILED, []).append({ + "a": self._api_id, "c": self._correlation_id, "e": error}) + diff --git a/msal/throttled_http_client.py b/msal/throttled_http_client.py new file mode 100644 index 00000000..378cd3df --- /dev/null +++ b/msal/throttled_http_client.py @@ -0,0 +1,141 @@ +from threading import Lock +from hashlib import sha256 + +from .individual_cache import _IndividualCache as IndividualCache +from .individual_cache import _ExpiringMapping as ExpiringMapping + + +# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4 +DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code" + + +def _hash(raw): + return sha256(repr(raw).encode("utf-8")).hexdigest() + + +def _parse_http_429_5xx_retry_after(result=None, **ignored): + """Return seconds to throttle""" + assert result is not None, """ + The signature defines it with a default value None, + only because the its shape is already decided by the + IndividualCache's.__call__(). + In actual code path, the result parameter here won't be None. + """ + response = result + lowercase_headers = {k.lower(): v for k, v in getattr( + # Historically, MSAL's HttpResponse does not always have headers + response, "headers", {}).items()} + if not (response.status_code == 429 or response.status_code >= 500 + or "retry-after" in lowercase_headers): + return 0 # Quick exit + default = 60 # Recommended at the end of + # https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview + retry_after = int(lowercase_headers.get("retry-after", default)) + try: + # AAD's retry_after uses integer format only + # https://stackoverflow.microsoft.com/questions/264931/264932 + delay_seconds = int(retry_after) + except ValueError: + delay_seconds = default + return min(3600, delay_seconds) + + +def _extract_data(kwargs, key, default=None): + data = kwargs.get("data", {}) # data is usually a dict, but occasionally a string + return data.get(key) if isinstance(data, dict) else default + + +class ThrottledHttpClient(object): + def __init__(self, http_client, http_cache): + """Throttle the given http_client by storing and retrieving data from cache. + + This wrapper exists so that our patching post() and get() would prevent + re-patching side effect when/if same http_client being reused. + """ + expiring_mapping = ExpiringMapping( # It will automatically clean up + mapping=http_cache if http_cache is not None else {}, + capacity=1024, # To prevent cache blowing up especially for CCA + lock=Lock(), # TODO: This should ideally also allow customization + ) + + _post = http_client.post # We'll patch _post, and keep original post() intact + + _post = IndividualCache( + # Internal specs requires throttling on at least token endpoint, + # here we have a generic patch for POST on all endpoints. + mapping=expiring_mapping, + key_maker=lambda func, args, kwargs: + "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format( + args[0], # It is the url, typically containing authority and tenant + _extract_data(kwargs, "client_id"), # Per internal specs + _extract_data(kwargs, "scope"), # Per internal specs + _hash( + # The followings are all approximations of the "account" concept + # to support per-account throttling. + # TODO: We may want to disable it for confidential client, though + _extract_data(kwargs, "refresh_token", # "account" during refresh + _extract_data(kwargs, "code", # "account" of auth code grant + _extract_data(kwargs, "username")))), # "account" of ROPC + ), + expires_in=_parse_http_429_5xx_retry_after, + )(_post) + + _post = IndividualCache( # It covers the "UI required cache" + mapping=expiring_mapping, + key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format( + args[0], # It is the url, typically containing authority and tenant + _hash( + # Here we use literally all parameters, even those short-lived + # parameters containing timestamps (WS-Trust or POP assertion), + # because they will automatically be cleaned up by ExpiringMapping. + # + # Furthermore, there is no need to implement + # "interactive requests would reset the cache", + # because acquire_token_silent()'s would be automatically unblocked + # due to token cache layer operates on top of http cache layer. + # + # And, acquire_token_silent(..., force_refresh=True) will NOT + # bypass http cache, because there is no real gain from that. + # We won't bother implement it, nor do we want to encourage + # acquire_token_silent(..., force_refresh=True) pattern. + str(kwargs.get("params")) + str(kwargs.get("data"))), + ), + expires_in=lambda result=None, kwargs=None, **ignored: + 60 + if result.status_code == 400 + # Here we choose to cache exact HTTP 400 errors only (rather than 4xx) + # because they are the ones defined in OAuth2 + # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2) + # Other 4xx errors might have different requirements e.g. + # "407 Proxy auth required" would need a key including http headers. + and not( # Exclude Device Flow whose retry is expected and regulated + isinstance(kwargs.get("data"), dict) + and kwargs["data"].get("grant_type") == DEVICE_AUTH_GRANT + ) + and "retry-after" not in set( # Leave it to the Retry-After decorator + h.lower() for h in getattr(result, "headers", {}).keys()) + else 0, + )(_post) + + self.post = _post + + self.get = IndividualCache( # Typically those discovery GETs + mapping=expiring_mapping, + key_maker=lambda func, args, kwargs: "GET {} hash={} 2xx".format( + args[0], # It is the url, sometimes containing inline params + _hash(kwargs.get("params", "")), + ), + expires_in=lambda result=None, **ignored: + 3600*24 if 200 <= result.status_code < 300 else 0, + )(http_client.get) + + self._http_client = http_client + + # The following 2 methods have been defined dynamically by __init__() + #def post(self, *args, **kwargs): pass + #def get(self, *args, **kwargs): pass + + def close(self): + """MSAL won't need this. But we allow throttled_http_client.close() anyway""" + return self._http_client.close() + diff --git a/msal/token_cache.py b/msal/token_cache.py new file mode 100644 index 00000000..f7d9f955 --- /dev/null +++ b/msal/token_cache.py @@ -0,0 +1,331 @@ +import json +import threading +import time +import logging + +from .authority import canonicalize +from .oauth2cli.oidc import decode_part, decode_id_token + + +logger = logging.getLogger(__name__) + +def is_subdict_of(small, big): + return dict(big, **small) == big + + +class TokenCache(object): + """This is considered as a base class containing minimal cache behavior. + + Although it maintains tokens using unified schema across all MSAL libraries, + this class does not serialize/persist them. + See subclass :class:`SerializableTokenCache` for details on serialization. + """ + + class CredentialType: + ACCESS_TOKEN = "AccessToken" + REFRESH_TOKEN = "RefreshToken" + ACCOUNT = "Account" # Not exactly a credential type, but we put it here + ID_TOKEN = "IdToken" + APP_METADATA = "AppMetadata" + + class AuthorityType: + ADFS = "ADFS" + MSSTS = "MSSTS" # MSSTS means AAD v2 for both AAD & MSA + + def __init__(self): + self._lock = threading.RLock() + self._cache = {} + self.key_makers = { + self.CredentialType.REFRESH_TOKEN: + lambda home_account_id=None, environment=None, client_id=None, + target=None, **ignored_payload_from_a_real_token: + "-".join([ + home_account_id or "", + environment or "", + self.CredentialType.REFRESH_TOKEN, + client_id or "", + "", # RT is cross-tenant in AAD + target or "", # raw value could be None if deserialized from other SDK + ]).lower(), + self.CredentialType.ACCESS_TOKEN: + lambda home_account_id=None, environment=None, client_id=None, + realm=None, target=None, **ignored_payload_from_a_real_token: + "-".join([ + home_account_id or "", + environment or "", + self.CredentialType.ACCESS_TOKEN, + client_id or "", + realm or "", + target or "", + ]).lower(), + self.CredentialType.ID_TOKEN: + lambda home_account_id=None, environment=None, client_id=None, + realm=None, **ignored_payload_from_a_real_token: + "-".join([ + home_account_id or "", + environment or "", + self.CredentialType.ID_TOKEN, + client_id or "", + realm or "", + "" # Albeit irrelevant, schema requires an empty scope here + ]).lower(), + self.CredentialType.ACCOUNT: + lambda home_account_id=None, environment=None, realm=None, + **ignored_payload_from_a_real_entry: + "-".join([ + home_account_id or "", + environment or "", + realm or "", + ]).lower(), + self.CredentialType.APP_METADATA: + lambda environment=None, client_id=None, **kwargs: + "appmetadata-{}-{}".format(environment or "", client_id or ""), + } + + def find(self, credential_type, target=None, query=None): + target = target or [] + assert isinstance(target, list), "Invalid parameter type" + target_set = set(target) + with self._lock: + # Since the target inside token cache key is (per schema) unsorted, + # there is no point to attempt an O(1) key-value search here. + # So we always do an O(n) in-memory search. + return [entry + for entry in self._cache.get(credential_type, {}).values() + if is_subdict_of(query or {}, entry) + and (target_set <= set(entry.get("target", "").split()) + if target else True) + ] + + def add(self, event, now=None): + # type: (dict) -> None + """Handle a token obtaining event, and add tokens into cache. + + Known side effects: This function modifies the input event in place. + """ + def wipe(dictionary, sensitive_fields): # Masks sensitive info + for sensitive in sensitive_fields: + if sensitive in dictionary: + dictionary[sensitive] = "********" + wipe(event.get("data", {}), + ("password", "client_secret", "refresh_token", "assertion")) + try: + return self.__add(event, now=now) + finally: + wipe(event.get("response", {}), ( # These claims were useful during __add() + "id_token_claims", # Provided by broker + "access_token", "refresh_token", "id_token", "username")) + wipe(event, ["username"]) # Needed for federated ROPC + logger.debug("event=%s", json.dumps( + # We examined and concluded that this log won't have Log Injection risk, + # because the event payload is already in JSON so CR/LF will be escaped. + event, indent=4, sort_keys=True, + default=str, # A workaround when assertion is in bytes in Python 3 + )) + + def __parse_account(self, response, id_token_claims): + """Return client_info and home_account_id""" + if "client_info" in response: # It happens when client_info and profile are in request + client_info = json.loads(decode_part(response["client_info"])) + if "uid" in client_info and "utid" in client_info: + return client_info, "{uid}.{utid}".format(**client_info) + # https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/387 + if id_token_claims: # This would be an end user on ADFS-direct scenario + sub = id_token_claims["sub"] # "sub" always exists, per OIDC specs + return {"uid": sub}, sub + # client_credentials flow will reach this code path + return {}, None + + def __add(self, event, now=None): + # event typically contains: client_id, scope, token_endpoint, + # response, params, data, grant_type + environment = realm = None + if "token_endpoint" in event: + _, environment, realm = canonicalize(event["token_endpoint"]) + if "environment" in event: # Always available unless in legacy test cases + environment = event["environment"] # Set by application.py + response = event.get("response", {}) + data = event.get("data", {}) + access_token = response.get("access_token") + refresh_token = response.get("refresh_token") + id_token = response.get("id_token") + id_token_claims = ( + decode_id_token(id_token, client_id=event["client_id"]) + if id_token + else response.get("id_token_claims", {})) # Broker would provide id_token_claims + client_info, home_account_id = self.__parse_account(response, id_token_claims) + + target = ' '.join(event.get("scope") or []) # Per schema, we don't sort it + + with self._lock: + now = int(time.time() if now is None else now) + + if access_token: + expires_in = int( # AADv1-like endpoint returns a string + response.get("expires_in", 3599)) + ext_expires_in = int( # AADv1-like endpoint returns a string + response.get("ext_expires_in", expires_in)) + at = { + "credential_type": self.CredentialType.ACCESS_TOKEN, + "secret": access_token, + "home_account_id": home_account_id, + "environment": environment, + "client_id": event.get("client_id"), + "target": target, + "realm": realm, + "token_type": response.get("token_type", "Bearer"), + "cached_at": str(now), # Schema defines it as a string + "expires_on": str(now + expires_in), # Same here + "extended_expires_on": str(now + ext_expires_in) # Same here + } + if data.get("key_id"): # It happens in SSH-cert or POP scenario + at["key_id"] = data.get("key_id") + if "refresh_in" in response: + refresh_in = response["refresh_in"] # It is an integer + at["refresh_on"] = str(now + refresh_in) # Schema wants a string + self.modify(self.CredentialType.ACCESS_TOKEN, at, at) + + if client_info and not event.get("skip_account_creation"): + account = { + "home_account_id": home_account_id, + "environment": environment, + "realm": realm, + "local_account_id": id_token_claims.get( + "oid", id_token_claims.get("sub")), + "username": id_token_claims.get("preferred_username") # AAD + or id_token_claims.get("upn") # ADFS 2019 + or data.get("username") # Falls back to ROPC username + or event.get("username") # Falls back to Federated ROPC username + or "", # The schema does not like null + "authority_type": event.get( + "authority_type", # Honor caller's choice of authority_type + self.AuthorityType.ADFS if realm == "adfs" + else self.AuthorityType.MSSTS), + # "client_info": response.get("client_info"), # Optional + } + self.modify(self.CredentialType.ACCOUNT, account, account) + + if id_token: + idt = { + "credential_type": self.CredentialType.ID_TOKEN, + "secret": id_token, + "home_account_id": home_account_id, + "environment": environment, + "realm": realm, + "client_id": event.get("client_id"), + # "authority": "it is optional", + } + self.modify(self.CredentialType.ID_TOKEN, idt, idt) + + if refresh_token: + rt = { + "credential_type": self.CredentialType.REFRESH_TOKEN, + "secret": refresh_token, + "home_account_id": home_account_id, + "environment": environment, + "client_id": event.get("client_id"), + "target": target, # Optional per schema though + "last_modification_time": str(now), # Optional. Schema defines it as a string. + } + if "foci" in response: + rt["family_id"] = response["foci"] + self.modify(self.CredentialType.REFRESH_TOKEN, rt, rt) + + app_metadata = { + "client_id": event.get("client_id"), + "environment": environment, + } + if "foci" in response: + app_metadata["family_id"] = response.get("foci") + self.modify(self.CredentialType.APP_METADATA, app_metadata, app_metadata) + + def modify(self, credential_type, old_entry, new_key_value_pairs=None): + # Modify the specified old_entry with new_key_value_pairs, + # or remove the old_entry if the new_key_value_pairs is None. + + # This helper exists to consolidate all token add/modify/remove behaviors, + # so that the sub-classes will have only one method to work on, + # instead of patching a pair of update_xx() and remove_xx() per type. + # You can monkeypatch self.key_makers to support more types on-the-fly. + key = self.key_makers[credential_type](**old_entry) + with self._lock: + if new_key_value_pairs: # Update with them + entries = self._cache.setdefault(credential_type, {}) + entries[key] = dict( + old_entry, # Do not use entries[key] b/c it might not exist + **new_key_value_pairs) + else: # Remove old_entry + self._cache.setdefault(credential_type, {}).pop(key, None) + + def remove_rt(self, rt_item): + assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN + return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item) + + def update_rt(self, rt_item, new_rt): + assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN + return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item, { + "secret": new_rt, + "last_modification_time": str(int(time.time())), # Optional. Schema defines it as a string. + }) + + def remove_at(self, at_item): + assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN + return self.modify(self.CredentialType.ACCESS_TOKEN, at_item) + + def remove_idt(self, idt_item): + assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN + return self.modify(self.CredentialType.ID_TOKEN, idt_item) + + def remove_account(self, account_item): + assert "authority_type" in account_item + return self.modify(self.CredentialType.ACCOUNT, account_item) + + +class SerializableTokenCache(TokenCache): + """This serialization can be a starting point to implement your own persistence. + + This class does NOT actually persist the cache on disk/db/etc.. + Depending on your need, + the following simple recipe for file-based persistence may be sufficient:: + + import os, atexit, msal + cache = msal.SerializableTokenCache() + if os.path.exists("my_cache.bin"): + cache.deserialize(open("my_cache.bin", "r").read()) + atexit.register(lambda: + open("my_cache.bin", "w").write(cache.serialize()) + # Hint: The following optional line persists only when state changed + if cache.has_state_changed else None + ) + app = msal.ClientApplication(..., token_cache=cache) + ... + + :var bool has_state_changed: + Indicates whether the cache state in the memory has changed since last + :func:`~serialize` or :func:`~deserialize` call. + """ + has_state_changed = False + + def add(self, event, **kwargs): + super(SerializableTokenCache, self).add(event, **kwargs) + self.has_state_changed = True + + def modify(self, credential_type, old_entry, new_key_value_pairs=None): + super(SerializableTokenCache, self).modify( + credential_type, old_entry, new_key_value_pairs) + self.has_state_changed = True + + def deserialize(self, state): + # type: (Optional[str]) -> None + """Deserialize the cache from a state previously obtained by serialize()""" + with self._lock: + self._cache = json.loads(state) if state else {} + self.has_state_changed = False # reset + + def serialize(self): + # type: () -> str + """Serialize the current cache state into a string.""" + with self._lock: + self.has_state_changed = False + return json.dumps(self._cache, indent=4) + diff --git a/msal/wstrust_request.py b/msal/wstrust_request.py new file mode 100644 index 00000000..43a2804f --- /dev/null +++ b/msal/wstrust_request.py @@ -0,0 +1,129 @@ +#------------------------------------------------------------------------------ +# +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions : +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +#------------------------------------------------------------------------------ + +import uuid +from datetime import datetime, timedelta +import logging + +from .mex import Mex +from .wstrust_response import parse_response + +logger = logging.getLogger(__name__) + +def send_request( + username, password, cloud_audience_urn, endpoint_address, soap_action, http_client, + **kwargs): + if not endpoint_address: + raise ValueError("WsTrust endpoint address can not be empty") + if soap_action is None: + if '/trust/2005/usernamemixed' in endpoint_address: + soap_action = Mex.ACTION_2005 + elif '/trust/13/usernamemixed' in endpoint_address: + soap_action = Mex.ACTION_13 + if soap_action not in (Mex.ACTION_13, Mex.ACTION_2005): + raise ValueError("Unsupported soap action: %s. " + "Contact your administrator to check your ADFS's MEX settings." % soap_action) + data = _build_rst( + username, password, cloud_audience_urn, endpoint_address, soap_action) + resp = http_client.post(endpoint_address, data=data, headers={ + 'Content-type':'application/soap+xml; charset=utf-8', + 'SOAPAction': soap_action, + }, **kwargs) + if resp.status_code >= 400: + logger.debug("Unsuccessful WsTrust request receives: %s", resp.text) + # It turns out ADFS uses 5xx status code even with client-side incorrect password error + # resp.raise_for_status() + return parse_response(resp.text) + + +def escape_password(password): + return (password.replace('&', '&').replace('"', '"') + .replace("'", ''') # the only one not provided by cgi.escape(s, True) + .replace('<', '<').replace('>', '>')) + + +def wsu_time_format(datetime_obj): + # WsTrust (http://docs.oasis-open.org/ws-sx/ws-trust/v1.4/ws-trust.html) + # does not seem to define timestamp format, but we see YYYY-mm-ddTHH:MM:SSZ + # here (https://www.ibm.com/developerworks/websphere/library/techarticles/1003_chades/1003_chades.html) + # It avoids the uncertainty of the optional ".ssssss" in datetime.isoformat() + # https://docs.python.org/2/library/datetime.html#datetime.datetime.isoformat + return datetime_obj.strftime('%Y-%m-%dT%H:%M:%SZ') + + +def _build_rst(username, password, cloud_audience_urn, endpoint_address, soap_action): + now = datetime.utcnow() + return """ + + {soap_action} + urn:uuid:{message_id} + + http://www.w3.org/2005/08/addressing/anonymous + + {endpoint_address} + + + + {time_now} + {time_expire} + + + {username} + {password} + + + + + + + + + {applies_to} + + + {key_type} + {request_type} + + + """.format( + s=Mex.NS["s"], wsu=Mex.NS["wsu"], wsa=Mex.NS["wsa10"], + soap_action=soap_action, message_id=str(uuid.uuid4()), + endpoint_address=endpoint_address, + time_now=wsu_time_format(now), + time_expire=wsu_time_format(now + timedelta(minutes=10)), + username=username, password=escape_password(password), + wst=Mex.NS["wst"] if soap_action == Mex.ACTION_13 else Mex.NS["wst2005"], + applies_to=cloud_audience_urn, + key_type='http://docs.oasis-open.org/ws-sx/ws-trust/200512/Bearer' + if soap_action == Mex.ACTION_13 else + 'http://schemas.xmlsoap.org/ws/2005/05/identity/NoProofKey', + request_type='http://docs.oasis-open.org/ws-sx/ws-trust/200512/Issue' + if soap_action == Mex.ACTION_13 else + 'http://schemas.xmlsoap.org/ws/2005/02/trust/Issue', + ) + diff --git a/msal/wstrust_response.py b/msal/wstrust_response.py new file mode 100644 index 00000000..9c58af23 --- /dev/null +++ b/msal/wstrust_response.py @@ -0,0 +1,94 @@ +#------------------------------------------------------------------------------ +# +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions : +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +#------------------------------------------------------------------------------ + +try: + from xml.etree import cElementTree as ET +except ImportError: + from xml.etree import ElementTree as ET +import re + +from .mex import Mex + + +SAML_TOKEN_TYPE_V1 = 'urn:oasis:names:tc:SAML:1.0:assertion' +SAML_TOKEN_TYPE_V2 = 'urn:oasis:names:tc:SAML:2.0:assertion' + +# http://docs.oasis-open.org/wss-m/wss/v1.1.1/os/wss-SAMLTokenProfile-v1.1.1-os.html#_Toc307397288 +WSS_SAML_TOKEN_PROFILE_V1_1 = "http://docs.oasis-open.org/wss/oasis-wss-saml-token-profile-1.1#SAMLV1.1" +WSS_SAML_TOKEN_PROFILE_V2 = "http://docs.oasis-open.org/wss/oasis-wss-saml-token-profile-1.1#SAMLV2.0" + +def parse_response(body): # Returns {"token": "", "type": "..."} + token = parse_token_by_re(body) + if token: + return token + error = parse_error(body) + raise RuntimeError("WsTrust server returned error in RSTR: %s" % (error or body)) + +def parse_error(body): # Returns error as a dict. See unit test case for an example. + dom = ET.fromstring(body) + reason_text_node = dom.find('s:Body/s:Fault/s:Reason/s:Text', Mex.NS) + subcode_value_node = dom.find('s:Body/s:Fault/s:Code/s:Subcode/s:Value', Mex.NS) + if reason_text_node is not None or subcode_value_node is not None: + return {"reason": reason_text_node.text, "code": subcode_value_node.text} + +def findall_content(xml_string, tag): + """ + Given a tag name without any prefix, + this function returns a list of the raw content inside this tag as-is. + + >>> findall_content(" what ever content ", "foo") + [" what ever content "] + + Motivation: + + Usually we would use XML parser to extract the data by xpath. + However the ElementTree in Python will implicitly normalize the output + by "hoisting" the inner inline namespaces into the outmost element. + The result will be a semantically equivalent XML snippet, + but not fully identical to the original one. + While this effect shouldn't become a problem in all other cases, + it does not seem to fully comply with Exclusive XML Canonicalization spec + (https://www.w3.org/TR/xml-exc-c14n/), and void the SAML token signature. + SAML signature algo needs the "XML -> C14N(XML) -> Signed(C14N(Xml))" order. + + The binary extention lxml is probably the canonical way to solve this + (https://stackoverflow.com/questions/22959577/python-exclusive-xml-canonicalization-xml-exc-c14n) + but here we use this workaround, based on Regex, to return raw content as-is. + """ + # \w+ is good enough for https://www.w3.org/TR/REC-xml/#NT-NameChar + pattern = r"<(?:\w+:)?%(tag)s(?:[^>]*)>(.*)=2.0.0,<3', - 'PyJWT>=1.0.0,<3', + 'PyJWT[crypto]>=1.0.0,<3', # MSAL does not use jwt.decode(), therefore is insusceptible to CVE-2022-29217 so no need to bump to PyJWT 2.4+ + + 'cryptography>=0.6,<40', + # load_pem_private_key() is available since 0.6 + # https://github.com/pyca/cryptography/blob/master/CHANGELOG.rst#06---2014-09-29 + # + # And we will use the cryptography (X+3).0.0 as the upper bound, + # based on their latest deprecation policy + # https://cryptography.io/en/latest/api-stability/#deprecation + + "mock;python_version<'3.3'", ] ) + diff --git a/tests/__init__.py b/tests/__init__.py index 8c717de0..741b1e08 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,6 @@ import sys import logging + if sys.version_info[:2] < (2, 7): # The unittest module got a significant overhaul in Python 2.7, # so if we're in 2.6 we can use the backported version unittest2. diff --git a/tests/archan.us.mex.xml b/tests/archan.us.mex.xml new file mode 100644 index 00000000..00bc35eb --- /dev/null +++ b/tests/archan.us.mex.xml @@ -0,0 +1,915 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + http://schemas.xmlsoap.org/ws/2005/02/trust/PublicKey + http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p + http://www.w3.org/2000/09/xmldsig#rsa-sha1 + http://www.w3.org/2001/10/xml-exc-c14n# + http://www.w3.org/2001/04/xmlenc#aes256-cbc + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + http://schemas.xmlsoap.org/ws/2005/02/trust/SymmetricKey + 256 + http://www.w3.org/2001/04/xmlenc#aes256-cbc + http://www.w3.org/2000/09/xmldsig#hmac-sha1 + http://www.w3.org/2001/10/xml-exc-c14n# + http://www.w3.org/2001/04/xmlenc#aes256-cbc + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + http://docs.oasis-open.org/ws-sx/ws-trust/200512/PublicKey + http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p + http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p + http://www.w3.org/2000/09/xmldsig#rsa-sha1 + http://www.w3.org/2001/10/xml-exc-c14n# + http://www.w3.org/2001/04/xmlenc#aes256-cbc + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + http://docs.oasis-open.org/ws-sx/ws-trust/200512/SymmetricKey + 256 + http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p + http://www.w3.org/2001/04/xmlenc#aes256-cbc + http://www.w3.org/2000/09/xmldsig#hmac-sha1 + http://www.w3.org/2001/10/xml-exc-c14n# + http://www.w3.org/2001/04/xmlenc#aes256-cbc + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + https://arvmserver2012.archan.us/adfs/services/trust/2005/windowstransport + + host/ARVMServer2012.archan.us + + + + + + + https://arvmserver2012.archan.us/adfs/services/trust/2005/certificatemixed + + + + + + https://arvmserver2012.archan.us/adfs/services/trust/2005/certificatetransport + + + + + + https://arvmserver2012.archan.us/adfs/services/trust/2005/usernamemixed + + + + + + https://arvmserver2012.archan.us/adfs/services/trust/2005/kerberosmixed + + + + + + https://arvmserver2012.archan.us/adfs/services/trust/2005/issuedtokenmixedasymmetricbasic256 + + + + + + https://arvmserver2012.archan.us/adfs/services/trust/2005/issuedtokenmixedsymmetricbasic256 + + + + + + https://arvmserver2012.archan.us/adfs/services/trust/13/kerberosmixed + + + + + + https://arvmserver2012.archan.us/adfs/services/trust/13/certificatemixed + + + + + + https://arvmserver2012.archan.us/adfs/services/trust/13/usernamemixed + + + + + + https://arvmserver2012.archan.us/adfs/services/trust/13/issuedtokenmixedasymmetricbasic256 + + + + + + https://arvmserver2012.archan.us/adfs/services/trust/13/issuedtokenmixedsymmetricbasic256 + + + + + + https://arvmserver2012.archan.us/adfs/services/trust/13/windowstransport + + host/ARVMServer2012.archan.us + + + + + \ No newline at end of file diff --git a/tests/arupela.mex.xml b/tests/arupela.mex.xml new file mode 100644 index 00000000..03dd84c1 --- /dev/null +++ b/tests/arupela.mex.xml @@ -0,0 +1,866 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + http://schemas.xmlsoap.org/ws/2005/02/trust/PublicKey + http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p + http://www.w3.org/2000/09/xmldsig#rsa-sha1 + http://www.w3.org/2001/10/xml-exc-c14n# + http://www.w3.org/2001/04/xmlenc#aes256-cbc + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + http://schemas.xmlsoap.org/ws/2005/02/trust/SymmetricKey + 256 + http://www.w3.org/2001/04/xmlenc#aes256-cbc + http://www.w3.org/2000/09/xmldsig#hmac-sha1 + http://www.w3.org/2001/10/xml-exc-c14n# + http://www.w3.org/2001/04/xmlenc#aes256-cbc + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + http://docs.oasis-open.org/ws-sx/ws-trust/200512/PublicKey + http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p + http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p + http://www.w3.org/2000/09/xmldsig#rsa-sha1 + http://www.w3.org/2001/10/xml-exc-c14n# + http://www.w3.org/2001/04/xmlenc#aes256-cbc + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + http://docs.oasis-open.org/ws-sx/ws-trust/200512/SymmetricKey + 256 + http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p + http://www.w3.org/2001/04/xmlenc#aes256-cbc + http://www.w3.org/2000/09/xmldsig#hmac-sha1 + http://www.w3.org/2001/10/xml-exc-c14n# + http://www.w3.org/2001/04/xmlenc#aes256-cbc + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + https://fs.arupela.com/adfs/services/trust/2005/windowstransport + + host/fs.arupela.com + + + + + + + https://fs.arupela.com/adfs/services/trust/2005/certificatemixed + + + + + + https://fs.arupela.com:49443/adfs/services/trust/2005/certificatetransport + + + + + + https://fs.arupela.com/adfs/services/trust/2005/usernamemixed + + + + + + https://fs.arupela.com/adfs/services/trust/2005/kerberosmixed + + + + + + https://fs.arupela.com/adfs/services/trust/2005/issuedtokenmixedasymmetricbasic256 + + + + + + https://fs.arupela.com/adfs/services/trust/2005/issuedtokenmixedsymmetricbasic256 + + + + + + https://fs.arupela.com/adfs/services/trust/13/kerberosmixed + + + + + + https://fs.arupela.com/adfs/services/trust/13/certificatemixed + + + + + + https://fs.arupela.com/adfs/services/trust/13/usernamemixed + + + + + + https://fs.arupela.com/adfs/services/trust/13/issuedtokenmixedasymmetricbasic256 + + + + + + https://fs.arupela.com/adfs/services/trust/13/issuedtokenmixedsymmetricbasic256 + + + + \ No newline at end of file diff --git a/tests/http_client.py b/tests/http_client.py index 4bff9b45..5adbbded 100644 --- a/tests/http_client.py +++ b/tests/http_client.py @@ -10,14 +10,19 @@ def __init__(self, verify=True, proxies=None, timeout=None): self.timeout = timeout def post(self, url, params=None, data=None, headers=None, **kwargs): + assert not kwargs, "Our stack shouldn't leak extra kwargs: %s" % kwargs return MinimalResponse(requests_resp=self.session.post( url, params=params, data=data, headers=headers, timeout=self.timeout)) def get(self, url, params=None, headers=None, **kwargs): + assert not kwargs, "Our stack shouldn't leak extra kwargs: %s" % kwargs return MinimalResponse(requests_resp=self.session.get( url, params=params, headers=headers, timeout=self.timeout)) + def close(self): # Not required, but we use it to avoid a warning in unit test + self.session.close() + class MinimalResponse(object): # Not for production use def __init__(self, requests_resp=None, status_code=None, text=None): @@ -26,5 +31,6 @@ def __init__(self, requests_resp=None, status_code=None, text=None): self._raw_resp = requests_resp def raise_for_status(self): - if self._raw_resp: + if self._raw_resp is not None: # Turns out `if requests.response` won't work + # cause it would be True when 200<=status<400 self._raw_resp.raise_for_status() diff --git a/tests/microsoft.mex.xml b/tests/microsoft.mex.xml new file mode 100644 index 00000000..e68a8bef --- /dev/null +++ b/tests/microsoft.mex.xml @@ -0,0 +1,916 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + http://schemas.xmlsoap.org/ws/2005/02/trust/PublicKey + http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p + http://www.w3.org/2000/09/xmldsig#rsa-sha1 + http://www.w3.org/2001/10/xml-exc-c14n# + http://www.w3.org/2001/04/xmlenc#aes256-cbc + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + http://schemas.xmlsoap.org/ws/2005/02/trust/SymmetricKey + 256 + http://www.w3.org/2001/04/xmlenc#aes256-cbc + http://www.w3.org/2000/09/xmldsig#hmac-sha1 + http://www.w3.org/2001/10/xml-exc-c14n# + http://www.w3.org/2001/04/xmlenc#aes256-cbc + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + http://docs.oasis-open.org/ws-sx/ws-trust/200512/PublicKey + http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p + http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p + http://www.w3.org/2000/09/xmldsig#rsa-sha1 + http://www.w3.org/2001/10/xml-exc-c14n# + http://www.w3.org/2001/04/xmlenc#aes256-cbc + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + http://docs.oasis-open.org/ws-sx/ws-trust/200512/SymmetricKey + 256 + http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p + http://www.w3.org/2001/04/xmlenc#aes256-cbc + http://www.w3.org/2000/09/xmldsig#hmac-sha1 + http://www.w3.org/2001/10/xml-exc-c14n# + http://www.w3.org/2001/04/xmlenc#aes256-cbc + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + https://corp.sts.microsoft.com/adfs/services/trust/2005/windowstransport + + iamfed@redmond.corp.microsoft.com + + + + + + + https://corp.sts.microsoft.com/adfs/services/trust/2005/certificatemixed + + + + + + https://corp.sts.microsoft.com/adfs/services/trust/2005/usernamemixed + + + + + + https://corp.sts.microsoft.com/adfs/services/trust/2005/kerberosmixed + + + + + + https://corp.sts.microsoft.com/adfs/services/trust/2005/issuedtokenmixedasymmetricbasic256 + + + + + + https://corp.sts.microsoft.com/adfs/services/trust/2005/issuedtokenmixedsymmetricbasic256 + + + + + + https://corp.sts.microsoft.com/adfs/services/trust/13/kerberosmixed + + + + + + https://corp.sts.microsoft.com/adfs/services/trust/13/certificatemixed + + + + + + https://corp.sts.microsoft.com/adfs/services/trust/13/usernamemixed + + + + + + https://corp.sts.microsoft.com/adfs/services/trust/13/issuedtokenmixedasymmetricbasic256 + + + + + + https://corp.sts.microsoft.com/adfs/services/trust/13/issuedtokenmixedsymmetricbasic256 + + + + + + https://corp.sts.microsoft.com/adfs/services/trust/13/windowstransport + + iamfed@redmond.corp.microsoft.com + + + + + \ No newline at end of file diff --git a/tests/msaltest.py b/tests/msaltest.py new file mode 100644 index 00000000..c1ef1e7c --- /dev/null +++ b/tests/msaltest.py @@ -0,0 +1,178 @@ +import getpass, logging, pprint, sys, msal + + +def _input_boolean(message): + return input( + "{} (N/n/F/f or empty means False, otherwise it is True): ".format(message) + ) not in ('N', 'n', 'F', 'f', '') + +def _input(message, default=None): + return input(message.format(default=default)).strip() or default + +def _select_options( + options, header="Your options:", footer=" Your choice? ", option_renderer=str, + accept_nonempty_string=False, + ): + assert options, "options must not be empty" + if header: + print(header) + for i, o in enumerate(options, start=1): + print(" {}: {}".format(i, option_renderer(o))) + if accept_nonempty_string: + print(" Or you can just type in your input.") + while True: + raw_data = input(footer) + try: + choice = int(raw_data) + if 1 <= choice <= len(options): + return options[choice - 1] + except ValueError: + if raw_data and accept_nonempty_string: + return raw_data + +def _input_scopes(): + return _select_options([ + "https://graph.microsoft.com/.default", + "https://management.azure.com/.default", + "User.Read", + "User.ReadBasic.All", + ], + header="Select a scope (multiple scopes can only be input by manually typing them, delimited by space):", + accept_nonempty_string=True, + ).split() + +def _select_account(app): + accounts = app.get_accounts() + if accounts: + return _select_options( + accounts, + option_renderer=lambda a: a["username"], + header="Account(s) already signed in inside MSAL Python:", + ) + else: + print("No account available inside MSAL Python. Use other methods to acquire token first.") + +def acquire_token_silent(app): + """acquire_token_silent() - with an account already signed into MSAL Python.""" + account = _select_account(app) + if account: + pprint.pprint(app.acquire_token_silent( + _input_scopes(), + account=account, + force_refresh=_input_boolean("Bypass MSAL Python's token cache?"), + )) + +def _acquire_token_interactive(app, scopes, data=None): + prompt = _select_options([ + {"value": None, "description": "Unspecified. Proceed silently with a default account (if any), fallback to prompt."}, + {"value": "none", "description": "none. Proceed silently with a default account (if any), or error out."}, + {"value": "select_account", "description": "select_account. Prompt with an account picker."}, + ], + option_renderer=lambda o: o["description"], + header="Prompt behavior?")["value"] + raw_login_hint = _select_options( + # login_hint is unnecessary when prompt=select_account, + # but we still let tester input login_hint, just for testing purpose. + [None] + [a["username"] for a in app.get_accounts()], + header="login_hint? (If you have multiple signed-in sessions in browser, and you specify a login_hint to match one of them, you will bypass the account picker.)", + accept_nonempty_string=True, + ) + login_hint = raw_login_hint["username"] if isinstance(raw_login_hint, dict) else raw_login_hint + result = app.acquire_token_interactive( + scopes, prompt=prompt, login_hint=login_hint, data=data or {}) + if login_hint and "id_token_claims" in result: + signed_in_user = result.get("id_token_claims", {}).get("preferred_username") + if signed_in_user != login_hint: + logging.warning('Signed-in user "%s" does not match login_hint', signed_in_user) + return result + +def acquire_token_interactive(app): + """acquire_token_interactive() - User will be prompted if app opts to do select_account.""" + pprint.pprint(_acquire_token_interactive(app, _input_scopes())) + +def acquire_token_by_username_password(app): + """acquire_token_by_username_password() - See constraints here: https://docs.microsoft.com/en-us/azure/active-directory/develop/msal-authentication-flows#constraints-for-ropc""" + pprint.pprint(app.acquire_token_by_username_password( + _input("username: "), getpass.getpass("password: "), scopes=_input_scopes())) + +_JWK1 = """{"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}""" +SSH_CERT_DATA = {"token_type": "ssh-cert", "key_id": "key1", "req_cnf": _JWK1} +SSH_CERT_SCOPE = ["https://pas.windows.net/CheckMyAccess/Linux/.default"] + +def acquire_ssh_cert_silently(app): + """Acquire an SSH Cert silently- This typically only works with Azure CLI""" + account = _select_account(app) + if account: + result = app.acquire_token_silent( + SSH_CERT_SCOPE, + account, + data=SSH_CERT_DATA, + force_refresh=_input_boolean("Bypass MSAL Python's token cache?"), + ) + pprint.pprint(result) + if result and result.get("token_type") != "ssh-cert": + logging.error("Unable to acquire an ssh-cert.") + +def acquire_ssh_cert_interactive(app): + """Acquire an SSH Cert interactively - This typically only works with Azure CLI""" + result = _acquire_token_interactive(app, SSH_CERT_SCOPE, data=SSH_CERT_DATA) + pprint.pprint(result) + if result.get("token_type") != "ssh-cert": + logging.error("Unable to acquire an ssh-cert") + +def remove_account(app): + """remove_account() - Invalidate account and/or token(s) from cache, so that acquire_token_silent() would be reset""" + account = _select_account(app) + if account: + app.remove_account(account) + print('Account "{}" and/or its token(s) are signed out from MSAL Python'.format(account["username"])) + +def exit(_): + """Exit""" + bug_link = "https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/new/choose" + print("Bye. If you found a bug, please report it here: {}".format(bug_link)) + sys.exit() + +def main(): + print("Welcome to the Msal Python Console Test App, committed at 2022-5-2\n") + chosen_app = _select_options([ + {"client_id": "04b07795-8ddb-461a-bbee-02f9e1bf7b46", "name": "Azure CLI (Correctly configured for MSA-PT)"}, + {"client_id": "04f0c124-f2bc-4f59-8241-bf6df9866bbd", "name": "Visual Studio (Correctly configured for MSA-PT)"}, + {"client_id": "95de633a-083e-42f5-b444-a4295d8e9314", "name": "Whiteboard Services (Non MSA-PT app. Accepts AAD & MSA accounts.)"}, + ], + option_renderer=lambda a: a["name"], + header="Impersonate this app (or you can type in the client_id of your own app)", + accept_nonempty_string=True) + app = msal.PublicClientApplication( + chosen_app["client_id"] if isinstance(chosen_app, dict) else chosen_app, + authority=_select_options([ + "https://login.microsoftonline.com/common", + "https://login.microsoftonline.com/organizations", + "https://login.microsoftonline.com/microsoft.onmicrosoft.com", + "https://login.microsoftonline.com/msidlab4.onmicrosoft.com", + "https://login.microsoftonline.com/consumers", + ], + header="Input authority (Note that MSA-PT apps would NOT use the /common authority)", + accept_nonempty_string=True, + ), + ) + if _input_boolean("Enable MSAL Python's DEBUG log?"): + logging.basicConfig(level=logging.DEBUG) + while True: + func = _select_options([ + acquire_token_silent, + acquire_token_interactive, + acquire_token_by_username_password, + acquire_ssh_cert_silently, + acquire_ssh_cert_interactive, + remove_account, + exit, + ], option_renderer=lambda f: f.__doc__, header="MSAL Python APIs:") + try: + func(app) + except KeyboardInterrupt: # Useful for bailing out a stuck interactive flow + print("Aborted") + +if __name__ == "__main__": + main() + diff --git a/tests/rst_response.xml b/tests/rst_response.xml new file mode 100644 index 00000000..c0c06206 --- /dev/null +++ b/tests/rst_response.xml @@ -0,0 +1,90 @@ + + + http://docs.oasis-open.org/ws-sx/ws-trust/200512/RSTRC/IssueFinal + + + 2013-11-15T03:08:25.221Z + 2013-11-15T03:13:25.221Z + + + + + + + + 2013-11-15T03:08:25.205Z + 2013-11-15T04:08:25.205Z + + + + https://login.microsoftonline.com/extSTS.srf + + + + + + + https://login.microsoftonline.com/extSTS.srf + + + + + 1TIu064jGEmmf+hnI+F0Jg== + + urn:oasis:names:tc:SAML:1.0:cm:bearer + + + + frizzo@richard-randall.com + + + 1TIu064jGEmmf+hnI+F0Jg== + + + + + 1TIu064jGEmmf+hnI+F0Jg== + + urn:oasis:names:tc:SAML:1.0:cm:bearer + + + + + + + + + + + + + + 3i95D+nRbsyRitSPeT7ZtEr5vbM= + + + aVNmmKLNdAlBxxcNciWVfxynZUPR9ql8ZZSZt/qpqL/GB3HX/cL/QnfG2OOKrmhgEaR0Ul4grZhGJxlxMPDL0fhnBz+VJ5HwztMFgMYs3Md8A2sZd9n4dfu7+CByAna06lCwwfdFWlNV1MBFvlWvYtCLNkpYVr/aglmb9zpMkNxEOmHe/cwxUtYlzH4RpIsIT5pruoJtUxKcqTRDEeeYdzjBAiJuguQTChLmHNoMPdX1RmtJlPsrZ1s9R/IJky7fHLjB7jiTDceRCS5QUbgUqYbLG1MjFXthY2Hr7K9kpYjxxIk6xmM7mFQE3Hts3bj6UU7ElUvHpX9bxxk3pqzlhg== + + + MIIC6DCCAdCgAwIBAgIQaztYF2TpvZZG6yreA3NRpzANBgkqhkiG9w0BAQsFADAwMS4wLAYDVQQDEyVBREZTIFNpZ25pbmcgLSBmcy5yaWNoYXJkLXJhbmRhbGwuY29tMB4XDTEzMTExMTAzNTMwMFoXDTE0MTExMTAzNTMwMFowMDEuMCwGA1UEAxMlQURGUyBTaWduaW5nIC0gZnMucmljaGFyZC1yYW5kYWxsLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAO+1VWY/sYDdN3hdsvT+mWHTcOwjp2G9e0AEZdmgh7bS54WUJw9y0cMxJmGB0jAAW40zomzIbS8/o3iuxcJyFgBVtMFfXwFjVQJnZJ7IMXFs1V/pJHrwWHxePz/WzXFtMaqEIe8QummJ07UBg9UsYZUYTGO9NDGw1Yr/oRNsl7bLA0S/QlW6yryf6l3snHzIgtO2xiWn6q3vCJTTVNMROkI2YKNKdYiD5fFD77kFACfJmOwP8MN9u+HM2IN6g0Nv5s7rMyw077Co/xKefamWQCB0jLpv89jo3hLgkwIgWX4cMVgHSNmdzXSgC3owG8ivRuJDATh83GiqI6jzA1+x4rkCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAxA5MQZHw9lJYDpU4f45EYrWPEaAPnncaoxIeLE9fG14gA01frajRfdyoO0AKqb+ZG6sePKngsuq4QHA2EnEI4Di5uWKsXy1Id0AXUSUhLpe63alZ8OwiNKDKn71nwpXnlGwKqljnG3xBMniGtGKrFS4WM+joEHzaKpvgtGRGoDdtXF4UXZJcn2maw6d/kiHrQ3kWoQcQcJ9hVIo8bC0BPvxV0Qh4TF3Nb3tKhaXsY68eMxMGbHok9trVHQ3Vew35FuTg1JzsfCFSDF8sxu7FJ4iZ7VLM8MQLnvIMcubLJvc57EHSsNyeiqBFQIYkdg7MSf+Ot2qJjfExgo+NOtWN+g== + + + + + + + + _9bd2b280-f153-471a-9b73-c1df0d555075 + + + + + _9bd2b280-f153-471a-9b73-c1df0d555075 + + + urn:oasis:names:tc:SAML:1.0:assertion + http://docs.oasis-open.org/ws-sx/ws-trust/200512/Issue + http://docs.oasis-open.org/ws-sx/ws-trust/200512/Bearer + + + + \ No newline at end of file diff --git a/tests/test_application.py b/tests/test_application.py new file mode 100644 index 00000000..804ccb82 --- /dev/null +++ b/tests/test_application.py @@ -0,0 +1,627 @@ +# Note: Since Aug 2019 we move all e2e tests into test_e2e.py, +# so this test_application file contains only unit tests without dependency. +import sys +from msal.application import * +from msal.application import _str2bytes +import msal +from msal.application import _merge_claims_challenge_and_capabilities +from tests import unittest +from tests.test_token_cache import build_id_token, build_response +from tests.http_client import MinimalHttpClient, MinimalResponse +from msal.telemetry import CLIENT_CURRENT_TELEMETRY, CLIENT_LAST_TELEMETRY + + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG) + + +class TestHelperExtractCerts(unittest.TestCase): # It is used by SNI scenario + + def test_extract_a_tag_less_public_cert(self): + pem = "my_cert" + self.assertEqual(["my_cert"], extract_certs(pem)) + + def test_extract_a_tag_enclosed_cert(self): + pem = """ + -----BEGIN CERTIFICATE----- + my_cert + -----END CERTIFICATE----- + """ + self.assertEqual(["my_cert"], extract_certs(pem)) + + def test_extract_multiple_tag_enclosed_certs(self): + pem = """ + -----BEGIN CERTIFICATE----- + my_cert1 + -----END CERTIFICATE----- + + -----BEGIN CERTIFICATE----- + my_cert2 + -----END CERTIFICATE----- + """ + self.assertEqual(["my_cert1", "my_cert2"], extract_certs(pem)) + + +class TestBytesConversion(unittest.TestCase): + def test_string_to_bytes(self): + self.assertEqual(type(_str2bytes("some string")), type(b"bytes")) + + def test_bytes_to_bytes(self): + self.assertEqual(type(_str2bytes(b"some bytes")), type(b"bytes")) + + +class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase): + + def setUp(self): + self.authority_url = "https://login.microsoftonline.com/common" + self.authority = msal.authority.Authority( + self.authority_url, MinimalHttpClient()) + self.scopes = ["s1", "s2"] + self.uid = "my_uid" + self.utid = "my_utid" + self.account = {"home_account_id": "{}.{}".format(self.uid, self.utid)} + self.rt = "this is a rt" + self.cache = msal.SerializableTokenCache() + self.client_id = "my_app" + self.cache.add({ # Pre-populate the cache + "client_id": self.client_id, + "scope": self.scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), + "response": build_response( + access_token="an expired AT to trigger refresh", expires_in=-99, + uid=self.uid, utid=self.utid, refresh_token=self.rt), + }) # The add(...) helper populates correct home_account_id for future searching + self.app = ClientApplication( + self.client_id, authority=self.authority_url, token_cache=self.cache) + + def test_cache_empty_will_be_returned_as_None(self): + self.app.token_cache = msal.SerializableTokenCache() # Reset it to empty + self.assertEqual( + None, self.app.acquire_token_silent_with_error(['cache_miss'], self.account)) + + def test_acquire_token_silent_will_suppress_error(self): + error_response = '{"error": "invalid_grant", "suberror": "xyz"}' + def tester(url, **kwargs): + return MinimalResponse(status_code=400, text=error_response) + self.assertEqual(None, self.app.acquire_token_silent( + self.scopes, self.account, post=tester)) + + def test_acquire_token_silent_with_error_will_return_error(self): + error_response = '{"error": "invalid_grant", "error_description": "xyz"}' + def tester(url, **kwargs): + return MinimalResponse(status_code=400, text=error_response) + self.assertEqual(json.loads(error_response), self.app.acquire_token_silent_with_error( + self.scopes, self.account, post=tester)) + + def test_atswe_will_map_some_suberror_to_classification_as_is(self): + error_response = '{"error": "invalid_grant", "suberror": "basic_action"}' + def tester(url, **kwargs): + return MinimalResponse(status_code=400, text=error_response) + result = self.app.acquire_token_silent_with_error( + self.scopes, self.account, post=tester) + self.assertEqual("basic_action", result.get("classification")) + + def test_atswe_will_map_some_suberror_to_classification_to_empty_string(self): + error_response = '{"error": "invalid_grant", "suberror": "client_mismatch"}' + def tester(url, **kwargs): + return MinimalResponse(status_code=400, text=error_response) + result = self.app.acquire_token_silent_with_error( + self.scopes, self.account, post=tester) + self.assertEqual("", result.get("classification")) + +class TestClientApplicationAcquireTokenSilentFociBehaviors(unittest.TestCase): + + def setUp(self): + self.authority_url = "https://login.microsoftonline.com/common" + self.authority = msal.authority.Authority( + self.authority_url, MinimalHttpClient()) + self.scopes = ["s1", "s2"] + self.uid = "my_uid" + self.utid = "my_utid" + self.account = {"home_account_id": "{}.{}".format(self.uid, self.utid)} + self.frt = "what the frt" + self.cache = msal.SerializableTokenCache() + self.preexisting_family_app_id = "preexisting_family_app" + self.cache.add({ # Pre-populate a FRT + "client_id": self.preexisting_family_app_id, + "scope": self.scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), + "response": build_response( + access_token="Siblings won't share AT. test_remove_account() will.", + id_token=build_id_token(aud=self.preexisting_family_app_id), + uid=self.uid, utid=self.utid, refresh_token=self.frt, foci="1"), + }) # The add(...) helper populates correct home_account_id for future searching + + def test_unknown_orphan_app_will_attempt_frt_and_not_remove_it(self): + app = ClientApplication( + "unknown_orphan", authority=self.authority_url, token_cache=self.cache) + logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + error_response = '{"error": "invalid_grant","error_description": "Was issued to another client"}' + def tester(url, data=None, **kwargs): + self.assertEqual(self.frt, data.get("refresh_token"), "Should attempt the FRT") + return MinimalResponse(status_code=400, text=error_response) + app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + self.authority, self.scopes, self.account, post=tester) + self.assertNotEqual([], app.token_cache.find( + msal.TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": self.frt}), + "The FRT should not be removed from the cache") + + def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self): + app = ClientApplication( + "known_orphan", authority=self.authority_url, token_cache=self.cache) + rt = "RT for this orphan app. We will check it being used by this test case." + self.cache.add({ # Populate its RT and AppMetadata, so it becomes a known orphan app + "client_id": app.client_id, + "scope": self.scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), + "response": build_response(uid=self.uid, utid=self.utid, refresh_token=rt), + }) + logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + def tester(url, data=None, **kwargs): + self.assertEqual(rt, data.get("refresh_token"), "Should attempt the RT") + return MinimalResponse(status_code=200, text='{}') + app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + self.authority, self.scopes, self.account, post=tester) + + def test_unknown_family_app_will_attempt_frt_and_join_family(self): + def tester(url, data=None, **kwargs): + self.assertEqual( + self.frt, data.get("refresh_token"), "Should attempt the FRT") + return MinimalResponse( + status_code=200, text=json.dumps(build_response( + uid=self.uid, utid=self.utid, foci="1", access_token="at"))) + app = ClientApplication( + "unknown_family_app", authority=self.authority_url, token_cache=self.cache) + at = app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + self.authority, self.scopes, self.account, post=tester) + logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + self.assertEqual("at", at.get("access_token"), "New app should get a new AT") + app_metadata = app.token_cache.find( + msal.TokenCache.CredentialType.APP_METADATA, + query={"client_id": app.client_id}) + self.assertNotEqual([], app_metadata, "Should record new app's metadata") + self.assertEqual("1", app_metadata[0].get("family_id"), + "The new family app should be recorded as in the same family") + # Known family app will simply use FRT, which is largely the same as this one + + # Will not test scenario of app leaving family. Per specs, it won't happen. + + def test_preexisting_family_app_will_attempt_frt_and_return_error(self): + error_response = '{"error": "invalid_grant", "error_description": "xyz"}' + def tester(url, data=None, **kwargs): + self.assertEqual( + self.frt, data.get("refresh_token"), "Should attempt the FRT") + return MinimalResponse(status_code=400, text=error_response) + app = ClientApplication( + "preexisting_family_app", authority=self.authority_url, token_cache=self.cache) + resp = app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + self.authority, self.scopes, self.account, post=tester) + logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + self.assertEqual(json.loads(error_response), resp, "Error raised will be returned") + + def test_family_app_remove_account(self): + logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + app = ClientApplication( + self.preexisting_family_app_id, + authority=self.authority_url, token_cache=self.cache) + account = app.get_accounts()[0] + mine = {"home_account_id": account["home_account_id"]} + + self.assertNotEqual([], self.cache.find( + self.cache.CredentialType.ACCESS_TOKEN, query=mine)) + self.assertNotEqual([], self.cache.find( + self.cache.CredentialType.REFRESH_TOKEN, query=mine)) + self.assertNotEqual([], self.cache.find( + self.cache.CredentialType.ID_TOKEN, query=mine)) + self.assertNotEqual([], self.cache.find( + self.cache.CredentialType.ACCOUNT, query=mine)) + + app.remove_account(account) + + self.assertEqual([], self.cache.find( + self.cache.CredentialType.ACCESS_TOKEN, query=mine)) + self.assertEqual([], self.cache.find( + self.cache.CredentialType.REFRESH_TOKEN, query=mine)) + self.assertEqual([], self.cache.find( + self.cache.CredentialType.ID_TOKEN, query=mine)) + self.assertEqual([], self.cache.find( + self.cache.CredentialType.ACCOUNT, query=mine)) + + +class TestClientApplicationForAuthorityMigration(unittest.TestCase): + + @classmethod + def setUp(self): + self.environment_in_cache = "sts.windows.net" + self.authority_url_in_app = "https://login.microsoftonline.com/common" + self.scopes = ["s1", "s2"] + uid = "uid" + utid = "utid" + self.account = {"home_account_id": "{}.{}".format(uid, utid)} + self.client_id = "my_app" + self.access_token = "access token for testing authority aliases" + self.cache = msal.SerializableTokenCache() + self.cache.add({ + "client_id": self.client_id, + "scope": self.scopes, + "token_endpoint": "https://{}/common/oauth2/v2.0/token".format( + self.environment_in_cache), + "response": build_response( + uid=uid, utid=utid, + access_token=self.access_token, refresh_token="some refresh token"), + }) # The add(...) helper populates correct home_account_id for future searching + self.app = ClientApplication( + self.client_id, + authority=self.authority_url_in_app, token_cache=self.cache) + + def test_get_accounts_should_find_accounts_under_different_alias(self): + accounts = self.app.get_accounts() + self.assertNotEqual([], accounts) + self.assertEqual(self.environment_in_cache, accounts[0].get("environment"), + "We should be able to find an account under an authority alias") + + def test_acquire_token_silent_should_find_at_under_different_alias(self): + result = self.app.acquire_token_silent(self.scopes, self.account) + self.assertNotEqual(None, result) + self.assertEqual(self.access_token, result.get('access_token')) + + def test_acquire_token_silent_should_find_rt_under_different_alias(self): + self.cache._cache["AccessToken"] = {} # A hacky way to clear ATs + class ExpectedBehavior(Exception): + pass + def helper(scopes, account, authority, *args, **kwargs): + if authority.instance == self.environment_in_cache: + raise ExpectedBehavior("RT of different alias being attempted") + self.app._acquire_token_silent_from_cache_and_possibly_refresh_it = helper + + with self.assertRaises(ExpectedBehavior): + self.app.acquire_token_silent(["different scope"], self.account) + + +class TestApplicationForClientCapabilities(unittest.TestCase): + + def test_capabilities_and_id_token_claims_merge(self): + client_capabilities = ["foo", "bar"] + claims_challenge = '''{"id_token": {"auth_time": {"essential": true}}}''' + merged_claims = '''{"id_token": {"auth_time": {"essential": true}}, + "access_token": {"xms_cc": {"values": ["foo", "bar"]}}}''' + # Comparing dictionaries as JSON object order differs based on python version + self.assertEqual( + json.loads(merged_claims), + json.loads(_merge_claims_challenge_and_capabilities( + client_capabilities, claims_challenge))) + + def test_capabilities_and_id_token_claims_and_access_token_claims_merge(self): + client_capabilities = ["foo", "bar"] + claims_challenge = '''{"id_token": {"auth_time": {"essential": true}}, + "access_token": {"nbf":{"essential":true, "value":"1563308371"}}}''' + merged_claims = '''{"id_token": {"auth_time": {"essential": true}}, + "access_token": {"nbf": {"essential": true, "value": "1563308371"}, + "xms_cc": {"values": ["foo", "bar"]}}}''' + # Comparing dictionaries as JSON object order differs based on python version + self.assertEqual( + json.loads(merged_claims), + json.loads(_merge_claims_challenge_and_capabilities( + client_capabilities, claims_challenge))) + + def test_no_capabilities_only_claims_merge(self): + claims_challenge = '''{"id_token": {"auth_time": {"essential": true}}}''' + self.assertEqual( + json.loads(claims_challenge), + json.loads(_merge_claims_challenge_and_capabilities(None, claims_challenge))) + + def test_only_client_capabilities_no_claims_merge(self): + client_capabilities = ["foo", "bar"] + merged_claims = '''{"access_token": {"xms_cc": {"values": ["foo", "bar"]}}}''' + self.assertEqual( + json.loads(merged_claims), + json.loads(_merge_claims_challenge_and_capabilities(client_capabilities, None))) + + def test_both_claims_and_capabilities_none(self): + self.assertEqual(_merge_claims_challenge_and_capabilities(None, None), None) + + +class TestApplicationForRefreshInBehaviors(unittest.TestCase): + """The following test cases were based on design doc here + https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview?path=%2FRefreshAtExpirationPercentage%2Foverview.md&version=GBdev&_a=preview&anchor=scenarios + """ + authority_url = "https://login.microsoftonline.com/common" + scopes = ["s1", "s2"] + uid = "my_uid" + utid = "my_utid" + account = {"home_account_id": "{}.{}".format(uid, utid)} + rt = "this is a rt" + client_id = "my_app" + + @classmethod + def setUpClass(cls): # Initialization at runtime, not interpret-time + cls.app = ClientApplication(cls.client_id, authority=cls.authority_url) + + def setUp(self): + self.app.token_cache = self.cache = msal.SerializableTokenCache() + + def populate_cache(self, access_token="at", expires_in=86400, refresh_in=43200): + self.cache.add({ + "client_id": self.client_id, + "scope": self.scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), + "response": build_response( + access_token=access_token, + expires_in=expires_in, refresh_in=refresh_in, + uid=self.uid, utid=self.utid, refresh_token=self.rt), + }) + + def test_fresh_token_should_be_returned_from_cache(self): + # a.k.a. Return unexpired token that is not above token refresh expiration threshold + access_token = "An access token prepopulated into cache" + self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450) + result = self.app.acquire_token_silent( + ['s1'], self.account, + post=lambda url, *args, **kwargs: # Utilize the undocumented test feature + self.fail("I/O shouldn't happen in cache hit AT scenario") + ) + self.assertEqual(access_token, result.get("access_token")) + self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") + + def test_aging_token_and_available_aad_should_return_new_token(self): + # a.k.a. Attempt to refresh unexpired token when AAD available + self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1) + new_access_token = "new AT" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": new_access_token, + "refresh_in": 123, + })) + result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post) + self.assertEqual(new_access_token, result.get("access_token")) + self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") + + def test_aging_token_and_unavailable_aad_should_return_old_token(self): + # a.k.a. Attempt refresh unexpired token when AAD unavailable + old_at = "old AT" + self.populate_cache(access_token=old_at, expires_in=3599, refresh_in=-1) + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|84,2|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=400, text=json.dumps({"error": error})) + result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post) + self.assertEqual(old_at, result.get("access_token")) + + def test_expired_token_and_unavailable_aad_should_return_error(self): + # a.k.a. Attempt refresh expired token when AAD unavailable + self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) + error = "something went wrong" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|84,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=400, text=json.dumps({"error": error})) + result = self.app.acquire_token_silent_with_error( + ['s1'], self.account, post=mock_post) + self.assertEqual(error, result.get("error"), "Error should be returned") + + def test_expired_token_and_available_aad_should_return_new_token(self): + # a.k.a. Attempt refresh expired token when AAD available + self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) + new_access_token = "new AT" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|84,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": new_access_token, + "refresh_in": 123, + })) + result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post) + self.assertEqual(new_access_token, result.get("access_token")) + self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") + + +class TestTelemetryMaintainingOfflineState(unittest.TestCase): + authority_url = "https://login.microsoftonline.com/common" + scopes = ["s1", "s2"] + uid = "my_uid" + utid = "my_utid" + account = {"home_account_id": "{}.{}".format(uid, utid)} + rt = "this is a rt" + client_id = "my_app" + + def populate_cache(self, cache, access_token="at"): + cache.add({ + "client_id": self.client_id, + "scope": self.scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), + "response": build_response( + access_token=access_token, + uid=self.uid, utid=self.utid, refresh_token=self.rt), + }) + + def test_maintaining_offline_state_and_sending_them(self): + app = PublicClientApplication( + self.client_id, + authority=self.authority_url, token_cache=msal.SerializableTokenCache()) + cached_access_token = "cached_at" + self.populate_cache(app.token_cache, access_token=cached_access_token) + + result = app.acquire_token_silent( + self.scopes, self.account, + post=lambda url, *args, **kwargs: # Utilize the undocumented test feature + self.fail("I/O shouldn't happen in cache hit AT scenario") + ) + self.assertEqual(cached_access_token, result.get("access_token")) + + error1 = "error_1" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|622,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + self.assertEqual("4|1|||", (headers or {}).get(CLIENT_LAST_TELEMETRY), + "The previous cache hit should result in success counter value as 1") + return MinimalResponse(status_code=400, text=json.dumps({"error": error1})) + result = app.acquire_token_by_device_flow({ # It allows customizing correlation_id + "device_code": "123", + PublicClientApplication.DEVICE_FLOW_CORRELATION_ID: "id_1", + }, post=mock_post) + self.assertEqual(error1, result.get("error")) + + error2 = "error_2" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|622,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + self.assertEqual("4|1|622,id_1|error_1|", (headers or {}).get(CLIENT_LAST_TELEMETRY), + "The previous error should result in same success counter plus latest error info") + return MinimalResponse(status_code=400, text=json.dumps({"error": error2})) + result = app.acquire_token_by_device_flow({ + "device_code": "123", + PublicClientApplication.DEVICE_FLOW_CORRELATION_ID: "id_2", + }, post=mock_post) + self.assertEqual(error2, result.get("error")) + + at = "ensures the successful path (which includes the mock) been used" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|622,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + self.assertEqual("4|1|622,id_1,622,id_2|error_1,error_2|", (headers or {}).get(CLIENT_LAST_TELEMETRY), + "The previous error should result in same success counter plus latest error info") + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = app.acquire_token_by_device_flow({"device_code": "123"}, post=mock_post) + self.assertEqual(at, result.get("access_token")) + + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|622,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + self.assertEqual("4|0|||", (headers or {}).get(CLIENT_LAST_TELEMETRY), + "The previous success should reset all offline telemetry counters") + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = app.acquire_token_by_device_flow({"device_code": "123"}, post=mock_post) + self.assertEqual(at, result.get("access_token")) + + +class TestTelemetryOnClientApplication(unittest.TestCase): + @classmethod + def setUpClass(cls): # Initialization at runtime, not interpret-time + cls.app = ClientApplication( + "client_id", authority="https://login.microsoftonline.com/common") + + def test_acquire_token_by_auth_code_flow(self): + at = "this is an access token" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|832,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + state = "foo" + result = self.app.acquire_token_by_auth_code_flow( + {"state": state, "code_verifier": "bar"}, {"state": state, "code": "012"}, + post=mock_post) + self.assertEqual(at, result.get("access_token")) + + def test_acquire_token_by_refresh_token(self): + at = "this is an access token" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|85,1|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = self.app.acquire_token_by_refresh_token("rt", ["s"], post=mock_post) + self.assertEqual(at, result.get("access_token")) + + +class TestTelemetryOnPublicClientApplication(unittest.TestCase): + @classmethod + def setUpClass(cls): # Initialization at runtime, not interpret-time + cls.app = PublicClientApplication( + "client_id", authority="https://login.microsoftonline.com/common") + + # For now, acquire_token_interactive() is verified by code review. + + def test_acquire_token_by_device_flow(self): + at = "this is an access token" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|622,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = self.app.acquire_token_by_device_flow( + {"device_code": "123"}, post=mock_post) + self.assertEqual(at, result.get("access_token")) + + def test_acquire_token_by_username_password(self): + at = "this is an access token" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|301,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = self.app.acquire_token_by_username_password( + "username", "password", ["scope"], post=mock_post) + self.assertEqual(at, result.get("access_token")) + + +class TestTelemetryOnConfidentialClientApplication(unittest.TestCase): + @classmethod + def setUpClass(cls): # Initialization at runtime, not interpret-time + cls.app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/common") + + def test_acquire_token_for_client(self): + at = "this is an access token" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|730,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = self.app.acquire_token_for_client(["scope"], post=mock_post) + self.assertEqual(at, result.get("access_token")) + + def test_acquire_token_on_behalf_of(self): + at = "this is an access token" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|523,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = self.app.acquire_token_on_behalf_of("assertion", ["s"], post=mock_post) + self.assertEqual(at, result.get("access_token")) + + +class TestClientApplicationWillGroupAccounts(unittest.TestCase): + def test_get_accounts(self): + client_id = "my_app" + scopes = ["scope_1", "scope_2"] + environment = "login.microsoftonline.com" + uid = "home_oid" + utid = "home_tenant_guid" + username = "Jane Doe" + cache = msal.SerializableTokenCache() + for tenant in ["contoso", "fabrikam"]: + cache.add({ + "client_id": client_id, + "scope": scopes, + "token_endpoint": + "https://{}/{}/oauth2/v2.0/token".format(environment, tenant), + "response": build_response( + uid=uid, utid=utid, access_token="at", refresh_token="rt", + id_token=build_id_token( + aud=client_id, + sub="oid_in_" + tenant, + preferred_username=username, + ), + ), + }) + app = ClientApplication( + client_id, + authority="https://{}/common".format(environment), + token_cache=cache) + accounts = app.get_accounts() + self.assertEqual(1, len(accounts), "Should return one grouped account") + account = accounts[0] + self.assertEqual("{}.{}".format(uid, utid), account["home_account_id"]) + self.assertEqual(environment, account["environment"]) + self.assertEqual(username, account["username"]) + self.assertIn("authority_type", account, "Backward compatibility") + self.assertIn("local_account_id", account, "Backward compatibility") + self.assertIn("realm", account, "Backward compatibility") + + +@unittest.skipUnless( + sys.version_info[0] >= 3 and sys.version_info[1] >= 2, + "assertWarns() is only available in Python 3.2+") +class TestClientCredentialGrant(unittest.TestCase): + def _test_certain_authority_should_emit_warnning(self, authority): + app = ConfidentialClientApplication( + "client_id", client_credential="secret", authority=authority) + def mock_post(url, headers=None, *args, **kwargs): + return MinimalResponse( + status_code=200, text=json.dumps({"access_token": "an AT"})) + with self.assertWarns(DeprecationWarning): + app.acquire_token_for_client(["scope"], post=mock_post) + + def test_common_authority_should_emit_warnning(self): + self._test_certain_authority_should_emit_warnning( + authority="https://login.microsoftonline.com/common") + + def test_organizations_authority_should_emit_warnning(self): + self._test_certain_authority_should_emit_warnning( + authority="https://login.microsoftonline.com/organizations") + diff --git a/tests/test_assertion.py b/tests/test_assertion.py new file mode 100644 index 00000000..7885afe8 --- /dev/null +++ b/tests/test_assertion.py @@ -0,0 +1,15 @@ +import json + +from msal.oauth2cli import JwtAssertionCreator +from msal.oauth2cli.oidc import decode_part + +from tests import unittest + + +class AssertionTestCase(unittest.TestCase): + def test_extra_claims(self): + assertion = JwtAssertionCreator(key=None, algorithm="none").sign_assertion( + "audience", "issuer", additional_claims={"client_ip": "1.2.3.4"}) + payload = json.loads(decode_part(assertion.split(b'.')[1].decode('utf-8'))) + self.assertEqual("1.2.3.4", payload.get("client_ip")) + diff --git a/tests/test_authcode.py b/tests/test_authcode.py index 385100fd..c7e7565f 100644 --- a/tests/test_authcode.py +++ b/tests/test_authcode.py @@ -2,7 +2,7 @@ import socket import sys -from oauth2cli.authcode import AuthCodeReceiver +from msal.oauth2cli.authcode import AuthCodeReceiver class TestAuthCodeReceiver(unittest.TestCase): diff --git a/tests/test_authority.py b/tests/test_authority.py new file mode 100644 index 00000000..ee81c15e --- /dev/null +++ b/tests/test_authority.py @@ -0,0 +1,124 @@ +import os + +from msal.authority import * +from tests import unittest +from tests.http_client import MinimalHttpClient + + +@unittest.skipIf(os.getenv("TRAVIS_TAG"), "Skip network io during tagged release") +class TestAuthority(unittest.TestCase): + + def _test_given_host_and_tenant(self, host, tenant): + c = MinimalHttpClient() + a = Authority('https://{}/{}'.format(host, tenant), c) + self.assertEqual( + a.authorization_endpoint, + 'https://{}/{}/oauth2/v2.0/authorize'.format(host, tenant)) + self.assertEqual( + a.token_endpoint, + 'https://{}/{}/oauth2/v2.0/token'.format(host, tenant)) + c.close() + + def _test_authority_builder(self, host, tenant): + c = MinimalHttpClient() + a = Authority(AuthorityBuilder(host, tenant), c) + self.assertEqual( + a.authorization_endpoint, + 'https://{}/{}/oauth2/v2.0/authorize'.format(host, tenant)) + self.assertEqual( + a.token_endpoint, + 'https://{}/{}/oauth2/v2.0/token'.format(host, tenant)) + c.close() + + def test_wellknown_host_and_tenant(self): + # Assert all well known authority hosts are using their own "common" tenant + for host in WELL_KNOWN_AUTHORITY_HOSTS: + self._test_given_host_and_tenant(host, "common") + + def test_wellknown_host_and_tenant_using_new_authority_builder(self): + self._test_authority_builder(AZURE_PUBLIC, "consumers") + self._test_authority_builder(AZURE_US_GOVERNMENT, "common") + ## AZURE_CHINA is prone to some ConnectionError. We skip it to speed up our tests. + # self._test_authority_builder(AZURE_CHINA, "organizations") + + @unittest.skip("As of Jan 2017, the server no longer returns V1 endpoint") + def test_lessknown_host_will_return_a_set_of_v1_endpoints(self): + # This is an observation for current (2016-10) server-side behavior. + # It is probably not a strict API contract. I simply mention it here. + less_known = 'login.windows.net' # less.known.host/ + v1_token_endpoint = 'https://{}/common/oauth2/token'.format(less_known) + a = Authority( + 'https://{}/common'.format(less_known), MinimalHttpClient()) + self.assertEqual(a.token_endpoint, v1_token_endpoint) + self.assertNotIn('v2.0', a.token_endpoint) + + def test_unknown_host_wont_pass_instance_discovery(self): + _assert = getattr(self, "assertRaisesRegex", self.assertRaisesRegexp) # Hack + with _assert(ValueError, "invalid_instance"): + Authority('https://example.com/tenant_doesnt_matter_in_this_case', + MinimalHttpClient()) + + def test_invalid_host_skipping_validation_can_be_turned_off(self): + try: + Authority( + 'https://example.com/invalid', + MinimalHttpClient(), validate_authority=False) + except ValueError as e: + if "invalid_instance" in str(e): # Imprecise but good enough + self.fail("validate_authority=False should turn off validation") + except: # Could be requests...RequestException, json...JSONDecodeError, etc. + pass # Those are expected for this unittest case + + +class TestAuthorityInternalHelperCanonicalize(unittest.TestCase): + + def test_canonicalize_tenant_followed_by_extra_paths(self): + _, i, t = canonicalize("https://example.com/tenant/subpath?foo=bar#fragment") + self.assertEqual("example.com", i) + self.assertEqual("tenant", t) + + def test_canonicalize_tenant_followed_by_extra_query(self): + _, i, t = canonicalize("https://example.com/tenant?foo=bar#fragment") + self.assertEqual("example.com", i) + self.assertEqual("tenant", t) + + def test_canonicalize_tenant_followed_by_extra_fragment(self): + _, i, t = canonicalize("https://example.com/tenant#fragment") + self.assertEqual("example.com", i) + self.assertEqual("tenant", t) + + def test_canonicalize_rejects_non_https(self): + with self.assertRaises(ValueError): + canonicalize("http://non.https.example.com/tenant") + + def test_canonicalize_rejects_tenantless(self): + with self.assertRaises(ValueError): + canonicalize("https://no.tenant.example.com") + + def test_canonicalize_rejects_tenantless_host_with_trailing_slash(self): + with self.assertRaises(ValueError): + canonicalize("https://no.tenant.example.com/") + + +@unittest.skipIf(os.getenv("TRAVIS_TAG"), "Skip network io during tagged release") +class TestAuthorityInternalHelperUserRealmDiscovery(unittest.TestCase): + def test_memorize(self): + # We use a real authority so the constructor can finish tenant discovery + authority = "https://login.microsoftonline.com/common" + self.assertNotIn(authority, Authority._domains_without_user_realm_discovery) + a = Authority(authority, MinimalHttpClient(), validate_authority=False) + + try: + # We now pretend this authority supports no User Realm Discovery + class MockResponse(object): + status_code = 404 + a.user_realm_discovery("john.doe@example.com", response=MockResponse()) + self.assertIn( + "login.microsoftonline.com", + Authority._domains_without_user_realm_discovery, + "user_realm_discovery() should memorize domains not supporting URD") + a.user_realm_discovery("john.doe@example.com", + response="This would cause exception if memorization did not work") + finally: # MUST NOT let the previous test changes affect other test cases + Authority._domains_without_user_realm_discovery = set([]) + diff --git a/tests/test_ccs.py b/tests/test_ccs.py new file mode 100644 index 00000000..8b801773 --- /dev/null +++ b/tests/test_ccs.py @@ -0,0 +1,73 @@ +import unittest +try: + from unittest.mock import patch, ANY +except: + from mock import patch, ANY + +from tests.http_client import MinimalResponse +from tests.test_token_cache import build_response + +import msal + + +class TestCcsRoutingInfoTestCase(unittest.TestCase): + + def test_acquire_token_by_auth_code_flow(self): + app = msal.ClientApplication("client_id") + state = "foo" + flow = app.initiate_auth_code_flow( + ["some", "scope"], login_hint="johndoe@contoso.com", state=state) + with patch.object(app.http_client, "post", return_value=MinimalResponse( + status_code=400, text='{"error": "mock"}')) as mocked_method: + app.acquire_token_by_auth_code_flow(flow, { + "state": state, + "code": "bar", + "client_info": # MSAL asks for client_info, so it would be available + "eyJ1aWQiOiJhYTkwNTk0OS1hMmI4LTRlMGEtOGFlYS1iMzJlNTNjY2RiNDEiLCJ1dGlkIjoiNzJmOTg4YmYtODZmMS00MWFmLTkxYWItMmQ3Y2QwMTFkYjQ3In0", + }) + self.assertEqual( + "Oid:aa905949-a2b8-4e0a-8aea-b32e53ccdb41@72f988bf-86f1-41af-91ab-2d7cd011db47", + mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'), + "CSS routing info should be derived from client_info") + + # I've manually tested acquire_token_interactive. No need to automate it, + # because it and acquire_token_by_auth_code_flow() share same code path. + + def test_acquire_token_silent(self): + uid = "foo" + utid = "bar" + client_id = "my_client_id" + scopes = ["some", "scope"] + authority_url = "https://login.microsoftonline.com/common" + token_cache = msal.TokenCache() + token_cache.add({ # Pre-populate the cache + "client_id": client_id, + "scope": scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(authority_url), + "response": build_response( + access_token="an expired AT to trigger refresh", expires_in=-99, + uid=uid, utid=utid, refresh_token="this is a RT"), + }) # The add(...) helper populates correct home_account_id for future searching + app = msal.ClientApplication( + client_id, authority=authority_url, token_cache=token_cache) + with patch.object(app.http_client, "post", return_value=MinimalResponse( + status_code=400, text='{"error": "mock"}')) as mocked_method: + account = {"home_account_id": "{}.{}".format(uid, utid)} + app.acquire_token_silent(["scope"], account) + self.assertEqual( + "Oid:{}@{}".format( # Server accepts case-insensitive value + uid, utid), # It would look like "Oid:foo@bar" + mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'), + "CSS routing info should be derived from home_account_id") + + def test_acquire_token_by_username_password(self): + app = msal.ClientApplication("client_id") + username = "johndoe@contoso.com" + with patch.object(app.http_client, "post", return_value=MinimalResponse( + status_code=400, text='{"error": "mock"}')) as mocked_method: + app.acquire_token_by_username_password(username, "password", ["scope"]) + self.assertEqual( + "upn:" + username, + mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'), + "CSS routing info should be derived from client_info") + diff --git a/tests/test_client.py b/tests/test_client.py index 4cbcc8cf..b180c6b8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,9 +9,8 @@ import requests -from oauth2cli.oidc import Client -from oauth2cli.authcode import obtain_auth_code, AuthCodeReceiver -from oauth2cli.assertion import JwtSigner +from msal.oauth2cli import Client, JwtSigner, AuthCodeReceiver +from msal.oauth2cli.authcode import obtain_auth_code from tests import unittest, Oauth2TestCase from tests.http_client import MinimalHttpClient, MinimalResponse @@ -80,12 +79,21 @@ def load_conf(filename): # Since the OAuth2 specs uses snake_case, this test config also uses snake_case @unittest.skipUnless("client_id" in CONFIG, "client_id missing") +@unittest.skipUnless(CONFIG.get("openid_configuration"), "openid_configuration missing") class TestClient(Oauth2TestCase): @classmethod def setUpClass(cls): http_client = MinimalHttpClient() - if "client_certificate" in CONFIG: + if "client_assertion" in CONFIG: + cls.client = Client( + CONFIG["openid_configuration"], + CONFIG['client_id'], + http_client=http_client, + client_assertion=CONFIG["client_assertion"], + client_assertion_type=Client.CLIENT_ASSERTION_TYPE_JWT, + ) + elif "client_certificate" in CONFIG: private_key_path = CONFIG["client_certificate"]["private_key_path"] with open(os.path.join(THIS_FOLDER, private_key_path)) as f: private_key = f.read() # Expecting PEM format diff --git a/tests/test_e2e.py b/tests/test_e2e.py new file mode 100644 index 00000000..f0fb226d --- /dev/null +++ b/tests/test_e2e.py @@ -0,0 +1,1022 @@ +"""If the following ENV VAR are available, many end-to-end test cases would run. +LAB_APP_CLIENT_SECRET=... +LAB_OBO_CLIENT_SECRET=... +LAB_APP_CLIENT_ID=... +LAB_OBO_PUBLIC_CLIENT_ID=... +LAB_OBO_CONFIDENTIAL_CLIENT_ID=... +""" +try: + from dotenv import load_dotenv # Use this only in local dev machine + load_dotenv() # take environment variables from .env. +except: + pass + +import logging +import os +import json +import time +import unittest +import sys +try: + from unittest.mock import patch, ANY +except: + from mock import patch, ANY + +import requests + +import msal +from tests.http_client import MinimalHttpClient, MinimalResponse +from msal.oauth2cli import AuthCodeReceiver + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG if "-v" in sys.argv else logging.INFO) + + +def _get_app_and_auth_code( + client_id, + client_secret=None, + authority="https://login.microsoftonline.com/common", + port=44331, + scopes=["https://graph.microsoft.com/.default"], # Microsoft Graph + **kwargs): + from msal.oauth2cli.authcode import obtain_auth_code + if client_secret: + app = msal.ConfidentialClientApplication( + client_id, + client_credential=client_secret, + authority=authority, http_client=MinimalHttpClient()) + else: + app = msal.PublicClientApplication( + client_id, authority=authority, http_client=MinimalHttpClient()) + redirect_uri = "http://localhost:%d" % port + ac = obtain_auth_code(port, auth_uri=app.get_authorization_request_url( + scopes, redirect_uri=redirect_uri, **kwargs)) + assert ac is not None + return (app, ac, redirect_uri) + +@unittest.skipIf(os.getenv("TRAVIS_TAG"), "Skip e2e tests during tagged release") +class E2eTestCase(unittest.TestCase): + + def assertLoosely(self, response, assertion=None, + skippable_errors=("invalid_grant", "interaction_required")): + if response.get("error") in skippable_errors: + logger.debug("Response = %s", response) + # Some of these errors are configuration issues, not library issues + raise unittest.SkipTest(response.get("error_description")) + else: + if assertion is None: + assertion = lambda: self.assertIn( + "access_token", response, + "{error}: {error_description}".format( + # Do explicit response.get(...) rather than **response + error=response.get("error"), + error_description=response.get("error_description"))) + assertion() + + def assertCacheWorksForUser( + self, result_from_wire, scope, username=None, data=None): + logger.debug( + "%s: cache = %s, id_token_claims = %s", + self.id(), + json.dumps(self.app.token_cache._cache, indent=4), + json.dumps(result_from_wire.get("id_token_claims"), indent=4), + ) + # You can filter by predefined username, or let end user to choose one + accounts = self.app.get_accounts(username=username) + self.assertNotEqual(0, len(accounts)) + account = accounts[0] + if ("scope" not in result_from_wire # This is the usual case + or # Authority server could return different set of scopes + set(scope) <= set(result_from_wire["scope"].split(" ")) + ): + # Going to test acquire_token_silent(...) to locate an AT from cache + result_from_cache = self.app.acquire_token_silent( + scope, account=account, data=data or {}) + self.assertIsNotNone(result_from_cache) + self.assertIsNone( + result_from_cache.get("refresh_token"), "A cache hit returns no RT") + self.assertEqual( + result_from_wire['access_token'], result_from_cache['access_token'], + "We should get a cached AT") + + if "refresh_token" in result_from_wire: + # Going to test acquire_token_silent(...) to obtain an AT by a RT from cache + self.app.token_cache._cache["AccessToken"] = {} # A hacky way to clear ATs + result_from_cache = self.app.acquire_token_silent( + scope, account=account, data=data or {}) + if "refresh_token" not in result_from_wire: + self.assertEqual( + result_from_cache["access_token"], result_from_wire["access_token"], + "The previously cached AT should be returned") + self.assertIsNotNone(result_from_cache, + "We should get a result from acquire_token_silent(...) call") + self.assertIsNotNone( + # We used to assert it this way: + # result_from_wire['access_token'] != result_from_cache['access_token'] + # but ROPC in B2C tends to return the same AT we obtained seconds ago. + # Now looking back, "refresh_token grant would return a brand new AT" + # was just an empirical observation but never a commitment in specs, + # so we adjust our way to assert here. + (result_from_cache or {}).get("access_token"), + "We should get an AT from acquire_token_silent(...) call") + + def assertCacheWorksForApp(self, result_from_wire, scope): + logger.debug( + "%s: cache = %s, id_token_claims = %s", + self.id(), + json.dumps(self.app.token_cache._cache, indent=4), + json.dumps(result_from_wire.get("id_token_claims"), indent=4), + ) + # Going to test acquire_token_silent(...) to locate an AT from cache + result_from_cache = self.app.acquire_token_silent(scope, account=None) + self.assertIsNotNone(result_from_cache) + self.assertEqual( + result_from_wire['access_token'], result_from_cache['access_token'], + "We should get a cached AT") + + def _test_username_password(self, + authority=None, client_id=None, username=None, password=None, scope=None, + client_secret=None, # Since MSAL 1.11, confidential client has ROPC too + azure_region=None, + http_client=None, + **ignored): + assert authority and client_id and username and password and scope + self.app = msal.ClientApplication( + client_id, authority=authority, + http_client=http_client or MinimalHttpClient(), + azure_region=azure_region, # Regional endpoint does not support ROPC. + # Here we just use it to test a regional app won't break ROPC. + client_credential=client_secret) + result = self.app.acquire_token_by_username_password( + username, password, scopes=scope) + self.assertLoosely(result) + self.assertCacheWorksForUser( + result, scope, + username=username, # Our implementation works even when "profile" scope was not requested, or when profile claims is unavailable in B2C + ) + + def _test_device_flow( + self, client_id=None, authority=None, scope=None, **ignored): + assert client_id and authority and scope + self.app = msal.PublicClientApplication( + client_id, authority=authority, http_client=MinimalHttpClient()) + flow = self.app.initiate_device_flow(scopes=scope) + assert "user_code" in flow, "DF does not seem to be provisioned: %s".format( + json.dumps(flow, indent=4)) + logger.info(flow["message"]) + + duration = 60 + logger.info("We will wait up to %d seconds for you to sign in" % duration) + flow["expires_at"] = min( # Shorten the time for quick test + flow["expires_at"], time.time() + duration) + result = self.app.acquire_token_by_device_flow(flow) + self.assertLoosely( # It will skip this test if there is no user interaction + result, + assertion=lambda: self.assertIn('access_token', result), + skippable_errors=self.app.client.DEVICE_FLOW_RETRIABLE_ERRORS) + if "access_token" not in result: + self.skipTest("End user did not complete Device Flow in time") + self.assertCacheWorksForUser(result, scope, username=None) + result["access_token"] = result["refresh_token"] = "************" + logger.info( + "%s obtained tokens: %s", self.id(), json.dumps(result, indent=4)) + + def _test_acquire_token_interactive( + self, client_id=None, authority=None, scope=None, port=None, + username_uri="", # But you would want to provide one + data=None, # Needed by ssh-cert feature + prompt=None, + **ignored): + assert client_id and authority and scope + self.app = msal.PublicClientApplication( + client_id, authority=authority, http_client=MinimalHttpClient()) + result = self.app.acquire_token_interactive( + scope, + prompt=prompt, + timeout=120, + port=port, + welcome_template= # This is an undocumented feature for testing + """

{id}

    +
  1. Get a username from the upn shown at here
  2. +
  3. Get its password from https://aka.ms/GetLabUserSecret?Secret=msidlabXYZ + (replace the lab name with the labName from the link above).
  4. +
  5. Sign In or Abort
  6. +
""".format(id=self.id(), username_uri=username_uri), + data=data or {}, + ) + self.assertIn( + "access_token", result, + "{error}: {error_description}".format( + # Note: No interpolation here, cause error won't always present + error=result.get("error"), + error_description=result.get("error_description"))) + self.assertCacheWorksForUser(result, scope, username=None, data=data or {}) + return result # For further testing + + +class SshCertTestCase(E2eTestCase): + _JWK1 = """{"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}""" + _JWK2 = """{"kty":"RSA", "n":"72u07mew8rw-ssw3tUs9clKstGO2lvD7ZNxJU7OPNKz5PGYx3gjkhUmtNah4I4FP0DuF1ogb_qSS5eD86w10Wb1ftjWcoY8zjNO9V3ph-Q2tMQWdDW5kLdeU3-EDzc0HQeou9E0udqmfQoPbuXFQcOkdcbh3eeYejs8sWn3TQprXRwGh_TRYi-CAurXXLxQ8rp-pltUVRIr1B63fXmXhMeCAGwCPEFX9FRRs-YHUszUJl9F9-E0nmdOitiAkKfCC9LhwB9_xKtjmHUM9VaEC9jWOcdvXZutwEoW2XPMOg0Ky-s197F9rfpgHle2gBrXsbvVMvS0D-wXg6vsq6BAHzQ", "e":"AQAB"}""" + DATA1 = {"token_type": "ssh-cert", "key_id": "key1", "req_cnf": _JWK1} + DATA2 = {"token_type": "ssh-cert", "key_id": "key2", "req_cnf": _JWK2} + _SCOPE_USER = ["https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"] + _SCOPE_SP = ["https://pas.windows.net/CheckMyAccess/Linux/.default"] + SCOPE = _SCOPE_SP # Historically there was a separation, at 2021 it is unified + + def test_ssh_cert_for_service_principal(self): + # Any SP can obtain an ssh-cert. Here we use the lab app. + result = get_lab_app().acquire_token_for_client(self.SCOPE, data=self.DATA1) + self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format( + result.get("error"), result.get("error_description"))) + self.assertEqual("ssh-cert", result["token_type"]) + + @unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented") + def test_ssh_cert_for_user(self): + result = self._test_acquire_token_interactive( + client_id="04b07795-8ddb-461a-bbee-02f9e1bf7b46", # Azure CLI is one + # of the only 2 clients that are PreAuthz to use ssh cert feature + authority="https://login.microsoftonline.com/common", + scope=self.SCOPE, + data=self.DATA1, + username_uri="https://msidlab.com/api/user?usertype=cloud", + prompt="none" if msal.application._is_running_in_cloud_shell() else None, + ) # It already tests reading AT from cache, and using RT to refresh + # acquire_token_silent() would work because we pass in the same key + self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format( + result.get("error"), result.get("error_description"))) + self.assertEqual("ssh-cert", result["token_type"]) + logger.debug("%s.cache = %s", + self.id(), json.dumps(self.app.token_cache._cache, indent=4)) + + # refresh_token grant can fetch an ssh-cert bound to a different key + account = self.app.get_accounts()[0] + refreshed_ssh_cert = self.app.acquire_token_silent( + self.SCOPE, account=account, data=self.DATA2) + self.assertIsNotNone(refreshed_ssh_cert) + self.assertEqual(refreshed_ssh_cert["token_type"], "ssh-cert") + self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token']) + + +@unittest.skipUnless( + msal.application._is_running_in_cloud_shell(), + "Manually run this test case from inside Cloud Shell") +class CloudShellTestCase(E2eTestCase): + app = msal.PublicClientApplication("client_id") + scope_that_requires_no_managed_device = "https://management.core.windows.net/" # Scopes came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json&version=GBmaster&_a=contents + def test_access_token_should_be_obtained_for_a_supported_scope(self): + result = self.app.acquire_token_interactive( + [self.scope_that_requires_no_managed_device], prompt="none") + self.assertEqual( + "Bearer", result.get("token_type"), "Unexpected result: %s" % result) + self.assertIsNotNone(result.get("access_token")) + + +THIS_FOLDER = os.path.dirname(__file__) +CONFIG = os.path.join(THIS_FOLDER, "config.json") +@unittest.skipUnless(os.path.exists(CONFIG), "Optional %s not found" % CONFIG) +class FileBasedTestCase(E2eTestCase): + # This covers scenarios that are not currently available for test automation. + # So they mean to be run on maintainer's machine for semi-automated tests. + + @classmethod + def setUpClass(cls): + with open(CONFIG) as f: + cls.config = json.load(f) + + def skipUnlessWithConfig(self, fields): + for field in fields: + if field not in self.config: + self.skipTest('Skipping due to lack of configuration "%s"' % field) + + def test_username_password(self): + self.skipUnlessWithConfig(["client_id", "username", "password", "scope"]) + self._test_username_password(**self.config) + + def _get_app_and_auth_code(self, scopes=None, **kwargs): + return _get_app_and_auth_code( + self.config["client_id"], + client_secret=self.config.get("client_secret"), + authority=self.config.get("authority"), + port=self.config.get("listen_port", 44331), + scopes=scopes or self.config["scope"], + **kwargs) + + def _test_auth_code(self, auth_kwargs, token_kwargs): + self.skipUnlessWithConfig(["client_id", "scope"]) + (self.app, ac, redirect_uri) = self._get_app_and_auth_code(**auth_kwargs) + result = self.app.acquire_token_by_authorization_code( + ac, self.config["scope"], redirect_uri=redirect_uri, **token_kwargs) + logger.debug("%s.cache = %s", + self.id(), json.dumps(self.app.token_cache._cache, indent=4)) + self.assertIn( + "access_token", result, + "{error}: {error_description}".format( + # Note: No interpolation here, cause error won't always present + error=result.get("error"), + error_description=result.get("error_description"))) + self.assertCacheWorksForUser(result, self.config["scope"], username=None) + + def test_auth_code(self): + self._test_auth_code({}, {}) + + def test_auth_code_with_matching_nonce(self): + self._test_auth_code({"nonce": "foo"}, {"nonce": "foo"}) + + def test_auth_code_with_mismatching_nonce(self): + self.skipUnlessWithConfig(["client_id", "scope"]) + (self.app, ac, redirect_uri) = self._get_app_and_auth_code(nonce="foo") + with self.assertRaises(ValueError): + self.app.acquire_token_by_authorization_code( + ac, self.config["scope"], redirect_uri=redirect_uri, nonce="bar") + + def test_client_secret(self): + self.skipUnlessWithConfig(["client_id", "client_secret"]) + self.app = msal.ConfidentialClientApplication( + self.config["client_id"], + client_credential=self.config.get("client_secret"), + authority=self.config.get("authority"), + http_client=MinimalHttpClient()) + scope = self.config.get("scope", []) + result = self.app.acquire_token_for_client(scope) + self.assertIn('access_token', result) + self.assertCacheWorksForApp(result, scope) + + def test_client_certificate(self): + self.skipUnlessWithConfig(["client_id", "client_certificate"]) + client_cert = self.config["client_certificate"] + assert "private_key_path" in client_cert and "thumbprint" in client_cert + with open(os.path.join(THIS_FOLDER, client_cert['private_key_path'])) as f: + private_key = f.read() # Should be in PEM format + self.app = msal.ConfidentialClientApplication( + self.config['client_id'], + {"private_key": private_key, "thumbprint": client_cert["thumbprint"]}, + http_client=MinimalHttpClient()) + scope = self.config.get("scope", []) + result = self.app.acquire_token_for_client(scope) + self.assertIn('access_token', result) + self.assertCacheWorksForApp(result, scope) + + def test_subject_name_issuer_authentication(self): + self.skipUnlessWithConfig(["client_id", "client_certificate"]) + client_cert = self.config["client_certificate"] + assert "private_key_path" in client_cert and "thumbprint" in client_cert + if not "public_certificate" in client_cert: + self.skipTest("Skipping SNI test due to lack of public_certificate") + with open(os.path.join(THIS_FOLDER, client_cert['private_key_path'])) as f: + private_key = f.read() # Should be in PEM format + with open(os.path.join(THIS_FOLDER, client_cert['public_certificate'])) as f: + public_certificate = f.read() + self.app = msal.ConfidentialClientApplication( + self.config['client_id'], authority=self.config["authority"], + client_credential={ + "private_key": private_key, + "thumbprint": self.config["thumbprint"], + "public_certificate": public_certificate, + }, + http_client=MinimalHttpClient()) + scope = self.config.get("scope", []) + result = self.app.acquire_token_for_client(scope) + self.assertIn('access_token', result) + self.assertCacheWorksForApp(result, scope) + + def test_client_assertion(self): + self.skipUnlessWithConfig(["client_id", "client_assertion"]) + self.app = msal.ConfidentialClientApplication( + self.config['client_id'], authority=self.config["authority"], + client_credential={"client_assertion": self.config["client_assertion"]}, + http_client=MinimalHttpClient()) + scope = self.config.get("scope", []) + result = self.app.acquire_token_for_client(scope) + self.assertIn('access_token', result) + self.assertCacheWorksForApp(result, scope) + +@unittest.skipUnless(os.path.exists(CONFIG), "Optional %s not found" % CONFIG) +class DeviceFlowTestCase(E2eTestCase): # A leaf class so it will be run only once + @classmethod + def setUpClass(cls): + with open(CONFIG) as f: + cls.config = json.load(f) + + def test_device_flow(self): + self._test_device_flow(**self.config) + + +def get_lab_app( + env_client_id="LAB_APP_CLIENT_ID", + env_client_secret="LAB_APP_CLIENT_SECRET", + authority="https://login.microsoftonline.com/" + "72f988bf-86f1-41af-91ab-2d7cd011db47", # Microsoft tenant ID + timeout=None, + **kwargs): + """Returns the lab app as an MSAL confidential client. + + Get it from environment variables if defined, otherwise fall back to use MSI. + """ + logger.info( + "Reading ENV variables %s and %s for lab app defined at " + "https://docs.msidlab.com/accounts/confidentialclient.html", + env_client_id, env_client_secret) + if os.getenv(env_client_id) and os.getenv(env_client_secret): + # A shortcut mainly for running tests on developer's local development machine + # or it could be setup on Travis CI + # https://docs.travis-ci.com/user/environment-variables/#defining-variables-in-repository-settings + # Data came from here + # https://docs.msidlab.com/accounts/confidentialclient.html + client_id = os.getenv(env_client_id) + client_secret = os.getenv(env_client_secret) + else: + logger.info("ENV variables %s and/or %s are not defined. Fall back to MSI.", + env_client_id, env_client_secret) + # See also https://microsoft.sharepoint-df.com/teams/MSIDLABSExtended/SitePages/Programmatically-accessing-LAB-API's.aspx + raise unittest.SkipTest("MSI-based mechanism has not been implemented yet") + return msal.ConfidentialClientApplication( + client_id, + client_credential=client_secret, + authority=authority, + http_client=MinimalHttpClient(timeout=timeout), + **kwargs) + +def get_session(lab_app, scopes): # BTW, this infrastructure tests the confidential client flow + logger.info("Creating session") + result = lab_app.acquire_token_for_client(scopes) + assert result.get("access_token"), \ + "Unable to obtain token for lab. Encountered {}: {}".format( + result.get("error"), result.get("error_description")) + session = requests.Session() + session.headers.update({"Authorization": "Bearer %s" % result["access_token"]}) + session.hooks["response"].append(lambda r, *args, **kwargs: r.raise_for_status()) + return session + + +class LabBasedTestCase(E2eTestCase): + _secrets = {} + adfs2019_scopes = ["placeholder"] # Need this to satisfy MSAL API surface. + # Internally, MSAL will also append more scopes like "openid" etc.. + # ADFS 2019 will issue tokens for valid scope only, by default "openid". + # https://docs.microsoft.com/en-us/windows-server/identity/ad-fs/overview/ad-fs-faq#what-permitted-scopes-are-supported-by-ad-fs + + @classmethod + def setUpClass(cls): + # https://docs.msidlab.com/accounts/apiaccess.html#code-snippet + cls.session = get_session(get_lab_app(), ["https://msidlab.com/.default"]) + + @classmethod + def tearDownClass(cls): + cls.session.close() + + @classmethod + def get_lab_app_object(cls, **query): # https://msidlab.com/swagger/index.html + url = "https://msidlab.com/api/app" + resp = cls.session.get(url, params=query) + result = resp.json()[0] + result["scopes"] = [ # Raw data has extra space, such as "s1, s2" + s.strip() for s in result["defaultScopes"].split(',')] + return result + + @classmethod + def get_lab_user_secret(cls, lab_name="msidlab4"): + lab_name = lab_name.lower() + if lab_name not in cls._secrets: + logger.info("Querying lab user password for %s", lab_name) + url = "https://msidlab.com/api/LabUserSecret?secret=%s" % lab_name + resp = cls.session.get(url) + cls._secrets[lab_name] = resp.json()["value"] + return cls._secrets[lab_name] + + @classmethod + def get_lab_user(cls, **query): # https://docs.msidlab.com/labapi/userapi.html + resp = cls.session.get("https://msidlab.com/api/user", params=query) + result = resp.json()[0] + _env = query.get("azureenvironment", "").lower() + authority_base = { + "azureusgovernment": "https://login.microsoftonline.us/" + }.get(_env, "https://login.microsoftonline.com/") + scope = { + "azureusgovernment": ["https://graph.microsoft.us/.default"], + }.get(_env, ["https://graph.microsoft.com/.default"]) + return { # Mapping lab API response to our simplified configuration format + "authority": authority_base + result["tenantID"], + "client_id": result["appId"], + "username": result["upn"], + "lab_name": result["labName"], + "scope": scope, + } + + def _test_acquire_token_by_auth_code( + self, client_id=None, authority=None, port=None, scope=None, + **ignored): + assert client_id and authority and port and scope + (self.app, ac, redirect_uri) = _get_app_and_auth_code( + client_id, authority=authority, port=port, scopes=scope) + result = self.app.acquire_token_by_authorization_code( + ac, scope, redirect_uri=redirect_uri) + logger.debug( + "%s: cache = %s, id_token_claims = %s", + self.id(), + json.dumps(self.app.token_cache._cache, indent=4), + json.dumps(result.get("id_token_claims"), indent=4), + ) + self.assertIn( + "access_token", result, + "{error}: {error_description}".format( + # Note: No interpolation here, cause error won't always present + error=result.get("error"), + error_description=result.get("error_description"))) + self.assertCacheWorksForUser(result, scope, username=None) + + def _test_acquire_token_by_auth_code_flow( + self, client_id=None, authority=None, port=None, scope=None, + username_uri="", # But you would want to provide one + **ignored): + assert client_id and authority and scope + self.app = msal.ClientApplication( + client_id, authority=authority, http_client=MinimalHttpClient()) + with AuthCodeReceiver(port=port) as receiver: + flow = self.app.initiate_auth_code_flow( + scope, + redirect_uri="http://localhost:%d" % receiver.get_port(), + ) + auth_response = receiver.get_auth_response( + auth_uri=flow["auth_uri"], state=flow["state"], timeout=60, + welcome_template="""

{id}

    +
  1. Get a username from the upn shown at here
  2. +
  3. Get its password from https://aka.ms/GetLabUserSecret?Secret=msidlabXYZ + (replace the lab name with the labName from the link above).
  4. +
  5. Sign In or Abort
  6. +
""".format(id=self.id(), username_uri=username_uri), + ) + if auth_response is None: + self.skipTest("Timed out. Did not have test settings in hand? Prepare and retry.") + self.assertIsNotNone( + auth_response.get("code"), "Error: {}, Detail: {}".format( + auth_response.get("error"), auth_response)) + result = self.app.acquire_token_by_auth_code_flow(flow, auth_response) + logger.debug( + "%s: cache = %s, id_token_claims = %s", + self.id(), + json.dumps(self.app.token_cache._cache, indent=4), + json.dumps(result.get("id_token_claims"), indent=4), + ) + self.assertIn( + "access_token", result, + "{error}: {error_description}".format( + # Note: No interpolation here, cause error won't always present + error=result.get("error"), + error_description=result.get("error_description"))) + self.assertCacheWorksForUser(result, scope, username=None) + + def _test_acquire_token_obo(self, config_pca, config_cca, + azure_region=None, # Regional endpoint does not really support OBO. + # Here we just test regional apps won't adversely break OBO + http_client=None, + ): + if "client_secret" not in config_pca: + # 1.a An app obtains a token representing a user, for our mid-tier service + result = msal.PublicClientApplication( + config_pca["client_id"], authority=config_pca["authority"], + azure_region=azure_region, + http_client=http_client or MinimalHttpClient(), + ).acquire_token_by_username_password( + config_pca["username"], config_pca["password"], + scopes=config_pca["scope"], + ) + else: # We repurpose the config_pca to contain client_secret for cca app 1 + # 1.b An app obtains a token representing itself, for our mid-tier service + result = msal.ConfidentialClientApplication( + config_pca["client_id"], authority=config_pca["authority"], + client_credential=config_pca["client_secret"], + azure_region=azure_region, + http_client=http_client or MinimalHttpClient(), + ).acquire_token_for_client(scopes=config_pca["scope"]) + assertion = result.get("access_token") + self.assertIsNotNone(assertion, "First app failed to get AT. {}".format( + json.dumps(result, indent=2))) + + # 2. Our mid-tier service uses OBO to obtain a token for downstream service + cca = msal.ConfidentialClientApplication( + config_cca["client_id"], + client_credential=config_cca["client_secret"], + authority=config_cca["authority"], + azure_region=azure_region, + http_client=http_client or MinimalHttpClient(), + # token_cache= ..., # Default token cache is all-tokens-store-in-memory. + # That's fine if OBO app uses short-lived msal instance per session. + # Otherwise, the OBO app need to implement a one-cache-per-user setup. + ) + cca_result = cca.acquire_token_on_behalf_of(assertion, config_cca["scope"]) + self.assertIsNotNone(cca_result.get("access_token"), "OBO call failed: {}".format( + json.dumps(cca_result, indent=2))) + + # 3. Now the OBO app can simply store downstream token(s) in same session. + # Alternatively, if you want to persist the downstream AT, and possibly + # the RT (if any) for prolonged access even after your own AT expires, + # now it is the time to persist current cache state for current user. + # Assuming you already did that (which is not shown in this test case), + # the following part shows one of the ways to obtain an AT from cache. + username = cca_result.get("id_token_claims", {}).get("preferred_username") + accounts = cca.get_accounts(username=username) + if username is not None: # It means CCA have requested an IDT w/ "profile" scope + assert config_cca["username"] == username, "Incorrect test case configuration" + self.assertEqual(1, len(accounts), "App is supposed to partition token cache per user") + account = accounts[0] # Alternatively, cca app could just loop through each account + result = cca.acquire_token_silent(config_cca["scope"], account) + self.assertTrue( + result and result.get("access_token") == cca_result["access_token"], + "CCA should hit an access token from cache: {}".format( + json.dumps(cca.token_cache._cache, indent=2))) + if "refresh_token" in cca_result: + result = cca.acquire_token_silent( + config_cca["scope"], account=account, force_refresh=True) + self.assertTrue( + result and "access_token" in result, + "CCA should get an AT silently, but we got this instead: {}".format(result)) + self.assertNotEqual( + result["access_token"], cca_result["access_token"], + "CCA should get a new AT") + else: + logger.info("AAD did not issue a RT for OBO flow") + + def _test_acquire_token_by_client_secret( + self, client_id=None, client_secret=None, authority=None, scope=None, + **ignored): + assert client_id and client_secret and authority and scope + app = msal.ConfidentialClientApplication( + client_id, client_credential=client_secret, authority=authority, + http_client=MinimalHttpClient()) + result = app.acquire_token_for_client(scope) + self.assertIsNotNone(result.get("access_token"), "Got %s instead" % result) + result2 = app.acquire_token_silent(scope, account=None) + self.assertEqual( + result2.get("access_token"), result["access_token"], + "CCA should hit an access token from cache: {}".format( + json.dumps(app.token_cache._cache, indent=2)) + ) + if "refresh_token" in result: # Empirically, RT is unavailable, but just in case... + result3 = app.acquire_token_silent(scope, account=None, force_refresh=True) + error_message = "CCA should get a new AT via RT in cache: {}".format( + json.dumps(app.token_cache._cache, indent=2)) + self.assertIsNotNone(result3, error_message) + self.assertNotEqual(result3.get("access_token"), result["access_token"], error_message) + + +class WorldWideTestCase(LabBasedTestCase): + + def test_aad_managed_user(self): # Pure cloud + config = self.get_lab_user(usertype="cloud") + config["password"] = self.get_lab_user_secret(config["lab_name"]) + self._test_username_password(**config) + + def test_adfs4_fed_user(self): + config = self.get_lab_user(usertype="federated", federationProvider="ADFSv4") + config["password"] = self.get_lab_user_secret(config["lab_name"]) + self._test_username_password(**config) + + def test_adfs3_fed_user(self): + config = self.get_lab_user(usertype="federated", federationProvider="ADFSv3") + config["password"] = self.get_lab_user_secret(config["lab_name"]) + self._test_username_password(**config) + + def test_adfs2_fed_user(self): + config = self.get_lab_user(usertype="federated", federationProvider="ADFSv2") + config["password"] = self.get_lab_user_secret(config["lab_name"]) + self._test_username_password(**config) + + def test_adfs2019_fed_user(self): + try: + config = self.get_lab_user(usertype="federated", federationProvider="ADFSv2019") + config["password"] = self.get_lab_user_secret(config["lab_name"]) + self._test_username_password(**config) + except requests.exceptions.HTTPError: + if os.getenv("TRAVIS"): + self.skipTest("MEX endpoint in our test environment tends to fail") + raise + + @unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented") + def test_cloud_acquire_token_interactive(self): + config = self.get_lab_user(usertype="cloud") + self._test_acquire_token_interactive( + username_uri="https://msidlab.com/api/user?usertype=cloud", + **config) + + def test_ropc_adfs2019_onprem(self): + # Configuration is derived from https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.7.0/tests/Microsoft.Identity.Test.Common/TestConstants.cs#L250-L259 + config = self.get_lab_user(usertype="onprem", federationProvider="ADFSv2019") + config["authority"] = "https://fs.%s.com/adfs" % config["lab_name"] + config["scope"] = self.adfs2019_scopes + config["password"] = self.get_lab_user_secret(config["lab_name"]) + self._test_username_password(**config) + + @unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented") + def test_adfs2019_onprem_acquire_token_by_auth_code(self): + """When prompted, you can manually login using this account: + + # https://msidlab.com/api/user?usertype=onprem&federationprovider=ADFSv2019 + username = "..." # The upn from the link above + password="***" # From https://aka.ms/GetLabUserSecret?Secret=msidlabXYZ + """ + config = self.get_lab_user(usertype="onprem", federationProvider="ADFSv2019") + config["authority"] = "https://fs.%s.com/adfs" % config["lab_name"] + config["scope"] = self.adfs2019_scopes + config["port"] = 8080 + self._test_acquire_token_by_auth_code(**config) + + @unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented") + def test_adfs2019_onprem_acquire_token_by_auth_code_flow(self): + config = self.get_lab_user(usertype="onprem", federationProvider="ADFSv2019") + config["authority"] = "https://fs.%s.com/adfs" % config["lab_name"] + config["scope"] = self.adfs2019_scopes + config["port"] = 8080 + self._test_acquire_token_by_auth_code_flow( + username_uri="https://msidlab.com/api/user?usertype=onprem&federationprovider=ADFSv2019", + **config) + + @unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented") + def test_adfs2019_onprem_acquire_token_interactive(self): + config = self.get_lab_user(usertype="onprem", federationProvider="ADFSv2019") + config["authority"] = "https://fs.%s.com/adfs" % config["lab_name"] + config["scope"] = self.adfs2019_scopes + config["port"] = 8080 + self._test_acquire_token_interactive( + username_uri="https://msidlab.com/api/user?usertype=onprem&federationprovider=ADFSv2019", + **config) + + @unittest.skipUnless( + os.getenv("LAB_OBO_CLIENT_SECRET"), + "Need LAB_OBO_CLIENT_SECRET from https://aka.ms/GetLabSecret?Secret=TodoListServiceV2-OBO") + @unittest.skipUnless( + os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID"), + "Need LAB_OBO_CONFIDENTIAL_CLIENT_ID from https://docs.msidlab.com/flows/onbehalfofflow.html") + @unittest.skipUnless( + os.getenv("LAB_OBO_PUBLIC_CLIENT_ID"), + "Need LAB_OBO_PUBLIC_CLIENT_ID from https://docs.msidlab.com/flows/onbehalfofflow.html") + def test_acquire_token_obo(self): + config = self.get_lab_user(usertype="cloud") + + config_cca = {} + config_cca.update(config) + config_cca["client_id"] = os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID") + config_cca["scope"] = ["https://graph.microsoft.com/.default"] + config_cca["client_secret"] = os.getenv("LAB_OBO_CLIENT_SECRET") + + config_pca = {} + config_pca.update(config) + config_pca["client_id"] = os.getenv("LAB_OBO_PUBLIC_CLIENT_ID") + config_pca["password"] = self.get_lab_user_secret(config_pca["lab_name"]) + config_pca["scope"] = ["api://%s/read" % config_cca["client_id"]] + + self._test_acquire_token_obo(config_pca, config_cca) + + @unittest.skipUnless( + os.path.exists("tests/sp_obo.pem"), + "Need a 'tests/sp_obo.pem' private to run OBO for SP test") + def test_acquire_token_obo_for_sp(self): + authority = "https://login.windows-ppe.net/f686d426-8d16-42db-81b7-ab578e110ccd" + with open("tests/sp_obo.pem") as pem: + client_secret = { + "private_key": pem.read(), + "thumbprint": "378938210C976692D7F523B8C4FFBB645D17CE92", + } + midtier_app = { + "authority": authority, + "client_id": "c84e9c32-0bc9-4a73-af05-9efe9982a322", + "client_secret": client_secret, + "scope": ["23d08a1e-1249-4f7c-b5a5-cb11f29b6923/.default"], + #"username": "OBO-Client-PPE", # We do NOT attempt locating initial_app by name + } + initial_app = { + "authority": authority, + "client_id": "9793041b-9078-4942-b1d2-babdc472cc0c", + "client_secret": client_secret, + "scope": [midtier_app["client_id"] + "/.default"], + } + self._test_acquire_token_obo(initial_app, midtier_app) + + def test_acquire_token_by_client_secret(self): + # Vastly different than ArlingtonCloudTestCase.test_acquire_token_by_client_secret() + _app = self.get_lab_app_object( + publicClient="no", signinAudience="AzureAdMyOrg") + self._test_acquire_token_by_client_secret( + client_id=_app["appId"], + client_secret=self.get_lab_user_secret( + _app["clientSecret"].split("/")[-1]), + authority="{}{}.onmicrosoft.com".format( + _app["authority"], _app["labName"].lower().rstrip(".com")), + scope=["https://graph.microsoft.com/.default"], + ) + + @unittest.skipUnless( + os.getenv("LAB_OBO_CLIENT_SECRET"), + "Need LAB_OBO_CLIENT_SECRET from https://aka.ms/GetLabSecret?Secret=TodoListServiceV2-OBO") + @unittest.skipUnless( + os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID"), + "Need LAB_OBO_CONFIDENTIAL_CLIENT_ID from https://docs.msidlab.com/flows/onbehalfofflow.html") + def test_confidential_client_acquire_token_by_username_password(self): + # This approach won't work: + # config = self.get_lab_user(usertype="cloud", publicClient="no") + # so we repurpose the obo confidential app to test ROPC + config = self.get_lab_user(usertype="cloud") + config["password"] = self.get_lab_user_secret(config["lab_name"]) + # Swap in the OBO confidential app + config["client_id"] = os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID") + config["scope"] = ["https://graph.microsoft.com/.default"] + config["client_secret"] = os.getenv("LAB_OBO_CLIENT_SECRET") + self._test_username_password(**config) + + def _build_b2c_authority(self, policy): + base = "https://msidlabb2c.b2clogin.com/msidlabb2c.onmicrosoft.com" + return base + "/" + policy # We do not support base + "?p=" + policy + + @unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented") + def test_b2c_acquire_token_by_auth_code(self): + """ + When prompted, you can manually login using this account: + + username="b2clocal@msidlabb2c.onmicrosoft.com" + # This won't work https://msidlab.com/api/user?usertype=b2c + password="***" # From https://aka.ms/GetLabUserSecret?Secret=msidlabb2c + """ + config = self.get_lab_app_object(azureenvironment="azureb2ccloud") + self._test_acquire_token_by_auth_code( + authority=self._build_b2c_authority("B2C_1_SignInPolicy"), + client_id=config["appId"], + port=3843, # Lab defines 4 of them: [3843, 4584, 4843, 60000] + scope=config["scopes"], + ) + + @unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented") + def test_b2c_acquire_token_by_auth_code_flow(self): + config = self.get_lab_app_object(azureenvironment="azureb2ccloud") + self._test_acquire_token_by_auth_code_flow( + authority=self._build_b2c_authority("B2C_1_SignInPolicy"), + client_id=config["appId"], + port=3843, # Lab defines 4 of them: [3843, 4584, 4843, 60000] + scope=config["scopes"], + username_uri="https://msidlab.com/api/user?usertype=b2c&b2cprovider=local", + ) + + def test_b2c_acquire_token_by_ropc(self): + config = self.get_lab_app_object(azureenvironment="azureb2ccloud") + self._test_username_password( + authority=self._build_b2c_authority("B2C_1_ROPC_Auth"), + client_id=config["appId"], + username="b2clocal@msidlabb2c.onmicrosoft.com", + password=self.get_lab_user_secret("msidlabb2c"), + scope=config["scopes"], + ) + + +class WorldWideRegionalEndpointTestCase(LabBasedTestCase): + region = "westus" + timeout = 2 # Short timeout makes this test case responsive on non-VM + + def _test_acquire_token_for_client(self, configured_region, expected_region): + """This is the only grant supported by regional endpoint, for now""" + self.app = get_lab_app( # Regional endpoint only supports confidential client + + ## FWIW, the MSAL<1.12 versions could use this to achieve similar result + #authority="https://westus.login.microsoft.com/microsoft.onmicrosoft.com", + #validate_authority=False, + authority="https://login.microsoftonline.com/microsoft.onmicrosoft.com", + azure_region=configured_region, + timeout=2, # Short timeout makes this test case responsive on non-VM + ) + scopes = ["https://graph.microsoft.com/.default"] + + with patch.object( # Test the request hit the regional endpoint + self.app.http_client, "post", return_value=MinimalResponse( + status_code=400, text='{"error": "mock"}')) as mocked_method: + self.app.acquire_token_for_client(scopes) + expected_host = '{}.r.login.microsoftonline.com'.format( + expected_region) if expected_region else 'login.microsoftonline.com' + mocked_method.assert_called_with( + 'https://{}/{}/oauth2/v2.0/token'.format( + expected_host, self.app.authority.tenant), + params=ANY, data=ANY, headers=ANY) + result = self.app.acquire_token_for_client( + scopes, + params={"AllowEstsRNonMsi": "true"}, # For testing regional endpoint. It will be removed once MSAL Python 1.12+ has been onboard to ESTS-R + ) + self.assertIn('access_token', result) + self.assertCacheWorksForApp(result, scopes) + + def test_acquire_token_for_client_should_hit_global_endpoint_by_default(self): + self._test_acquire_token_for_client(None, None) + + def test_acquire_token_for_client_should_ignore_env_var_by_default(self): + os.environ["REGION_NAME"] = "eastus" + self._test_acquire_token_for_client(None, None) + del os.environ["REGION_NAME"] + + def test_acquire_token_for_client_should_use_a_specified_region(self): + self._test_acquire_token_for_client("westus", "westus") + + def test_acquire_token_for_client_should_use_an_env_var_with_short_region_name(self): + os.environ["REGION_NAME"] = "eastus" + self._test_acquire_token_for_client( + msal.ConfidentialClientApplication.ATTEMPT_REGION_DISCOVERY, "eastus") + del os.environ["REGION_NAME"] + + def test_acquire_token_for_client_should_use_an_env_var_with_long_region_name(self): + os.environ["REGION_NAME"] = "East Us 2" + self._test_acquire_token_for_client( + msal.ConfidentialClientApplication.ATTEMPT_REGION_DISCOVERY, "eastus2") + del os.environ["REGION_NAME"] + + @unittest.skipUnless( + os.getenv("LAB_OBO_CLIENT_SECRET"), + "Need LAB_OBO_CLIENT_SECRET from https://aka.ms/GetLabSecret?Secret=TodoListServiceV2-OBO") + @unittest.skipUnless( + os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID"), + "Need LAB_OBO_CONFIDENTIAL_CLIENT_ID from https://docs.msidlab.com/flows/onbehalfofflow.html") + @unittest.skipUnless( + os.getenv("LAB_OBO_PUBLIC_CLIENT_ID"), + "Need LAB_OBO_PUBLIC_CLIENT_ID from https://docs.msidlab.com/flows/onbehalfofflow.html") + def test_cca_obo_should_bypass_regional_endpoint_therefore_still_work(self): + """We test OBO because it is implemented in sub class ConfidentialClientApplication""" + config = self.get_lab_user(usertype="cloud") + + config_cca = {} + config_cca.update(config) + config_cca["client_id"] = os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID") + config_cca["scope"] = ["https://graph.microsoft.com/.default"] + config_cca["client_secret"] = os.getenv("LAB_OBO_CLIENT_SECRET") + + config_pca = {} + config_pca.update(config) + config_pca["client_id"] = os.getenv("LAB_OBO_PUBLIC_CLIENT_ID") + config_pca["password"] = self.get_lab_user_secret(config_pca["lab_name"]) + config_pca["scope"] = ["api://%s/read" % config_cca["client_id"]] + + self._test_acquire_token_obo( + config_pca, config_cca, + azure_region=self.region, + http_client=MinimalHttpClient(timeout=self.timeout), + ) + + @unittest.skipUnless( + os.getenv("LAB_OBO_CLIENT_SECRET"), + "Need LAB_OBO_CLIENT_SECRET from https://aka.ms/GetLabSecret?Secret=TodoListServiceV2-OBO") + @unittest.skipUnless( + os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID"), + "Need LAB_OBO_CONFIDENTIAL_CLIENT_ID from https://docs.msidlab.com/flows/onbehalfofflow.html") + def test_cca_ropc_should_bypass_regional_endpoint_therefore_still_work(self): + """We test ROPC because it is implemented in base class ClientApplication""" + config = self.get_lab_user(usertype="cloud") + config["password"] = self.get_lab_user_secret(config["lab_name"]) + # We repurpose the obo confidential app to test ROPC + # Swap in the OBO confidential app + config["client_id"] = os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID") + config["scope"] = ["https://graph.microsoft.com/.default"] + config["client_secret"] = os.getenv("LAB_OBO_CLIENT_SECRET") + self._test_username_password( + azure_region=self.region, + http_client=MinimalHttpClient(timeout=self.timeout), + **config) + + +class ArlingtonCloudTestCase(LabBasedTestCase): + environment = "azureusgovernment" + + def test_acquire_token_by_ropc(self): + config = self.get_lab_user(azureenvironment=self.environment) + config["password"] = self.get_lab_user_secret(config["lab_name"]) + self._test_username_password(**config) + + def test_acquire_token_by_client_secret(self): + config = self.get_lab_user(usertype="cloud", azureenvironment=self.environment, publicClient="no") + config["client_secret"] = self.get_lab_user_secret("ARLMSIDLAB1-IDLASBS-App-CC-Secret") + self._test_acquire_token_by_client_secret(**config) + + def test_acquire_token_obo(self): + config_cca = self.get_lab_user( + usertype="cloud", azureenvironment=self.environment, publicClient="no") + config_cca["scope"] = ["https://graph.microsoft.us/.default"] + config_cca["client_secret"] = self.get_lab_user_secret("ARLMSIDLAB1-IDLASBS-App-CC-Secret") + + config_pca = self.get_lab_user(usertype="cloud", azureenvironment=self.environment, publicClient="yes") + obo_app_object = self.get_lab_app_object( + usertype="cloud", azureenvironment=self.environment, publicClient="no") + config_pca["password"] = self.get_lab_user_secret(config_pca["lab_name"]) + config_pca["scope"] = ["{app_uri}/files.read".format(app_uri=obo_app_object.get("identifierUris"))] + + self._test_acquire_token_obo(config_pca, config_cca) + + def test_acquire_token_device_flow(self): + config = self.get_lab_user(usertype="cloud", azureenvironment=self.environment, publicClient="yes") + config["scope"] = ["user.read"] + self._test_device_flow(**config) + + def test_acquire_token_silent_with_an_empty_cache_should_return_none(self): + config = self.get_lab_user( + usertype="cloud", azureenvironment=self.environment, publicClient="no") + app = msal.ConfidentialClientApplication( + config['client_id'], authority=config['authority'], + http_client=MinimalHttpClient()) + result = app.acquire_token_silent(scopes=config['scope'], account=None) + self.assertEqual(result, None) + # Note: An alias in this region is no longer accepting HTTPS traffic. + # If this test case passes without exception, + # it means MSAL Python is not affected by that. + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_individual_cache.py b/tests/test_individual_cache.py new file mode 100644 index 00000000..38bd572d --- /dev/null +++ b/tests/test_individual_cache.py @@ -0,0 +1,93 @@ +from time import sleep +from random import random +import unittest +from msal.individual_cache import _ExpiringMapping as ExpiringMapping +from msal.individual_cache import _IndividualCache as IndividualCache + + +class TestExpiringMapping(unittest.TestCase): + def setUp(self): + self.mapping = {} + self.m = ExpiringMapping(mapping=self.mapping, capacity=2, expires_in=1) + + def test_should_disallow_accessing_reserved_keyword(self): + with self.assertRaises(ValueError): + self.m.get(ExpiringMapping._INDEX) + + def test_setitem(self): + self.assertEqual(0, len(self.m)) + self.m["thing one"] = "one" + self.assertIn(ExpiringMapping._INDEX, self.mapping, "Index created") + self.assertEqual(1, len(self.m), "It contains one item (excluding index)") + self.assertEqual("one", self.m["thing one"]) + self.assertEqual(["thing one"], list(self.m)) + + def test_set(self): + self.assertEqual(0, len(self.m)) + self.m.set("thing two", "two", 2) + self.assertIn(ExpiringMapping._INDEX, self.mapping, "Index created") + self.assertEqual(1, len(self.m), "It contains one item (excluding index)") + self.assertEqual("two", self.m["thing two"]) + self.assertEqual(["thing two"], list(self.m)) + + def test_len_should_purge(self): + self.m["thing one"] = "one" + sleep(1) + self.assertEqual(0, len(self.m)) + + def test_iter_should_purge(self): + self.m["thing one"] = "one" + sleep(1) + self.assertEqual([], list(self.m)) + + def test_get_should_purge(self): + self.m["thing one"] = "one" + sleep(1) + with self.assertRaises(KeyError): + self.m["thing one"] + + def test_various_expiring_time(self): + self.assertEqual(0, len(self.m)) + self.m["thing one"] = "one" + self.m.set("thing two", "two", 2) + self.assertEqual(2, len(self.m), "It contains 2 items") + sleep(1) + self.assertEqual(["thing two"], list(self.m), "One expires, another remains") + + def test_old_item_can_be_updated_with_new_expiry_time(self): + self.assertEqual(0, len(self.m)) + self.m["thing"] = "one" + self.m.set("thing", "two", 2) + self.assertEqual(1, len(self.m), "It contains 1 item") + self.assertEqual("two", self.m["thing"], 'Already been updated to "two"') + sleep(1) + self.assertEqual("two", self.m["thing"], "Not yet expires") + sleep(1) + self.assertEqual(0, len(self.m)) + + def test_oversized_input_should_purge_most_aging_item(self): + self.assertEqual(0, len(self.m)) + self.m["thing one"] = "one" + self.m.set("thing two", "two", 2) + self.assertEqual(2, len(self.m), "It contains 2 items") + self.m["thing three"] = "three" + self.assertEqual(2, len(self.m), "It contains 2 items") + self.assertNotIn("thing one", self.m) + + +class TestIndividualCache(unittest.TestCase): + mapping = {} + + @IndividualCache(mapping=mapping) + def foo(self, a, b, c=None, d=None): + return random() # So that we'd know whether a new response is received + + def test_memorize_a_function_call(self): + self.assertNotEqual(self.foo(1, 1), self.foo(2, 2)) + self.assertEqual( + self.foo(1, 2, c=3, d=4), + self.foo(1, 2, c=3, d=4), + "Subsequent run should obtain same result from cache") + # Note: In Python 3.7+, dict is ordered, so the following is typically True: + #self.assertNotEqual(self.foo(a=1, b=2), self.foo(b=2, a=1)) + diff --git a/tests/test_mex.py b/tests/test_mex.py new file mode 100644 index 00000000..fe330f71 --- /dev/null +++ b/tests/test_mex.py @@ -0,0 +1,28 @@ +import os + +from tests import unittest +from msal.mex import * + + +THIS_FOLDER = os.path.dirname(__file__) + +class TestMex(unittest.TestCase): + + def _test_parser(self, sample, expected_endpoint): + with open(os.path.join(THIS_FOLDER, sample)) as sample_file: + endpoint = Mex(mex_document=sample_file.read() + ).get_wstrust_username_password_endpoint()["address"] + self.assertEqual(expected_endpoint, endpoint) + + def test_happy_path_1(self): + self._test_parser("microsoft.mex.xml", + 'https://corp.sts.microsoft.com/adfs/services/trust/13/usernamemixed') + + def test_happy_path_2(self): + self._test_parser('arupela.mex.xml', + 'https://fs.arupela.com/adfs/services/trust/13/usernamemixed') + + def test_happy_path_3(self): + self._test_parser('archan.us.mex.xml', + 'https://arvmserver2012.archan.us/adfs/services/trust/13/usernamemixed') + diff --git a/tests/test_throttled_http_client.py b/tests/test_throttled_http_client.py new file mode 100644 index 00000000..93820505 --- /dev/null +++ b/tests/test_throttled_http_client.py @@ -0,0 +1,179 @@ +# Test cases for https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview&anchor=common-test-cases +from time import sleep +from random import random +import logging +from msal.throttled_http_client import ThrottledHttpClient +from tests import unittest +from tests.http_client import MinimalResponse + + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG) + + +class DummyHttpResponse(MinimalResponse): + def __init__(self, headers=None, **kwargs): + self.headers = {} if headers is None else headers + super(DummyHttpResponse, self).__init__(**kwargs) + + +class DummyHttpClient(object): + def __init__(self, status_code=None, response_headers=None): + self._status_code = status_code + self._response_headers = response_headers + + def _build_dummy_response(self): + return DummyHttpResponse( + status_code=self._status_code, + headers=self._response_headers, + text=random(), # So that we'd know whether a new response is received + ) + + def post(self, url, params=None, data=None, headers=None, **kwargs): + return self._build_dummy_response() + + def get(self, url, params=None, headers=None, **kwargs): + return self._build_dummy_response() + + def close(self): + raise CloseMethodCalled("Not used by MSAL, but our customers may use it") + + +class CloseMethodCalled(Exception): + pass + + +class TestHttpDecoration(unittest.TestCase): + + def test_throttled_http_client_should_not_alter_original_http_client(self): + http_cache = {} + original_http_client = DummyHttpClient() + original_get = original_http_client.get + original_post = original_http_client.post + throttled_http_client = ThrottledHttpClient(original_http_client, http_cache) + goal = """The implementation should wrap original http_client + and keep it intact, instead of monkey-patching it""" + self.assertNotEqual(throttled_http_client, original_http_client, goal) + self.assertEqual(original_post, original_http_client.post) + self.assertEqual(original_get, original_http_client.get) + + def _test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds( + self, http_client, retry_after): + http_cache = {} + http_client = ThrottledHttpClient(http_client, http_cache) + resp1 = http_client.post("https://example.com") # We implemented POST only + resp2 = http_client.post("https://example.com") # We implemented POST only + logger.debug(http_cache) + self.assertEqual(resp1.text, resp2.text, "Should return a cached response") + sleep(retry_after + 1) + resp3 = http_client.post("https://example.com") # We implemented POST only + self.assertNotEqual(resp1.text, resp3.text, "Should return a new response") + + def test_429_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self): + retry_after = 1 + self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds( + DummyHttpClient( + status_code=429, response_headers={"Retry-After": retry_after}), + retry_after) + + def test_5xx_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self): + retry_after = 1 + self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds( + DummyHttpClient( + status_code=503, response_headers={"Retry-After": retry_after}), + retry_after) + + def test_400_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self): + """Retry-After is supposed to only shown in http 429/5xx, + but we choose to support Retry-After for arbitrary http response.""" + retry_after = 1 + self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds( + DummyHttpClient( + status_code=400, response_headers={"Retry-After": retry_after}), + retry_after) + + def test_one_RetryAfter_request_should_block_a_similar_request(self): + http_cache = {} + http_client = DummyHttpClient( + status_code=429, response_headers={"Retry-After": 2}) + http_client = ThrottledHttpClient(http_client, http_cache) + resp1 = http_client.post("https://example.com", data={ + "scope": "one", "claims": "bar", "grant_type": "authorization_code"}) + resp2 = http_client.post("https://example.com", data={ + "scope": "one", "claims": "foo", "grant_type": "password"}) + logger.debug(http_cache) + self.assertEqual(resp1.text, resp2.text, "Should return a cached response") + + def test_one_RetryAfter_request_should_not_block_a_different_request(self): + http_cache = {} + http_client = DummyHttpClient( + status_code=429, response_headers={"Retry-After": 2}) + http_client = ThrottledHttpClient(http_client, http_cache) + resp1 = http_client.post("https://example.com", data={"scope": "one"}) + resp2 = http_client.post("https://example.com", data={"scope": "two"}) + logger.debug(http_cache) + self.assertNotEqual(resp1.text, resp2.text, "Should return a new response") + + def test_one_invalid_grant_should_block_a_similar_request(self): + http_cache = {} + http_client = DummyHttpClient( + status_code=400) # It covers invalid_grant and interaction_required + http_client = ThrottledHttpClient(http_client, http_cache) + resp1 = http_client.post("https://example.com", data={"claims": "foo"}) + logger.debug(http_cache) + resp1_again = http_client.post("https://example.com", data={"claims": "foo"}) + self.assertEqual(resp1.text, resp1_again.text, "Should return a cached response") + resp2 = http_client.post("https://example.com", data={"claims": "bar"}) + self.assertNotEqual(resp1.text, resp2.text, "Should return a new response") + resp2_again = http_client.post("https://example.com", data={"claims": "bar"}) + self.assertEqual(resp2.text, resp2_again.text, "Should return a cached response") + + def test_one_foci_app_recovering_from_invalid_grant_should_also_unblock_another(self): + """ + Need not test multiple FOCI app's acquire_token_silent() here. By design, + one FOCI app's successful populating token cache would result in another + FOCI app's acquire_token_silent() to hit a token without invoking http request. + """ + + def test_forcefresh_behavior(self): + """ + The implementation let token cache and http cache operate in different + layers. They do not couple with each other. + Therefore, acquire_token_silent(..., force_refresh=True) + would bypass the token cache yet technically still hit the http cache. + + But that is OK, cause the customer need no force_refresh in the first place. + After a successful AT/RT acquisition, AT/RT will be in the token cache, + and a normal acquire_token_silent(...) without force_refresh would just work. + This was discussed in https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview/pullrequest/3618?_a=files + """ + + def test_http_get_200_should_be_cached(self): + http_cache = {} + http_client = DummyHttpClient( + status_code=200) # It covers UserRealm discovery and OIDC discovery + http_client = ThrottledHttpClient(http_client, http_cache) + resp1 = http_client.get("https://example.com?foo=bar") + resp2 = http_client.get("https://example.com?foo=bar") + logger.debug(http_cache) + self.assertEqual(resp1.text, resp2.text, "Should return a cached response") + + def test_device_flow_retry_should_not_be_cached(self): + DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code" + http_cache = {} + http_client = DummyHttpClient(status_code=400) + http_client = ThrottledHttpClient(http_client, http_cache) + resp1 = http_client.post( + "https://example.com", data={"grant_type": DEVICE_AUTH_GRANT}) + resp2 = http_client.post( + "https://example.com", data={"grant_type": DEVICE_AUTH_GRANT}) + logger.debug(http_cache) + self.assertNotEqual(resp1.text, resp2.text, "Should return a new response") + + def test_throttled_http_client_should_provide_close(self): + http_cache = {} + http_client = DummyHttpClient(status_code=200) + http_client = ThrottledHttpClient(http_client, http_cache) + with self.assertRaises(CloseMethodCalled): + http_client.close() + diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py new file mode 100644 index 00000000..2fe486c2 --- /dev/null +++ b/tests/test_token_cache.py @@ -0,0 +1,288 @@ +import logging +import base64 +import json +import time + +from msal.token_cache import * +from tests import unittest + + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG) + + +# NOTE: These helpers were once implemented as static methods in TokenCacheTestCase. +# That would cause other test files' "from ... import TokenCacheTestCase" +# to re-run all test cases in this file. +# Now we avoid that, by defining these helpers in module level. +def build_id_token( + iss="issuer", sub="subject", aud="my_client_id", exp=None, iat=None, + **claims): # AAD issues "preferred_username", ADFS issues "upn" + return "header.%s.signature" % base64.b64encode(json.dumps(dict({ + "iss": iss, + "sub": sub, + "aud": aud, + "exp": exp or (time.time() + 100), + "iat": iat or time.time(), + }, **claims)).encode()).decode('utf-8') + + +def build_response( # simulate a response from AAD + uid=None, utid=None, # If present, they will form client_info + access_token=None, expires_in=3600, token_type="some type", + **kwargs # Pass-through: refresh_token, foci, id_token, error, refresh_in, ... + ): + response = {} + if uid and utid: # Mimic the AAD behavior for "client_info=1" request + response["client_info"] = base64.b64encode(json.dumps({ + "uid": uid, "utid": utid, + }).encode()).decode('utf-8') + if access_token: + response.update({ + "access_token": access_token, + "expires_in": expires_in, + "token_type": token_type, + }) + response.update(kwargs) # Pass-through key-value pairs as top-level fields + return response + + +class TokenCacheTestCase(unittest.TestCase): + + def setUp(self): + self.cache = TokenCache() + + def testAddByAad(self): + client_id = "my_client_id" + id_token = build_id_token( + oid="object1234", preferred_username="John Doe", aud=client_id) + self.cache.add({ + "client_id": client_id, + "scope": ["s2", "s1", "s3"], # Not in particular order + "token_endpoint": "https://login.example.com/contoso/v2/token", + "response": build_response( + uid="uid", utid="utid", # client_info + expires_in=3600, access_token="an access token", + id_token=id_token, refresh_token="a refresh token"), + }, now=1000) + self.assertEqual( + { + 'cached_at': "1000", + 'client_id': 'my_client_id', + 'credential_type': 'AccessToken', + 'environment': 'login.example.com', + 'expires_on': "4600", + 'extended_expires_on': "4600", + 'home_account_id': "uid.utid", + 'realm': 'contoso', + 'secret': 'an access token', + 'target': 's2 s1 s3', + 'token_type': 'some type', + }, + self.cache._cache["AccessToken"].get( + 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3') + ) + self.assertEqual( + { + 'client_id': 'my_client_id', + 'credential_type': 'RefreshToken', + 'environment': 'login.example.com', + 'home_account_id': "uid.utid", + 'last_modification_time': '1000', + 'secret': 'a refresh token', + 'target': 's2 s1 s3', + }, + self.cache._cache["RefreshToken"].get( + 'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3') + ) + self.assertEqual( + { + 'home_account_id': "uid.utid", + 'environment': 'login.example.com', + 'realm': 'contoso', + 'local_account_id': "object1234", + 'username': "John Doe", + 'authority_type': "MSSTS", + }, + self.cache._cache["Account"].get('uid.utid-login.example.com-contoso') + ) + self.assertEqual( + { + 'credential_type': 'IdToken', + 'secret': id_token, + 'home_account_id': "uid.utid", + 'environment': 'login.example.com', + 'realm': 'contoso', + 'client_id': 'my_client_id', + }, + self.cache._cache["IdToken"].get( + 'uid.utid-login.example.com-idtoken-my_client_id-contoso-') + ) + self.assertEqual( + { + "client_id": "my_client_id", + 'environment': 'login.example.com', + }, + self.cache._cache.get("AppMetadata", {}).get( + "appmetadata-login.example.com-my_client_id") + ) + + def testAddByAdfs(self): + client_id = "my_client_id" + id_token = build_id_token(aud=client_id, upn="JaneDoe@example.com") + self.cache.add({ + "client_id": client_id, + "scope": ["s2", "s1", "s3"], # Not in particular order + "token_endpoint": "https://fs.msidlab8.com/adfs/oauth2/token", + "response": build_response( + uid=None, utid=None, # ADFS will provide no client_info + expires_in=3600, access_token="an access token", + id_token=id_token, refresh_token="a refresh token"), + }, now=1000) + self.assertEqual( + { + 'cached_at': "1000", + 'client_id': 'my_client_id', + 'credential_type': 'AccessToken', + 'environment': 'fs.msidlab8.com', + 'expires_on': "4600", + 'extended_expires_on': "4600", + 'home_account_id': "subject", + 'realm': 'adfs', + 'secret': 'an access token', + 'target': 's2 s1 s3', + 'token_type': 'some type', + }, + self.cache._cache["AccessToken"].get( + 'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s2 s1 s3') + ) + self.assertEqual( + { + 'client_id': 'my_client_id', + 'credential_type': 'RefreshToken', + 'environment': 'fs.msidlab8.com', + 'home_account_id': "subject", + 'last_modification_time': "1000", + 'secret': 'a refresh token', + 'target': 's2 s1 s3', + }, + self.cache._cache["RefreshToken"].get( + 'subject-fs.msidlab8.com-refreshtoken-my_client_id--s2 s1 s3') + ) + self.assertEqual( + { + 'home_account_id': "subject", + 'environment': 'fs.msidlab8.com', + 'realm': 'adfs', + 'local_account_id': "subject", + 'username': "JaneDoe@example.com", + 'authority_type': "ADFS", + }, + self.cache._cache["Account"].get('subject-fs.msidlab8.com-adfs') + ) + self.assertEqual( + { + 'credential_type': 'IdToken', + 'secret': id_token, + 'home_account_id': "subject", + 'environment': 'fs.msidlab8.com', + 'realm': 'adfs', + 'client_id': 'my_client_id', + }, + self.cache._cache["IdToken"].get( + 'subject-fs.msidlab8.com-idtoken-my_client_id-adfs-') + ) + self.assertEqual( + { + "client_id": "my_client_id", + 'environment': 'fs.msidlab8.com', + }, + self.cache._cache.get("AppMetadata", {}).get( + "appmetadata-fs.msidlab8.com-my_client_id") + ) + + def test_key_id_is_also_recorded(self): + my_key_id = "some_key_id_123" + self.cache.add({ + "data": {"key_id": my_key_id}, + "client_id": "my_client_id", + "scope": ["s2", "s1", "s3"], # Not in particular order + "token_endpoint": "https://login.example.com/contoso/v2/token", + "response": build_response( + uid="uid", utid="utid", # client_info + expires_in=3600, access_token="an access token", + refresh_token="a refresh token"), + }, now=1000) + cached_key_id = self.cache._cache["AccessToken"].get( + 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3', + {}).get("key_id") + self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key") + + def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep. + self.cache.add({ + "client_id": "my_client_id", + "scope": ["s2", "s1", "s3"], # Not in particular order + "token_endpoint": "https://login.example.com/contoso/v2/token", + "response": build_response( + uid="uid", utid="utid", # client_info + expires_in=3600, refresh_in=1800, access_token="an access token", + ), #refresh_token="a refresh token"), + }, now=1000) + refresh_on = self.cache._cache["AccessToken"].get( + 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3', + {}).get("refresh_on") + self.assertEqual("2800", refresh_on, "Should save refresh_on") + + def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self): + sample = { + 'client_id': 'my_client_id', + 'credential_type': 'RefreshToken', + 'environment': 'login.example.com', + 'home_account_id': "uid.utid", + 'secret': 'a refresh token', + 'target': 's2 s1 s3', + } + new_rt = "this is a new RT" + self.cache._cache["RefreshToken"] = {"wrong-key": sample} + self.cache.modify( + self.cache.CredentialType.REFRESH_TOKEN, sample, {"secret": new_rt}) + self.assertEqual( + dict(sample, secret=new_rt), + self.cache._cache["RefreshToken"].get( + 'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3') + ) + + +class SerializableTokenCacheTestCase(TokenCacheTestCase): + # Run all inherited test methods, and have extra check in tearDown() + + def setUp(self): + self.cache = SerializableTokenCache() + self.cache.deserialize(""" + { + "AccessToken": { + "an-entry": { + "foo": "bar" + } + }, + "customized": "whatever" + } + """) + + def test_has_state_changed(self): + cache = SerializableTokenCache() + self.assertFalse(cache.has_state_changed) + cache.add({}) # An NO-OP add() still counts as a state change. Good enough. + self.assertTrue(cache.has_state_changed) + + def tearDown(self): + state = self.cache.serialize() + logger.debug("serialize() = %s", state) + # Now assert all extended content are kept intact + output = json.loads(state) + self.assertEqual(output.get("customized"), "whatever", + "Undefined cache keys and their values should be intact") + self.assertEqual( + output.get("AccessToken", {}).get("an-entry"), {"foo": "bar"}, + "Undefined token keys and their values should be intact") + diff --git a/tests/test_wstrust.py b/tests/test_wstrust.py new file mode 100644 index 00000000..1d909585 --- /dev/null +++ b/tests/test_wstrust.py @@ -0,0 +1,98 @@ +#------------------------------------------------------------------------------ +# +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions : +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +#------------------------------------------------------------------------------ + +try: + from xml.etree import cElementTree as ET +except ImportError: + from xml.etree import ElementTree as ET +import os + +from msal.wstrust_response import * + +from tests import unittest + + +class Test_WsTrustResponse(unittest.TestCase): + + def test_findall_content_with_comparison(self): + content = """ + + + foo + + """ + sample = ('' + + content + + '') + + # Demonstrating how XML-based parser won't give you the raw content as-is + element = ET.fromstring(sample).findall('{SAML:assertion}Assertion')[0] + assertion_via_xml_parser = ET.tostring(element) + self.assertNotEqual(content, assertion_via_xml_parser) + self.assertNotIn(b"", assertion_via_xml_parser) + + # The findall_content() helper, based on Regex, will return content as-is. + self.assertEqual([content], findall_content(sample, "Wrapper")) + + def test_parse_error(self): + error_response = ''' + + + http://www.w3.org/2005/08/addressing/soap/fault + + + 2013-07-30T00:32:21.989Z + 2013-07-30T00:37:21.989Z + + + + + + + s:Sender + + a:RequestFailed + + + + MSIS3127: The specified request failed. + + + + ''' + self.assertEqual({ + "reason": "MSIS3127: The specified request failed.", + "code": "a:RequestFailed", + }, parse_error(error_response)) + + def test_token_parsing_happy_path(self): + with open(os.path.join(os.path.dirname(__file__), "rst_response.xml")) as f: + rst_body = f.read() + result = parse_token_by_re(rst_body) + self.assertEqual(result.get("type"), SAML_TOKEN_TYPE_V1) + self.assertIn(b"