Commit ad1dca56 by Clinton Blackburn

Converted tests to py.test

parent 93eb5c0f
[settings]
line_length=120
multi_line_output=5
import logging
import pytest
from django.contrib.sites.models import Site
from django.core.cache import cache
from django.test.client import Client
from haystack import connections as haystack_connections
from pytest_django.lazy_django import skip_if_no_django
from course_discovery.apps.core.tests.factories import PartnerFactory, SiteFactory
from course_discovery.apps.core.utils import ElasticsearchUtils
logger = logging.getLogger(__name__)
TEST_DOMAIN = 'testserver.fake'
@pytest.fixture
def django_cache(request, settings):
@pytest.fixture(scope='session', autouse=True)
def django_cache_add_xdist_key_prefix(request):
skip_if_no_django()
from django.conf import settings
xdist_prefix = getattr(request.config, 'slaveinput', {}).get('slaveid')
if xdist_prefix:
# Put a prefix like gw0_, gw1_ etc on xdist processes
for name, cache_settings in settings.CACHES.items():
# Put a prefix like _gw0, _gw1 etc on xdist processes
cache_settings['KEY_PREFIX'] = xdist_prefix + '_' + cache_settings.get('KEY_PREFIX', '')
logger.info('Set cache key prefix for [%s] cache to [%s]', name, cache_settings['KEY_PREFIX'])
@pytest.fixture
def django_cache(django_cache_add_xdist_key_prefix): # pylint: disable=unused-argument
skip_if_no_django()
cache.clear()
yield cache
cache.clear()
@pytest.fixture(scope='session', autouse=True)
def haystack_add_xdist_suffix_to_index_name(request):
skip_if_no_django()
from django.conf import settings
xdist_suffix = getattr(request.config, 'slaveinput', {}).get('slaveid')
if xdist_suffix:
# Put a prefix like _gw0, _gw1 etc on xdist processes
for name, connection in settings.HAYSTACK_CONNECTIONS.items():
connection['INDEX_NAME'] = connection['INDEX_NAME'] + '_' + xdist_suffix
logger.info('Set index name for Haystack connection [%s] to [%s]', name, connection['INDEX_NAME'])
@pytest.fixture
def haystack_default_connection(haystack_add_xdist_suffix_to_index_name): # pylint: disable=unused-argument
skip_if_no_django()
backend = haystack_connections['default'].get_backend()
# Force Haystack to update the mapping for the index
backend.setup_complete = False
es = backend.conn
index_name = backend.index_name
ElasticsearchUtils.delete_index(es, index_name)
ElasticsearchUtils.create_alias_and_index(es, index_name)
ElasticsearchUtils.refresh_index(es, index_name)
yield backend
ElasticsearchUtils.delete_index(es, index_name)
@pytest.fixture
def site(db): # pylint: disable=unused-argument
skip_if_no_django()
from django.conf import settings
Site.objects.all().delete()
return SiteFactory(id=settings.SITE_ID, domain=TEST_DOMAIN)
@pytest.fixture
def partner(db, site): # pylint: disable=unused-argument
skip_if_no_django()
return PartnerFactory(site=site)
@pytest.fixture
def client():
skip_if_no_django()
return Client(SERVER_NAME=TEST_DOMAIN)
# pylint: disable=no-member
import base64
import ddt
import pytest
from django.core.files.base import ContentFile
from django.test import TestCase
from course_discovery.apps.api.fields import ImageField, StdImageSerializerField
from course_discovery.apps.api.tests.test_serializers import make_request
......@@ -11,40 +10,27 @@ from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests.factories import ProgramFactory
class ImageFieldTests(TestCase):
def test_to_representation(self):
value = 'https://example.com/image.jpg'
expected = {
'src': value,
'description': None,
'height': None,
'width': None
}
self.assertEqual(ImageField().to_representation(value), expected)
@pytest.mark.django_db
def test_imagefield_to_representation():
value = 'https://example.com/image.jpg'
expected = {'src': value, 'description': None, 'height': None, 'width': None}
assert ImageField().to_representation(value) == expected
@ddt.ddt
class StdImageSerializerFieldTests(TestCase):
@pytest.mark.django_db
class TestStdImageSerializerField:
def test_to_representation(self):
request = make_request()
# TODO Create test-only model to avoid unnecessary dependency on Program model.
program = ProgramFactory(banner_image=make_image_file('test.jpg'))
field = StdImageSerializerField()
field._context = {'request': request} # pylint: disable=protected-access
expected = {
size_key: {
'url': '{}{}'.format(
'http://testserver',
getattr(program.banner_image, size_key).url
),
'url': '{}{}'.format('http://testserver', getattr(program.banner_image, size_key).url),
'width': program.banner_image.field.variations[size_key]['width'],
'height': program.banner_image.field.variations[size_key]['height']
}
for size_key in program.banner_image.field.variations
}
self.assertDictEqual(field.to_representation(program.banner_image), expected)
} for size_key in program.banner_image.field.variations}
assert field.to_representation(program.banner_image) == expected
def test_to_internal_value(self):
base64_header = "data:image/jpeg;base64,"
......@@ -76,15 +62,9 @@ class StdImageSerializerFieldTests(TestCase):
"//Z"
)
base64_full = base64_header + base64_data
expected = ContentFile(base64.b64decode(base64_data), name='tmp.jpeg')
assert list(StdImageSerializerField().to_internal_value(base64_full).chunks()) == list(expected.chunks())
# assert that the data chunks are equal
self.assertListEqual(
list(StdImageSerializerField().to_internal_value(base64_full).chunks()),
list(expected.chunks())
)
@ddt.data("", False, None, [])
@pytest.mark.parametrize('falsey_value', ("", False, None, []))
def test_to_internal_value_falsey(self, falsey_value):
self.assertIsNone(StdImageSerializerField().to_internal_value(falsey_value))
assert StdImageSerializerField().to_internal_value(falsey_value) is None
import ddt
from django.test import TestCase
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
from course_discovery.apps.api.filters import HaystackRequestFilterMixin
@ddt.ddt
class HaystackRequestFilterMixinTests(TestCase):
class TestHaystackRequestFilterMixin:
def test_get_request_filters(self):
""" Verify the method removes query parameters with empty values """
request = APIRequestFactory().get('/?q=')
request = APIView().initialize_request(request)
filters = HaystackRequestFilterMixin.get_request_filters(request)
self.assertDictEqual(filters, {})
assert filters == {}
def test_get_request_filters_with_list(self):
""" Verify the method does not affect list values. """
request = APIRequestFactory().get('/?q=&content_type=courserun&content_type=program')
request = APIView().initialize_request(request)
filters = HaystackRequestFilterMixin.get_request_filters(request)
self.assertNotIn('q', filters)
self.assertEqual(filters.getlist('content_type'), ['courserun', 'program'])
assert 'q' not in filters
assert filters.getlist('content_type') == ['courserun', 'program']
def test_get_request_filters_with_falsey_values(self):
""" Verify the method does not strip valid falsey values. """
request = APIRequestFactory().get('/?q=&test=0')
request = APIView().initialize_request(request)
filters = HaystackRequestFilterMixin.get_request_filters(request)
self.assertNotIn('q', filters)
self.assertEqual(filters.get('test'), '0')
assert 'q' not in filters
assert filters.get('test') == '0'
......@@ -4,6 +4,7 @@ import itertools
from urllib.parse import urlencode
import ddt
import pytest
import pytz
from django.test import TestCase
from haystack.query import SearchQuerySet
......@@ -1249,7 +1250,9 @@ class CourseRunSearchSerializerTests(ElasticsearchTestMixin, TestCase):
assert serializer.data['level_type'] is None
class ProgramSearchSerializerTests(TestCase):
@pytest.mark.django_db
@pytest.mark.usefixtures('haystack_default_connection')
class TestProgramSearchSerializer:
def _create_expected_data(self, program):
return {
'uuid': str(program.uuid),
......@@ -1327,7 +1330,9 @@ class ProgramSearchSerializerTests(TestCase):
assert {'English', 'Chinese - Mandarin'} == {*expected['language']}
class TypeaheadCourseRunSearchSerializerTests(TestCase):
@pytest.mark.django_db
@pytest.mark.usefixtures('haystack_default_connection')
class TestTypeaheadCourseRunSearchSerializer:
def test_data(self):
authoring_organization = OrganizationFactory()
course_run = CourseRunFactory(authoring_organizations=[authoring_organization])
......@@ -1348,7 +1353,9 @@ class TypeaheadCourseRunSearchSerializerTests(TestCase):
return serializer
class TypeaheadProgramSearchSerializerTests(TestCase):
@pytest.mark.django_db
@pytest.mark.usefixtures('haystack_default_connection')
class TestTypeaheadProgramSearchSerializer:
def _create_expected_data(self, program):
return {
'uuid': str(program.uuid),
......
......@@ -15,7 +15,7 @@ from course_discovery.apps.api.serializers import (
from course_discovery.apps.api.tests.mixins import SiteMixin
class SerializationMixin(object):
class SerializationMixin:
def _get_request(self, format=None):
if getattr(self, 'request', None):
return self.request
......
import urllib.parse
import ddt
from django.core.cache import cache
import pytest
from django.test import RequestFactory
from django.urls import reverse
from course_discovery.apps.api.serializers import MinimalProgramSerializer
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.views.programs import ProgramViewSet
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
......@@ -17,18 +17,30 @@ from course_discovery.apps.course_metadata.tests.factories import (
)
@ddt.ddt
class ProgramViewSetTests(SerializationMixin, APITestCase):
@pytest.mark.django_db
@pytest.mark.usefixtures('django_cache')
class TestProgramViewSet(SerializationMixin):
client = None
django_assert_num_queries = None
list_path = reverse('api:v1:program-list')
partner = None
request = None
def setUp(self):
super(ProgramViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD)
@pytest.fixture(autouse=True)
def setup(self, client, django_assert_num_queries, partner):
user = UserFactory(is_staff=True, is_superuser=True)
# Clear the cache between test cases, so they don't interfere with each other.
cache.clear()
client.login(username=user.username, password=USER_PASSWORD)
site = partner.site
request = RequestFactory(SERVER_NAME=site.domain).get('')
request.site = site
request.user = user
self.client = client
self.django_assert_num_queries = django_assert_num_queries
self.partner = partner
self.request = request
def create_program(self):
organizations = [OrganizationFactory(partner=self.partner)]
......@@ -59,37 +71,37 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
url += '?' + urllib.parse.urlencode(querystring)
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
return response
def test_authentication(self):
""" Verify the endpoint requires the user to be authenticated. """
response = self.client.get(self.list_path)
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
self.client.logout()
response = self.client.get(self.list_path)
self.assertEqual(response.status_code, 403)
assert response.status_code == 403
def test_retrieve(self):
def test_retrieve(self, django_assert_num_queries):
""" Verify the endpoint returns the details for a single program. """
program = self.create_program()
with self.assertNumQueries(39):
with django_assert_num_queries(39):
response = self.assert_retrieve_success(program)
# property does not have the right values while being indexed
del program._course_run_weeks_to_complete
assert response.data == self.serialize_program(program)
# Verify that repeated retrieve requests use the cache.
with self.assertNumQueries(4):
with django_assert_num_queries(4):
self.assert_retrieve_success(program)
# Verify that requests including querystring parameters are cached separately.
response = self.assert_retrieve_success(program, querystring={'use_full_course_serializer': 1})
assert response.data == self.serialize_program(program, extra_context={'use_full_course_serializer': 1})
@ddt.data(True, False)
def test_retrieve_with_sorting_flag(self, order_courses_by_start_date):
@pytest.mark.parametrize('order_courses_by_start_date', (True, False,))
def test_retrieve_with_sorting_flag(self, order_courses_by_start_date, django_assert_num_queries):
""" Verify the number of queries is the same with sorting flag set to true. """
course_list = CourseFactory.create_batch(3, partner=self.partner)
for course in course_list:
......@@ -100,16 +112,16 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
partner=self.partner)
# property does not have the right values while being indexed
del program._course_run_weeks_to_complete
with self.assertNumQueries(28):
with django_assert_num_queries(28):
response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program)
self.assertEqual(course_list, list(program.courses.all())) # pylint: disable=no-member
assert course_list == list(program.courses.all()) # pylint: disable=no-member
def test_retrieve_without_course_runs(self):
def test_retrieve_without_course_runs(self, django_assert_num_queries):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course = CourseFactory(partner=self.partner)
program = ProgramFactory(courses=[course], partner=self.partner)
with self.assertNumQueries(22):
with django_assert_num_queries(22):
response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program)
......@@ -127,13 +139,10 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
Returns:
None
"""
with self.assertNumQueries(expected_query_count):
with self.django_assert_num_queries(expected_query_count):
response = self.client.get(url)
self.assertEqual(
response.data['results'],
self.serialize_program(expected, many=True, extra_context=extra_context)
)
assert response.data['results'] == self.serialize_program(expected, many=True, extra_context=extra_context)
def test_list(self):
""" Verify the endpoint returns a list of all programs. """
......@@ -200,11 +209,13 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
self.assert_list_results(url, expected, 10)
@ddt.data(
(ProgramStatus.Unpublished, False, 5),
(ProgramStatus.Active, True, 10),
@pytest.mark.parametrize(
'status,is_marketable,expected_query_count',
(
(ProgramStatus.Unpublished, False, 5),
(ProgramStatus.Active, True, 10),
)
)
@ddt.unpack
def test_filter_by_marketable(self, status, is_marketable, expected_query_count):
""" Verify the endpoint filters programs to those that are marketable. """
url = self.list_path + '?marketable=1'
......@@ -213,7 +224,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
programs.reverse()
expected = programs if is_marketable else []
self.assertEqual(list(Program.objects.marketable()), expected)
assert list(Program.objects.marketable()) == expected
self.assert_list_results(url, expected, expected_query_count)
def test_filter_by_status(self):
......@@ -271,4 +282,4 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_minimal_serializer_use(self):
""" Verify that the list view uses the minimal serializer. """
self.assertEqual(ProgramViewSet(action='list').get_serializer_class(), MinimalProgramSerializer)
assert ProgramViewSet(action='list').get_serializer_class() == MinimalProgramSerializer
......@@ -323,8 +323,9 @@ class AggregateSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
def process_response(self, response):
response = self.get_response(response).json()
self.assertTrue(response['objects']['count'])
return response['objects']
objects = response['objects']
assert objects['count'] > 0
return objects
def test_results_only_include_published_objects(self):
""" Verify the search results only include items with status set to 'Published'. """
......
import logging
import pytest
from django.conf import settings
from haystack import connections as haystack_connections
......@@ -9,44 +10,13 @@ from course_discovery.apps.course_metadata.models import Course, CourseRun
logger = logging.getLogger(__name__)
@pytest.mark.usefixtures('haystack_default_connection')
class ElasticsearchTestMixin(object):
@classmethod
def setUpClass(cls):
super(ElasticsearchTestMixin, cls).setUpClass()
cls.index = settings.HAYSTACK_CONNECTIONS['default']['INDEX_NAME']
# Make use of the changes in our custom ES backend
# This is required for typeahead autocomplete to work in the tests
connection = haystack_connections['default']
cls.backend = connection.get_backend()
# Without this line, haystack doesn't fully recreate the connection
# The first test using this backend succeeds, but the following tests
# do not set the Elasticsearch _mapping
def setUp(self):
super(ElasticsearchTestMixin, self).setUp()
self.backend.setup_complete = False
self.es = self.backend.conn
self.reset_index()
self.refresh_index()
def reset_index(self):
""" Deletes and re-creates the Elasticsearch index. """
self.delete_index(self.index)
ElasticsearchUtils.create_alias_and_index(self.es, self.index)
def delete_index(self, index):
"""
Deletes an index.
Args:
index (str): Name of index to delete
Returns:
None
"""
logger.info('Deleting index [%s]...', index)
self.es.indices.delete(index=index, ignore=404) # pylint: disable=unexpected-keyword-arg
logger.info('...index deleted.')
self.index = settings.HAYSTACK_CONNECTIONS['default']['INDEX_NAME']
connection = haystack_connections['default']
self.es = connection.get_backend().conn
def refresh_index(self):
"""
......@@ -54,9 +24,7 @@ class ElasticsearchTestMixin(object):
https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-refresh.html
"""
# pylint: disable=unexpected-keyword-arg
self.es.indices.refresh(index=self.index)
self.es.cluster.health(index=self.index, wait_for_status='yellow', request_timeout=1)
ElasticsearchUtils.refresh_index(self.es, self.index)
def reindex_course_runs(self, course):
index = haystack_connections['default'].get_unified_index().get_index(CourseRun)
......
......@@ -53,6 +53,24 @@ class ElasticsearchUtils(object):
logger.info('...index [%s] created.', index_name)
return index_name
@classmethod
def delete_index(cls, es_connection, index):
logger.info('Deleting index [%s]...', index)
es_connection.indices.delete(index=index, ignore=404) # pylint: disable=unexpected-keyword-arg
logger.info('...index deleted.')
@classmethod
def refresh_index(cls, es_connection, index):
"""
Refreshes the index.
https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-refresh.html
"""
logger.info('Refreshing index [%s]...', index)
es_connection.indices.refresh(index=index)
es_connection.cluster.health(index=index, wait_for_status='yellow', request_timeout=1)
logger.info('...index refreshed.')
def get_all_related_field_names(model):
"""
......
import json
from urllib.parse import quote
import ddt
import pytest
from django.test import TestCase
from django.urls import reverse
......@@ -16,128 +16,93 @@ from course_discovery.apps.publisher.tests import factories
# pylint: disable=no-member
@ddt.ddt
class AutocompleteTests(SiteMixin, TestCase):
""" Tests for autocomplete lookups."""
def setUp(self):
super(AutocompleteTests, self).setUp()
self.user = UserFactory(is_staff=True)
self.client.login(username=self.user.username, password=USER_PASSWORD)
self.courses = CourseFactory.create_batch(3, title='Some random course title')
for course in self.courses:
CourseRunFactory(course=course)
self.organizations = OrganizationFactory.create_batch(3)
first_instructor = PersonFactory(given_name="First Instructor")
second_instructor = PersonFactory(given_name="Second Instructor")
self.instructors = [first_instructor, second_instructor]
@ddt.data('dum', 'ing')
def test_course_autocomplete(self, search_key):
""" Verify course autocomplete returns the data. """
response = self.client.get(reverse('admin_metadata:course-autocomplete'))
@pytest.mark.django_db
class TestAutocomplete:
def assert_valid_query_result(self, client, path, query, expected_result):
""" Asserts a query made against the given endpoint returns the expected result. """
response = client.get(path + '?q={q}'.format(q=query))
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(response.status_code, 200)
self.assertEqual(len(data['results']), 3)
# update the first course title
self.courses[0].key = 'edx/dummy/key'
self.courses[0].title = 'this is some thing new'
self.courses[0].save()
response = self.client.get(
reverse('admin_metadata:course-autocomplete') + '?q={title}'.format(title=search_key)
)
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(data['results'][0]['text'], str(self.courses[0]))
assert len(data['results']) == 1
assert data['results'][0]['text'] == str(expected_result)
def test_course_autocomplete_un_authorize_user(self):
""" Verify course autocomplete returns empty list for un-authorized users. """
self._make_user_non_staff()
response = self.client.get(reverse('admin_metadata:course-autocomplete'))
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(data['results'], [])
@ddt.data('ing', 'dum')
def test_course_run_autocomplete(self, search_key):
""" Verify course run autocomplete returns the data. """
response = self.client.get(reverse('admin_metadata:course-run-autocomplete'))
def test_course_autocomplete(self, admin_client):
""" Verify course autocomplete returns the data. """
courses = CourseFactory.create_batch(3)
path = reverse('admin_metadata:course-autocomplete')
response = admin_client.get(path)
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(response.status_code, 200)
self.assertEqual(len(data['results']), 3)
# update the first course title
course = self.courses[0]
course.title = 'this is some thing new'
course.save()
course_run = self.courses[0].course_runs.first()
course_run.key = 'edx/dummy/testrun'
course_run.save()
response = self.client.get(
reverse('admin_metadata:course-run-autocomplete') + '?q={q}'.format(q=search_key)
)
assert response.status_code == 200
assert len(data['results']) == 3
# Search for substrings of course keys and titles
course = courses[0]
self.assert_valid_query_result(admin_client, path, course.key[12:], course)
self.assert_valid_query_result(admin_client, path, course.title[12:], course)
def test_course_run_autocomplete(self, admin_client):
course_runs = CourseRunFactory.create_batch(3)
path = reverse('admin_metadata:course-run-autocomplete')
response = admin_client.get(path)
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(data['results'][0]['text'], str(course_run))
assert response.status_code == 200
assert len(data['results']) == 3
def test_course_run_autocomplete_un_authorize_user(self):
""" Verify course run autocomplete returns empty list for un-authorized users. """
self._make_user_non_staff()
response = self.client.get(reverse('admin_metadata:course-run-autocomplete'))
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(data['results'], [])
# Search for substrings of course run keys and titles
course_run = course_runs[0]
self.assert_valid_query_result(admin_client, path, course_run.key[14:], course_run)
self.assert_valid_query_result(admin_client, path, course_run.title[12:], course_run)
@ddt.data('irc', 'ing')
def test_organization_autocomplete(self, search_key):
def test_organization_autocomplete(self, admin_client):
""" Verify Organization autocomplete returns the data. """
response = self.client.get(reverse('admin_metadata:organisation-autocomplete'))
organizations = OrganizationFactory.create_batch(3)
path = reverse('admin_metadata:organisation-autocomplete')
response = admin_client.get(path)
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(response.status_code, 200)
self.assertEqual(len(data['results']), 3)
assert response.status_code == 200
assert len(data['results']) == 3
self.organizations[0].key = 'Mirco'
self.organizations[0].name = 'testing name'
self.organizations[0].save()
# Search for substrings of organization keys and names
organization = organizations[0]
self.assert_valid_query_result(admin_client, path, organization.key[:3], organization)
self.assert_valid_query_result(admin_client, path, organization.name[:5], organization)
response = self.client.get(
reverse('admin_metadata:organisation-autocomplete') + '?q={key}'.format(
key=search_key
)
)
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(data['results'][0]['text'], str(self.organizations[0]))
self.assertEqual(len(data['results']), 1)
@pytest.mark.parametrize('view_prefix', ['organisation', 'course', 'course-run'])
def test_autocomplete_requires_staff_permission(self, view_prefix, client):
""" Verify autocomplete returns empty list for non-staff users. """
def test_organization_autocomplete_un_authorize_user(self):
""" Verify Organization autocomplete returns empty list for un-authorized users. """
self._make_user_non_staff()
response = self.client.get(reverse('admin_metadata:organisation-autocomplete'))
user = UserFactory(is_staff=False)
client.login(username=user.username, password=USER_PASSWORD)
response = client.get(reverse('admin_metadata:{}-autocomplete'.format(view_prefix)))
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(data['results'], [])
def _make_user_non_staff(self):
self.client.logout()
self.user = UserFactory(is_staff=False)
self.user.save()
self.client.login(username=self.user.username, password=USER_PASSWORD)
assert response.status_code == 200
assert data['results'] == []
@ddt.ddt
class AutoCompletePersonTests(SiteMixin, TestCase):
"""
Tests for person autocomplete lookups
"""
def setUp(self):
super(AutoCompletePersonTests, self).setUp()
self.user = UserFactory(is_staff=True)
self.client.login(username=self.user.username, password=USER_PASSWORD)
self.courses = factories.CourseFactory.create_batch(3, title='Some random course title')
for course in self.courses:
factories.CourseRunFactory(course=course)
self.organizations = OrganizationFactory.create_batch(3)
self.organization_extensions = []
for organization in self.organizations:
self.organization_extensions.append(factories.OrganizationExtensionFactory(organization=organization))
self.user.groups.add(self.organization_extensions[0].group)
first_instructor = PersonFactory(given_name="First Instructor")
second_instructor = PersonFactory(given_name="Second Instructor")
self.instructors = [first_instructor, second_instructor]
for instructor in self.instructors:
PositionFactory(organization=self.organizations[0], title="professor", person=instructor)
......@@ -185,9 +150,9 @@ class AutoCompletePersonTests(SiteMixin, TestCase):
def _assert_response(self, response, expected_length):
""" Assert autocomplete response. """
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(len(data['results']), expected_length)
assert len(data['results']) == expected_length
def test_instructor_autocomplete_with_uuid(self):
""" Verify instructor autocomplete returns the data with valid uuid. """
......@@ -243,10 +208,10 @@ class AutoCompletePersonTests(SiteMixin, TestCase):
reverse('admin_metadata:person-autocomplete') + '?q={q}'.format(q='ins'),
HTTP_REFERER=reverse('admin:publisher_courserun_add')
)
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
data = json.loads(response.content.decode('utf-8'))
expected_results = [{'id': instructor.id, 'text': str(instructor)} for instructor in self.instructors]
self.assertEqual(data.get('results'), expected_results)
assert data.get('results') == expected_results
def _make_user_non_staff(self):
self.client.logout()
......
......@@ -4,23 +4,23 @@ from decimal import Decimal
import ddt
import mock
import pytest
from dateutil.parser import parse
from django.conf import settings
from django.core.exceptions import ValidationError
from django.db import IntegrityError
from django.db.models.functions import Lower
from django.test import TestCase
from freezegun import freeze_time
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import Currency
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import (
FAQ, AbstractMediaModel, AbstractNamedModel, AbstractValueModel, CorporateEndorsement, Course, CourseRun,
Endorsement, Seat, SeatType, Subject)
Endorsement, Seat, SeatType, Subject
)
from course_discovery.apps.course_metadata.publishers import (
CourseRunMarketingSitePublisher, ProgramMarketingSitePublisher
)
......@@ -31,34 +31,17 @@ from course_discovery.apps.ietf_language_tags.models import LanguageTag
# pylint: disable=no-member
class CourseTests(ElasticsearchTestMixin, TestCase):
""" Tests for the `Course` model. """
def setUp(self):
super(CourseTests, self).setUp()
self.course = factories.CourseFactory()
@pytest.mark.django_db
class TestCourse:
def test_str(self):
""" Verify casting an instance to a string returns a string containing the key and title. """
self.assertEqual(str(self.course), '{key}: {title}'.format(key=self.course.key, title=self.course.title))
course = factories.CourseFactory()
assert str(course), '{key}: {title}'.format(key=course.key, title=course.title)
def test_search(self):
""" Verify the method returns a filtered queryset of courses. """
toggle_switch('log_course_search_queries', active=True)
def test_search(self, haystack_default_connection): # pylint: disable=unused-argument
title = 'Some random title'
courses = factories.CourseFactory.create_batch(3, title=title)
# Sort lowercase keys to prevent different sort orders due to casing.
# For example, sorted(['a', 'Z']) gives ['Z', 'a'], but an ordered
# queryset containing the same values may give ['a', 'Z'] depending
# on the database backend in use.
courses = sorted(courses, key=lambda course: course.key.lower())
expected = set(factories.CourseFactory.create_batch(3, title=title))
query = 'title:' + title
# Use Lower() to force a case-insensitive sort.
actual = list(Course.search(query).order_by(Lower('key')))
self.assertEqual(actual, courses)
assert set(Course.search(query)) == expected
def test_image_url(self):
course = factories.CourseFactory()
......
......@@ -2,12 +2,10 @@ import datetime
import mock
import pytest
from django.test import TestCase
from haystack.backends import SQ
from haystack.backends.elasticsearch_backend import ElasticsearchSearchQuery
from haystack.query import SearchQuerySet
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.models import CourseRun
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory
from course_discovery.apps.edx_haystack_extensions.distinct_counts.backends import (
......@@ -16,7 +14,9 @@ from course_discovery.apps.edx_haystack_extensions.distinct_counts.backends impo
# pylint: disable=protected-access
class DistinctCountsSearchQueryTests(ElasticsearchTestMixin, TestCase):
@pytest.mark.django_db
@pytest.mark.usefixtures('haystack_default_connection')
class TestDistinctCountsSearchQuery:
def test_clone(self):
""" Verify that clone copies all fields, including the aggregation_key and distinct_hit_count."""
query = DistinctCountsSearchQuery()
......@@ -234,7 +234,9 @@ class DistinctCountsSearchQueryTests(ElasticsearchTestMixin, TestCase):
assert query.facets['pacing_type_exact']['size'] == 5
class DistinctCountsElasticsearchBackendWrapperTests(ElasticsearchTestMixin, TestCase):
@pytest.mark.django_db
@pytest.mark.usefixtures('haystack_default_connection')
class TestDistinctCountsElasticsearchBackendWrapper:
def test_search_raises_when_called_with_date_facet(self):
now = datetime.datetime.now()
one_day = datetime.timedelta(days=1)
......
import datetime
import mock
import mock
import pytest
from django.test import TestCase
from haystack.query import SearchQuerySet
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.models import CourseRun
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory
from course_discovery.apps.edx_haystack_extensions.distinct_counts.backends import DistinctCountsSearchQuery
from course_discovery.apps.edx_haystack_extensions.distinct_counts.query import DistinctCountsSearchQuerySet
class DistinctCountsSearchQuerySetTests(ElasticsearchTestMixin, TestCase):
@pytest.mark.django_db
@pytest.mark.usefixtures('haystack_default_connection')
class TestDistinctCountsSearchQuerySet:
def test_from_queryset(self):
""" Verify that a DistinctCountsSearchQuerySet can be built from an existing SearchQuerySet."""
course_1 = CourseFactory()
......
import datetime
import ddt
import pytest
import pytz
from django.test import TestCase
from haystack.query import SearchQuerySet
from mock import patch
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.models import CourseRun, Program, ProgramType
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, ProgramFactory
@ddt.ddt
class TestSearchBoosting(ElasticsearchTestMixin, TestCase):
@pytest.mark.django_db
@pytest.mark.usefixtures('haystack_default_connection')
class TestSearchBoosting:
def build_normalized_course_run(self, **kwargs):
"""Builds a CourseRun with fields set to normalize boosting behavior."""
defaults = {
......@@ -35,7 +34,7 @@ class TestSearchBoosting(ElasticsearchTestMixin, TestCase):
assert search_results[0].score > search_results[1].score
assert test_record.pacing_type == search_results[0].pacing_type
@ddt.data('MicroMasters', 'Professional Certificate')
@pytest.mark.parametrize('program_type', ['MicroMasters', 'Professional Certificate'])
def test_program_type_boosting(self, program_type):
"""Verify MicroMasters and Professional Certificate are boosted over XSeries."""
ProgramFactory(type=ProgramType.objects.get(name='XSeries'))
......@@ -55,20 +54,23 @@ class TestSearchBoosting(ElasticsearchTestMixin, TestCase):
search_results = SearchQuerySet().models(CourseRun).all()
assert len(search_results) == 2
assert search_results[0].score > search_results[1].score
assert int(test_record.start.timestamp()) == int(search_results[0].start.timestamp()) # pylint: disable=no-member
@ddt.data(
# Should not get boost if has_enrollable_paid_seats is False, has_enrollable_paid_seats
# is None, or paid_seat_enrollment_end is in the past.
(False, None, False),
(None, None, False),
(True, datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15), False),
# Should get boost if has_enrollable_paid_seats is True and paid_seat_enrollment_end
# is None or in the future.
(True, None, True),
(True, datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15), True),
# pylint: disable=no-member
assert int(test_record.start.timestamp()) == int(search_results[0].start.timestamp())
@pytest.mark.parametrize(
'has_enrollable_paid_seats,paid_seat_enrollment_end,expects_boost',
[
# Should not get boost if has_enrollable_paid_seats is False, has_enrollable_paid_seats
# is None, or paid_seat_enrollment_end is in the past.
(False, None, False),
(None, None, False),
(True, datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15), False),
# Should get boost if has_enrollable_paid_seats is True and paid_seat_enrollment_end
# is None or in the future.
(True, None, True),
(True, datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15), True),
]
)
@ddt.unpack
def test_enrollable_paid_seat_boosting(self, has_enrollable_paid_seats, paid_seat_enrollment_end, expects_boost):
"""
Verify that CourseRuns for which an unenrolled user may enroll and
......@@ -123,25 +125,27 @@ class TestSearchBoosting(ElasticsearchTestMixin, TestCase):
# negative relevance score.
assert 0 > search_results[1].score
@ddt.data(
# Should get boost if enrollment_start and enrollment_end unspecified.
(None, None, True),
# Should get boost if enrollment_start unspecified and enrollment_end in future.
(None, datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15), True),
# Should get boost if enrollment_start in past and enrollment_end unspecified.
(datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15), None, True),
# Should get boost if enrollment_start in past and enrollment_end in future.
(
datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15),
datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15),
True
),
# Should not get boost if enrollment_start in future.
(datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15), None, False),
# Should not get boost if enrollment_end in past.
(None, datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15), False),
@pytest.mark.parametrize(
'enrollment_start,enrollment_end,expects_boost',
[
# Should get boost if enrollment_start and enrollment_end unspecified.
(None, None, True),
# Should get boost if enrollment_start unspecified and enrollment_end in future.
(None, datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15), True),
# Should get boost if enrollment_start in past and enrollment_end unspecified.
(datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15), None, True),
# Should get boost if enrollment_start in past and enrollment_end in future.
(
datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15),
datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15),
True
),
# Should not get boost if enrollment_start in future.
(datetime.datetime.now(pytz.timezone('utc')) + datetime.timedelta(days=15), None, False),
# Should not get boost if enrollment_end in past.
(None, datetime.datetime.now(pytz.timezone('utc')) - datetime.timedelta(days=15), False),
]
)
@ddt.unpack
def test_enrollable_course_run_boosting(self, enrollment_start, enrollment_end, expects_boost):
"""Verify that enrollable CourseRuns are boosted."""
......
......@@ -2,40 +2,38 @@ import datetime
from django.conf import settings
from django.core.management import call_command
from django.test import TestCase
from freezegun import freeze_time
from course_discovery.apps.core.utils import ElasticsearchUtils
from course_discovery.apps.edx_haystack_extensions.tests.mixins import SearchIndexTestMixin
class RemoveUnusedIndexesTests(SearchIndexTestMixin, TestCase):
class TestRemoveUnusedIndexes:
backend = None
def test_handle(self):
def test_handle(self, haystack_default_connection):
""" Verify the command removes all but the newest indexes. """
backend = haystack_default_connection
# Create initial index with alias
ElasticsearchUtils.create_alias_and_index(es_connection=self.backend.conn, alias=self.backend.index_name)
# Use now as initial time, so indexes are created AFTER the current index so expected values are accurate
initial_time = datetime.datetime.now()
# Create 2 more indexes than we expect to exist after removal
for number in range(1, settings.HAYSTACK_INDEX_RETENTION_LIMIT + 2):
current_time = initial_time + datetime.timedelta(seconds=number)
freezer = freeze_time(current_time)
freezer.start()
ElasticsearchUtils.create_index(es_connection=self.backend.conn, prefix=self.backend.index_name)
ElasticsearchUtils.create_index(es_connection=backend.conn, prefix=backend.index_name)
freezer.stop()
# Prune indexes and confirm the right indexes are removed
call_command('remove_unused_indexes')
current_alias_name = self.backend.index_name
indices_client = self.backend.conn.indices
current_alias_name = backend.index_name
indices_client = backend.conn.indices
current_alias = indices_client.get_alias(name=current_alias_name)
indexes_to_keep = current_alias.keys()
# check that we keep the current indexes, which we don't want removed
all_indexes = self.get_current_index_names(indices_client=indices_client, index_prefix=self.backend.index_name)
all_indexes = self.get_current_index_names(indices_client=indices_client, index_prefix=backend.index_name)
assert set(all_indexes).issuperset(set(indexes_to_keep))
# check that other indexes are removed, excepting those that don't hit the retention limit
......@@ -46,12 +44,15 @@ class RemoveUnusedIndexesTests(SearchIndexTestMixin, TestCase):
call_command('remove_unused_indexes')
# check that we keep the current indexes, which we don't want removed
all_indexes = self.get_current_index_names(indices_client=indices_client, index_prefix=self.backend.index_name)
all_indexes = self.get_current_index_names(indices_client=indices_client, index_prefix=backend.index_name)
assert set(all_indexes).issuperset(set(indexes_to_keep))
# check that index count remains the same as before
assert len(all_indexes) == expected_count
# Cleanup indexes created by this test
backend.conn.indices.delete(index=backend.index_name + '_*')
@staticmethod
def get_current_index_names(indices_client, index_prefix):
all_index_settings = indices_client.get_settings()
......
......@@ -11,7 +11,8 @@ lxml==3.6.1
mock==2.0.0
pep8==1.7.0
pytest==3.0.6
pytest-django==3.1.2
# See https://github.com/pytest-dev/pytest-django/issues/473
git+https://github.com/pytest-dev/pytest-django@de31fab4122cfc53e7116b499235b610366d941a#egg=pytest-django==3.1.3.dev47+gde31fab
pytest-django-ordering==1.0.1
pytest-responses==0.3.0
responses==0.7.0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment