Skip to content

Commit

Permalink
Merge pull request #73 from AzureAD/release-0.3.0
Browse files Browse the repository at this point in the history
Release 0.3.0
  • Loading branch information
abhidnya13 authored Sep 1, 2020
2 parents 3df9da0 + 6d2efab commit b90d20e
Show file tree
Hide file tree
Showing 8 changed files with 302 additions and 19 deletions.
46 changes: 46 additions & 0 deletions .github/workflows/codeql.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Weekly + per-push CodeQL static analysis (GitHub code scanning) for this repo.
name: "Code Scanning - Action"

on:
  push:
  schedule:
    # Every Sunday at 00:00 UTC.
    - cron: '0 0 * * 0'

jobs:
  CodeQL-Build:

    strategy:
      fail-fast: false


    # CodeQL runs on ubuntu-latest, windows-latest, and macos-latest
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v2

      # Initializes the CodeQL tools for scanning.
      - name: Initialize CodeQL
        uses: github/codeql-action/init@v1
        # Override language selection by uncommenting this and choosing your languages
        # with:
        #   languages: go, javascript, csharp, python, cpp, java

      # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
      # If this step fails, then you should remove it and run the build manually (see below).
      - name: Autobuild
        uses: github/codeql-action/autobuild@v1

      # ℹ️ Command-line programs to run using the OS shell.
      # 📚 https://git.io/JvXDl

      # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
      #    and modify them (or add more) to build your code if your project
      #    uses a compiled language

      #- run: |
      #    make bootstrap
      #    make release

      - name: Perform CodeQL Analysis
        uses: github/codeql-action/analyze@v1
2 changes: 1 addition & 1 deletion msal_extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Provides auxiliary functionality to the `msal` package."""
__version__ = "0.2.2"
__version__ = "0.3.0"

import sys

Expand Down
105 changes: 93 additions & 12 deletions msal_extensions/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import abc
import os
import errno
import logging
try:
from pathlib import Path # Built-in in Python 3
except:
Expand All @@ -21,6 +22,9 @@
ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore


logger = logging.getLogger(__name__)


def _mkdir_p(path):
"""Creates a directory, and any necessary parents.
Expand All @@ -41,6 +45,20 @@ def _mkdir_p(path):
raise


class PersistenceNotFound(IOError):
    """Raised when a persistence exists but holds no saved content yet.

    We do not aim to wrap every os-specific exception; this is only the most
    common one, so callers need not catch os-specific persistence errors.

    The base class is IOError rather than OSError because historically an
    IOError was bubbled up and expected
    (https://github.com/AzureAD/microsoft-authentication-extensions-for-python/blob/0.2.2/msal_extensions/token_cache.py#L38)
    and we want to keep backward compatibility even on Python 2.x.
    On Python 3.3+ it makes no difference, since IOError aliases OSError.
    """

    def __init__(
            self,
            err_no=errno.ENOENT, message="Persistence not found", location=None):
        # IOError(errno, strerror, filename) populates the standard attributes.
        super(PersistenceNotFound, self).__init__(err_no, message, location)


class BasePersistence(ABC):
"""An abstract persistence defining the common interface of this family"""

Expand All @@ -55,12 +73,18 @@ def save(self, content):
@abc.abstractmethod
def load(self):
# type: () -> str
"""Load content from this persistence"""
"""Load content from this persistence.
Could raise PersistenceNotFound if no save() was called before.
"""
raise NotImplementedError

@abc.abstractmethod
def time_last_modified(self):
"""Get the last time when this persistence has been modified"""
"""Get the last time when this persistence has been modified.
Could raise PersistenceNotFound if no save() was called before.
"""
raise NotImplementedError

@abc.abstractmethod
Expand All @@ -87,11 +111,32 @@ def save(self, content):
def load(self):
# type: () -> str
"""Load content from this persistence"""
with open(self._location, 'r') as handle:
return handle.read()
try:
with open(self._location, 'r') as handle:
return handle.read()
except EnvironmentError as exp: # EnvironmentError in Py 2.7 works across platform
if exp.errno == errno.ENOENT:
raise PersistenceNotFound(
message=(
"Persistence not initialized. "
"You can recover by calling a save() first."),
location=self._location,
)
raise


def time_last_modified(self):
return os.path.getmtime(self._location)
try:
return os.path.getmtime(self._location)
except EnvironmentError as exp: # EnvironmentError in Py 2.7 works across platform
if exp.errno == errno.ENOENT:
raise PersistenceNotFound(
message=(
"Persistence not initialized. "
"You can recover by calling a save() first."),
location=self._location,
)
raise

def touch(self):
"""To touch this file-based persistence without writing content into it"""
Expand All @@ -115,13 +160,28 @@ def __init__(self, location, entropy=''):

def save(self, content):
# type: (str) -> None
data = self._dp_agent.protect(content)
with open(self._location, 'wb+') as handle:
handle.write(self._dp_agent.protect(content))
handle.write(data)

def load(self):
# type: () -> str
with open(self._location, 'rb') as handle:
return self._dp_agent.unprotect(handle.read())
try:
with open(self._location, 'rb') as handle:
data = handle.read()
return self._dp_agent.unprotect(data)
except EnvironmentError as exp: # EnvironmentError in Py 2.7 works across platform
if exp.errno == errno.ENOENT:
raise PersistenceNotFound(
message=(
"Persistence not initialized. "
"You can recover by calling a save() first."),
location=self._location,
)
logger.exception(
"DPAPI error likely caused by file content not previously encrypted. "
"App developer should migrate by calling save(plaintext) first.")
raise


class KeychainPersistence(BasePersistence):
Expand All @@ -136,9 +196,10 @@ def __init__(self, signal_location, service_name, account_name):
"""
if not (service_name and account_name): # It would hang on OSX
raise ValueError("service_name and account_name are required")
from .osx import Keychain # pylint: disable=import-outside-toplevel
from .osx import Keychain, KeychainError # pylint: disable=import-outside-toplevel
self._file_persistence = FilePersistence(signal_location) # Favor composition
self._Keychain = Keychain # pylint: disable=invalid-name
self._KeychainError = KeychainError # pylint: disable=invalid-name
self._service_name = service_name
self._account_name = account_name

Expand All @@ -150,8 +211,21 @@ def save(self, content):

def load(self):
with self._Keychain() as locker:
return locker.get_generic_password(
self._service_name, self._account_name)
try:
return locker.get_generic_password(
self._service_name, self._account_name)
except self._KeychainError as ex:
if ex.exit_status == self._KeychainError.ITEM_NOT_FOUND:
# This happens when a load() is called before a save().
# We map it into cross-platform error for unified catching.
raise PersistenceNotFound(
location="Service:{} Account:{}".format(
self._service_name, self._account_name),
message=(
"Keychain persistence not initialized. "
"You can recover by call a save() first."),
)
raise # We do not intend to hide any other underlying exceptions

def time_last_modified(self):
return self._file_persistence.time_last_modified()
Expand Down Expand Up @@ -188,7 +262,14 @@ def save(self, content):
self._file_persistence.touch() # For time_last_modified()

def load(self):
return self._agent.load()
data = self._agent.load()
if data is None:
# Lower level libsecret would return None when found nothing. Here
# in persistence layer, we convert it to a unified error for consistency.
raise PersistenceNotFound(message=(
"Keyring persistence not initialized. "
"You can recover by call a save() first."))
return data

def time_last_modified(self):
return self._file_persistence.time_last_modified()
Expand Down
11 changes: 5 additions & 6 deletions msal_extensions/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import os
import warnings
import time
import errno
import logging

import msal

from .cache_lock import CrossPlatLock
from .persistence import (
_mkdir_p, FilePersistence,
_mkdir_p, PersistenceNotFound, FilePersistence,
FilePersistenceWithDataProtection, KeychainPersistence)


Expand All @@ -35,10 +34,10 @@ def _reload_if_necessary(self):
if self._last_sync < self._persistence.time_last_modified():
self.deserialize(self._persistence.load())
self._last_sync = time.time()
except IOError as exp:
if exp.errno != errno.ENOENT:
raise
# Otherwise, from cache's perspective, a nonexistent file is a NO-OP
except PersistenceNotFound:
# From cache's perspective, a nonexistent persistence is a NO-OP.
pass
# However, existing data unable to be decrypted will still be bubbled up.

def modify(self, credential_type, old_entry, new_key_value_pairs=None):
with CrossPlatLock(self._lock_location):
Expand Down
43 changes: 43 additions & 0 deletions tests/cache_file_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Usage: cache_file_generator.py cache_file_path sleep_interval
This is a console application which is to be used for cross-platform lock performance testing.
The app will acquire the lock for the cache file, log the process id, and then release the lock.
It takes in two arguments - cache file path and the sleep interval.
The cache file path is the path of cache file.
The sleep interval is the time in seconds for which the lock is held by a process.
"""

import logging
import os
import sys
import time

from portalocker import exceptions

from msal_extensions import FilePersistence, CrossPlatLock


def _acquire_lock_and_write_to_cache(cache_location, sleep_interval):
    """Grab the cross-platform lock, record this process in the cache, release.

    While holding the lock, appends "< pid" on entry and "> pid" on exit
    (sleeping sleep_interval seconds in between), so that a validator can
    later detect whether two processes' critical sections ever overlapped.
    Failure to acquire the lock is logged as a warning, not raised.
    """
    persistence = FilePersistence(cache_location)
    lock_path = persistence.get_location() + ".lockfile"
    try:
        with CrossPlatLock(lock_path):
            content = persistence.load()
            if content is None:
                content = ""
            pid = str(os.getpid())
            content += "< " + pid + "\n"
            time.sleep(sleep_interval)
            content += "> " + pid + "\n"
            persistence.save(content)
    except exceptions.LockException as e:
        logging.warning("Unable to acquire lock %s", e)


if __name__ == "__main__":
    # Require both positional arguments; otherwise show usage and exit cleanly.
    if len(sys.argv) >= 3:
        _acquire_lock_and_write_to_cache(sys.argv[1], float(sys.argv[2]))
    else:
        print(__doc__)
        sys.exit(0)

5 changes: 5 additions & 0 deletions tests/test_agnostic_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,8 @@ def test_current_platform_cache_roundtrip_with_alias_class(temp_location):
def test_persisted_token_cache(temp_location):
    """Round-trip tokens through a PersistedTokenCache backed by plain FilePersistence."""
    _test_token_cache_roundtrip(PersistedTokenCache(FilePersistence(temp_location)))

def test_file_not_found_error_is_not_raised():
    """A cache whose persistence file does not exist must behave as empty, not raise."""
    persistence = FilePersistence('non_existing_file')
    cache = PersistedTokenCache(persistence=persistence)
    # An exception raised here will fail the test case as it is supposed to be a NO-OP
    cache.find('')
75 changes: 75 additions & 0 deletions tests/test_cache_lock_file_perf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import multiprocessing
import os
import shutil
import tempfile

import pytest

from cache_file_generator import _acquire_lock_and_write_to_cache


@pytest.fixture
def temp_location():
    """Yield a path for a persistence file inside a fresh temporary folder.

    The folder is removed during teardown even when the consuming test
    raises: pytest closes the fixture generator at the yield point, so
    without try/finally the rmtree cleanup would be skipped in that case.
    """
    test_folder = tempfile.mkdtemp(prefix="test_persistence_roundtrip")
    try:
        yield os.path.join(test_folder, 'persistence.bin')
    finally:
        shutil.rmtree(test_folder, ignore_errors=True)


def _validate_result_in_cache(cache_location):
with open(cache_location) as handle:
data = handle.read()
prev_process_id = None
count = 0
for line in data.split("\n"):
if line:
count += 1
tag, process_id = line.split(" ")
if prev_process_id is not None:
assert process_id == prev_process_id, "Process overlap found"
assert tag == '>', "Process overlap_found"
prev_process_id = None
else:
assert tag == '<', "Opening bracket not found"
prev_process_id = process_id
return count


def _run_multiple_processes(no_of_processes, cache_location, sleep_interval):
    """Spawn worker processes that contend for the cache lock, and wait for all.

    :param no_of_processes: how many workers to launch
    :param cache_location: cache file each worker appends its markers to
    :param sleep_interval: seconds each worker holds the lock
    """
    # Create/truncate the cache file up front so every worker starts from an
    # existing empty file. The original `open(cache_location, "w+")` leaked
    # the handle; the with-block closes it immediately.
    with open(cache_location, "w+"):
        pass
    processes = [
        multiprocessing.Process(
            target=_acquire_lock_and_write_to_cache,
            args=(cache_location, sleep_interval))
        for _ in range(no_of_processes)
    ]

    for process in processes:
        process.start()

    for process in processes:
        process.join()


def test_lock_for_normal_workload(temp_location):
    """Under modest contention every worker must record one full enter/exit pair."""
    num_of_processes = 4
    sleep_interval = 0.1
    _run_multiple_processes(num_of_processes, temp_location, sleep_interval)
    count = _validate_result_in_cache(temp_location)
    # 2 lines per process: one "<" on entry and one ">" on exit.
    assert count == num_of_processes * 2, "Should not observe starvation"


def test_lock_for_high_workload(temp_location):
    """Under heavy contention some workers may starve, but pairs must stay intact."""
    num_of_processes = 80
    sleep_interval = 0
    _run_multiple_processes(num_of_processes, temp_location, sleep_interval)
    # _validate_result_in_cache raises AssertionError on any interleaving,
    # so reaching the count check already proves the payload is not garbled.
    count = _validate_result_in_cache(temp_location)
    assert count <= num_of_processes * 2, "Starvation or not, we should not observe garbled payload"


def test_lock_for_timeout(temp_location):
    """Long hold times should force some workers to give up (starvation expected)."""
    num_of_processes = 10
    sleep_interval = 1
    _run_multiple_processes(num_of_processes, temp_location, sleep_interval)
    count = _validate_result_in_cache(temp_location)
    # NOTE(review): presumably CrossPlatLock's acquire timeout is shorter than
    # the total contention here, so some workers time out and write nothing —
    # confirm against CrossPlatLock's configuration.
    assert count < num_of_processes * 2, "Should observe starvation"

Loading

0 comments on commit b90d20e

Please sign in to comment.