Commit 5cd58ec7 by Matthew Piatetsky

Add config model to control max threads for data loaders

ECOM-7188
parent ae50ef24
...@@ -222,5 +222,5 @@ for model in (LevelType, Prerequisite,): ...@@ -222,5 +222,5 @@ for model in (LevelType, Prerequisite,):
# Register remaining models using basic ModelAdmin classes # Register remaining models using basic ModelAdmin classes
for model in (Image, Video, ExpectedLearningItem, SyllabusItem, PersonSocialNetwork, CourseRunSocialNetwork, for model in (Image, Video, ExpectedLearningItem, SyllabusItem, PersonSocialNetwork, CourseRunSocialNetwork,
JobOutlookItem,): JobOutlookItem, DataLoaderConfig):
admin.site.register(model) admin.site.register(model)
...@@ -16,7 +16,7 @@ from course_discovery.apps.course_metadata.data_loaders.marketing_site import ( ...@@ -16,7 +16,7 @@ from course_discovery.apps.course_metadata.data_loaders.marketing_site import (
CourseMarketingSiteDataLoader, PersonMarketingSiteDataLoader, SchoolMarketingSiteDataLoader, CourseMarketingSiteDataLoader, PersonMarketingSiteDataLoader, SchoolMarketingSiteDataLoader,
SponsorMarketingSiteDataLoader, SubjectMarketingSiteDataLoader, XSeriesMarketingSiteDataLoader SponsorMarketingSiteDataLoader, SubjectMarketingSiteDataLoader, XSeriesMarketingSiteDataLoader
) )
from course_discovery.apps.course_metadata.models import Course from course_discovery.apps.course_metadata.models import Course, DataLoaderConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -59,18 +59,7 @@ class Command(BaseCommand): ...@@ -59,18 +59,7 @@ class Command(BaseCommand):
help='The short code for a specific partner to refresh.' help='The short code for a specific partner to refresh.'
) )
parser.add_argument(
'-w', '--max_workers',
type=int,
action='store',
dest='max_workers',
default=7,
help='Number of worker threads to use when traversing paginated responses.'
)
def handle(self, *args, **options): def handle(self, *args, **options):
max_workers = options.get('max_workers')
# For each partner defined... # For each partner defined...
partners = Partner.objects.all() partners = Partner.objects.all()
...@@ -129,6 +118,7 @@ class Command(BaseCommand): ...@@ -129,6 +118,7 @@ class Command(BaseCommand):
# as an update, significantly lowering the probability of race conditions. # as an update, significantly lowering the probability of race conditions.
courses_exist = Course.objects.filter(partner=partner).exists() courses_exist = Course.objects.filter(partner=partner).exists()
is_threadsafe = courses_exist and waffle.switch_is_active('threaded_metadata_write') is_threadsafe = courses_exist and waffle.switch_is_active('threaded_metadata_write')
max_workers = DataLoaderConfig.get_solo().max_workers
logger.info( logger.info(
'Command is{negation} using threads to write data.'.format(negation='' if is_threadsafe else ' not') 'Command is{negation} using threads to write data.'.format(negation='' if is_threadsafe else ' not')
...@@ -136,31 +126,31 @@ class Command(BaseCommand): ...@@ -136,31 +126,31 @@ class Command(BaseCommand):
pipeline = ( pipeline = (
( (
(SubjectMarketingSiteDataLoader, partner.marketing_site_url_root, None), (SubjectMarketingSiteDataLoader, partner.marketing_site_url_root, max_workers),
(SchoolMarketingSiteDataLoader, partner.marketing_site_url_root, None), (SchoolMarketingSiteDataLoader, partner.marketing_site_url_root, max_workers),
(SponsorMarketingSiteDataLoader, partner.marketing_site_url_root, None), (SponsorMarketingSiteDataLoader, partner.marketing_site_url_root, max_workers),
(PersonMarketingSiteDataLoader, partner.marketing_site_url_root, None), (PersonMarketingSiteDataLoader, partner.marketing_site_url_root, max_workers),
), ),
( (
(CourseMarketingSiteDataLoader, partner.marketing_site_url_root, None), (CourseMarketingSiteDataLoader, partner.marketing_site_url_root, max_workers),
(OrganizationsApiDataLoader, partner.organizations_api_url, None), (OrganizationsApiDataLoader, partner.organizations_api_url, max_workers),
), ),
( (
(CoursesApiDataLoader, partner.courses_api_url, None), (CoursesApiDataLoader, partner.courses_api_url, max_workers),
), ),
( (
(EcommerceApiDataLoader, partner.ecommerce_api_url, 1), (EcommerceApiDataLoader, partner.ecommerce_api_url, 1),
(ProgramsApiDataLoader, partner.programs_api_url, None), (ProgramsApiDataLoader, partner.programs_api_url, max_workers),
), ),
( (
(XSeriesMarketingSiteDataLoader, partner.marketing_site_url_root, None), (XSeriesMarketingSiteDataLoader, partner.marketing_site_url_root, max_workers),
), ),
) )
if waffle.switch_is_active('parallel_refresh_pipeline'): if waffle.switch_is_active('parallel_refresh_pipeline'):
for stage in pipeline: for stage in pipeline:
with concurrent.futures.ProcessPoolExecutor() as executor: with concurrent.futures.ProcessPoolExecutor() as executor:
for loader_class, api_url, max_workers_override in stage: for loader_class, api_url, max_workers in stage:
if api_url: if api_url:
executor.submit( executor.submit(
execute_parallel_loader, execute_parallel_loader,
...@@ -169,13 +159,13 @@ class Command(BaseCommand): ...@@ -169,13 +159,13 @@ class Command(BaseCommand):
api_url, api_url,
access_token, access_token,
token_type, token_type,
(max_workers_override or max_workers), max_workers,
is_threadsafe, is_threadsafe,
**kwargs, **kwargs,
) )
else: else:
# Flatten pipeline and run serially. # Flatten pipeline and run serially.
for loader_class, api_url, max_workers_override in itertools.chain(*(stage for stage in pipeline)): for loader_class, api_url, max_workers in itertools.chain(*(stage for stage in pipeline)):
if api_url: if api_url:
execute_loader( execute_loader(
loader_class, loader_class,
...@@ -183,7 +173,7 @@ class Command(BaseCommand): ...@@ -183,7 +173,7 @@ class Command(BaseCommand):
api_url, api_url,
access_token, access_token,
token_type, token_type,
(max_workers_override or max_workers), max_workers,
is_threadsafe, is_threadsafe,
**kwargs, **kwargs,
) )
......
...@@ -135,8 +135,8 @@ class RefreshCourseMetadataCommandTests(TransactionTestCase): ...@@ -135,8 +135,8 @@ class RefreshCourseMetadataCommandTests(TransactionTestCase):
# Set up expected calls # Set up expected calls
expected_calls = [mock.call(loader_class, self.partner, api_url, expected_calls = [mock.call(loader_class, self.partner, api_url,
ACCESS_TOKEN, 'JWT', max_workers_override or 7, False, **self.kwargs) ACCESS_TOKEN, 'JWT', max_workers or 7, False, **self.kwargs)
for loader_class, api_url, max_workers_override in self.pipeline] for loader_class, api_url, max_workers in self.pipeline]
mock_executor.assert_has_calls(expected_calls) mock_executor.assert_has_calls(expected_calls)
def test_refresh_course_metadata_parallel(self): def test_refresh_course_metadata_parallel(self):
...@@ -157,8 +157,8 @@ class RefreshCourseMetadataCommandTests(TransactionTestCase): ...@@ -157,8 +157,8 @@ class RefreshCourseMetadataCommandTests(TransactionTestCase):
# Set up expected calls # Set up expected calls
expected_calls = [mock.call(execute_parallel_loader, loader_class, expected_calls = [mock.call(execute_parallel_loader, loader_class,
self.partner, api_url, ACCESS_TOKEN, self.partner, api_url, ACCESS_TOKEN,
'JWT', max_workers_override or 7, True, **self.kwargs) 'JWT', max_workers or 7, True, **self.kwargs)
for loader_class, api_url, max_workers_override in self.pipeline] for loader_class, api_url, max_workers in self.pipeline]
mock_executor.assert_has_calls(expected_calls, any_order=True) mock_executor.assert_has_calls(expected_calls, any_order=True)
def test_refresh_course_metadata_with_invalid_partner_code(self): def test_refresh_course_metadata_with_invalid_partner_code(self):
......
# -*- coding: utf-8 -*-
# Generated by Django 1.9.11 on 2017-02-21 20:11
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('course_metadata', '0047_personwork'),
]
operations = [
migrations.CreateModel(
name='DataLoaderConfig',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('max_workers', models.PositiveSmallIntegerField(default=7)),
],
options={
'abstract': False,
},
),
]
...@@ -15,6 +15,7 @@ from django_extensions.db.fields import AutoSlugField ...@@ -15,6 +15,7 @@ from django_extensions.db.fields import AutoSlugField
from django_extensions.db.models import TimeStampedModel from django_extensions.db.models import TimeStampedModel
from haystack import connections from haystack import connections
from haystack.query import SearchQuerySet from haystack.query import SearchQuerySet
from solo.models import SingletonModel
from sortedm2m.fields import SortedManyToManyField from sortedm2m.fields import SortedManyToManyField
from stdimage.models import StdImageField from stdimage.models import StdImageField
from stdimage.utils import UploadToAutoSlug from stdimage.utils import UploadToAutoSlug
...@@ -902,3 +903,10 @@ class CourseRunSocialNetwork(AbstractSocialNetworkModel): ...@@ -902,3 +903,10 @@ class CourseRunSocialNetwork(AbstractSocialNetworkModel):
class PersonWork(AbstractValueModel): class PersonWork(AbstractValueModel):
""" Person Works model. """ """ Person Works model. """
person = models.ForeignKey(Person, related_name='person_works') person = models.ForeignKey(Person, related_name='person_works')
class DataLoaderConfig(SingletonModel):
"""
Configuration for data loaders used in the refresh_course_metadata command.
"""
max_workers = models.PositiveSmallIntegerField(default=7)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment