Commit b6450384 by Matt Drayer

Add SAML metadata refresh control flag

mattdrayer: Change model fieldname, revise code, fix bad tests.
parent 4ef009cf
......@@ -25,14 +25,19 @@ class Command(BaseCommand):
log = logging.getLogger('third_party_auth.tasks')
log.propagate = False
log.addHandler(log_handler)
num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
total, skipped, attempted, updated, failed, failure_messages = fetch_saml_metadata()
self.stdout.write(
"\nDone. Fetched {num_total} total. {num_changed} were updated and {num_failed} failed.\n".format(
num_changed=num_changed, num_failed=num_failed, num_total=num_total
"\nDone."
"\n{total} provider(s) found in database."
"\n{skipped} skipped and {attempted} attempted."
"\n{updated} updated and {failed} failed.\n".format(
total=total,
skipped=skipped, attempted=attempted,
updated=updated, failed=failed,
)
)
if num_failed > 0:
if failed > 0:
raise CommandError(
"Command finished with the following exceptions:\n\n{failures}".format(
failures="\n\n".join(failure_messages)
......
......@@ -92,10 +92,9 @@ class TestSAMLCommand(TestCase):
Test that management command completes without errors and logs correct information when no
saml configurations are enabled/present.
"""
# Capture command output log for testing.
expected = "\nDone.\n1 provider(s) found in database.\n1 skipped and 0 attempted.\n0 updated and 0 failed.\n"
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 0 total. 0 were updated and 0 failed.', self.stdout.getvalue())
self.assertIn(expected, self.stdout.getvalue())
@mock.patch("requests.get", mock_get())
def test_fetch_saml_metadata(self):
......@@ -106,10 +105,9 @@ class TestSAMLCommand(TestCase):
# Create enabled configurations
self.__create_saml_configurations__()
# Capture command output log for testing.
expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n1 updated and 0 failed.\n"
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 1 were updated and 0 failed.', self.stdout.getvalue())
self.assertIn(expected, self.stdout.getvalue())
@mock.patch("requests.get", mock_get(status_code=404))
def test_fetch_saml_metadata_failure(self):
......@@ -120,11 +118,11 @@ class TestSAMLCommand(TestCase):
# Create enabled configurations
self.__create_saml_configurations__()
expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n0 updated and 1 failed.\n"
with self.assertRaisesRegexp(CommandError, r"HTTPError: 404 Client Error"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
self.assertIn(expected, self.stdout.getvalue())
@mock.patch("requests.get", mock_get(status_code=200))
def test_fetch_multiple_providers_data(self):
......@@ -162,11 +160,31 @@ class TestSAMLCommand(TestCase):
}
)
expected = '\n3 provider(s) found in database.\n0 skipped and 3 attempted.\n2 updated and 1 failed.\n'
with self.assertRaisesRegexp(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn(expected, self.stdout.getvalue())
self.assertIn('Done. Fetched 3 total. 2 were updated and 1 failed.', self.stdout.getvalue())
# Now add a fourth configuration, and indicate that it should not be included in the update
self.__create_saml_configurations__(
saml_config={
"site__domain": "fourth.testserver.fake",
},
saml_provider_config={
"site__domain": "fourth.testserver.fake",
"idp_slug": "fourth-test-shib",
"automatic_refresh_enabled": False,
# Note: This invalid entity id will not be present in the refresh set
"entity_id": "https://idp.testshib.org/idp/fourth-shibboleth",
"metadata_source": "https://www.testshib.org/metadata/fourth/testshib-providers.xml",
}
)
# Four configurations -- one will be skipped and three attempted, with similar results.
expected = '\nDone.\n4 provider(s) found in database.\n1 skipped and 3 attempted.\n0 updated and 1 failed.\n'
with self.assertRaisesRegexp(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"):
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn(expected, self.stdout.getvalue())
@mock.patch("requests.get")
def test_saml_request_exceptions(self, mocked_get):
......@@ -178,27 +196,23 @@ class TestSAMLCommand(TestCase):
mocked_get.side_effect = exceptions.SSLError
expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n0 updated and 1 failed.\n"
with self.assertRaisesRegexp(CommandError, "SSLError:"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
self.assertIn(expected, self.stdout.getvalue())
mocked_get.side_effect = exceptions.ConnectionError
with self.assertRaisesRegexp(CommandError, "ConnectionError:"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
self.assertIn(expected, self.stdout.getvalue())
mocked_get.side_effect = exceptions.HTTPError
with self.assertRaisesRegexp(CommandError, "HTTPError:"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
self.assertIn(expected, self.stdout.getvalue())
@mock.patch("requests.get", mock_get(status_code=200))
def test_saml_parse_exceptions(self):
......@@ -219,11 +233,11 @@ class TestSAMLCommand(TestCase):
}
)
expected = "\nDone.\n2 provider(s) found in database.\n1 skipped and 1 attempted.\n0 updated and 1 failed.\n"
with self.assertRaisesRegexp(CommandError, "MetadataParseError: Can't find EntityDescriptor for entityID"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
self.assertIn(expected, self.stdout.getvalue())
@mock.patch("requests.get")
def test_xml_parse_exceptions(self, mocked_get):
......@@ -239,8 +253,8 @@ class TestSAMLCommand(TestCase):
# create enabled configuration
self.__create_saml_configurations__()
expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n0 updated and 1 failed.\n"
with self.assertRaisesRegexp(CommandError, "XMLSyntaxError:"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
self.assertIn(expected, self.stdout.getvalue())
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('third_party_auth', '0005_add_site_field'),
]
operations = [
migrations.AddField(
model_name='samlproviderconfig',
name='automatic_refresh_enabled',
field=models.BooleanField(default=True, help_text=b"When checked, the SAML provider's metadata will be included in the automatic refresh job, if configured.", verbose_name=b'Enable automatic metadata refresh'),
),
]
......@@ -347,6 +347,9 @@ class SAMLProviderConfig(ProviderConfig):
attr_email = models.CharField(
max_length=128, blank=True, verbose_name="Email Attribute",
help_text="URN of SAML attribute containing the user's email address[es]. Leave blank for default.")
automatic_refresh_enabled = models.BooleanField(
default=True, verbose_name="Enable automatic metadata refresh",
help_text="When checked, the SAML provider's metadata will be included in the automatic refresh job, if configured.")
debug_mode = models.BooleanField(
default=False, verbose_name="Debug Mode",
help_text=(
......
......@@ -33,27 +33,42 @@ def fetch_saml_metadata():
It's OK to run this whether or not SAML is enabled.
Return value:
tuple(num_changed, num_failed, num_total, failure_messages)
num_changed: Number of providers that are either new or whose metadata has changed
tuple(num_skipped, num_attempted, num_updated, num_failed, failure_messages)
num_total: Total number of providers found in the database
num_skipped: Number of providers skipped for various reasons (see L52)
num_attempted: Number of providers whose metadata was fetched
num_updated: Number of providers that are either new or whose metadata has changed
num_failed: Number of providers that could not be updated
num_total: Total number of providers whose metadata was fetched
failure_messages: List of error messages for the providers that could not be updated
"""
num_changed = 0
failure_messages = []
# First make a list of all the metadata XML URLs:
saml_providers = SAMLProviderConfig.key_values('idp_slug', flat=True)
num_total = len(saml_providers)
num_skipped = 0
url_map = {}
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
for idp_slug in saml_providers:
config = SAMLProviderConfig.current(idp_slug)
if not config.enabled or not SAMLConfiguration.is_enabled(config.site):
# Skip SAML provider configurations which do not qualify for fetching
if any([
not config.enabled,
not config.automatic_refresh_enabled,
not SAMLConfiguration.is_enabled(config.site)
]):
num_skipped += 1
continue
url = config.metadata_source
if url not in url_map:
url_map[url] = []
if config.entity_id not in url_map[url]:
url_map[url].append(config.entity_id)
# Now fetch the metadata:
# Now attempt to fetch the metadata for the remaining SAML providers:
num_attempted = len(url_map)
num_updated = 0
failure_messages = [] # We return the length of this array for num_failed
for url, entity_ids in url_map.items():
try:
log.info("Fetching %s", url)
......@@ -75,7 +90,7 @@ def fetch_saml_metadata():
changed = _update_data(entity_id, public_key, sso_url, expires_at)
if changed:
log.info(u"→ Created new record for SAMLProviderData")
num_changed += 1
num_updated += 1
else:
log.info(u"→ Updated existing SAMLProviderData. Nothing has changed.")
except (exceptions.SSLError, exceptions.HTTPError, exceptions.RequestException, MetadataParseError) as error:
......@@ -109,7 +124,8 @@ def fetch_saml_metadata():
)
)
return num_changed, len(failure_messages), len(url_map), failure_messages
# Return counts for total, skipped, attempted, updated, and failed, along with any failure messages
return num_total, num_skipped, num_attempted, num_updated, len(failure_messages), failure_messages
def _parse_metadata_xml(xml, entity_id):
......
......@@ -149,11 +149,13 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase):
kwargs.setdefault('attr_email', 'urn:oid:1.3.6.1.4.1.5923.1.1.1.6') # eduPersonPrincipalName
self.configure_saml_provider(**kwargs)
self.assertTrue(httpretty.is_enabled())
num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
num_total, num_skipped, num_attempted, num_updated, num_failed, failure_messages = fetch_saml_metadata()
self.assertEqual(num_total, 1)
self.assertEqual(num_skipped, 0)
self.assertEqual(num_attempted, 1)
self.assertEqual(num_updated, 1)
self.assertEqual(num_failed, 0)
self.assertEqual(len(failure_messages), 0)
self.assertEqual(num_changed, 1)
self.assertEqual(num_total, 1)
def _freeze_time(self, timestamp):
""" Mock the current time for SAML, so we can replay canned requests/responses """
......@@ -177,12 +179,14 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase):
if fetch_metadata:
self.assertTrue(httpretty.is_enabled())
num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
num_total, num_skipped, num_attempted, num_updated, num_failed, failure_messages = fetch_saml_metadata()
if assert_metadata_updates:
self.assertEqual(num_total, 1)
self.assertEqual(num_skipped, 0)
self.assertEqual(num_attempted, 1)
self.assertEqual(num_updated, 1)
self.assertEqual(num_failed, 0)
self.assertEqual(len(failure_messages), 0)
self.assertEqual(num_changed, 1)
self.assertEqual(num_total, 1)
def do_provider_login(self, provider_redirect_url):
""" Mocked: the user logs in to TestShib and then gets redirected back """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment