Commit befe3052 by Saleem Latif

Update saml --pull command to raise error when it fails.

parent a15be622
......@@ -25,9 +25,16 @@ class Command(BaseCommand):
log = logging.getLogger('third_party_auth.tasks')
log.propagate = False
log.addHandler(log_handler)
num_changed, num_failed, num_total = fetch_saml_metadata()
num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
self.stdout.write(
"\nDone. Fetched {num_total} total. {num_changed} were updated and {num_failed} failed.\n".format(
num_changed=num_changed, num_failed=num_failed, num_total=num_total
)
)
if num_failed > 0:
raise CommandError(
"Command finished with the following exceptions:\n\n{failures}".format(
failures="\n\n".join(failure_messages)
)
)
......@@ -12,6 +12,7 @@ from django.core.management.base import CommandError
from django.conf import settings
from django.utils.six import StringIO
from requests import exceptions
from requests.models import Response
from third_party_auth.tests.factories import SAMLConfigurationFactory, SAMLProviderConfigFactory
......@@ -119,10 +120,11 @@ class TestSAMLCommand(TestCase):
# Create enabled configurations
self.__create_saml_configurations__()
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
with self.assertRaisesRegexp(CommandError, r"HTTPError: 404 Client Error"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
@mock.patch("requests.get", mock_get(status_code=200))
def test_fetch_multiple_providers_data(self):
......@@ -160,7 +162,85 @@ class TestSAMLCommand(TestCase):
}
)
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
with self.assertRaisesRegexp(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 3 total. 2 were updated and 1 failed.', self.stdout.getvalue())
@mock.patch("requests.get")
def test_saml_request_exceptions(self, mocked_get):
"""
Test that management command errors out in case of fatal exceptions instead of failing silently.
"""
# Create enabled configurations
self.__create_saml_configurations__()
mocked_get.side_effect = exceptions.SSLError
with self.assertRaisesRegexp(CommandError, "SSLError:"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
mocked_get.side_effect = exceptions.ConnectionError
with self.assertRaisesRegexp(CommandError, "ConnectionError:"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
mocked_get.side_effect = exceptions.HTTPError
with self.assertRaisesRegexp(CommandError, "HTTPError:"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
@mock.patch("requests.get", mock_get(status_code=200))
def test_saml_parse_exceptions(self):
"""
Test that management command errors out in case of fatal exceptions instead of failing silently.
"""
# Create enabled configurations, this configuration will raise MetadataParseError.
self.__create_saml_configurations__(
saml_config={
"site__domain": "third.testserver.fake",
},
saml_provider_config={
"site__domain": "third.testserver.fake",
"idp_slug": "third-test-shib",
# Note: This entity id will not be present in returned response and will cause failed update.
"entity_id": "https://idp.testshib.org/idp/non-existent-shibboleth",
"metadata_source": "https://www.testshib.org/metadata/third/testshib-providers.xml",
}
)
with self.assertRaisesRegexp(CommandError, "MetadataParseError: Can't find EntityDescriptor for entityID"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
@mock.patch("requests.get")
def test_xml_parse_exceptions(self, mocked_get):
"""
Test that management command errors out in case of fatal exceptions instead of failing silently.
"""
response = Response()
response._content = "" # pylint: disable=protected-access
response.status_code = 200
mocked_get.return_value = response
# create enabled configuration
self.__create_saml_configurations__()
with self.assertRaisesRegexp(CommandError, "XMLSyntaxError:"):
# Capture command output log for testing.
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn('Done. Fetched 3 total. 2 were updated and 1 failed.', self.stdout.getvalue())
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
......@@ -10,6 +10,7 @@ import pytz
import logging
from lxml import etree
import requests
from requests import exceptions
from onelogin.saml2.utils import OneLogin_Saml2_Utils
from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData
......@@ -32,12 +33,14 @@ def fetch_saml_metadata():
It's OK to run this whether or not SAML is enabled.
Return value:
tuple(num_changed, num_failed, num_total)
tuple(num_changed, num_failed, num_total, failure_messages)
num_changed: Number of providers that are either new or whose metadata has changed
num_failed: Number of providers that could not be updated
num_total: Total number of providers whose metadata was fetched
failure_messages: List of error messages for the providers that could not be updated
"""
num_changed, num_failed = 0, 0
num_changed = 0
failure_messages = []
# First make a list of all the metadata XML URLs:
url_map = {}
......@@ -75,10 +78,38 @@ def fetch_saml_metadata():
num_changed += 1
else:
log.info(u"→ Updated existing SAMLProviderData. Nothing has changed.")
except Exception as err: # pylint: disable=broad-except
log.exception(err.message)
num_failed += 1
return (num_changed, num_failed, len(url_map))
except (exceptions.SSLError, exceptions.HTTPError, exceptions.RequestException, MetadataParseError) as error:
# Catch and process exception in case of errors during fetching and processing saml metadata.
# Here is a description of each exception.
# SSLError is raised in case of errors caused by SSL (e.g. SSL cer verification failure etc.)
# HTTPError is raised in case of unexpected status code (e.g. 500 error etc.)
# RequestException is the base exception for any request related error that "requests" lib raises.
# MetadataParseError is raised if there is error in the fetched meta data (e.g. missing @entityID etc.)
log.exception(error.message)
failure_messages.append(
"{error_type}: {error_message}\nMetadata Source: {url}\nEntity IDs: \n{entity_ids}.".format(
error_type=type(error).__name__,
error_message=error.message,
url=url,
entity_ids="\n".join(
["\t{}: {}".format(count, item) for count, item in enumerate(entity_ids, start=1)],
)
)
)
except etree.XMLSyntaxError as error:
log.exception(error.message)
failure_messages.append(
"XMLSyntaxError: {error_message}\nMetadata Source: {url}\nEntity IDs: \n{entity_ids}.".format(
error_message=str(error.error_log),
url=url,
entity_ids="\n".join(
["\t{}: {}".format(count, item) for count, item in enumerate(entity_ids, start=1)],
)
)
)
return num_changed, len(failure_messages), len(url_map), failure_messages
def _parse_metadata_xml(xml, entity_id):
......
......@@ -149,8 +149,9 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase):
kwargs.setdefault('attr_email', 'urn:oid:1.3.6.1.4.1.5923.1.1.1.6') # eduPersonPrincipalName
self.configure_saml_provider(**kwargs)
self.assertTrue(httpretty.is_enabled())
num_changed, num_failed, num_total = fetch_saml_metadata()
num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
self.assertEqual(num_failed, 0)
self.assertEqual(len(failure_messages), 0)
self.assertEqual(num_changed, 1)
self.assertEqual(num_total, 1)
......@@ -176,9 +177,10 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase):
if fetch_metadata:
self.assertTrue(httpretty.is_enabled())
num_changed, num_failed, num_total = fetch_saml_metadata()
num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
if assert_metadata_updates:
self.assertEqual(num_failed, 0)
self.assertEqual(len(failure_messages), 0)
self.assertEqual(num_changed, 1)
self.assertEqual(num_total, 1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment