testutil.py 11.5 KB
Newer Older
1 2 3 4 5 6
"""
Utilities for writing third_party_auth tests.

Used by Django tests; note that this module itself imports Django (settings, auth/sites models and django.test) at import time.
"""

7
from contextlib import contextmanager
8
from django.conf import settings
9
from django.contrib.auth.models import User
10
from django.contrib.sites.models import Site
Pan Luo committed
11 12
from provider.oauth2.models import Client as OAuth2Client
from provider import constants
13
import django.test
14
from mako.template import Template
15
import mock
16
import os.path
17
from storages.backends.overwrite import OverwriteStorage
18

19 20 21 22 23 24
from third_party_auth.models import (
    OAuth2ProviderConfig,
    SAMLProviderConfig,
    SAMLConfiguration,
    LTIProviderConfig,
    cache as config_cache,
Pan Luo committed
25
    ProviderApiPermissions,
26
)
27

28 29
from third_party_auth.saml import get_saml_idp_class, SAMLIdentityProvider

30

31
# Name of the third-party-auth feature flag inside settings.FEATURES.
AUTH_FEATURES_KEY = 'ENABLE_THIRD_PARTY_AUTH'

# True when settings.FEATURES contains the flag above.
# NOTE(review): this is a membership test on the key, not a read of its value;
# assuming FEATURES is a dict, an entry explicitly set to False still makes this
# True -- confirm that "the flag is defined" is the intended gate for callers.
AUTH_FEATURE_ENABLED = AUTH_FEATURES_KEY in settings.FEATURES
33 34


35 36 37 38 39 40 41 42 43 44 45 46
def patch_mako_templates():
    """
    Patch mako so the django test client can access template context.

    Returns an (unstarted) ``mock.patch.multiple`` patcher that replaces both
    ``Template.render_unicode`` and ``Template.render``.  Each replacement first
    sends django's ``template_rendered`` signal with the render keyword
    arguments as the context, then delegates to the *same* original method it
    replaced.  Because of this, ``render`` keeps its own result (mako may encode it
    per the template's output encoding) instead of silently turning into
    ``render_unicode``.
    """
    def _instrument(orig_method):
        """ Wrap ``orig_method`` so it reports the render context before rendering. """
        def wrapped_render(*args, **kwargs):
            """ Render the template and send the context info to any listeners that want it """
            django.test.signals.template_rendered.send(sender=None, template=None, context=kwargs)
            return orig_method(*args, **kwargs)
        return wrapped_render

    # Originals are captured now (when the patcher is built), as before.
    return mock.patch.multiple(
        Template,
        render_unicode=_instrument(Template.render_unicode),
        render=_instrument(Template.render),
    )


47 48 49 50
class FakeDjangoSettings(object):
    """A fake for Django settings: every mapping entry becomes an attribute."""

    def __init__(self, mappings):
        """
        Initializes the fake from a mappings dict.

        Args:
            mappings (dict): attribute name -> value; each pair is set on the
                instance with ``setattr``.
        """
        # ``items()`` works on Python 2 and 3; ``iteritems()`` is Python 2 only
        # and raises AttributeError on a Python 3 dict.
        for key, value in mappings.items():
            setattr(self, key, value)


56 57
class ThirdPartyAuthTestMixin(object):
    """ Helper methods useful for testing third party auth functionality """

    def setUp(self, *args, **kwargs):
        """
        Swap the storage of ``OAuth2ProviderConfig.icon_image`` for the test.

        The patch is undone automatically via ``addCleanup``.
        """
        # Django's FileSystemStorage will rename files if they already exist.
        # This storage backend overwrites files instead, which makes it easier
        # to make assertions about filenames.
        icon_image_field = OAuth2ProviderConfig._meta.get_field('icon_image')  # pylint: disable=protected-access
        patch = mock.patch.object(icon_image_field, 'storage', OverwriteStorage())
        patch.start()
        self.addCleanup(patch.stop)

        super(ThirdPartyAuthTestMixin, self).setUp(*args, **kwargs)

    def tearDown(self):
        """ Clear the third_party_auth models cache so state does not leak between tests. """
        config_cache.clear()
        super(ThirdPartyAuthTestMixin, self).tearDown()

    def enable_saml(self, **kwargs):
        """ Enable SAML support (via SAMLConfiguration, not for any particular provider) """
        kwargs.setdefault('enabled', True)
        SAMLConfiguration(**kwargs).save()

    @staticmethod
    def configure_oauth_provider(**kwargs):
        """
        Update the settings for an OAuth2-based third party auth provider.

        ``backend_name`` is required; ``provider_slug`` defaults to it.
        Returns the saved ``OAuth2ProviderConfig``.
        """
        kwargs.setdefault('provider_slug', kwargs['backend_name'])
        obj = OAuth2ProviderConfig(**kwargs)
        obj.save()
        return obj

    @classmethod
    def _configure_oauth_provider_with_defaults(cls, defaults, overrides):
        """
        Create an OAuth2 provider from ``defaults`` updated with ``overrides``.

        Caller-supplied ``overrides`` win over ``defaults``; neither dict is
        mutated.  Delegates to ``cls.configure_oauth_provider`` so subclasses
        that override it are honored.
        """
        options = dict(defaults)
        options.update(overrides)
        return cls.configure_oauth_provider(**options)

    def configure_saml_provider(self, **kwargs):
        """
        Update the settings for a SAML-based third party auth provider.

        Asserts that SAML is enabled for the current Site first, then returns
        the saved ``SAMLProviderConfig``.
        """
        self.assertTrue(
            SAMLConfiguration.is_enabled(Site.objects.get_current()),
            "SAML Provider Configuration only works if SAML is enabled."
        )
        obj = SAMLProviderConfig(**kwargs)
        obj.save()
        return obj

    @staticmethod
    def configure_lti_provider(**kwargs):
        """ Update the settings for a LTI Tool Consumer third party auth provider; returns the saved config """
        obj = LTIProviderConfig(**kwargs)
        obj.save()
        return obj

    @classmethod
    def configure_google_provider(cls, **kwargs):
        """ Update the settings for the Google third party auth provider/backend """
        return cls._configure_oauth_provider_with_defaults({
            "name": "Google",
            "backend_name": "google-oauth2",
            "icon_class": "fa-google-plus",
            "key": "test-fake-key.apps.googleusercontent.com",
            "secret": "opensesame",
        }, kwargs)

    @classmethod
    def configure_facebook_provider(cls, **kwargs):
        """ Update the settings for the Facebook third party auth provider/backend """
        return cls._configure_oauth_provider_with_defaults({
            "name": "Facebook",
            "backend_name": "facebook",
            "icon_class": "fa-facebook",
            "key": "FB_TEST_APP",
            "secret": "opensesame",
        }, kwargs)

    @classmethod
    def configure_linkedin_provider(cls, **kwargs):
        """ Update the settings for the LinkedIn third party auth provider/backend """
        return cls._configure_oauth_provider_with_defaults({
            "name": "LinkedIn",
            "backend_name": "linkedin-oauth2",
            "icon_class": "fa-linkedin",
            "key": "test",
            "secret": "test",
        }, kwargs)

    @classmethod
    def configure_azure_ad_provider(cls, **kwargs):
        """ Update the settings for the Azure AD third party auth provider/backend """
        return cls._configure_oauth_provider_with_defaults({
            "name": "Azure AD",
            "backend_name": "azuread-oauth2",
            "icon_class": "fa-azuread",
            "key": "test",
            "secret": "test",
        }, kwargs)

    @classmethod
    def configure_twitter_provider(cls, **kwargs):
        """ Update the settings for the Twitter third party auth provider/backend """
        return cls._configure_oauth_provider_with_defaults({
            "name": "Twitter",
            "backend_name": "twitter",
            "icon_class": "fa-twitter",
            "key": "test",
            "secret": "test",
        }, kwargs)

    @classmethod
    def configure_dummy_provider(cls, **kwargs):
        """ Update the settings for the Dummy third party auth provider/backend """
        return cls._configure_oauth_provider_with_defaults({
            "name": "Dummy",
            "backend_name": "dummy",
        }, kwargs)

    @classmethod
    def verify_user_email(cls, email):
        """ Mark the user with the given email as verified (by activating the account) """
        user = User.objects.get(email=email)
        user.is_active = True
        user.save()

    @staticmethod
    def configure_oauth_client():
        """ Configure a confidential oauth client for testing """
        return OAuth2Client.objects.create(client_type=constants.CONFIDENTIAL)

    @staticmethod
    def configure_api_permission(client, provider_id):
        """ Configure the client and provider_id pair. This will give the access to a client for that provider. """
        return ProviderApiPermissions.objects.create(client=client, provider_id=provider_id)

    @staticmethod
    def read_data_file(filename):
        """ Read the contents of a file in the ``data`` folder next to this module """
        with open(os.path.join(os.path.dirname(__file__), 'data', filename)) as f:
            return f.read()

184 185 186

class TestCase(ThirdPartyAuthTestMixin, django.test.TestCase):
    """
    Base class for auth test cases.

    Requests made through ``self.client`` use the ``example.none`` host, and
    ``self.url_prefix`` holds the matching absolute URL prefix.
    """
    def setUp(self):
        super(TestCase, self).setUp()
        # The SAML lib we use rejects the default 'testserver' host as a domain,
        # so pin a server name that every provider accepts.
        client_defaults = self.client.defaults
        client_defaults['SERVER_NAME'] = 'example.none'
        self.url_prefix = 'http://example.none'
193 194


195 196 197 198 199 200 201
class SAMLTestCase(TestCase):
    """
    Base class for SAML-related third_party_auth tests.

    Supplies default key material and an entity ID when enabling SAML.
    """

    @classmethod
    def _get_public_key(cls, key_name='saml_key'):
        """ Return the contents of the ``<key_name>.pub`` test data file. """
        file_name = '{}.pub'.format(key_name)
        return cls.read_data_file(file_name)

    @classmethod
    def _get_private_key(cls, key_name='saml_key'):
        """ Return the contents of the ``<key_name>.key`` test data file. """
        file_name = '{}.key'.format(key_name)
        return cls.read_data_file(file_name)

    def enable_saml(self, **kwargs):
        """
        Enable SAML support (via SAMLConfiguration, not for any particular provider).

        A missing ``private_key`` or ``public_key`` is loaded from the test data
        files. Each file is read only when its key is absent, so explicit keys
        skip the read. ``entity_id`` defaults to the example SAML entity.
        """
        key_loaders = (
            ('private_key', self._get_private_key),
            ('public_key', self._get_public_key),
        )
        for field_name, load_key in key_loaders:
            if field_name not in kwargs:
                kwargs[field_name] = load_key()
        kwargs.setdefault('entity_id', "https://saml.example.none")
        super(SAMLTestCase, self).enable_saml(**kwargs)

    # NOTE(review): because this test is defined on a base class, every subclass
    # of SAMLTestCase inherits and re-runs it; consider moving it to a dedicated
    # test module.
    @mock.patch('third_party_auth.saml.log')
    def test_get_saml_idp_class_with_fake_identifier(self, log_mock):
        """ An unknown IdP class identifier logs an error and falls back to SAMLIdentityProvider. """
        bogus_identifier = 'fake_idp_class_option'
        resolved_class = get_saml_idp_class(bogus_identifier)
        log_mock.error.assert_called_once_with(
            '%s is not a valid SAMLIdentityProvider subclass; using SAMLIdentityProvider base class.',
            bogus_identifier
        )
        self.assertIs(resolved_class, SAMLIdentityProvider)

228

229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304
@contextmanager
def simulate_running_pipeline(pipeline_target, backend, email=None, fullname=None, username=None):
    """Simulate that a pipeline is currently running.

    You can use this context manager to test packages that rely on third party auth.

    This uses `mock.patch` to override some calls in `third_party_auth.pipeline`,
    so you will need to provide the "target" module *as it is imported*
    in the software under test.  For example, if `foo/bar.py` does this:

    >>> from third_party_auth import pipeline

    then you will need to do something like this:

    >>> with simulate_running_pipeline("foo.bar.pipeline", "google-oauth2"):
    ...     bar.do_something_with_the_pipeline()

    If, on the other hand, `foo/bar.py` had done this:

    >>> import third_party_auth

    then you would use the target "foo.bar.third_party_auth.pipeline" instead.

    Arguments:

        pipeline_target (string): The path to `third_party_auth.pipeline` as it is imported
            in the software under test.

        backend (string): The name of the backend currently running, for example "google-oauth2".
            Note that this is NOT the same as the name of the *provider*.  See the Python
            social auth documentation for the names of the backends.

    Keyword Arguments:
        email (string): If provided, simulate that the current provider has
            included the user's email address (useful for filling in the registration form).

        fullname (string): If provided, simulate that the current provider has
            included the user's full name (useful for filling in the registration form).

        username (string): If provided, simulate that the pipeline has provided
            this suggested username.  This is something that the `third_party_auth`
            app generates itself and should be available by the time the user
            is authenticating with a third-party provider.

    Returns:
        None

    Both patches are always undone on exit, including when the body raises or
    when the second patch cannot be applied.
    """
    details = {}
    if email is not None:
        details["email"] = email
    if fullname is not None:
        details["fullname"] = fullname

    pipeline_data = {
        "backend": backend,
        "kwargs": {
            "details": details
        }
    }
    if username is not None:
        pipeline_data["kwargs"]["username"] = username

    pipeline_get = mock.patch("{pipeline}.get".format(pipeline=pipeline_target), spec=True)
    pipeline_running = mock.patch("{pipeline}.running".format(pipeline=pipeline_target), spec=True)

    # Entering both patchers in one ``with`` guarantees that whichever was
    # started gets stopped.  If applying ``running`` fails, or the body raises,
    # the ``get`` patch is still undone instead of leaking into later tests.
    with pipeline_get as mock_get, pipeline_running as mock_running:
        mock_get.return_value = pipeline_data
        mock_running.return_value = True
        yield