Commit 00226bf3 by Braden MacDonald

Asynchronous metadata fetching using celery beat - PR 8518

parent cd941ead
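The commit introduces two ways to trigger the SAML metadata refresh: a recurring celery beat entry and a one-off task queued from the Django admin whenever a provider configuration is saved. A condensed sketch of the two paths (names taken from the diff below; standard Celery semantics assumed):

```python
# Condensed sketch of the two trigger paths added by this commit (illustrative only).
import datetime

# 1) Periodic refresh: celery beat re-queues the task on a fixed interval (24h default below).
CELERYBEAT_SCHEDULE = {
    'refresh-saml-metadata': {
        'task': 'third_party_auth.fetch_saml_metadata',  # name registered via @task(name=...) in tasks.py
        'schedule': datetime.timedelta(hours=24),
    },
}

# 2) On-demand refresh: the admin's save_model() queues a single run shortly after a provider is edited.
# fetch_saml_metadata.apply_async((), countdown=2)
```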
@@ -91,6 +91,9 @@ logs
chromedriver.log
ghostdriver.log
### Celery artifacts ###
celerybeat-schedule
### Unknown artifacts
database.sqlite
courseware/static/js/mathjax/*
...
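The new ignore entry exists because the workers are now started with `--beat` (see the pavelib changes at the end of this commit): celery beat persists its schedule state to a local file, which defaults to `celerybeat-schedule`. A hedged note on the relevant Celery 3.x setting:

```python
# Not part of the commit: celery beat stores its last-run state in a local file so the
# schedule survives restarts. With default settings that file is named "celerybeat-schedule",
# hence the ignore entry above. The location can be overridden via:
CELERYBEAT_SCHEDULE_FILENAME = "celerybeat-schedule"  # Celery 3.x default shown
```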
@@ -7,6 +7,7 @@ from django.contrib import admin
from config_models.admin import ConfigurationModelAdmin, KeyedConfigurationModelAdmin
from .models import OAuth2ProviderConfig, SAMLProviderConfig, SAMLConfiguration, SAMLProviderData
from .tasks import fetch_saml_metadata
admin.site.register(OAuth2ProviderConfig, KeyedConfigurationModelAdmin)
@@ -29,6 +30,17 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin):
has_data.short_description = u'Metadata Ready'
has_data.boolean = True
def save_model(self, request, obj, form, change):
"""
Post save: Queue an asynchronous metadata fetch to update SAMLProviderData.
We only want to do this for manual edits done using the admin interface.
Note: This only works if the celery worker and the app worker are using the
same 'configuration' cache.
"""
super(SAMLProviderConfigAdmin, self).save_model(request, obj, form, change)
fetch_saml_metadata.apply_async((), countdown=2)
admin.site.register(SAMLProviderConfig, SAMLProviderConfigAdmin)
@@ -54,7 +66,7 @@ admin.site.register(SAMLConfiguration, SAMLConfigurationAdmin)
class SAMLProviderDataAdmin(admin.ModelAdmin):
""" Django Admin class for SAMLProviderData """
""" Django Admin class for SAMLProviderData (Read Only) """
list_display = ('entity_id', 'is_valid', 'fetched_at', 'expires_at', 'sso_url')
readonly_fields = ('is_valid', )
...
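A note on the calling convention used in `save_model()` above: the task is queued with `apply_async(..., countdown=2)` rather than `delay()`, presumably so the just-saved configuration has a moment to become visible to the worker through the shared 'configuration' cache mentioned in the docstring. Standard Celery API, shown for contrast:

```python
# Standard Celery calling conventions (illustrative, not new code in this commit):
fetch_saml_metadata.delay()                       # enqueue for immediate execution
fetch_saml_metadata.apply_async((), countdown=2)  # enqueue, but hold for ~2 seconds so the
                                                  # saved config can reach the shared cache
```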
@@ -2,20 +2,10 @@
"""
Management commands for third_party_auth
"""
import datetime
import dateutil.parser
from django.core.management.base import BaseCommand, CommandError
from lxml import etree
import requests
from onelogin.saml2.utils import OneLogin_Saml2_Utils
from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData
import logging
from third_party_auth.models import SAMLConfiguration
from third_party_auth.tasks import fetch_saml_metadata
#pylint: disable=superfluous-parens,no-member
class MetadataParseError(Exception):
""" An error occurred while parsing the SAML metadata from an IdP """
pass
class Command(BaseCommand):
@@ -27,120 +17,21 @@ class Command(BaseCommand):
raise CommandError("saml requires one argument: pull")
if not SAMLConfiguration.is_enabled():
self.stdout.write("Warning: SAML support is disabled via SAMLConfiguration.\n")
raise CommandError("SAML support is disabled via SAMLConfiguration.")
subcommand = args[0]
if subcommand == "pull":
self.cmd_pull()
log_handler = logging.StreamHandler(self.stdout)
log_handler.setLevel(logging.DEBUG)
log = logging.getLogger('third_party_auth.tasks')
log.propagate = False
log.addHandler(log_handler)
num_changed, num_failed, num_total = fetch_saml_metadata()
self.stdout.write(
"\nDone. Fetched {num_total} total. {num_changed} were updated and {num_failed} failed.\n".format(
num_changed=num_changed, num_failed=num_failed, num_total=num_total
)
)
else:
raise CommandError("Unknown argment: {}".format(subcommand))
@staticmethod
def tag_name(tag_name):
""" Get the namespaced-qualified name for an XML tag """
return '{urn:oasis:names:tc:SAML:2.0:metadata}' + tag_name
def cmd_pull(self):
""" Fetch the metadata for each provider and update the DB """
# First make a list of all the metadata XML URLs:
url_map = {}
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
config = SAMLProviderConfig.current(idp_slug)
if not config.enabled:
continue
url = config.metadata_source
if url not in url_map:
url_map[url] = []
if config.entity_id not in url_map[url]:
url_map[url].append(config.entity_id)
# Now fetch the metadata:
for url, entity_ids in url_map.items():
try:
self.stdout.write("\n→ Fetching {}\n".format(url))
if not url.lower().startswith('https'):
self.stdout.write("→ WARNING: This URL is not secure! It should use HTTPS.\n")
response = requests.get(url, verify=True) # May raise HTTPError or SSLError or ConnectionError
response.raise_for_status() # May raise an HTTPError
try:
parser = etree.XMLParser(remove_comments=True)
xml = etree.fromstring(response.text, parser)
except etree.XMLSyntaxError:
raise
# TODO: Can use OneLogin_Saml2_Utils to validate signed XML if anyone is using that
for entity_id in entity_ids:
self.stdout.write("→ Processing IdP with entityID {}\n".format(entity_id))
public_key, sso_url, expires_at = self._parse_metadata_xml(xml, entity_id)
self._update_data(entity_id, public_key, sso_url, expires_at)
except Exception as err: # pylint: disable=broad-except
self.stderr.write(u"→ ERROR: {}\n\n".format(err.message))
@classmethod
def _parse_metadata_xml(cls, xml, entity_id):
"""
Given an XML document containing SAML 2.0 metadata, parse it and return a tuple of
(public_key, sso_url, expires_at) for the specified entityID.
Raises MetadataParseError if anything is wrong.
"""
if xml.tag == cls.tag_name('EntityDescriptor'):
entity_desc = xml
else:
if xml.tag != cls.tag_name('EntitiesDescriptor'):
raise MetadataParseError("Expected root element to be <EntitiesDescriptor>, not {}".format(xml.tag))
entity_desc = xml.find(".//{}[@entityID='{}']".format(cls.tag_name('EntityDescriptor'), entity_id))
if not entity_desc:
raise MetadataParseError("Can't find EntityDescriptor for entityID {}".format(entity_id))
expires_at = None
if "validUntil" in xml.attrib:
expires_at = dateutil.parser.parse(xml.attrib["validUntil"])
if "cacheDuration" in xml.attrib:
cache_expires = OneLogin_Saml2_Utils.parse_duration(xml.attrib["cacheDuration"])
if expires_at is None or cache_expires < expires_at:
expires_at = cache_expires
sso_desc = entity_desc.find(cls.tag_name("IDPSSODescriptor"))
if not sso_desc:
raise MetadataParseError("IDPSSODescriptor missing")
if 'urn:oasis:names:tc:SAML:2.0:protocol' not in sso_desc.get("protocolSupportEnumeration"):
raise MetadataParseError("This IdP does not support SAML 2.0")
# Now we just need to get the public_key and sso_url
public_key = sso_desc.findtext("./{}//{}".format(
cls.tag_name("KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate"
))
if not public_key:
raise MetadataParseError("Public Key missing. Expected an <X509Certificate>")
public_key = public_key.replace(" ", "")
binding_elements = sso_desc.iterfind("./{}".format(cls.tag_name("SingleSignOnService")))
sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements}
try:
# The only binding supported by python-saml and python-social-auth is HTTP-Redirect:
sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect']
except KeyError:
raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.")
return public_key, sso_url, expires_at
def _update_data(self, entity_id, public_key, sso_url, expires_at):
"""
Update/Create the SAMLProviderData for the given entity ID.
"""
data_obj = SAMLProviderData.current(entity_id)
fetched_at = datetime.datetime.now()
if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url):
data_obj.expires_at = expires_at
data_obj.fetched_at = fetched_at
data_obj.save()
self.stdout.write("→ Updated existing SAMLProviderData. Nothing has changed.\n")
else:
SAMLProviderData.objects.create(
entity_id=entity_id,
fetched_at=fetched_at,
expires_at=expires_at,
sso_url=sso_url,
public_key=public_key,
)
self.stdout.write("→ Created new record for SAMLProviderData\n")
# -*- coding: utf-8 -*-
"""
Code to manage fetching and storing the metadata of IdPs.
"""
#pylint: disable=no-member
from celery.task import task # pylint: disable=import-error,no-name-in-module
import datetime
import dateutil.parser
import logging
from lxml import etree
import requests
from onelogin.saml2.utils import OneLogin_Saml2_Utils
from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData
log = logging.getLogger(__name__)
SAML_XML_NS = 'urn:oasis:names:tc:SAML:2.0:metadata' # The SAML Metadata XML namespace
class MetadataParseError(Exception):
""" An error occurred while parsing the SAML metadata from an IdP """
pass
@task(name='third_party_auth.fetch_saml_metadata')
def fetch_saml_metadata():
"""
Fetch and store/update the metadata of all IdPs
This task should be run on a daily basis.
It's OK to run this whether or not SAML is enabled.
Return value:
tuple(num_changed, num_failed, num_total)
num_changed: Number of providers that are either new or whose metadata has changed
num_failed: Number of providers that could not be updated
num_total: Total number of providers whose metadata was fetched
"""
if not SAMLConfiguration.is_enabled():
return (0, 0, 0) # Nothing to do until SAML is enabled.
num_changed, num_failed = 0, 0
# First make a list of all the metadata XML URLs:
url_map = {}
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
config = SAMLProviderConfig.current(idp_slug)
if not config.enabled:
continue
url = config.metadata_source
if url not in url_map:
url_map[url] = []
if config.entity_id not in url_map[url]:
url_map[url].append(config.entity_id)
# Now fetch the metadata:
for url, entity_ids in url_map.items():
try:
log.info("Fetching %s", url)
if not url.lower().startswith('https'):
log.warning("This SAML metadata URL is not secure! It should use HTTPS. (%s)", url)
response = requests.get(url, verify=True) # May raise HTTPError or SSLError or ConnectionError
response.raise_for_status() # May raise an HTTPError
try:
parser = etree.XMLParser(remove_comments=True)
xml = etree.fromstring(response.text, parser)
except etree.XMLSyntaxError:
raise
# TODO: Can use OneLogin_Saml2_Utils to validate signed XML if anyone is using that
for entity_id in entity_ids:
log.info(u"Processing IdP with entityID %s", entity_id)
public_key, sso_url, expires_at = _parse_metadata_xml(xml, entity_id)
changed = _update_data(entity_id, public_key, sso_url, expires_at)
if changed:
log.info(u"→ Created new record for SAMLProviderData")
num_changed += 1
else:
log.info(u"→ Updated existing SAMLProviderData. Nothing has changed.")
except Exception as err: # pylint: disable=broad-except
log.exception(err.message)
num_failed += 1
return (num_changed, num_failed, len(url_map))
def _parse_metadata_xml(xml, entity_id):
"""
Given an XML document containing SAML 2.0 metadata, parse it and return a tuple of
(public_key, sso_url, expires_at) for the specified entityID.
Raises MetadataParseError if anything is wrong.
"""
if xml.tag == etree.QName(SAML_XML_NS, 'EntityDescriptor'):
entity_desc = xml
else:
if xml.tag != etree.QName(SAML_XML_NS, 'EntitiesDescriptor'):
raise MetadataParseError("Expected root element to be <EntitiesDescriptor>, not {}".format(xml.tag))
entity_desc = xml.find(
".//{}[@entityID='{}']".format(etree.QName(SAML_XML_NS, 'EntityDescriptor'), entity_id)
)
if not entity_desc:
raise MetadataParseError("Can't find EntityDescriptor for entityID {}".format(entity_id))
expires_at = None
if "validUntil" in xml.attrib:
expires_at = dateutil.parser.parse(xml.attrib["validUntil"])
if "cacheDuration" in xml.attrib:
cache_expires = OneLogin_Saml2_Utils.parse_duration(xml.attrib["cacheDuration"])
if expires_at is None or cache_expires < expires_at:
expires_at = cache_expires
sso_desc = entity_desc.find(etree.QName(SAML_XML_NS, "IDPSSODescriptor"))
if not sso_desc:
raise MetadataParseError("IDPSSODescriptor missing")
if 'urn:oasis:names:tc:SAML:2.0:protocol' not in sso_desc.get("protocolSupportEnumeration"):
raise MetadataParseError("This IdP does not support SAML 2.0")
# Now we just need to get the public_key and sso_url
public_key = sso_desc.findtext("./{}//{}".format(
etree.QName(SAML_XML_NS, "KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate"
))
if not public_key:
raise MetadataParseError("Public Key missing. Expected an <X509Certificate>")
public_key = public_key.replace(" ", "")
binding_elements = sso_desc.iterfind("./{}".format(etree.QName(SAML_XML_NS, "SingleSignOnService")))
sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements}
try:
# The only binding supported by python-saml and python-social-auth is HTTP-Redirect:
sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect']
except KeyError:
raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.")
return public_key, sso_url, expires_at
def _update_data(entity_id, public_key, sso_url, expires_at):
"""
Update/Create the SAMLProviderData for the given entity ID.
Return value:
False if nothing has changed and existing data's "fetched at" timestamp is just updated.
True if a new record was created. (Either this is a new provider or something changed.)
"""
data_obj = SAMLProviderData.current(entity_id)
fetched_at = datetime.datetime.now()
if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url):
data_obj.expires_at = expires_at
data_obj.fetched_at = fetched_at
data_obj.save()
return False
else:
SAMLProviderData.objects.create(
entity_id=entity_id,
fetched_at=fetched_at,
expires_at=expires_at,
sso_url=sso_url,
public_key=public_key,
)
return True
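Because `fetch_saml_metadata` is a plain function decorated with `@task`, it can be called synchronously (as the management command and tests now do) or dispatched to a worker. A small usage sketch, assuming a broker and a running worker for the asynchronous case:

```python
# Illustrative usage of the task defined above.
from third_party_auth.tasks import fetch_saml_metadata

num_changed, num_failed, num_total = fetch_saml_metadata()  # run in-process, returns the tuple
async_result = fetch_saml_metadata.delay()                  # or hand it to a celery worker
# async_result.get() would yield the same (num_changed, num_failed, num_total) tuple
```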
""" """
Third_party_auth integration tests using a mock version of the TestShib provider Third_party_auth integration tests using a mock version of the TestShib provider
""" """
from django.core.management import call_command
from django.core.urlresolvers import reverse
import httpretty
from mock import patch
import StringIO
from student.tests.factories import UserFactory
from third_party_auth.tasks import fetch_saml_metadata
from third_party_auth.tests import testutil
import unittest
@@ -209,15 +208,11 @@ class TestShibIntegrationTest(testutil.SAMLTestCase):
self.configure_saml_provider(**kwargs)
if fetch_metadata:
stdout = StringIO.StringIO()
stderr = StringIO.StringIO()
self.assertTrue(httpretty.is_enabled())
call_command('saml', 'pull', stdout=stdout, stderr=stderr)
stdout = stdout.getvalue().decode('utf-8')
stderr = stderr.getvalue().decode('utf-8')
self.assertEqual(stderr, '')
self.assertIn(u'Fetching {}'.format(TESTSHIB_METADATA_URL), stdout)
self.assertIn(u'Created new record for SAMLProviderData', stdout)
num_changed, num_failed, num_total = fetch_saml_metadata()
self.assertEqual(num_failed, 0)
self.assertEqual(num_changed, 1)
self.assertEqual(num_total, 1)
def _fake_testshib_login_and_return(self):
""" Mocked: the user logs in to TestShib and then gets redirected back """
...
@@ -8,7 +8,7 @@ import unittest
from .testutil import AUTH_FEATURE_ENABLED, SAMLTestCase
# Define some XML namespaces:
SAML_XML_NS = 'urn:oasis:names:tc:SAML:2.0:metadata'
from third_party_auth.tasks import SAML_XML_NS
XMLDSIG_XML_NS = 'http://www.w3.org/2000/09/xmldsig#'
...
@@ -16,6 +16,7 @@ Common traits:
# and throws spurious errors. Therefore, we disable invalid-name checking.
# pylint: disable=invalid-name
import datetime
import json
from .common import *
@@ -107,6 +108,7 @@ CELERY_QUEUES = {
if os.environ.get('QUEUE') == 'high_mem':
CELERYD_MAX_TASKS_PER_CHILD = 1
CELERYBEAT_SCHEDULE = {} # For scheduling tasks, entries can be added to this dict
########################## NON-SECURE ENV CONFIG ##############################
# Things like server locations, ports, etc.
@@ -552,6 +554,12 @@ if FEATURES.get('ENABLE_THIRD_PARTY_AUTH'):
# third_party_auth config moved to ConfigurationModels. This is for data migration only:
THIRD_PARTY_AUTH_OLD_CONFIG = AUTH_TOKENS.get('THIRD_PARTY_AUTH', None)
if ENV_TOKENS.get('THIRD_PARTY_AUTH_SAML_FETCH_PERIOD_HOURS', 24) is not None:
CELERYBEAT_SCHEDULE['refresh-saml-metadata'] = {
'task': 'third_party_auth.fetch_saml_metadata',
'schedule': datetime.timedelta(hours=ENV_TOKENS.get('THIRD_PARTY_AUTH_SAML_FETCH_PERIOD_HOURS', 24)),
}
##### OAUTH2 Provider ##############
if FEATURES.get('ENABLE_OAUTH2_PROVIDER'):
OAUTH_OIDC_ISSUER = ENV_TOKENS['OAUTH_OIDC_ISSUER']
...
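The `is not None` guard above means the beat entry is installed by default with a 24-hour period and can be tuned, or disabled outright, from the environment config that populates `ENV_TOKENS`. A sketch of the three cases:

```python
# Illustrative ENV_TOKENS values and their effect on the guard above:
ENV_TOKENS = {}                                                   # key absent: default, refresh every 24 hours
ENV_TOKENS = {'THIRD_PARTY_AUTH_SAML_FETCH_PERIOD_HOURS': 6}      # refresh every 6 hours
ENV_TOKENS = {'THIRD_PARTY_AUTH_SAML_FETCH_PERIOD_HOURS': None}   # no 'refresh-saml-metadata' entry is added
```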
@@ -109,7 +109,7 @@ def celery(options):
Runs Celery workers.
"""
settings = getattr(options, 'settings', 'dev_with_worker')
run_process(django_cmd('lms', settings, 'celery', 'worker', '--loglevel=INFO', '--pythonpath=.'))
run_process(django_cmd('lms', settings, 'celery', 'worker', '--beat', '--loglevel=INFO', '--pythonpath=.'))
@task
@@ -142,7 +142,7 @@ def run_all_servers(options):
run_multi_processes([
django_cmd('lms', settings_lms, 'runserver', '--traceback', '--pythonpath=.', "0.0.0.0:{}".format(DEFAULT_PORT['lms'])),
django_cmd('studio', settings_cms, 'runserver', '--traceback', '--pythonpath=.', "0.0.0.0:{}".format(DEFAULT_PORT['studio'])),
django_cmd('lms', worker_settings, 'celery', 'worker', '--loglevel=INFO', '--pythonpath=.')
django_cmd('lms', worker_settings, 'celery', 'worker', '--beat', '--loglevel=INFO', '--pythonpath=.')
])
...