Commit ba8a85ca by Renzo Lucioni Committed by GitHub

Further accelerate metadata refresh (#363)

A Waffle switch can now be used to toggle parallelized execution of the data loading pipeline. The pipeline is structured so that stages that can run independently of others are interleaved using separate processes.

A separate Waffle switch can be used to toggle threaded writes to the database. Instead of using threads to request all data, then writing it all serially, use this switch to spawn separate threads responsible for reading and writing each page of data.

ECOM-5871
parent 846188ff
......@@ -26,13 +26,15 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
SUPPORTED_TOKEN_TYPES = ('bearer', 'jwt',)
MARKDOWN_CLEANUP_REGEX = re.compile(r'^<p>(.*)</p>$')
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None):
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None, is_threadsafe=False):
"""
Arguments:
partner (Partner): Partner which owns the APIs and data being loaded
api_url (str): URL of the API from which data is loaded
access_token (str): OAuth2 access token
token_type (str): The type of access token passed in (e.g. Bearer, JWT)
max_workers (int): Number of worker threads to use when traversing paginated responses.
is_threadsafe (bool): True if multiple threads can be used to write data.
"""
if token_type:
token_type = token_type.lower()
......@@ -46,6 +48,7 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
self.api_url = api_url.strip('/')
self.max_workers = max_workers
self.is_threadsafe = is_threadsafe
@cached_property
def api_client(self):
......
......@@ -75,17 +75,26 @@ class CoursesApiDataLoader(AbstractDataLoader):
pages = response['pagination']['num_pages']
self._process_response(response)
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [executor.submit(self._make_request, page) for page in range(initial_page + 1, pages + 1)]
pagerange = range(initial_page + 1, pages + 1)
for future in futures: # pragma: no cover
response = future.result()
self._process_response(response)
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor: # pragma: no cover
if self.is_threadsafe:
for page in pagerange:
executor.submit(self._load_data, page)
else:
for future in [executor.submit(self._make_request, page) for page in pagerange]:
response = future.result()
self._process_response(response)
logger.info('Retrieved %d course runs from %s.', count, self.partner.courses_api_url)
self.delete_orphans()
def _load_data(self, page):  # pragma: no cover
    """Retrieve a single page of course run data and immediately process it.

    Used when is_threadsafe is set, so each worker thread both fetches
    and writes its own page instead of deferring writes to the caller.
    """
    self._process_response(self._make_request(page))
def _make_request(self, page):
return self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE)
......@@ -190,17 +199,26 @@ class EcommerceApiDataLoader(AbstractDataLoader):
pages = math.ceil(count / self.PAGE_SIZE)
self._process_response(response)
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [executor.submit(self._make_request, page) for page in range(initial_page + 1, pages + 1)]
pagerange = range(initial_page + 1, pages + 1)
for future in futures: # pragma: no cover
response = future.result()
self._process_response(response)
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor: # pragma: no cover
if self.is_threadsafe:
for page in pagerange:
executor.submit(self._load_data, page)
else:
for future in [executor.submit(self._make_request, page) for page in pagerange]:
response = future.result()
self._process_response(response)
logger.info('Retrieved %d course seats from %s.', count, self.partner.ecommerce_api_url)
self.delete_orphans()
def _load_data(self, page):  # pragma: no cover
    """Fetch one page of data and process it in the same call.

    Invoked per-thread when is_threadsafe is set, so reads and writes
    for a page happen together rather than being serialized by the caller.
    """
    self._process_response(self._make_request(page))
def _make_request(self, page):
return self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE, include_products=True)
......@@ -273,8 +291,10 @@ class ProgramsApiDataLoader(AbstractDataLoader):
image_height = 480
XSERIES = None
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None):
super(ProgramsApiDataLoader, self).__init__(partner, api_url, access_token, token_type, max_workers)
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None, is_threadsafe=False):
super(ProgramsApiDataLoader, self).__init__(
partner, api_url, access_token, token_type, max_workers, is_threadsafe
)
self.XSERIES = ProgramType.objects.get(name='XSeries')
def ingest(self):
......
......@@ -22,8 +22,10 @@ logger = logging.getLogger(__name__)
class AbstractMarketingSiteDataLoader(AbstractDataLoader):
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None):
super(AbstractMarketingSiteDataLoader, self).__init__(partner, api_url, access_token, token_type, max_workers)
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None, is_threadsafe=False):
super(AbstractMarketingSiteDataLoader, self).__init__(
partner, api_url, access_token, token_type, max_workers, is_threadsafe
)
if not (self.partner.marketing_site_api_username and self.partner.marketing_site_api_password):
msg = 'Marketing Site API credentials are not properly configured for Partner [{partner}]!'.format(
......@@ -69,13 +71,21 @@ class AbstractMarketingSiteDataLoader(AbstractDataLoader):
# Add one to avoid requesting the first page again and to make sure
# we get the last page when range() is used below.
pages = [self._extract_page(url) + 1 for url in (data['first'], data['last'])]
pagerange = range(*pages)
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [executor.submit(self._request, page) for page in range(*pages)]
for future in futures:
response = future.result()
self._process_response(response)
if self.is_threadsafe: # pragma: no cover
for page in pagerange:
executor.submit(self._load_data, page)
else:
for future in [executor.submit(self._request, page) for page in pagerange]:
response = future.result()
self._process_response(response)
def _load_data(self, page):  # pragma: no cover
    """Request the given page from the marketing site and process the response.

    Runs inside a worker thread when is_threadsafe is set, combining the
    read and write for a page into one unit of work.
    """
    self._process_response(self._request(page))
def _request(self, page):
"""Make a request to the marketing site."""
......
import concurrent.futures
import itertools
import logging
from django.core.management import BaseCommand, CommandError
from edx_rest_api_client.client import EdxRestApiClient
import waffle
from course_discovery.apps.core.models import Partner
from course_discovery.apps.course_metadata.data_loaders.api import (
......@@ -11,10 +14,19 @@ from course_discovery.apps.course_metadata.data_loaders.marketing_site import (
XSeriesMarketingSiteDataLoader, SubjectMarketingSiteDataLoader, SchoolMarketingSiteDataLoader,
SponsorMarketingSiteDataLoader, PersonMarketingSiteDataLoader, CourseMarketingSiteDataLoader,
)
from course_discovery.apps.course_metadata.models import Course
logger = logging.getLogger(__name__)
def execute_loader(loader_class, *loader_args):
    """Instantiate the given data loader and run its ingest method.

    Module-level so it can be pickled and submitted to a
    ProcessPoolExecutor. Any failure is logged rather than propagated,
    so one failing loader does not abort the rest of the pipeline stage.

    Arguments:
        loader_class (class): Data loader class to instantiate.
        *loader_args: Positional arguments forwarded to the loader's constructor.
    """
    try:
        loader = loader_class(*loader_args)
        loader.ingest()
    except Exception:  # pylint: disable=broad-except
        logger.exception('%s failed!', loader_class.__name__)
class Command(BaseCommand):
help = 'Refresh course metadata from external sources.'
......@@ -88,30 +100,67 @@ class Command(BaseCommand):
logger.exception('No access token provided or acquired through client_credential flow.')
raise
data_loaders = (
(partner.marketing_site_url_root, SubjectMarketingSiteDataLoader, None),
(partner.marketing_site_url_root, SchoolMarketingSiteDataLoader, None),
(partner.marketing_site_url_root, SponsorMarketingSiteDataLoader, None),
(partner.marketing_site_url_root, PersonMarketingSiteDataLoader, None),
(partner.marketing_site_url_root, CourseMarketingSiteDataLoader, None),
(partner.organizations_api_url, OrganizationsApiDataLoader, None),
(partner.courses_api_url, CoursesApiDataLoader, None),
(partner.ecommerce_api_url, EcommerceApiDataLoader, 1),
(partner.programs_api_url, ProgramsApiDataLoader, None),
(partner.marketing_site_url_root, XSeriesMarketingSiteDataLoader, None),
# If no courses exist for this partner, this command is likely being run on a
# new catalog installation. In that case, we don't want multiple threads racing
# to create courses. If courses do exist, this command is likely being run
# as an update, significantly lowering the probability of race conditions.
courses_exist = Course.objects.filter(partner=partner).exists()
is_threadsafe = True if courses_exist and waffle.switch_is_active('threaded_metadata_write') else False
logger.info(
'Command is{negation} using threads to write data.'.format(negation='' if is_threadsafe else ' not')
)
pipeline = (
(
(SubjectMarketingSiteDataLoader, partner.marketing_site_url_root, None),
(SchoolMarketingSiteDataLoader, partner.marketing_site_url_root, None),
(SponsorMarketingSiteDataLoader, partner.marketing_site_url_root, None),
(PersonMarketingSiteDataLoader, partner.marketing_site_url_root, None),
),
(
(CourseMarketingSiteDataLoader, partner.marketing_site_url_root, None),
(OrganizationsApiDataLoader, partner.organizations_api_url, None),
),
(
(CoursesApiDataLoader, partner.courses_api_url, None),
),
(
(EcommerceApiDataLoader, partner.ecommerce_api_url, 1),
(ProgramsApiDataLoader, partner.programs_api_url, None),
),
(
(XSeriesMarketingSiteDataLoader, partner.marketing_site_url_root, None),
),
)
for api_url, loader_class, max_workers_override in data_loaders:
if api_url:
try:
loader_class(
if waffle.switch_is_active('parallel_refresh_pipeline'):
for stage in pipeline:
with concurrent.futures.ProcessPoolExecutor() as executor:
for loader_class, api_url, max_workers_override in stage:
if api_url:
executor.submit(
execute_loader,
loader_class,
partner,
api_url,
access_token,
token_type,
(max_workers_override or max_workers),
is_threadsafe,
)
else:
# Flatten pipeline and run serially.
for loader_class, api_url, max_workers_override in itertools.chain(*(stage for stage in pipeline)):
if api_url:
execute_loader(
loader_class,
partner,
api_url,
access_token,
token_type,
(max_workers_override or max_workers)
).ingest()
except Exception: # pylint: disable=broad-except
logger.exception('%s failed!', loader_class.__name__)
(max_workers_override or max_workers),
is_threadsafe,
)
# TODO Cleanup CourseRun overrides equivalent to the Course values.
import json
import ddt
import mock
import responses
from django.core.management import call_command, CommandError
......@@ -15,11 +16,14 @@ from course_discovery.apps.course_metadata.data_loaders.marketing_site import (
SponsorMarketingSiteDataLoader, PersonMarketingSiteDataLoader, CourseMarketingSiteDataLoader
)
from course_discovery.apps.course_metadata.data_loaders.tests import mock_data
from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.course_metadata.tests.factories import CourseFactory
ACCESS_TOKEN = 'secret'
JSON = 'application/json'
@ddt.ddt
class RefreshCourseMetadataCommandTests(TestCase):
def setUp(self):
super(RefreshCourseMetadataCommandTests, self).setUp()
......@@ -107,6 +111,23 @@ class RefreshCourseMetadataCommandTests(TestCase):
)
return bodies
@ddt.data(True, False)
def test_refresh_course_metadata(self, is_parallel):
if is_parallel:
for name in ['threaded_metadata_write', 'parallel_refresh_pipeline']:
toggle_switch(name)
with responses.RequestsMock() as rsps:
self.mock_access_token_api(rsps)
self.mock_apis()
# Courses must exist for the command to use multiple threads. If there are no
# courses, the command won't risk race conditions between threads trying to
# create the same course.
CourseFactory(partner=self.partner)
call_command('refresh_course_metadata')
def test_refresh_course_metadata_with_invalid_partner_code(self):
""" Verify an error is raised if an invalid partner code is passed on the command line. """
with self.assertRaises(CommandError):
......
from waffle.models import Switch
def toggle_switch(name, active):
def toggle_switch(name, active=True):
"""
Activate or deactivate a feature switch.
The switch is created if it does not exist.
Activate or deactivate a feature switch. The switch is created if it does not exist.
Arguments:
name (str): name of the switch to be toggled
active (bool): boolean indicating if the switch should be activated or deactivated
name (str): name of the switch to be toggled.
Keyword Arguments:
active (bool): Whether the switch should be on or off.
Returns:
Switch: Waffle Switch
"""
switch, __ = Switch.objects.get_or_create(name=name,
defaults={'active': active})
switch, __ = Switch.objects.get_or_create(name=name, defaults={'active': active})
switch.active = active
switch.save()
return switch
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment