Commit ad1dca56 by Clinton Blackburn

Converted tests to py.test

parent 93eb5c0f
[settings] [settings]
line_length=120 line_length=120
multi_line_output=5
import logging import logging
import pytest import pytest
from django.contrib.sites.models import Site
from django.core.cache import cache from django.core.cache import cache
from django.test.client import Client
from haystack import connections as haystack_connections
from pytest_django.lazy_django import skip_if_no_django from pytest_django.lazy_django import skip_if_no_django
from course_discovery.apps.core.tests.factories import PartnerFactory, SiteFactory
from course_discovery.apps.core.utils import ElasticsearchUtils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TEST_DOMAIN = 'testserver.fake'
@pytest.fixture
def django_cache(request, settings): @pytest.fixture(scope='session', autouse=True)
def django_cache_add_xdist_key_prefix(request):
skip_if_no_django() skip_if_no_django()
from django.conf import settings
xdist_prefix = getattr(request.config, 'slaveinput', {}).get('slaveid') xdist_prefix = getattr(request.config, 'slaveinput', {}).get('slaveid')
if xdist_prefix: if xdist_prefix:
# Put a prefix like gw0_, gw1_ etc on xdist processes
for name, cache_settings in settings.CACHES.items(): for name, cache_settings in settings.CACHES.items():
# Put a prefix like _gw0, _gw1 etc on xdist processes
cache_settings['KEY_PREFIX'] = xdist_prefix + '_' + cache_settings.get('KEY_PREFIX', '') cache_settings['KEY_PREFIX'] = xdist_prefix + '_' + cache_settings.get('KEY_PREFIX', '')
logger.info('Set cache key prefix for [%s] cache to [%s]', name, cache_settings['KEY_PREFIX']) logger.info('Set cache key prefix for [%s] cache to [%s]', name, cache_settings['KEY_PREFIX'])
@pytest.fixture
def django_cache(django_cache_add_xdist_key_prefix): # pylint: disable=unused-argument
skip_if_no_django()
cache.clear()
yield cache yield cache
cache.clear() cache.clear()
@pytest.fixture(scope='session', autouse=True)
def haystack_add_xdist_suffix_to_index_name(request):
skip_if_no_django()
from django.conf import settings
xdist_suffix = getattr(request.config, 'slaveinput', {}).get('slaveid')
if xdist_suffix:
# Put a prefix like _gw0, _gw1 etc on xdist processes
for name, connection in settings.HAYSTACK_CONNECTIONS.items():
connection['INDEX_NAME'] = connection['INDEX_NAME'] + '_' + xdist_suffix
logger.info('Set index name for Haystack connection [%s] to [%s]', name, connection['INDEX_NAME'])
@pytest.fixture
def haystack_default_connection(haystack_add_xdist_suffix_to_index_name): # pylint: disable=unused-argument
skip_if_no_django()
backend = haystack_connections['default'].get_backend()
# Force Haystack to update the mapping for the index
backend.setup_complete = False
es = backend.conn
index_name = backend.index_name
ElasticsearchUtils.delete_index(es, index_name)
ElasticsearchUtils.create_alias_and_index(es, index_name)
ElasticsearchUtils.refresh_index(es, index_name)
yield backend
ElasticsearchUtils.delete_index(es, index_name)
@pytest.fixture
def site(db): # pylint: disable=unused-argument
skip_if_no_django()
from django.conf import settings
Site.objects.all().delete()
return SiteFactory(id=settings.SITE_ID, domain=TEST_DOMAIN)
@pytest.fixture
def partner(db, site): # pylint: disable=unused-argument
skip_if_no_django()
return PartnerFactory(site=site)
@pytest.fixture
def client():
skip_if_no_django()
return Client(SERVER_NAME=TEST_DOMAIN)
# pylint: disable=no-member # pylint: disable=no-member
import base64 import base64
import ddt import pytest
from django.core.files.base import ContentFile from django.core.files.base import ContentFile
from django.test import TestCase
from course_discovery.apps.api.fields import ImageField, StdImageSerializerField from course_discovery.apps.api.fields import ImageField, StdImageSerializerField
from course_discovery.apps.api.tests.test_serializers import make_request from course_discovery.apps.api.tests.test_serializers import make_request
...@@ -11,40 +10,27 @@ from course_discovery.apps.core.tests.helpers import make_image_file ...@@ -11,40 +10,27 @@ from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests.factories import ProgramFactory from course_discovery.apps.course_metadata.tests.factories import ProgramFactory
class ImageFieldTests(TestCase): @pytest.mark.django_db
def test_to_representation(self): def test_imagefield_to_representation():
value = 'https://example.com/image.jpg' value = 'https://example.com/image.jpg'
expected = { expected = {'src': value, 'description': None, 'height': None, 'width': None}
'src': value, assert ImageField().to_representation(value) == expected
'description': None,
'height': None,
'width': None
}
self.assertEqual(ImageField().to_representation(value), expected)
@ddt.ddt @pytest.mark.django_db
class StdImageSerializerFieldTests(TestCase): class TestStdImageSerializerField:
def test_to_representation(self): def test_to_representation(self):
request = make_request() request = make_request()
# TODO Create test-only model to avoid unnecessary dependency on Program model.
program = ProgramFactory(banner_image=make_image_file('test.jpg')) program = ProgramFactory(banner_image=make_image_file('test.jpg'))
field = StdImageSerializerField() field = StdImageSerializerField()
field._context = {'request': request} # pylint: disable=protected-access field._context = {'request': request} # pylint: disable=protected-access
expected = { expected = {
size_key: { size_key: {
'url': '{}{}'.format( 'url': '{}{}'.format('http://testserver', getattr(program.banner_image, size_key).url),
'http://testserver',
getattr(program.banner_image, size_key).url
),
'width': program.banner_image.field.variations[size_key]['width'], 'width': program.banner_image.field.variations[size_key]['width'],
'height': program.banner_image.field.variations[size_key]['height'] 'height': program.banner_image.field.variations[size_key]['height']
} } for size_key in program.banner_image.field.variations}
for size_key in program.banner_image.field.variations assert field.to_representation(program.banner_image) == expected
}
self.assertDictEqual(field.to_representation(program.banner_image), expected)
def test_to_internal_value(self): def test_to_internal_value(self):
base64_header = "data:image/jpeg;base64," base64_header = "data:image/jpeg;base64,"
...@@ -76,15 +62,9 @@ class StdImageSerializerFieldTests(TestCase): ...@@ -76,15 +62,9 @@ class StdImageSerializerFieldTests(TestCase):
"//Z" "//Z"
) )
base64_full = base64_header + base64_data base64_full = base64_header + base64_data
expected = ContentFile(base64.b64decode(base64_data), name='tmp.jpeg') expected = ContentFile(base64.b64decode(base64_data), name='tmp.jpeg')
assert list(StdImageSerializerField().to_internal_value(base64_full).chunks()) == list(expected.chunks())
# assert that the data chunks are equal @pytest.mark.parametrize('falsey_value', ("", False, None, []))
self.assertListEqual(
list(StdImageSerializerField().to_internal_value(base64_full).chunks()),
list(expected.chunks())
)
@ddt.data("", False, None, [])
def test_to_internal_value_falsey(self, falsey_value): def test_to_internal_value_falsey(self, falsey_value):
self.assertIsNone(StdImageSerializerField().to_internal_value(falsey_value)) assert StdImageSerializerField().to_internal_value(falsey_value) is None
import ddt
from django.test import TestCase
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
from course_discovery.apps.api.filters import HaystackRequestFilterMixin from course_discovery.apps.api.filters import HaystackRequestFilterMixin
@ddt.ddt class TestHaystackRequestFilterMixin:
class HaystackRequestFilterMixinTests(TestCase):
def test_get_request_filters(self): def test_get_request_filters(self):
""" Verify the method removes query parameters with empty values """ """ Verify the method removes query parameters with empty values """
request = APIRequestFactory().get('/?q=') request = APIRequestFactory().get('/?q=')
request = APIView().initialize_request(request) request = APIView().initialize_request(request)
filters = HaystackRequestFilterMixin.get_request_filters(request) filters = HaystackRequestFilterMixin.get_request_filters(request)
self.assertDictEqual(filters, {}) assert filters == {}
def test_get_request_filters_with_list(self): def test_get_request_filters_with_list(self):
""" Verify the method does not affect list values. """ """ Verify the method does not affect list values. """
request = APIRequestFactory().get('/?q=&content_type=courserun&content_type=program') request = APIRequestFactory().get('/?q=&content_type=courserun&content_type=program')
request = APIView().initialize_request(request) request = APIView().initialize_request(request)
filters = HaystackRequestFilterMixin.get_request_filters(request) filters = HaystackRequestFilterMixin.get_request_filters(request)
self.assertNotIn('q', filters) assert 'q' not in filters
self.assertEqual(filters.getlist('content_type'), ['courserun', 'program']) assert filters.getlist('content_type') == ['courserun', 'program']
def test_get_request_filters_with_falsey_values(self): def test_get_request_filters_with_falsey_values(self):
""" Verify the method does not strip valid falsey values. """ """ Verify the method does not strip valid falsey values. """
request = APIRequestFactory().get('/?q=&test=0') request = APIRequestFactory().get('/?q=&test=0')
request = APIView().initialize_request(request) request = APIView().initialize_request(request)
filters = HaystackRequestFilterMixin.get_request_filters(request) filters = HaystackRequestFilterMixin.get_request_filters(request)
self.assertNotIn('q', filters) assert 'q' not in filters
self.assertEqual(filters.get('test'), '0') assert filters.get('test') == '0'
...@@ -4,6 +4,7 @@ import itertools ...@@ -4,6 +4,7 @@ import itertools
from urllib.parse import urlencode from urllib.parse import urlencode
import ddt import ddt
import pytest
import pytz import pytz
from django.test import TestCase from django.test import TestCase
from haystack.query import SearchQuerySet from haystack.query import SearchQuerySet
...@@ -1249,7 +1250,9 @@ class CourseRunSearchSerializerTests(ElasticsearchTestMixin, TestCase): ...@@ -1249,7 +1250,9 @@ class CourseRunSearchSerializerTests(ElasticsearchTestMixin, TestCase):
assert serializer.data['level_type'] is None assert serializer.data['level_type'] is None
class ProgramSearchSerializerTests(TestCase): @pytest.mark.django_db
@pytest.mark.usefixtures('haystack_default_connection')
class TestProgramSearchSerializer:
def _create_expected_data(self, program): def _create_expected_data(self, program):
return { return {
'uuid': str(program.uuid), 'uuid': str(program.uuid),
...@@ -1327,7 +1330,9 @@ class ProgramSearchSerializerTests(TestCase): ...@@ -1327,7 +1330,9 @@ class ProgramSearchSerializerTests(TestCase):
assert {'English', 'Chinese - Mandarin'} == {*expected['language']} assert {'English', 'Chinese - Mandarin'} == {*expected['language']}
class TypeaheadCourseRunSearchSerializerTests(TestCase): @pytest.mark.django_db
@pytest.mark.usefixtures('haystack_default_connection')
class TestTypeaheadCourseRunSearchSerializer:
def test_data(self): def test_data(self):
authoring_organization = OrganizationFactory() authoring_organization = OrganizationFactory()
course_run = CourseRunFactory(authoring_organizations=[authoring_organization]) course_run = CourseRunFactory(authoring_organizations=[authoring_organization])
...@@ -1348,7 +1353,9 @@ class TypeaheadCourseRunSearchSerializerTests(TestCase): ...@@ -1348,7 +1353,9 @@ class TypeaheadCourseRunSearchSerializerTests(TestCase):
return serializer return serializer
class TypeaheadProgramSearchSerializerTests(TestCase): @pytest.mark.django_db
@pytest.mark.usefixtures('haystack_default_connection')
class TestTypeaheadProgramSearchSerializer:
def _create_expected_data(self, program): def _create_expected_data(self, program):
return { return {
'uuid': str(program.uuid), 'uuid': str(program.uuid),
......
...@@ -15,7 +15,7 @@ from course_discovery.apps.api.serializers import ( ...@@ -15,7 +15,7 @@ from course_discovery.apps.api.serializers import (
from course_discovery.apps.api.tests.mixins import SiteMixin from course_discovery.apps.api.tests.mixins import SiteMixin
class SerializationMixin(object): class SerializationMixin:
def _get_request(self, format=None): def _get_request(self, format=None):
if getattr(self, 'request', None): if getattr(self, 'request', None):
return self.request return self.request
......
import urllib.parse import urllib.parse
import ddt import pytest
from django.core.cache import cache from django.test import RequestFactory
from django.urls import reverse from django.urls import reverse
from course_discovery.apps.api.serializers import MinimalProgramSerializer from course_discovery.apps.api.serializers import MinimalProgramSerializer
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.views.programs import ProgramViewSet from course_discovery.apps.api.v1.views.programs import ProgramViewSet
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file from course_discovery.apps.core.tests.helpers import make_image_file
...@@ -17,18 +17,30 @@ from course_discovery.apps.course_metadata.tests.factories import ( ...@@ -17,18 +17,30 @@ from course_discovery.apps.course_metadata.tests.factories import (
) )
@ddt.ddt @pytest.mark.django_db
class ProgramViewSetTests(SerializationMixin, APITestCase): @pytest.mark.usefixtures('django_cache')
class TestProgramViewSet(SerializationMixin):
client = None
django_assert_num_queries = None
list_path = reverse('api:v1:program-list') list_path = reverse('api:v1:program-list')
partner = None
request = None
def setUp(self): @pytest.fixture(autouse=True)
super(ProgramViewSetTests, self).setUp() def setup(self, client, django_assert_num_queries, partner):
self.user = UserFactory(is_staff=True, is_superuser=True) user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD)
# Clear the cache between test cases, so they don't interfere with each other. client.login(username=user.username, password=USER_PASSWORD)
cache.clear()
site = partner.site
request = RequestFactory(SERVER_NAME=site.domain).get('')
request.site = site
request.user = user
self.client = client
self.django_assert_num_queries = django_assert_num_queries
self.partner = partner
self.request = request
def create_program(self): def create_program(self):
organizations = [OrganizationFactory(partner=self.partner)] organizations = [OrganizationFactory(partner=self.partner)]
...@@ -59,37 +71,37 @@ class ProgramViewSetTests(SerializationMixin, APITestCase): ...@@ -59,37 +71,37 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
url += '?' + urllib.parse.urlencode(querystring) url += '?' + urllib.parse.urlencode(querystring)
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) assert response.status_code == 200
return response return response
def test_authentication(self): def test_authentication(self):
""" Verify the endpoint requires the user to be authenticated. """ """ Verify the endpoint requires the user to be authenticated. """
response = self.client.get(self.list_path) response = self.client.get(self.list_path)
self.assertEqual(response.status_code, 200) assert response.status_code == 200
self.client.logout() self.client.logout()
response = self.client.get(self.list_path) response = self.client.get(self.list_path)
self.assertEqual(response.status_code, 403) assert response.status_code == 403
def test_retrieve(self): def test_retrieve(self, django_assert_num_queries):
""" Verify the endpoint returns the details for a single program. """ """ Verify the endpoint returns the details for a single program. """
program = self.create_program() program = self.create_program()
with self.assertNumQueries(39): with django_assert_num_queries(39):
response = self.assert_retrieve_success(program) response = self.assert_retrieve_success(program)
# property does not have the right values while being indexed # property does not have the right values while being indexed
del program._course_run_weeks_to_complete del program._course_run_weeks_to_complete
assert response.data == self.serialize_program(program) assert response.data == self.serialize_program(program)
# Verify that repeated retrieve requests use the cache. # Verify that repeated retrieve requests use the cache.
with self.assertNumQueries(4): with django_assert_num_queries(4):
self.assert_retrieve_success(program) self.assert_retrieve_success(program)
# Verify that requests including querystring parameters are cached separately. # Verify that requests including querystring parameters are cached separately.
response = self.assert_retrieve_success(program, querystring={'use_full_course_serializer': 1}) response = self.assert_retrieve_success(program, querystring={'use_full_course_serializer': 1})
assert response.data == self.serialize_program(program, extra_context={'use_full_course_serializer': 1}) assert response.data == self.serialize_program(program, extra_context={'use_full_course_serializer': 1})
@ddt.data(True, False) @pytest.mark.parametrize('order_courses_by_start_date', (True, False,))
def test_retrieve_with_sorting_flag(self, order_courses_by_start_date): def test_retrieve_with_sorting_flag(self, order_courses_by_start_date, django_assert_num_queries):
""" Verify the number of queries is the same with sorting flag set to true. """ """ Verify the number of queries is the same with sorting flag set to true. """
course_list = CourseFactory.create_batch(3, partner=self.partner) course_list = CourseFactory.create_batch(3, partner=self.partner)
for course in course_list: for course in course_list:
...@@ -100,16 +112,16 @@ class ProgramViewSetTests(SerializationMixin, APITestCase): ...@@ -100,16 +112,16 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
partner=self.partner) partner=self.partner)
# property does not have the right values while being indexed # property does not have the right values while being indexed
del program._course_run_weeks_to_complete del program._course_run_weeks_to_complete
with self.assertNumQueries(28): with django_assert_num_queries(28):
response = self.assert_retrieve_success(program) response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program) assert response.data == self.serialize_program(program)
self.assertEqual(course_list, list(program.courses.all())) # pylint: disable=no-member assert course_list == list(program.courses.all()) # pylint: disable=no-member
def test_retrieve_without_course_runs(self): def test_retrieve_without_course_runs(self, django_assert_num_queries):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """ """ Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course = CourseFactory(partner=self.partner) course = CourseFactory(partner=self.partner)
program = ProgramFactory(courses=[course], partner=self.partner) program = ProgramFactory(courses=[course], partner=self.partner)
with self.assertNumQueries(22): with django_assert_num_queries(22):
response = self.assert_retrieve_success(program) response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program) assert response.data == self.serialize_program(program)
...@@ -127,13 +139,10 @@ class ProgramViewSetTests(SerializationMixin, APITestCase): ...@@ -127,13 +139,10 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
Returns: Returns:
None None
""" """
with self.assertNumQueries(expected_query_count): with self.django_assert_num_queries(expected_query_count):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual( assert response.data['results'] == self.serialize_program(expected, many=True, extra_context=extra_context)
response.data['results'],
self.serialize_program(expected, many=True, extra_context=extra_context)
)
def test_list(self): def test_list(self):
""" Verify the endpoint returns a list of all programs. """ """ Verify the endpoint returns a list of all programs. """
...@@ -200,11 +209,13 @@ class ProgramViewSetTests(SerializationMixin, APITestCase): ...@@ -200,11 +209,13 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
self.assert_list_results(url, expected, 10) self.assert_list_results(url, expected, 10)
@ddt.data( @pytest.mark.parametrize(
(ProgramStatus.Unpublished, False, 5), 'status,is_marketable,expected_query_count',
(ProgramStatus.Active, True, 10), (
(ProgramStatus.Unpublished, False, 5),
(ProgramStatus.Active, True, 10),
)
) )
@ddt.unpack
def test_filter_by_marketable(self, status, is_marketable, expected_query_count): def test_filter_by_marketable(self, status, is_marketable, expected_query_count):
""" Verify the endpoint filters programs to those that are marketable. """ """ Verify the endpoint filters programs to those that are marketable. """
url = self.list_path + '?marketable=1' url = self.list_path + '?marketable=1'
...@@ -213,7 +224,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase): ...@@ -213,7 +224,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
programs.reverse() programs.reverse()
expected = programs if is_marketable else [] expected = programs if is_marketable else []
self.assertEqual(list(Program.objects.marketable()), expected) assert list(Program.objects.marketable()) == expected
self.assert_list_results(url, expected, expected_query_count) self.assert_list_results(url, expected, expected_query_count)
def test_filter_by_status(self): def test_filter_by_status(self):
...@@ -271,4 +282,4 @@ class ProgramViewSetTests(SerializationMixin, APITestCase): ...@@ -271,4 +282,4 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_minimal_serializer_use(self): def test_minimal_serializer_use(self):
""" Verify that the list view uses the minimal serializer. """ """ Verify that the list view uses the minimal serializer. """
self.assertEqual(ProgramViewSet(action='list').get_serializer_class(), MinimalProgramSerializer) assert ProgramViewSet(action='list').get_serializer_class() == MinimalProgramSerializer
...@@ -323,8 +323,9 @@ class AggregateSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT ...@@ -323,8 +323,9 @@ class AggregateSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
def process_response(self, response): def process_response(self, response):
response = self.get_response(response).json() response = self.get_response(response).json()
self.assertTrue(response['objects']['count']) objects = response['objects']
return response['objects'] assert objects['count'] > 0
return objects
def test_results_only_include_published_objects(self): def test_results_only_include_published_objects(self):
""" Verify the search results only include items with status set to 'Published'. """ """ Verify the search results only include items with status set to 'Published'. """
......
import logging import logging
import pytest
from django.conf import settings from django.conf import settings
from haystack import connections as haystack_connections from haystack import connections as haystack_connections
...@@ -9,44 +10,13 @@ from course_discovery.apps.course_metadata.models import Course, CourseRun ...@@ -9,44 +10,13 @@ from course_discovery.apps.course_metadata.models import Course, CourseRun
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@pytest.mark.usefixtures('haystack_default_connection')
class ElasticsearchTestMixin(object): class ElasticsearchTestMixin(object):
@classmethod
def setUpClass(cls):
super(ElasticsearchTestMixin, cls).setUpClass()
cls.index = settings.HAYSTACK_CONNECTIONS['default']['INDEX_NAME']
# Make use of the changes in our custom ES backend
# This is required for typeahead autocomplete to work in the tests
connection = haystack_connections['default']
cls.backend = connection.get_backend()
# Without this line, haystack doesn't fully recreate the connection
# The first test using this backend succeeds, but the following tests
# do not set the Elasticsearch _mapping
def setUp(self): def setUp(self):
super(ElasticsearchTestMixin, self).setUp() super(ElasticsearchTestMixin, self).setUp()
self.backend.setup_complete = False self.index = settings.HAYSTACK_CONNECTIONS['default']['INDEX_NAME']
self.es = self.backend.conn connection = haystack_connections['default']
self.reset_index() self.es = connection.get_backend().conn
self.refresh_index()
def reset_index(self):
""" Deletes and re-creates the Elasticsearch index. """
self.delete_index(self.index)
ElasticsearchUtils.create_alias_and_index(self.es, self.index)
def delete_index(self, index):
"""
Deletes an index.
Args:
index (str): Name of index to delete
Returns:
None
"""
logger.info('Deleting index [%s]...', index)
self.es.indices.delete(index=index, ignore=404) # pylint: disable=unexpected-keyword-arg
logger.info('...index deleted.')
def refresh_index(self): def refresh_index(self):
""" """
...@@ -54,9 +24,7 @@ class ElasticsearchTestMixin(object): ...@@ -54,9 +24,7 @@ class ElasticsearchTestMixin(object):
https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-refresh.html https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-refresh.html
""" """
# pylint: disable=unexpected-keyword-arg ElasticsearchUtils.refresh_index(self.es, self.index)
self.es.indices.refresh(index=self.index)
self.es.cluster.health(index=self.index, wait_for_status='yellow', request_timeout=1)
def reindex_course_runs(self, course): def reindex_course_runs(self, course):
index = haystack_connections['default'].get_unified_index().get_index(CourseRun) index = haystack_connections['default'].get_unified_index().get_index(CourseRun)
......
...@@ -53,6 +53,24 @@ class ElasticsearchUtils(object): ...@@ -53,6 +53,24 @@ class ElasticsearchUtils(object):
logger.info('...index [%s] created.', index_name) logger.info('...index [%s] created.', index_name)
return index_name return index_name
@classmethod
def delete_index(cls, es_connection, index):
logger.info('Deleting index [%s]...', index)
es_connection.indices.delete(index=index, ignore=404) # pylint: disable=unexpected-keyword-arg
logger.info('...index deleted.')
@classmethod
def refresh_index(cls, es_connection, index):
"""
Refreshes the index.
https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-refresh.html
"""
logger.info('Refreshing index [%s]...', index)
es_connection.indices.refresh(index=index)
es_connection.cluster.health(index=index, wait_for_status='yellow', request_timeout=1)
logger.info('...index refreshed.')
def get_all_related_field_names(model): def get_all_related_field_names(model):
""" """
......
import json import json
from urllib.parse import quote from urllib.parse import quote
import ddt import pytest
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
...@@ -16,128 +16,93 @@ from course_discovery.apps.publisher.tests import factories ...@@ -16,128 +16,93 @@ from course_discovery.apps.publisher.tests import factories
# pylint: disable=no-member # pylint: disable=no-member
@ddt.ddt @pytest.mark.django_db
class AutocompleteTests(SiteMixin, TestCase): class TestAutocomplete:
""" Tests for autocomplete lookups.""" def assert_valid_query_result(self, client, path, query, expected_result):
def setUp(self): """ Asserts a query made against the given endpoint returns the expected result. """
super(AutocompleteTests, self).setUp() response = client.get(path + '?q={q}'.format(q=query))
self.user = UserFactory(is_staff=True)
self.client.login(username=self.user.username, password=USER_PASSWORD)
self.courses = CourseFactory.create_batch(3, title='Some random course title')
for course in self.courses:
CourseRunFactory(course=course)
self.organizations = OrganizationFactory.create_batch(3)
first_instructor = PersonFactory(given_name="First Instructor")
second_instructor = PersonFactory(given_name="Second Instructor")
self.instructors = [first_instructor, second_instructor]
@ddt.data('dum', 'ing')
def test_course_autocomplete(self, search_key):
""" Verify course autocomplete returns the data. """
response = self.client.get(reverse('admin_metadata:course-autocomplete'))
data = json.loads(response.content.decode('utf-8')) data = json.loads(response.content.decode('utf-8'))
self.assertEqual(response.status_code, 200) assert len(data['results']) == 1
self.assertEqual(len(data['results']), 3) assert data['results'][0]['text'] == str(expected_result)
# update the first course title
self.courses[0].key = 'edx/dummy/key'
self.courses[0].title = 'this is some thing new'
self.courses[0].save()
response = self.client.get(
reverse('admin_metadata:course-autocomplete') + '?q={title}'.format(title=search_key)
)
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(data['results'][0]['text'], str(self.courses[0]))
def test_course_autocomplete_un_authorize_user(self): def test_course_autocomplete(self, admin_client):
""" Verify course autocomplete returns empty list for un-authorized users. """ """ Verify course autocomplete returns the data. """
self._make_user_non_staff() courses = CourseFactory.create_batch(3)
response = self.client.get(reverse('admin_metadata:course-autocomplete')) path = reverse('admin_metadata:course-autocomplete')
data = json.loads(response.content.decode('utf-8')) response = admin_client.get(path)
self.assertEqual(data['results'], [])
@ddt.data('ing', 'dum')
def test_course_run_autocomplete(self, search_key):
""" Verify course run autocomplete returns the data. """
response = self.client.get(reverse('admin_metadata:course-run-autocomplete'))
data = json.loads(response.content.decode('utf-8')) data = json.loads(response.content.decode('utf-8'))
self.assertEqual(response.status_code, 200) assert response.status_code == 200
self.assertEqual(len(data['results']), 3) assert len(data['results']) == 3
# update the first course title
course = self.courses[0] # Search for substrings of course keys and titles
course.title = 'this is some thing new' course = courses[0]
course.save() self.assert_valid_query_result(admin_client, path, course.key[12:], course)
course_run = self.courses[0].course_runs.first() self.assert_valid_query_result(admin_client, path, course.title[12:], course)
course_run.key = 'edx/dummy/testrun'
course_run.save() def test_course_run_autocomplete(self, admin_client):
course_runs = CourseRunFactory.create_batch(3)
response = self.client.get( path = reverse('admin_metadata:course-run-autocomplete')
reverse('admin_metadata:course-run-autocomplete') + '?q={q}'.format(q=search_key) response = admin_client.get(path)
)
data = json.loads(response.content.decode('utf-8')) data = json.loads(response.content.decode('utf-8'))
self.assertEqual(data['results'][0]['text'], str(course_run)) assert response.status_code == 200
assert len(data['results']) == 3
def test_course_run_autocomplete_un_authorize_user(self): # Search for substrings of course run keys and titles
""" Verify course run autocomplete returns empty list for un-authorized users. """ course_run = course_runs[0]
self._make_user_non_staff() self.assert_valid_query_result(admin_client, path, course_run.key[14:], course_run)
response = self.client.get(reverse('admin_metadata:course-run-autocomplete')) self.assert_valid_query_result(admin_client, path, course_run.title[12:], course_run)
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(data['results'], [])
@ddt.data('irc', 'ing') def test_organization_autocomplete(self, admin_client):
def test_organization_autocomplete(self, search_key):
""" Verify Organization autocomplete returns the data. """ """ Verify Organization autocomplete returns the data. """
response = self.client.get(reverse('admin_metadata:organisation-autocomplete')) organizations = OrganizationFactory.create_batch(3)
path = reverse('admin_metadata:organisation-autocomplete')
response = admin_client.get(path)
data = json.loads(response.content.decode('utf-8')) data = json.loads(response.content.decode('utf-8'))
self.assertEqual(response.status_code, 200) assert response.status_code == 200
self.assertEqual(len(data['results']), 3) assert len(data['results']) == 3
self.organizations[0].key = 'Mirco' # Search for substrings of organization keys and names
self.organizations[0].name = 'testing name' organization = organizations[0]
self.organizations[0].save() self.assert_valid_query_result(admin_client, path, organization.key[:3], organization)
self.assert_valid_query_result(admin_client, path, organization.name[:5], organization)
response = self.client.get( @pytest.mark.parametrize('view_prefix', ['organisation', 'course', 'course-run'])
reverse('admin_metadata:organisation-autocomplete') + '?q={key}'.format( def test_autocomplete_requires_staff_permission(self, view_prefix, client):
key=search_key """ Verify autocomplete returns empty list for non-staff users. """
)
)
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(data['results'][0]['text'], str(self.organizations[0]))
self.assertEqual(len(data['results']), 1)
def test_organization_autocomplete_un_authorize_user(self): user = UserFactory(is_staff=False)
""" Verify Organization autocomplete returns empty list for un-authorized users. """ client.login(username=user.username, password=USER_PASSWORD)
self._make_user_non_staff() response = client.get(reverse('admin_metadata:{}-autocomplete'.format(view_prefix)))
response = self.client.get(reverse('admin_metadata:organisation-autocomplete'))
data = json.loads(response.content.decode('utf-8')) data = json.loads(response.content.decode('utf-8'))
self.assertEqual(data['results'], []) assert response.status_code == 200
assert data['results'] == []
def _make_user_non_staff(self):
self.client.logout()
self.user = UserFactory(is_staff=False)
self.user.save()
self.client.login(username=self.user.username, password=USER_PASSWORD)
@ddt.ddt
class AutoCompletePersonTests(SiteMixin, TestCase): class AutoCompletePersonTests(SiteMixin, TestCase):
""" """
Tests for person autocomplete lookups Tests for person autocomplete lookups
""" """
def setUp(self): def setUp(self):
super(AutoCompletePersonTests, self).setUp() super(AutoCompletePersonTests, self).setUp()
self.user = UserFactory(is_staff=True) self.user = UserFactory(is_staff=True)
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
self.courses = factories.CourseFactory.create_batch(3, title='Some random course title') self.courses = factories.CourseFactory.create_batch(3, title='Some random course title')
for course in self.courses: for course in self.courses:
factories.CourseRunFactory(course=course) factories.CourseRunFactory(course=course)
self.organizations = OrganizationFactory.create_batch(3) self.organizations = OrganizationFactory.create_batch(3)
self.organization_extensions = [] self.organization_extensions = []
for organization in self.organizations: for organization in self.organizations:
self.organization_extensions.append(factories.OrganizationExtensionFactory(organization=organization)) self.organization_extensions.append(factories.OrganizationExtensionFactory(organization=organization))
self.user.groups.add(self.organization_extensions[0].group) self.user.groups.add(self.organization_extensions[0].group)
first_instructor = PersonFactory(given_name="First Instructor") first_instructor = PersonFactory(given_name="First Instructor")
second_instructor = PersonFactory(given_name="Second Instructor") second_instructor = PersonFactory(given_name="Second Instructor")
self.instructors = [first_instructor, second_instructor] self.instructors = [first_instructor, second_instructor]
for instructor in self.instructors: for instructor in self.instructors:
PositionFactory(organization=self.organizations[0], title="professor", person=instructor) PositionFactory(organization=self.organizations[0], title="professor", person=instructor)
...@@ -185,9 +150,9 @@ class AutoCompletePersonTests(SiteMixin, TestCase): ...@@ -185,9 +150,9 @@ class AutoCompletePersonTests(SiteMixin, TestCase):
def _assert_response(self, response, expected_length): def _assert_response(self, response, expected_length):
""" Assert autocomplete response. """ """ Assert autocomplete response. """
self.assertEqual(response.status_code, 200) assert response.status_code == 200
data = json.loads(response.content.decode('utf-8')) data = json.loads(response.content.decode('utf-8'))
self.assertEqual(len(data['results']), expected_length) assert len(data['results']) == expected_length
def test_instructor_autocomplete_with_uuid(self): def test_instructor_autocomplete_with_uuid(self):
""" Verify instructor autocomplete returns the data with valid uuid. """ """ Verify instructor autocomplete returns the data with valid uuid. """
...@@ -243,10 +208,10 @@ class AutoCompletePersonTests(SiteMixin, TestCase): ...@@ -243,10 +208,10 @@ class AutoCompletePersonTests(SiteMixin, TestCase):
reverse('admin_metadata:person-autocomplete') + '?q={q}'.format(q='ins'), reverse('admin_metadata:person-autocomplete') + '?q={q}'.format(q='ins'),
HTTP_REFERER=reverse('admin:publisher_courserun_add') HTTP_REFERER=reverse('admin:publisher_courserun_add')
) )
self.assertEqual(response.status_code, 200) assert response.status_code == 200
data = json.loads(response.content.decode('utf-8')) data = json.loads(response.content.decode('utf-8'))
expected_results = [{'id': instructor.id, 'text': str(instructor)} for instructor in self.instructors] expected_results = [{'id': instructor.id, 'text': str(instructor)} for instructor in self.instructors]
self.assertEqual(data.get('results'), expected_results) assert data.get('results') == expected_results
def _make_user_non_staff(self): def _make_user_non_staff(self):
self.client.logout() self.client.logout()
......
...@@ -4,23 +4,23 @@ from decimal import Decimal ...@@ -4,23 +4,23 @@ from decimal import Decimal
import ddt import ddt
import mock import mock
import pytest
from dateutil.parser import parse from dateutil.parser import parse
from django.conf import settings from django.conf import settings
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import IntegrityError from django.db import IntegrityError
from django.db.models.functions import Lower
from django.test import TestCase from django.test import TestCase
from freezegun import freeze_time from freezegun import freeze_time
from course_discovery.apps.api.tests.mixins import SiteMixin from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import Currency from course_discovery.apps.core.models import Currency
from course_discovery.apps.core.tests.helpers import make_image_file from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.core.utils import SearchQuerySetWrapper from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import ( from course_discovery.apps.course_metadata.models import (
FAQ, AbstractMediaModel, AbstractNamedModel, AbstractValueModel, CorporateEndorsement, Course, CourseRun, FAQ, AbstractMediaModel, AbstractNamedModel, AbstractValueModel, CorporateEndorsement, Course, CourseRun,
Endorsement, Seat, SeatType, Subject) Endorsement, Seat, SeatType, Subject
)
from course_discovery.apps.course_metadata.publishers import ( from course_discovery.apps.course_metadata.publishers import (
CourseRunMarketingSitePublisher, ProgramMarketingSitePublisher CourseRunMarketingSitePublisher, ProgramMarketingSitePublisher
) )
...@@ -31,34 +31,17 @@ from course_discovery.apps.ietf_language_tags.models import LanguageTag ...@@ -31,34 +31,17 @@ from course_discovery.apps.ietf_language_tags.models import LanguageTag
# pylint: disable=no-member # pylint: disable=no-member
@pytest.mark.django_db
class CourseTests(ElasticsearchTestMixin, TestCase): class TestCourse:
""" Tests for the `Course` model. """
def setUp(self):
super(CourseTests, self).setUp()
self.course = factories.CourseFactory()
def test_str(self): def test_str(self):
""" Verify casting an instance to a string returns a string containing the key and title. """ course = factories.CourseFactory()
self.assertEqual(str(self.course), '{key}: {title}'.format(key=self.course.key, title=self.course.title)) assert str(course), '{key}: {title}'.format(key=course.key, title=course.title)
def test_search(self): def test_search(self, haystack_default_connection): # pylint: disable=unused-argument
""" Verify the method returns a filtered queryset of courses. """
toggle_switch('log_course_search_queries', active=True)
title = 'Some random title' title = 'Some random title'
courses = factories.CourseFactory.create_batch(3, title=title) expected = set(factories.CourseFactory.create_batch(3, title=title))
# Sort lowercase keys to prevent different sort orders due to casing.
# For example, sorted(['a', 'Z']) gives ['Z', 'a'], but an ordered
# queryset containing the same values may give ['a', 'Z'] depending
# on the database backend in use.
courses = sorted(courses, key=lambda course: course.key.lower())
query = 'title:' + title query = 'title:' + title
# Use Lower() to force a case-insensitive sort. assert set(Course.search(query)) == expected
actual = list(Course.search(query).order_by(Lower('key')))
self.assertEqual(actual, courses)
def test_image_url(self): def test_image_url(self):
course = factories.CourseFactory() course = factories.CourseFactory()
......
...@@ -2,12 +2,10 @@ import datetime ...@@ -2,12 +2,10 @@ import datetime
import mock import mock
import pytest import pytest
from django.test import TestCase
from haystack.backends import SQ from haystack.backends import SQ
from haystack.backends.elasticsearch_backend import ElasticsearchSearchQuery from haystack.backends.elasticsearch_backend import ElasticsearchSearchQuery
from haystack.query import SearchQuerySet from haystack.query import SearchQuerySet
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.models import CourseRun from course_discovery.apps.course_metadata.models import CourseRun
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory
from course_discovery.apps.edx_haystack_extensions.distinct_counts.backends import ( from course_discovery.apps.edx_haystack_extensions.distinct_counts.backends import (
...@@ -16,7 +14,9 @@ from course_discovery.apps.edx_haystack_extensions.distinct_counts.backends impo ...@@ -16,7 +14,9 @@ from course_discovery.apps.edx_haystack_extensions.distinct_counts.backends impo
# pylint: disable=protected-access # pylint: disable=protected-access
class DistinctCountsSearchQueryTests(ElasticsearchTestMixin, TestCase): @pytest.mark.django_db
@pytest.mark.usefixtures('haystack_default_connection')
class TestDistinctCountsSearchQuery:
def test_clone(self): def test_clone(self):
""" Verify that clone copies all fields, including the aggregation_key and distinct_hit_count.""" """ Verify that clone copies all fields, including the aggregation_key and distinct_hit_count."""
query = DistinctCountsSearchQuery() query = DistinctCountsSearchQuery()
...@@ -234,7 +234,9 @@ class DistinctCountsSearchQueryTests(ElasticsearchTestMixin, TestCase): ...@@ -234,7 +234,9 @@ class DistinctCountsSearchQueryTests(ElasticsearchTestMixin, TestCase):
assert query.facets['pacing_type_exact']['size'] == 5 assert query.facets['pacing_type_exact']['size'] == 5
class DistinctCountsElasticsearchBackendWrapperTests(ElasticsearchTestMixin, TestCase): @pytest.mark.django_db
@pytest.mark.usefixtures('haystack_default_connection')
class TestDistinctCountsElasticsearchBackendWrapper:
def test_search_raises_when_called_with_date_facet(self): def test_search_raises_when_called_with_date_facet(self):
now = datetime.datetime.now() now = datetime.datetime.now()
one_day = datetime.timedelta(days=1) one_day = datetime.timedelta(days=1)
......
import datetime import datetime
import mock
import mock
import pytest import pytest
from django.test import TestCase
from haystack.query import SearchQuerySet from haystack.query import SearchQuerySet
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.models import CourseRun from course_discovery.apps.course_metadata.models import CourseRun
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory
from course_discovery.apps.edx_haystack_extensions.distinct_counts.backends import DistinctCountsSearchQuery from course_discovery.apps.edx_haystack_extensions.distinct_counts.backends import DistinctCountsSearchQuery
from course_discovery.apps.edx_haystack_extensions.distinct_counts.query import DistinctCountsSearchQuerySet from course_discovery.apps.edx_haystack_extensions.distinct_counts.query import DistinctCountsSearchQuerySet
class DistinctCountsSearchQuerySetTests(ElasticsearchTestMixin, TestCase): @pytest.mark.django_db
@pytest.mark.usefixtures('haystack_default_connection')
class TestDistinctCountsSearchQuerySet:
def test_from_queryset(self): def test_from_queryset(self):
""" Verify that a DistinctCountsSearchQuerySet can be built from an existing SearchQuerySet.""" """ Verify that a DistinctCountsSearchQuerySet can be built from an existing SearchQuerySet."""
course_1 = CourseFactory() course_1 = CourseFactory()
......
import datetime import datetime
import ddt import pytest
import pytz import pytz
from django.test import TestCase
from haystack.query import SearchQuerySet from haystack.query import SearchQuerySet
from mock import patch from mock import patch
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.models import CourseRun, Program, ProgramType from course_discovery.apps.course_metadata.models import CourseRun, Program, ProgramType
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, ProgramFactory from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, ProgramFactory
@ddt.ddt @pytest.mark.django_db
class TestSearchBoosting(ElasticsearchTestMixin, TestCase): @pytest.mark.usefixtures('haystack_default_connection')
class TestSearchBoosting:
def build_normalized_course_run(self, **kwargs): def build_normalized_course_run(self, **kwargs):
"""Builds a CourseRun with fields set to normalize boosting behavior.""" """Builds a CourseRun with fields set to normalize boosting behavior."""
defaults = { defaults = {
...@@ -35,7 +34,7 @@ class TestSearchBoosting(ElasticsearchTestMixin, TestCase): ...@@ -35,7 +34,7 @@ class TestSearchBoosting(ElasticsearchTestMixin, TestCase):
assert search_results[0].score > search_results[1].score assert search_results[0].score > search_results[1].score
assert test_record.pacing_type == search_results[0].pacing_type assert test_record.pacing_type == search_results[0].pacing_type
@ddt.data('MicroMasters', 'Professional Certificate') @pytest.mark.parametrize('program_type', ['MicroMasters', 'Professional Certificate'])
def test_program_type_boosting(self, program_type): def test_program_type_boosting(self, program_type):
"""Verify MicroMasters and Professional Certificate are boosted over XSeries.""" """Verify MicroMasters and Professional Certificate are boosted over XSeries."""
ProgramFactory(type=ProgramType.objects.get(name='XSeries')) ProgramFactory(type=ProgramType.objects.get(name='XSeries'))
...@@ -55,20 +54,23 @@ class TestSearchBoosting(ElasticsearchTestMixin, TestCase): ...@@ -55,20 +54,23 @@ class TestSearchBoosting(ElasticsearchTestMixin, TestCase):
search_results = SearchQuerySet().models(CourseRun).all() search_results = SearchQuerySet().models(CourseRun).all()
assert len(search_results) == 2 assert len(search_results) == 2
assert search_results[0].score > search_results[1].score assert search_results[0].score > search_results[1].score
assert int(test_record.start.timestamp()) == int(search_results[0].start.timestamp()) # pylint: disable=no-member # pylint: disable=no-member
assert int(test_record.start.timestamp()) == int(search_results[0].start.timestamp())
@ddt.data(
# Should not get boost if has_enrollable_paid_seats is False, has_enrollable_paid_seats @pytest.mark.parametrize(
# is None, or paid_seat_enrollment_end is in the past. 'has_enrollable_paid_seats,paid_seat_enrollment_end,expects_boost',
(False, None, False), [
(None, None, False), # Should not get boost if has_enrollable_paid_seats is False, has_enrollable_paid_seats
(True, datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15), False), # is None, or paid_seat_enrollment_end is in the past.
# Should get boost if has_enrollable_paid_seats is True and paid_seat_enrollment_end (False, None, False),
# is None or in the future. (None, None, False),
(True, None, True), (True, datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15), False),
(True, datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15), True), # Should get boost if has_enrollable_paid_seats is True and paid_seat_enrollment_end
# is None or in the future.
(True, None, True),
(True, datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15), True),
]
) )
@ddt.unpack
def test_enrollable_paid_seat_boosting(self, has_enrollable_paid_seats, paid_seat_enrollment_end, expects_boost): def test_enrollable_paid_seat_boosting(self, has_enrollable_paid_seats, paid_seat_enrollment_end, expects_boost):
""" """
Verify that CourseRuns for which an unenrolled user may enroll and Verify that CourseRuns for which an unenrolled user may enroll and
...@@ -123,25 +125,27 @@ class TestSearchBoosting(ElasticsearchTestMixin, TestCase): ...@@ -123,25 +125,27 @@ class TestSearchBoosting(ElasticsearchTestMixin, TestCase):
# negative relevance score. # negative relevance score.
assert 0 > search_results[1].score assert 0 > search_results[1].score
@ddt.data( @pytest.mark.parametrize(
# Should get boost if enrollment_start and enrollment_end unspecified. 'enrollment_start,enrollment_end,expects_boost',
(None, None, True), [
# Should get boost if enrollment_start unspecified and enrollment_end in future. # Should get boost if enrollment_start and enrollment_end unspecified.
(None, datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15), True), (None, None, True),
# Should get boost if enrollment_start in past and enrollment_end unspecified. # Should get boost if enrollment_start unspecified and enrollment_end in future.
(datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15), None, True), (None, datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15), True),
# Should get boost if enrollment_start in past and enrollment_end in future. # Should get boost if enrollment_start in past and enrollment_end unspecified.
( (datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15), None, True),
datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15), # Should get boost if enrollment_start in past and enrollment_end in future.
datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15), (
True datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15),
), datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15),
# Should not get boost if enrollment_start in future. True
(datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15), None, False), ),
# Should not get boost if enrollment_end in past. # Should not get boost if enrollment_start in future.
(None, datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15), False), (datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15), None, False),
# Should not get boost if enrollment_end in past.
(None, datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15), False),
]
) )
@ddt.unpack
def test_enrollable_course_run_boosting(self, enrollment_start, enrollment_end, expects_boost): def test_enrollable_course_run_boosting(self, enrollment_start, enrollment_end, expects_boost):
"""Verify that enrollable CourseRuns are boosted.""" """Verify that enrollable CourseRuns are boosted."""
......
...@@ -2,40 +2,38 @@ import datetime ...@@ -2,40 +2,38 @@ import datetime
from django.conf import settings from django.conf import settings
from django.core.management import call_command from django.core.management import call_command
from django.test import TestCase
from freezegun import freeze_time from freezegun import freeze_time
from course_discovery.apps.core.utils import ElasticsearchUtils from course_discovery.apps.core.utils import ElasticsearchUtils
from course_discovery.apps.edx_haystack_extensions.tests.mixins import SearchIndexTestMixin
class RemoveUnusedIndexesTests(SearchIndexTestMixin, TestCase): class TestRemoveUnusedIndexes:
backend = None backend = None
def test_handle(self): def test_handle(self, haystack_default_connection):
""" Verify the command removes all but the newest indexes. """ """ Verify the command removes all but the newest indexes. """
backend = haystack_default_connection
# Create initial index with alias
ElasticsearchUtils.create_alias_and_index(es_connection=self.backend.conn, alias=self.backend.index_name)
# Use now as initial time, so indexes are created AFTER the current index so expected values are accurate # Use now as initial time, so indexes are created AFTER the current index so expected values are accurate
initial_time = datetime.datetime.now() initial_time = datetime.datetime.now()
# Create 2 more indexes than we expect to exist after removal # Create 2 more indexes than we expect to exist after removal
for number in range(1, settings.HAYSTACK_INDEX_RETENTION_LIMIT + 2): for number in range(1, settings.HAYSTACK_INDEX_RETENTION_LIMIT + 2):
current_time = initial_time + datetime.timedelta(seconds=number) current_time = initial_time + datetime.timedelta(seconds=number)
freezer = freeze_time(current_time) freezer = freeze_time(current_time)
freezer.start() freezer.start()
ElasticsearchUtils.create_index(es_connection=self.backend.conn, prefix=self.backend.index_name) ElasticsearchUtils.create_index(es_connection=backend.conn, prefix=backend.index_name)
freezer.stop() freezer.stop()
# Prune indexes and confirm the right indexes are removed # Prune indexes and confirm the right indexes are removed
call_command('remove_unused_indexes') call_command('remove_unused_indexes')
current_alias_name = self.backend.index_name current_alias_name = backend.index_name
indices_client = self.backend.conn.indices indices_client = backend.conn.indices
current_alias = indices_client.get_alias(name=current_alias_name) current_alias = indices_client.get_alias(name=current_alias_name)
indexes_to_keep = current_alias.keys() indexes_to_keep = current_alias.keys()
# check that we keep the current indexes, which we don't want removed # check that we keep the current indexes, which we don't want removed
all_indexes = self.get_current_index_names(indices_client=indices_client, index_prefix=self.backend.index_name) all_indexes = self.get_current_index_names(indices_client=indices_client, index_prefix=backend.index_name)
assert set(all_indexes).issuperset(set(indexes_to_keep)) assert set(all_indexes).issuperset(set(indexes_to_keep))
# check that other indexes are removed, excepting those that don't hit the retention limit # check that other indexes are removed, excepting those that don't hit the retention limit
...@@ -46,12 +44,15 @@ class RemoveUnusedIndexesTests(SearchIndexTestMixin, TestCase): ...@@ -46,12 +44,15 @@ class RemoveUnusedIndexesTests(SearchIndexTestMixin, TestCase):
call_command('remove_unused_indexes') call_command('remove_unused_indexes')
# check that we keep the current indexes, which we don't want removed # check that we keep the current indexes, which we don't want removed
all_indexes = self.get_current_index_names(indices_client=indices_client, index_prefix=self.backend.index_name) all_indexes = self.get_current_index_names(indices_client=indices_client, index_prefix=backend.index_name)
assert set(all_indexes).issuperset(set(indexes_to_keep)) assert set(all_indexes).issuperset(set(indexes_to_keep))
# check that index count remains the same as before # check that index count remains the same as before
assert len(all_indexes) == expected_count assert len(all_indexes) == expected_count
# Cleanup indexes created by this test
backend.conn.indices.delete(index=backend.index_name + '_*')
@staticmethod @staticmethod
def get_current_index_names(indices_client, index_prefix): def get_current_index_names(indices_client, index_prefix):
all_index_settings = indices_client.get_settings() all_index_settings = indices_client.get_settings()
......
...@@ -11,7 +11,8 @@ lxml==3.6.1 ...@@ -11,7 +11,8 @@ lxml==3.6.1
mock==2.0.0 mock==2.0.0
pep8==1.7.0 pep8==1.7.0
pytest==3.0.6 pytest==3.0.6
pytest-django==3.1.2 # See https://github.com/pytest-dev/pytest-django/issues/473
git+https://github.com/pytest-dev/pytest-django@de31fab4122cfc53e7116b499235b610366d941a#egg=pytest-django==3.1.3.dev47+gde31fab
pytest-django-ordering==1.0.1 pytest-django-ordering==1.0.1
pytest-responses==0.3.0 pytest-responses==0.3.0
responses==0.7.0 responses==0.7.0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment