Commit 00226bf3 by Braden MacDonald

Asynchronous metadata fetching using celery beat - PR 8518

parent cd941ead
......@@ -91,6 +91,9 @@ logs
chromedriver.log
ghostdriver.log
### Celery artifacts ###
celerybeat-schedule
### Unknown artifacts
database.sqlite
courseware/static/js/mathjax/*
......
......@@ -7,6 +7,7 @@ from django.contrib import admin
from config_models.admin import ConfigurationModelAdmin, KeyedConfigurationModelAdmin
from .models import OAuth2ProviderConfig, SAMLProviderConfig, SAMLConfiguration, SAMLProviderData
from .tasks import fetch_saml_metadata
admin.site.register(OAuth2ProviderConfig, KeyedConfigurationModelAdmin)
......@@ -29,6 +30,17 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin):
has_data.short_description = u'Metadata Ready'
has_data.boolean = True
def save_model(self, request, obj, form, change):
"""
Post save: Queue an asynchronous metadata fetch to update SAMLProviderData.
We only want to do this for manual edits done using the admin interface.
Note: This only works if the celery worker and the app worker are using the
same 'configuration' cache.
"""
super(SAMLProviderConfigAdmin, self).save_model(request, obj, form, change)
fetch_saml_metadata.apply_async((), countdown=2)
admin.site.register(SAMLProviderConfig, SAMLProviderConfigAdmin)
......@@ -54,7 +66,7 @@ admin.site.register(SAMLConfiguration, SAMLConfigurationAdmin)
class SAMLProviderDataAdmin(admin.ModelAdmin):
""" Django Admin class for SAMLProviderData """
""" Django Admin class for SAMLProviderData (Read Only) """
list_display = ('entity_id', 'is_valid', 'fetched_at', 'expires_at', 'sso_url')
readonly_fields = ('is_valid', )
......
......@@ -2,20 +2,10 @@
"""
Management commands for third_party_auth
"""
import datetime
import dateutil.parser
from django.core.management.base import BaseCommand, CommandError
from lxml import etree
import requests
from onelogin.saml2.utils import OneLogin_Saml2_Utils
from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData
#pylint: disable=superfluous-parens,no-member
class MetadataParseError(Exception):
""" An error occurred while parsing the SAML metadata from an IdP """
pass
import logging
from third_party_auth.models import SAMLConfiguration
from third_party_auth.tasks import fetch_saml_metadata
class Command(BaseCommand):
......@@ -27,120 +17,21 @@ class Command(BaseCommand):
raise CommandError("saml requires one argument: pull")
if not SAMLConfiguration.is_enabled():
self.stdout.write("Warning: SAML support is disabled via SAMLConfiguration.\n")
raise CommandError("SAML support is disabled via SAMLConfiguration.")
subcommand = args[0]
if subcommand == "pull":
self.cmd_pull()
log_handler = logging.StreamHandler(self.stdout)
log_handler.setLevel(logging.DEBUG)
log = logging.getLogger('third_party_auth.tasks')
log.propagate = False
log.addHandler(log_handler)
num_changed, num_failed, num_total = fetch_saml_metadata()
self.stdout.write(
"\nDone. Fetched {num_total} total. {num_changed} were updated and {num_failed} failed.\n".format(
num_changed=num_changed, num_failed=num_failed, num_total=num_total
)
)
else:
raise CommandError("Unknown argment: {}".format(subcommand))
@staticmethod
def tag_name(tag_name):
""" Get the namespaced-qualified name for an XML tag """
return '{urn:oasis:names:tc:SAML:2.0:metadata}' + tag_name
def cmd_pull(self):
""" Fetch the metadata for each provider and update the DB """
# First make a list of all the metadata XML URLs:
url_map = {}
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
config = SAMLProviderConfig.current(idp_slug)
if not config.enabled:
continue
url = config.metadata_source
if url not in url_map:
url_map[url] = []
if config.entity_id not in url_map[url]:
url_map[url].append(config.entity_id)
# Now fetch the metadata:
for url, entity_ids in url_map.items():
try:
self.stdout.write("\n→ Fetching {}\n".format(url))
if not url.lower().startswith('https'):
self.stdout.write("→ WARNING: This URL is not secure! It should use HTTPS.\n")
response = requests.get(url, verify=True) # May raise HTTPError or SSLError or ConnectionError
response.raise_for_status() # May raise an HTTPError
try:
parser = etree.XMLParser(remove_comments=True)
xml = etree.fromstring(response.text, parser)
except etree.XMLSyntaxError:
raise
# TODO: Can use OneLogin_Saml2_Utils to validate signed XML if anyone is using that
for entity_id in entity_ids:
self.stdout.write("→ Processing IdP with entityID {}\n".format(entity_id))
public_key, sso_url, expires_at = self._parse_metadata_xml(xml, entity_id)
self._update_data(entity_id, public_key, sso_url, expires_at)
except Exception as err: # pylint: disable=broad-except
self.stderr.write(u"→ ERROR: {}\n\n".format(err.message))
@classmethod
def _parse_metadata_xml(cls, xml, entity_id):
"""
Given an XML document containing SAML 2.0 metadata, parse it and return a tuple of
(public_key, sso_url, expires_at) for the specified entityID.
Raises MetadataParseError if anything is wrong.
"""
if xml.tag == cls.tag_name('EntityDescriptor'):
entity_desc = xml
else:
if xml.tag != cls.tag_name('EntitiesDescriptor'):
raise MetadataParseError("Expected root element to be <EntitiesDescriptor>, not {}".format(xml.tag))
entity_desc = xml.find(".//{}[@entityID='{}']".format(cls.tag_name('EntityDescriptor'), entity_id))
if not entity_desc:
raise MetadataParseError("Can't find EntityDescriptor for entityID {}".format(entity_id))
expires_at = None
if "validUntil" in xml.attrib:
expires_at = dateutil.parser.parse(xml.attrib["validUntil"])
if "cacheDuration" in xml.attrib:
cache_expires = OneLogin_Saml2_Utils.parse_duration(xml.attrib["cacheDuration"])
if expires_at is None or cache_expires < expires_at:
expires_at = cache_expires
sso_desc = entity_desc.find(cls.tag_name("IDPSSODescriptor"))
if not sso_desc:
raise MetadataParseError("IDPSSODescriptor missing")
if 'urn:oasis:names:tc:SAML:2.0:protocol' not in sso_desc.get("protocolSupportEnumeration"):
raise MetadataParseError("This IdP does not support SAML 2.0")
# Now we just need to get the public_key and sso_url
public_key = sso_desc.findtext("./{}//{}".format(
cls.tag_name("KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate"
))
if not public_key:
raise MetadataParseError("Public Key missing. Expected an <X509Certificate>")
public_key = public_key.replace(" ", "")
binding_elements = sso_desc.iterfind("./{}".format(cls.tag_name("SingleSignOnService")))
sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements}
try:
# The only binding supported by python-saml and python-social-auth is HTTP-Redirect:
sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect']
except KeyError:
raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.")
return public_key, sso_url, expires_at
def _update_data(self, entity_id, public_key, sso_url, expires_at):
"""
Update/Create the SAMLProviderData for the given entity ID.
"""
data_obj = SAMLProviderData.current(entity_id)
fetched_at = datetime.datetime.now()
if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url):
data_obj.expires_at = expires_at
data_obj.fetched_at = fetched_at
data_obj.save()
self.stdout.write("→ Updated existing SAMLProviderData. Nothing has changed.\n")
else:
SAMLProviderData.objects.create(
entity_id=entity_id,
fetched_at=fetched_at,
expires_at=expires_at,
sso_url=sso_url,
public_key=public_key,
)
self.stdout.write("→ Created new record for SAMLProviderData\n")
# -*- coding: utf-8 -*-
"""
Code to manage fetching and storing the metadata of IdPs.
"""
#pylint: disable=no-member
from celery.task import task # pylint: disable=import-error,no-name-in-module
import datetime
import dateutil.parser
import logging
from lxml import etree
import requests
from onelogin.saml2.utils import OneLogin_Saml2_Utils
from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData
log = logging.getLogger(__name__)
SAML_XML_NS = 'urn:oasis:names:tc:SAML:2.0:metadata' # The SAML Metadata XML namespace
class MetadataParseError(Exception):
""" An error occurred while parsing the SAML metadata from an IdP """
pass
@task(name='third_party_auth.fetch_saml_metadata')
def fetch_saml_metadata():
"""
Fetch and store/update the metadata of all IdPs
This task should be run on a daily basis.
It's OK to run this whether or not SAML is enabled.
Return value:
tuple(num_changed, num_failed, num_total)
num_changed: Number of providers that are either new or whose metadata has changed
num_failed: Number of providers that could not be updated
num_total: Total number of providers whose metadata was fetched
"""
if not SAMLConfiguration.is_enabled():
return (0, 0, 0) # Nothing to do until SAML is enabled.
num_changed, num_failed = 0, 0
# First make a list of all the metadata XML URLs:
url_map = {}
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
config = SAMLProviderConfig.current(idp_slug)
if not config.enabled:
continue
url = config.metadata_source
if url not in url_map:
url_map[url] = []
if config.entity_id not in url_map[url]:
url_map[url].append(config.entity_id)
# Now fetch the metadata:
for url, entity_ids in url_map.items():
try:
log.info("Fetching %s", url)
if not url.lower().startswith('https'):
log.warning("This SAML metadata URL is not secure! It should use HTTPS. (%s)", url)
response = requests.get(url, verify=True) # May raise HTTPError or SSLError or ConnectionError
response.raise_for_status() # May raise an HTTPError
try:
parser = etree.XMLParser(remove_comments=True)
xml = etree.fromstring(response.text, parser)
except etree.XMLSyntaxError:
raise
# TODO: Can use OneLogin_Saml2_Utils to validate signed XML if anyone is using that
for entity_id in entity_ids:
log.info(u"Processing IdP with entityID %s", entity_id)
public_key, sso_url, expires_at = _parse_metadata_xml(xml, entity_id)
changed = _update_data(entity_id, public_key, sso_url, expires_at)
if changed:
log.info(u"→ Created new record for SAMLProviderData")
num_changed += 1
else:
log.info(u"→ Updated existing SAMLProviderData. Nothing has changed.")
except Exception as err: # pylint: disable=broad-except
log.exception(err.message)
num_failed += 1
return (num_changed, num_failed, len(url_map))
def _parse_metadata_xml(xml, entity_id):
"""
Given an XML document containing SAML 2.0 metadata, parse it and return a tuple of
(public_key, sso_url, expires_at) for the specified entityID.
Raises MetadataParseError if anything is wrong.
"""
if xml.tag == etree.QName(SAML_XML_NS, 'EntityDescriptor'):
entity_desc = xml
else:
if xml.tag != etree.QName(SAML_XML_NS, 'EntitiesDescriptor'):
raise MetadataParseError("Expected root element to be <EntitiesDescriptor>, not {}".format(xml.tag))
entity_desc = xml.find(
".//{}[@entityID='{}']".format(etree.QName(SAML_XML_NS, 'EntityDescriptor'), entity_id)
)
if not entity_desc:
raise MetadataParseError("Can't find EntityDescriptor for entityID {}".format(entity_id))
expires_at = None
if "validUntil" in xml.attrib:
expires_at = dateutil.parser.parse(xml.attrib["validUntil"])
if "cacheDuration" in xml.attrib:
cache_expires = OneLogin_Saml2_Utils.parse_duration(xml.attrib["cacheDuration"])
if expires_at is None or cache_expires < expires_at:
expires_at = cache_expires
sso_desc = entity_desc.find(etree.QName(SAML_XML_NS, "IDPSSODescriptor"))
if not sso_desc:
raise MetadataParseError("IDPSSODescriptor missing")
if 'urn:oasis:names:tc:SAML:2.0:protocol' not in sso_desc.get("protocolSupportEnumeration"):
raise MetadataParseError("This IdP does not support SAML 2.0")
# Now we just need to get the public_key and sso_url
public_key = sso_desc.findtext("./{}//{}".format(
etree.QName(SAML_XML_NS, "KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate"
))
if not public_key:
raise MetadataParseError("Public Key missing. Expected an <X509Certificate>")
public_key = public_key.replace(" ", "")
binding_elements = sso_desc.iterfind("./{}".format(etree.QName(SAML_XML_NS, "SingleSignOnService")))
sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements}
try:
# The only binding supported by python-saml and python-social-auth is HTTP-Redirect:
sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect']
except KeyError:
raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.")
return public_key, sso_url, expires_at
def _update_data(entity_id, public_key, sso_url, expires_at):
"""
Update/Create the SAMLProviderData for the given entity ID.
Return value:
False if nothing has changed and existing data's "fetched at" timestamp is just updated.
True if a new record was created. (Either this is a new provider or something changed.)
"""
data_obj = SAMLProviderData.current(entity_id)
fetched_at = datetime.datetime.now()
if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url):
data_obj.expires_at = expires_at
data_obj.fetched_at = fetched_at
data_obj.save()
return False
else:
SAMLProviderData.objects.create(
entity_id=entity_id,
fetched_at=fetched_at,
expires_at=expires_at,
sso_url=sso_url,
public_key=public_key,
)
return True
"""
Third_party_auth integration tests using a mock version of the TestShib provider
"""
from django.core.management import call_command
from django.core.urlresolvers import reverse
import httpretty
from mock import patch
import StringIO
from student.tests.factories import UserFactory
from third_party_auth.tasks import fetch_saml_metadata
from third_party_auth.tests import testutil
import unittest
......@@ -209,15 +208,11 @@ class TestShibIntegrationTest(testutil.SAMLTestCase):
self.configure_saml_provider(**kwargs)
if fetch_metadata:
stdout = StringIO.StringIO()
stderr = StringIO.StringIO()
self.assertTrue(httpretty.is_enabled())
call_command('saml', 'pull', stdout=stdout, stderr=stderr)
stdout = stdout.getvalue().decode('utf-8')
stderr = stderr.getvalue().decode('utf-8')
self.assertEqual(stderr, '')
self.assertIn(u'Fetching {}'.format(TESTSHIB_METADATA_URL), stdout)
self.assertIn(u'Created new record for SAMLProviderData', stdout)
num_changed, num_failed, num_total = fetch_saml_metadata()
self.assertEqual(num_failed, 0)
self.assertEqual(num_changed, 1)
self.assertEqual(num_total, 1)
def _fake_testshib_login_and_return(self):
""" Mocked: the user logs in to TestShib and then gets redirected back """
......
......@@ -8,7 +8,7 @@ import unittest
from .testutil import AUTH_FEATURE_ENABLED, SAMLTestCase
# Define some XML namespaces:
SAML_XML_NS = 'urn:oasis:names:tc:SAML:2.0:metadata'
from third_party_auth.tasks import SAML_XML_NS
XMLDSIG_XML_NS = 'http://www.w3.org/2000/09/xmldsig#'
......
......@@ -16,6 +16,7 @@ Common traits:
# and throws spurious errors. Therefore, we disable invalid-name checking.
# pylint: disable=invalid-name
import datetime
import json
from .common import *
......@@ -107,6 +108,7 @@ CELERY_QUEUES = {
if os.environ.get('QUEUE') == 'high_mem':
CELERYD_MAX_TASKS_PER_CHILD = 1
CELERYBEAT_SCHEDULE = {} # For scheduling tasks, entries can be added to this dict
########################## NON-SECURE ENV CONFIG ##############################
# Things like server locations, ports, etc.
......@@ -552,6 +554,12 @@ if FEATURES.get('ENABLE_THIRD_PARTY_AUTH'):
# third_party_auth config moved to ConfigurationModels. This is for data migration only:
THIRD_PARTY_AUTH_OLD_CONFIG = AUTH_TOKENS.get('THIRD_PARTY_AUTH', None)
if ENV_TOKENS.get('THIRD_PARTY_AUTH_SAML_FETCH_PERIOD_HOURS', 24) is not None:
CELERYBEAT_SCHEDULE['refresh-saml-metadata'] = {
'task': 'third_party_auth.fetch_saml_metadata',
'schedule': datetime.timedelta(hours=ENV_TOKENS.get('THIRD_PARTY_AUTH_SAML_FETCH_PERIOD_HOURS', 24)),
}
##### OAUTH2 Provider ##############
if FEATURES.get('ENABLE_OAUTH2_PROVIDER'):
OAUTH_OIDC_ISSUER = ENV_TOKENS['OAUTH_OIDC_ISSUER']
......
......@@ -109,7 +109,7 @@ def celery(options):
Runs Celery workers.
"""
settings = getattr(options, 'settings', 'dev_with_worker')
run_process(django_cmd('lms', settings, 'celery', 'worker', '--loglevel=INFO', '--pythonpath=.'))
run_process(django_cmd('lms', settings, 'celery', 'worker', '--beat', '--loglevel=INFO', '--pythonpath=.'))
@task
......@@ -142,7 +142,7 @@ def run_all_servers(options):
run_multi_processes([
django_cmd('lms', settings_lms, 'runserver', '--traceback', '--pythonpath=.', "0.0.0.0:{}".format(DEFAULT_PORT['lms'])),
django_cmd('studio', settings_cms, 'runserver', '--traceback', '--pythonpath=.', "0.0.0.0:{}".format(DEFAULT_PORT['studio'])),
django_cmd('lms', worker_settings, 'celery', 'worker', '--loglevel=INFO', '--pythonpath=.')
django_cmd('lms', worker_settings, 'celery', 'worker', '--beat', '--loglevel=INFO', '--pythonpath=.')
])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment