Commit 48f53dd8 by Renzo Lucioni

Use threads to overlap I/O during data loading.

ECOM-5520.
parent c9d9d2fc
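
All three loaders in this diff share the same shape: request the first page synchronously to learn how many pages exist, process it, then submit the remaining pages to a ThreadPoolExecutor so the network requests overlap while responses are processed in submission order. A minimal, self-contained sketch of that pattern, with illustrative stand-ins (fetch_page, process, and the fake payload are not part of this codebase):

    import concurrent.futures
    import math

    PAGE_SIZE = 50

    def fetch_page(page):
        """Stand-in for one paginated API call, e.g. api_client.courses().get(page=page)."""
        return {'count': 120, 'results': ['row-{}-{}'.format(page, i) for i in range(PAGE_SIZE)]}

    def process(response):
        """Stand-in for per-page processing (clean_strings, model updates, etc.)."""
        print('processing {} results'.format(len(response['results'])))

    def ingest(max_workers=7):
        first = fetch_page(1)
        pages = math.ceil(first['count'] / PAGE_SIZE)
        process(first)

        # Submitting every remaining page up front lets the I/O overlap; iterating
        # the futures list consumes results in page order as each one completes.
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(fetch_page, page) for page in range(2, pages + 1)]
            for future in futures:
                process(future.result())

    if __name__ == '__main__':
        ingest()
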
 import json
+import math
 from urllib.parse import parse_qs, urlparse

 from factory.fuzzy import (

@@ -68,6 +69,7 @@ def mock_api_callback(url, data, results_key=True, pagination=False):
     body = {
         'count': count,
         'next': next_url,
+        'num_pages': math.ceil(count / page_size),
         'previous': previous_url,
     }

@@ -23,7 +23,7 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
     PAGE_SIZE = 50
     SUPPORTED_TOKEN_TYPES = ('bearer', 'jwt',)

-    def __init__(self, partner, api_url, access_token=None, token_type=None):
+    def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None):
         """
         Arguments:
             partner (Partner): Partner which owns the APIs and data being loaded

@@ -42,6 +42,8 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
         self.partner = partner
         self.api_url = api_url.strip('/')
+        self.max_workers = max_workers
+
     @cached_property
     def api_client(self):
         """

-import logging
+import concurrent.futures
 from decimal import Decimal
 from io import BytesIO
+import logging
+import math

-import requests
 from django.core.files import File
 from opaque_keys.edx.keys import CourseKey
+import requests

 from course_discovery.apps.core.models import Currency
 from course_discovery.apps.course_metadata.data_loaders import AbstractDataLoader

@@ -60,40 +62,47 @@ class CoursesApiDataLoader(AbstractDataLoader):
     """ Loads course runs from the Courses API. """

     def ingest(self):
-        api_url = self.partner.courses_api_url
-        count = None
-        page = 1
-
-        logger.info('Refreshing Courses and CourseRuns from %s...', api_url)
-
-        while page:
-            response = self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE)
-            count = response['pagination']['count']
-            results = response['results']
-            logger.info('Retrieved %d course runs...', len(results))
-
-            if response['pagination']['next']:
-                page += 1
-            else:
-                page = None
-
-            for body in results:
-                course_run_id = body['id']
-
-                try:
-                    body = self.clean_strings(body)
-                    course = self.update_course(body)
-                    self.update_course_run(course, body)
-                except:  # pylint: disable=bare-except
-                    msg = 'An error occurred while updating {course_run} from {api_url}'.format(
-                        course_run=course_run_id,
-                        api_url=api_url
-                    )
-                    logger.exception(msg)
-
-        logger.info('Retrieved %d course runs from %s.', count, api_url)
+        logger.info('Refreshing Courses and CourseRuns from %s...', self.partner.courses_api_url)
+
+        initial_page = 1
+        response = self._request(initial_page)
+        count = response['pagination']['count']
+        pages = response['pagination']['num_pages']
+        self._process_response(response)
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            futures = [executor.submit(self._request, page) for page in range(initial_page + 1, pages + 1)]
+
+            for future in futures:  # pragma: no cover
+                response = future.result()
+                self._process_response(response)
+
+        logger.info('Retrieved %d course runs from %s.', count, self.partner.courses_api_url)

         self.delete_orphans()

+    def _request(self, page):
+        """Make a request."""
+        return self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE)
+
+    def _process_response(self, response):
+        """Process a response."""
+        results = response['results']
+        logger.info('Retrieved %d course runs...', len(results))
+
+        for body in results:
+            course_run_id = body['id']
+
+            try:
+                body = self.clean_strings(body)
+                course = self.update_course(body)
+                self.update_course_run(course, body)
+            except:  # pylint: disable=bare-except
+                msg = 'An error occurred while updating {course_run} from {api_url}'.format(
+                    course_run=course_run_id,
+                    api_url=self.partner.courses_api_url
+                )
+                logger.exception(msg)
+
     def update_course(self, body):
         course_run_key = CourseKey.from_string(body['id'])

@@ -170,31 +179,37 @@ class EcommerceApiDataLoader(AbstractDataLoader):
     """ Loads course seats from the E-Commerce API. """

     def ingest(self):
-        api_url = self.partner.ecommerce_api_url
-        count = None
-        page = 1
-
-        logger.info('Refreshing course seats from %s...', api_url)
-
-        while page:
-            response = self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE, include_products=True)
-            count = response['count']
-            results = response['results']
-            logger.info('Retrieved %d course seats...', len(results))
-
-            if response['next']:
-                page += 1
-            else:
-                page = None
-
-            for body in results:
-                body = self.clean_strings(body)
-                self.update_seats(body)
-
-        logger.info('Retrieved %d course seats from %s.', count, api_url)
+        logger.info('Refreshing course seats from %s...', self.partner.ecommerce_api_url)
+
+        response = self._request(1)
+        count = response['count']
+        pages = math.ceil(count / self.PAGE_SIZE)
+        self._process_response(response)
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            futures = [executor.submit(self._request, page) for page in range(2, pages + 1)]
+
+            for future in futures:  # pragma: no cover
+                response = future.result()
+                self._process_response(response)
+
+        logger.info('Retrieved %d course seats from %s.', count, self.partner.ecommerce_api_url)

         self.delete_orphans()

+    def _request(self, page):
+        """Make a request."""
+        return self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE, include_products=True)
+
+    def _process_response(self, response):
+        """Process a response."""
+        results = response['results']
+        logger.info('Retrieved %d course seats...', len(results))
+
+        for body in results:
+            body = self.clean_strings(body)
+            self.update_seats(body)
+
     def update_seats(self, body):
         course_run_key = body['id']

         try:

 import abc
+import concurrent.futures
 import datetime
 import logging
-from urllib.parse import urlencode
+from urllib.parse import parse_qs, urlencode, urlparse
 from uuid import UUID

 import pytz

@@ -20,8 +21,8 @@ logger = logging.getLogger(__name__)
 class AbstractMarketingSiteDataLoader(AbstractDataLoader):
-    def __init__(self, partner, api_url, access_token=None, token_type=None):
-        super(AbstractMarketingSiteDataLoader, self).__init__(partner, api_url, access_token, token_type)
+    def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None):
+        super(AbstractMarketingSiteDataLoader, self).__init__(partner, api_url, access_token, token_type, max_workers)

         if not (self.partner.marketing_site_api_username and self.partner.marketing_site_api_password):
             msg = 'Marketing Site API credentials are not properly configured for Partner [{partner}]!'.format(
@@ -58,40 +59,61 @@ class AbstractMarketingSiteDataLoader(AbstractDataLoader):
     def ingest(self):
         """ Load data for all supported objects (e.g. courses, runs). """
-        page = 0
-        query_kwargs = self.get_query_kwargs()
-
-        while page is not None and page >= 0:  # pragma: no cover
-            kwargs = {
-                'page': page,
-            }
-            kwargs.update(query_kwargs)
-
-            qs = urlencode(kwargs)
-            url = '{root}/node.json?{qs}'.format(root=self.api_url, qs=qs)
-            response = self.api_client.get(url)
-
-            status_code = response.status_code
-            if status_code is not 200:
-                msg = 'Failed to retrieve data from {url}\nStatus Code: {status}\nBody: {body}'.format(
-                    url=url, status=status_code, body=response.content)
-                logger.error(msg)
-                raise Exception(msg)
-
-            data = response.json()
-            for datum in data['list']:
-                try:
-                    url = datum['url']
-                    datum = self.clean_strings(datum)
-                    self.process_node(datum)
-                except:  # pylint: disable=bare-except
-                    logger.exception('Failed to load %s.', url)
-
-            if 'next' in data:
-                page += 1
-            else:
-                break
+        initial_page = 0
+        response = self._request(initial_page)
+        self._check_status_code(response)
+        self._process_response(response)
+
+        data = response.json()
+        if 'next' in data:
+            # Add one to avoid requesting the first page again and to make sure
+            # we get the last page when range() is used below.
+            pages = [self._extract_page(url) + 1 for url in (data['first'], data['last'])]
+
+            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+                futures = [executor.submit(self._request, page) for page in range(*pages)]
+
+                for future in futures:
+                    response = future.result()
+                    self._process_response(response)
+
+    def _request(self, page):
+        """Make a request to the marketing site."""
+        kwargs = {'page': page}
+        kwargs.update(self.get_query_kwargs())
+
+        qs = urlencode(kwargs)
+        url = '{root}/node.json?{qs}'.format(root=self.api_url, qs=qs)
+
+        return self.api_client.get(url)
+
+    def _check_status_code(self, response):
+        """Check the status code on a response from the marketing site."""
+        status_code = response.status_code
+        if status_code != 200:
+            msg = 'Failed to retrieve data from {url}\nStatus Code: {status}\nBody: {body}'.format(
+                url=response.url, status=status_code, body=response.content)
+            logger.error(msg)
+            raise Exception(msg)
+
+    def _extract_page(self, url):
+        """Extract page number from a marketing site URL."""
+        qs = parse_qs(urlparse(url).query)
+        return int(qs['page'][0])
+
+    def _process_response(self, response):
+        """Process a response from the marketing site."""
+        self._check_status_code(response)
+
+        data = response.json()
+        for node in data['list']:
+            try:
+                url = node['url']
+                node = self.clean_strings(node)
+                self.process_node(node)
+            except:  # pylint: disable=bare-except
+                logger.exception('Failed to load %s.', url)
+
     def _get_nested_url(self, field):
         """ Helper method that retrieves the nested `url` field in the specified field, if it exists.

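The marketing site endpoint does not report a page count, so when the first response contains a 'next' link the loader derives the page range from the 'first' and 'last' pagination URLs via _extract_page. A standalone sketch of that parsing (the sample URL is made up):

    from urllib.parse import parse_qs, urlparse

    def extract_page(url):
        """Return the integer 'page' query parameter from a pagination URL."""
        return int(parse_qs(urlparse(url).query)['page'][0])

    assert extract_page('https://example.org/node.json?page=12') == 12
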
@@ -930,7 +930,6 @@ MARKETING_SITE_API_SPONSOR_BODIES = [
         'title': 'Turkcell Akademi',
         'url': 'https://www.edx.org/sponsorer/turkcell-akademi',
         'uuid': 'fcb48e7e-8f1b-4d4b-8bb0-77617aaad9ba'
     },
-
     {
         'body': [],

 import datetime
 import json
+import math
 from urllib.parse import parse_qs, urlparse
 from uuid import UUID

@@ -45,7 +46,9 @@ class AbstractMarketingSiteDataLoaderTestMixin(DataLoaderTestMixin):
         page_size = 1

         body = {
-            'list': [data[page]]
+            'list': [data[page]],
+            'first': '{}?page={}'.format(url, 0),
+            'last': '{}?page={}'.format(url, math.ceil(count / page_size) - 1),
         }

         if (page * page_size) < count - 1:

@@ -43,7 +43,18 @@ class Command(BaseCommand):
             help='The short code for a specific partner to refresh.'
         )

+        parser.add_argument(
+            '-w', '--max_workers',
+            type=int,
+            action='store',
+            dest='max_workers',
+            default=7,
+            help='Number of worker threads to use when traversing paginated responses.'
+        )
+
     def handle(self, *args, **options):
+        max_workers = options.get('max_workers')
+
         # For each partner defined...
         partners = Partner.objects.all()

@@ -56,7 +67,6 @@ class Command(BaseCommand):
             raise CommandError('No partners available!')

         for partner in partners:
-
             access_token = options.get('access_token')
             token_type = options.get('token_type')

@@ -94,7 +104,7 @@ class Command(BaseCommand):
             for api_url, loader_class in data_loaders:
                 if api_url:
                     try:
-                        loader_class(partner, api_url, access_token, token_type).ingest()
+                        loader_class(partner, api_url, access_token, token_type, max_workers).ingest()
                     except Exception:  # pylint: disable=broad-except
                         logger.exception('%s failed!', loader_class.__name__)
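
With the flag wired through, each loader receives the worker count at construction time; a hypothetical direct invocation outside the management command would look like:

    CoursesApiDataLoader(partner, api_url, access_token, token_type, max_workers=7).ingest()

Passing max_workers=None instead would fall back to ThreadPoolExecutor's own default (5 x CPU count on Python 3.5-3.7), so the command's explicit default of 7 keeps the request rate predictable.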