Commit befe3052 by Saleem Latif

Update saml --pull command to raise error when it fails.

parent a15be622
...@@ -25,9 +25,16 @@ class Command(BaseCommand): ...@@ -25,9 +25,16 @@ class Command(BaseCommand):
log = logging.getLogger('third_party_auth.tasks') log = logging.getLogger('third_party_auth.tasks')
log.propagate = False log.propagate = False
log.addHandler(log_handler) log.addHandler(log_handler)
num_changed, num_failed, num_total = fetch_saml_metadata() num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
self.stdout.write( self.stdout.write(
"\nDone. Fetched {num_total} total. {num_changed} were updated and {num_failed} failed.\n".format( "\nDone. Fetched {num_total} total. {num_changed} were updated and {num_failed} failed.\n".format(
num_changed=num_changed, num_failed=num_failed, num_total=num_total num_changed=num_changed, num_failed=num_failed, num_total=num_total
) )
) )
if num_failed > 0:
raise CommandError(
"Command finished with the following exceptions:\n\n{failures}".format(
failures="\n\n".join(failure_messages)
)
)
...@@ -12,6 +12,7 @@ from django.core.management.base import CommandError ...@@ -12,6 +12,7 @@ from django.core.management.base import CommandError
from django.conf import settings from django.conf import settings
from django.utils.six import StringIO from django.utils.six import StringIO
from requests import exceptions
from requests.models import Response from requests.models import Response
from third_party_auth.tests.factories import SAMLConfigurationFactory, SAMLProviderConfigFactory from third_party_auth.tests.factories import SAMLConfigurationFactory, SAMLProviderConfigFactory
...@@ -119,10 +120,11 @@ class TestSAMLCommand(TestCase): ...@@ -119,10 +120,11 @@ class TestSAMLCommand(TestCase):
# Create enabled configurations # Create enabled configurations
self.__create_saml_configurations__() self.__create_saml_configurations__()
# Capture command output log for testing. with self.assertRaisesRegexp(CommandError, r"HTTPError: 404 Client Error"):
call_command("saml", pull=True, stdout=self.stdout) # Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
@mock.patch("requests.get", mock_get(status_code=200)) @mock.patch("requests.get", mock_get(status_code=200))
def test_fetch_multiple_providers_data(self): def test_fetch_multiple_providers_data(self):
...@@ -160,7 +162,85 @@ class TestSAMLCommand(TestCase): ...@@ -160,7 +162,85 @@ class TestSAMLCommand(TestCase):
} }
) )
# Capture command output log for testing. with self.assertRaisesRegexp(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"):
call_command("saml", pull=True, stdout=self.stdout) # Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 3 total. 2 were updated and 1 failed.', self.stdout.getvalue())
@mock.patch("requests.get")
def test_saml_request_exceptions(self, mocked_get):
"""
Test that management command errors out in case of fatal exceptions instead of failing silently.
"""
# Create enabled configurations
self.__create_saml_configurations__()
mocked_get.side_effect = exceptions.SSLError
with self.assertRaisesRegexp(CommandError, "SSLError:"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
mocked_get.side_effect = exceptions.ConnectionError
with self.assertRaisesRegexp(CommandError, "ConnectionError:"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
mocked_get.side_effect = exceptions.HTTPError
with self.assertRaisesRegexp(CommandError, "HTTPError:"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
@mock.patch("requests.get", mock_get(status_code=200))
def test_saml_parse_exceptions(self):
"""
Test that management command errors out in case of fatal exceptions instead of failing silently.
"""
# Create enabled configurations, this configuration will raise MetadataParseError.
self.__create_saml_configurations__(
saml_config={
"site__domain": "third.testserver.fake",
},
saml_provider_config={
"site__domain": "third.testserver.fake",
"idp_slug": "third-test-shib",
# Note: This entity id will not be present in returned response and will cause failed update.
"entity_id": "https://idp.testshib.org/idp/non-existent-shibboleth",
"metadata_source": "https://www.testshib.org/metadata/third/testshib-providers.xml",
}
)
with self.assertRaisesRegexp(CommandError, "MetadataParseError: Can't find EntityDescriptor for entityID"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
@mock.patch("requests.get")
def test_xml_parse_exceptions(self, mocked_get):
"""
Test that management command errors out in case of fatal exceptions instead of failing silently.
"""
response = Response()
response._content = "" # pylint: disable=protected-access
response.status_code = 200
mocked_get.return_value = response
# create enabled configuration
self.__create_saml_configurations__()
with self.assertRaisesRegexp(CommandError, "XMLSyntaxError:"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 3 total. 2 were updated and 1 failed.', self.stdout.getvalue()) self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
...@@ -10,6 +10,7 @@ import pytz ...@@ -10,6 +10,7 @@ import pytz
import logging import logging
from lxml import etree from lxml import etree
import requests import requests
from requests import exceptions
from onelogin.saml2.utils import OneLogin_Saml2_Utils from onelogin.saml2.utils import OneLogin_Saml2_Utils
from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData
...@@ -32,12 +33,14 @@ def fetch_saml_metadata(): ...@@ -32,12 +33,14 @@ def fetch_saml_metadata():
It's OK to run this whether or not SAML is enabled. It's OK to run this whether or not SAML is enabled.
Return value: Return value:
tuple(num_changed, num_failed, num_total) tuple(num_changed, num_failed, num_total, failure_messages)
num_changed: Number of providers that are either new or whose metadata has changed num_changed: Number of providers that are either new or whose metadata has changed
num_failed: Number of providers that could not be updated num_failed: Number of providers that could not be updated
num_total: Total number of providers whose metadata was fetched num_total: Total number of providers whose metadata was fetched
failure_messages: List of error messages for the providers that could not be updated
""" """
num_changed, num_failed = 0, 0 num_changed = 0
failure_messages = []
# First make a list of all the metadata XML URLs: # First make a list of all the metadata XML URLs:
url_map = {} url_map = {}
...@@ -75,10 +78,38 @@ def fetch_saml_metadata(): ...@@ -75,10 +78,38 @@ def fetch_saml_metadata():
num_changed += 1 num_changed += 1
else: else:
log.info(u"→ Updated existing SAMLProviderData. Nothing has changed.") log.info(u"→ Updated existing SAMLProviderData. Nothing has changed.")
except Exception as err: # pylint: disable=broad-except except (exceptions.SSLError, exceptions.HTTPError, exceptions.RequestException, MetadataParseError) as error:
log.exception(err.message) # Catch and process exception in case of errors during fetching and processing saml metadata.
num_failed += 1 # Here is a description of each exception.
return (num_changed, num_failed, len(url_map)) # SSLError is raised in case of errors caused by SSL (e.g. SSL cer verification failure etc.)
# HTTPError is raised in case of unexpected status code (e.g. 500 error etc.)
# RequestException is the base exception for any request related error that "requests" lib raises.
# MetadataParseError is raised if there is error in the fetched meta data (e.g. missing @entityID etc.)
log.exception(error.message)
failure_messages.append(
"{error_type}: {error_message}\nMetadata Source: {url}\nEntity IDs: \n{entity_ids}.".format(
error_type=type(error).__name__,
error_message=error.message,
url=url,
entity_ids="\n".join(
["\t{}: {}".format(count, item) for count, item in enumerate(entity_ids, start=1)],
)
)
)
except etree.XMLSyntaxError as error:
log.exception(error.message)
failure_messages.append(
"XMLSyntaxError: {error_message}\nMetadata Source: {url}\nEntity IDs: \n{entity_ids}.".format(
error_message=str(error.error_log),
url=url,
entity_ids="\n".join(
["\t{}: {}".format(count, item) for count, item in enumerate(entity_ids, start=1)],
)
)
)
return num_changed, len(failure_messages), len(url_map), failure_messages
def _parse_metadata_xml(xml, entity_id): def _parse_metadata_xml(xml, entity_id):
......
...@@ -149,8 +149,9 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase): ...@@ -149,8 +149,9 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase):
kwargs.setdefault('attr_email', 'urn:oid:1.3.6.1.4.1.5923.1.1.1.6') # eduPersonPrincipalName kwargs.setdefault('attr_email', 'urn:oid:1.3.6.1.4.1.5923.1.1.1.6') # eduPersonPrincipalName
self.configure_saml_provider(**kwargs) self.configure_saml_provider(**kwargs)
self.assertTrue(httpretty.is_enabled()) self.assertTrue(httpretty.is_enabled())
num_changed, num_failed, num_total = fetch_saml_metadata() num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
self.assertEqual(num_failed, 0) self.assertEqual(num_failed, 0)
self.assertEqual(len(failure_messages), 0)
self.assertEqual(num_changed, 1) self.assertEqual(num_changed, 1)
self.assertEqual(num_total, 1) self.assertEqual(num_total, 1)
...@@ -176,9 +177,10 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase): ...@@ -176,9 +177,10 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase):
if fetch_metadata: if fetch_metadata:
self.assertTrue(httpretty.is_enabled()) self.assertTrue(httpretty.is_enabled())
num_changed, num_failed, num_total = fetch_saml_metadata() num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
if assert_metadata_updates: if assert_metadata_updates:
self.assertEqual(num_failed, 0) self.assertEqual(num_failed, 0)
self.assertEqual(len(failure_messages), 0)
self.assertEqual(num_changed, 1) self.assertEqual(num_changed, 1)
self.assertEqual(num_total, 1) self.assertEqual(num_total, 1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment