Commit eb987ed6 by Diana Huang Committed by GitHub

Merge pull request #15594 from edx/diana/access-checks

Move prereq checking into course load checks.
parents d5fca9ee c338b751
...@@ -734,7 +734,6 @@ def dashboard(request): ...@@ -734,7 +734,6 @@ def dashboard(request):
show_courseware_links_for = frozenset( show_courseware_links_for = frozenset(
enrollment.course_id for enrollment in course_enrollments enrollment.course_id for enrollment in course_enrollments
if has_access(request.user, 'load', enrollment.course_overview) if has_access(request.user, 'load', enrollment.course_overview)
and has_access(request.user, 'view_courseware_with_prerequisites', enrollment.course_overview)
) )
# Find programs associated with course runs being displayed. This information # Find programs associated with course runs being displayed. This information
......
...@@ -21,7 +21,11 @@ from django.utils.timezone import UTC ...@@ -21,7 +21,11 @@ from django.utils.timezone import UTC
from opaque_keys.edx.keys import CourseKey, UsageKey from opaque_keys.edx.keys import CourseKey, UsageKey
from xblock.core import XBlock from xblock.core import XBlock
from courseware.access_response import MilestoneError, MobileAvailabilityError, VisibilityError from courseware.access_response import (
MilestoneAccessError,
MobileAvailabilityError,
VisibilityError,
)
from courseware.access_utils import ( from courseware.access_utils import (
ACCESS_DENIED, ACCESS_DENIED,
ACCESS_GRANTED, ACCESS_GRANTED,
...@@ -309,7 +313,8 @@ def _has_access_course(user, action, courselike): ...@@ -309,7 +313,8 @@ def _has_access_course(user, action, courselike):
""" """
response = ( response = (
_visible_to_nonstaff_users(courselike) and _visible_to_nonstaff_users(courselike) and
check_course_open_for_learner(user, courselike) check_course_open_for_learner(user, courselike) and
_can_view_courseware_with_prerequisites(user, courselike)
) )
return ( return (
...@@ -355,8 +360,6 @@ def _has_access_course(user, action, courselike): ...@@ -355,8 +360,6 @@ def _has_access_course(user, action, courselike):
checkers = { checkers = {
'load': can_load, 'load': can_load,
'view_courseware_with_prerequisites':
lambda: _can_view_courseware_with_prerequisites(user, courselike),
'load_mobile': lambda: can_load() and _can_load_course_on_mobile(user, courselike), 'load_mobile': lambda: can_load() and _can_load_course_on_mobile(user, courselike),
'enroll': can_enroll, 'enroll': can_enroll,
'see_exists': see_exists, 'see_exists': see_exists,
...@@ -770,7 +773,7 @@ def _has_fulfilled_all_milestones(user, course_id): ...@@ -770,7 +773,7 @@ def _has_fulfilled_all_milestones(user, course_id):
course_id: ID of the course to check course_id: ID of the course to check
user_id: ID of the user to check user_id: ID of the user to check
""" """
return MilestoneError() if any_unfulfilled_milestones(course_id, user.id) else ACCESS_GRANTED return MilestoneAccessError() if any_unfulfilled_milestones(course_id, user.id) else ACCESS_GRANTED
def _has_fulfilled_prerequisites(user, course_id): def _has_fulfilled_prerequisites(user, course_id):
...@@ -782,7 +785,7 @@ def _has_fulfilled_prerequisites(user, course_id): ...@@ -782,7 +785,7 @@ def _has_fulfilled_prerequisites(user, course_id):
user: user to check user: user to check
course_id: ID of the course to check course_id: ID of the course to check
""" """
return MilestoneError() if get_pre_requisite_courses_not_completed(user, course_id) else ACCESS_GRANTED return MilestoneAccessError() if get_pre_requisite_courses_not_completed(user, course_id) else ACCESS_GRANTED
def _has_catalog_visibility(course, visibility_type): def _has_catalog_visibility(course, visibility_type):
......
...@@ -105,7 +105,7 @@ class StartDateError(AccessError): ...@@ -105,7 +105,7 @@ class StartDateError(AccessError):
super(StartDateError, self).__init__(error_code, developer_message, user_message) super(StartDateError, self).__init__(error_code, developer_message, user_message)
class MilestoneError(AccessError): class MilestoneAccessError(AccessError):
""" """
Access denied because the user has unfulfilled milestones Access denied because the user has unfulfilled milestones
""" """
...@@ -113,7 +113,7 @@ class MilestoneError(AccessError): ...@@ -113,7 +113,7 @@ class MilestoneError(AccessError):
error_code = "unfulfilled_milestones" error_code = "unfulfilled_milestones"
developer_message = "User has unfulfilled milestones" developer_message = "User has unfulfilled milestones"
user_message = _("You have unfulfilled milestones") user_message = _("You have unfulfilled milestones")
super(MilestoneError, self).__init__(error_code, developer_message, user_message) super(MilestoneAccessError, self).__init__(error_code, developer_message, user_message)
class VisibilityError(AccessError): class VisibilityError(AccessError):
......
...@@ -9,7 +9,7 @@ from datetime import datetime ...@@ -9,7 +9,7 @@ from datetime import datetime
import branding import branding
import pytz import pytz
from courseware.access import has_access from courseware.access import has_access
from courseware.access_response import StartDateError from courseware.access_response import StartDateError, MilestoneAccessError
from courseware.date_summary import ( from courseware.date_summary import (
CourseEndDate, CourseEndDate,
CourseStartDate, CourseStartDate,
...@@ -32,6 +32,7 @@ from openedx.core.djangoapps.site_configuration import helpers as configuration_ ...@@ -32,6 +32,7 @@ from openedx.core.djangoapps.site_configuration import helpers as configuration_
from path import Path as path from path import Path as path
from static_replace import replace_static_urls from static_replace import replace_static_urls
from student.models import CourseEnrollment from student.models import CourseEnrollment
from survey.utils import is_survey_required_and_unanswered
from util.date_utils import strftime_localized from util.date_utils import strftime_localized
from xmodule.modulestore.django import modulestore from xmodule.modulestore.django import modulestore
from xmodule.modulestore.exceptions import ItemNotFoundError from xmodule.modulestore.exceptions import ItemNotFoundError
...@@ -72,7 +73,7 @@ def get_course_by_id(course_key, depth=0): ...@@ -72,7 +73,7 @@ def get_course_by_id(course_key, depth=0):
raise Http404("Course not found: {}.".format(unicode(course_key))) raise Http404("Course not found: {}.".format(unicode(course_key)))
def get_course_with_access(user, action, course_key, depth=0, check_if_enrolled=False): def get_course_with_access(user, action, course_key, depth=0, check_if_enrolled=False, check_survey_complete=True):
""" """
Given a course_key, look up the corresponding course descriptor, Given a course_key, look up the corresponding course descriptor,
check that the user has the access to perform the specified action check that the user has the access to perform the specified action
...@@ -84,9 +85,14 @@ def get_course_with_access(user, action, course_key, depth=0, check_if_enrolled= ...@@ -84,9 +85,14 @@ def get_course_with_access(user, action, course_key, depth=0, check_if_enrolled=
check_if_enrolled: If true, additionally verifies that the user is either enrolled in the course check_if_enrolled: If true, additionally verifies that the user is either enrolled in the course
or has staff access. or has staff access.
check_survey_complete: If true, additionally verifies that the user has either completed the course survey
or has staff access.
Note: We do not want to continually add these optional booleans. Ideally,
these special cases could not only be handled inside has_access, but could
be plugged in as additional callback checks for different actions.
""" """
course = get_course_by_id(course_key, depth) course = get_course_by_id(course_key, depth)
check_course_access(course, user, action, check_if_enrolled) check_course_access(course, user, action, check_if_enrolled, check_survey_complete)
return course return course
...@@ -109,12 +115,13 @@ def get_course_overview_with_access(user, action, course_key, check_if_enrolled= ...@@ -109,12 +115,13 @@ def get_course_overview_with_access(user, action, course_key, check_if_enrolled=
return course_overview return course_overview
def check_course_access(course, user, action, check_if_enrolled=False): def check_course_access(course, user, action, check_if_enrolled=False, check_survey_complete=True):
""" """
Check that the user has the access to perform the specified action Check that the user has the access to perform the specified action
on the course (CourseDescriptor|CourseOverview). on the course (CourseDescriptor|CourseOverview).
check_if_enrolled: If true, additionally verifies that the user is enrolled. check_if_enrolled: If true, additionally verifies that the user is enrolled.
check_survey_complete: If true, additionally verifies that the user has completed the survey.
""" """
# Allow staff full access to the course even if not enrolled # Allow staff full access to the course even if not enrolled
if has_access(user, 'staff', course.id): if has_access(user, 'staff', course.id):
...@@ -130,7 +137,13 @@ def check_course_access(course, user, action, check_if_enrolled=False): ...@@ -130,7 +137,13 @@ def check_course_access(course, user, action, check_if_enrolled=False):
raise CourseAccessRedirect('{dashboard_url}?{params}'.format( raise CourseAccessRedirect('{dashboard_url}?{params}'.format(
dashboard_url=reverse('dashboard'), dashboard_url=reverse('dashboard'),
params=params.urlencode() params=params.urlencode()
)) ), access_response)
# Redirect if the user must answer a survey before entering the course.
if isinstance(access_response, MilestoneAccessError):
raise CourseAccessRedirect('{dashboard_url}'.format(
dashboard_url=reverse('dashboard'),
), access_response)
# Deliberately return a non-specific error message to avoid # Deliberately return a non-specific error message to avoid
# leaking info about access control settings # leaking info about access control settings
...@@ -141,6 +154,11 @@ def check_course_access(course, user, action, check_if_enrolled=False): ...@@ -141,6 +154,11 @@ def check_course_access(course, user, action, check_if_enrolled=False):
if not CourseEnrollment.is_enrolled(user, course.id): if not CourseEnrollment.is_enrolled(user, course.id):
raise CourseAccessRedirect(reverse('about_course', args=[unicode(course.id)])) raise CourseAccessRedirect(reverse('about_course', args=[unicode(course.id)]))
# Redirect if the user must answer a survey before entering the course.
if check_survey_complete and action == 'load':
if is_survey_required_and_unanswered(user, course):
raise CourseAccessRedirect(reverse('course_survey', args=[unicode(course.id)]))
def can_self_enroll_in_course(course_key): def can_self_enroll_in_course(course_key):
""" """
......
...@@ -15,5 +15,13 @@ class Redirect(Exception): ...@@ -15,5 +15,13 @@ class Redirect(Exception):
class CourseAccessRedirect(Redirect): class CourseAccessRedirect(Redirect):
""" """
Redirect raised when user does not have access to a course. Redirect raised when user does not have access to a course.
Arguments:
url (string): The redirect url.
access_error (AccessErro): The AccessError that caused the redirect.
The AccessError contains messages for developers and users explaining why
the user was denied access. These strings can then be exposed to the user.
""" """
pass def __init__(self, url, access_error=None):
super(CourseAccessRedirect, self).__init__(url)
self.access_error = access_error
...@@ -595,16 +595,16 @@ class AccessTestCase(LoginEnrollmentTestCase, ModuleStoreTestCase, MilestonesTes ...@@ -595,16 +595,16 @@ class AccessTestCase(LoginEnrollmentTestCase, ModuleStoreTestCase, MilestonesTes
# user should not be able to load course even if enrolled # user should not be able to load course even if enrolled
CourseEnrollmentFactory(user=user, course_id=course.id) CourseEnrollmentFactory(user=user, course_id=course.id)
response = access._has_access_course(user, 'view_courseware_with_prerequisites', course) response = access._has_access_course(user, 'load', course)
self.assertFalse(response) self.assertFalse(response)
self.assertIsInstance(response, access_response.MilestoneError) self.assertIsInstance(response, access_response.MilestoneAccessError)
# Staff can always access course # Staff can always access course
staff = StaffFactory.create(course_key=course.id) staff = StaffFactory.create(course_key=course.id)
self.assertTrue(access._has_access_course(staff, 'view_courseware_with_prerequisites', course)) self.assertTrue(access._has_access_course(staff, 'load', course))
# User should be able access after completing required course # User should be able access after completing required course
fulfill_course_milestone(pre_requisite_course.id, user) fulfill_course_milestone(pre_requisite_course.id, user)
self.assertTrue(access._has_access_course(user, 'view_courseware_with_prerequisites', course)) self.assertTrue(access._has_access_course(user, 'load', course))
@ddt.data( @ddt.data(
(True, True, True), (True, True, True),
...@@ -615,8 +615,7 @@ class AccessTestCase(LoginEnrollmentTestCase, ModuleStoreTestCase, MilestonesTes ...@@ -615,8 +615,7 @@ class AccessTestCase(LoginEnrollmentTestCase, ModuleStoreTestCase, MilestonesTes
""" """
Test course access on mobile for staff and students. Test course access on mobile for staff and students.
""" """
descriptor = Mock(id=self.course.id, user_partitions=[]) descriptor = CourseFactory()
descriptor._class_tags = {}
descriptor.visible_to_staff_only = False descriptor.visible_to_staff_only = False
descriptor.mobile_available = mobile_available descriptor.mobile_available = mobile_available
...@@ -773,7 +772,7 @@ class CourseOverviewAccessTestCase(ModuleStoreTestCase): ...@@ -773,7 +772,7 @@ class CourseOverviewAccessTestCase(ModuleStoreTestCase):
PREREQUISITES_TEST_DATA = list(itertools.product( PREREQUISITES_TEST_DATA = list(itertools.product(
['user_normal', 'user_completed_pre_requisite', 'user_staff', 'user_anonymous'], ['user_normal', 'user_completed_pre_requisite', 'user_staff', 'user_anonymous'],
['view_courseware_with_prerequisites'], ['load'],
['course_default', 'course_with_pre_requisite', 'course_with_pre_requisites'], ['course_default', 'course_with_pre_requisite', 'course_with_pre_requisites'],
)) ))
......
...@@ -52,7 +52,6 @@ from ..model_data import FieldDataCache ...@@ -52,7 +52,6 @@ from ..model_data import FieldDataCache
from ..module_render import get_module_for_descriptor, toc_for_course from ..module_render import get_module_for_descriptor, toc_for_course
from .views import ( from .views import (
CourseTabView, CourseTabView,
check_access_to_course,
check_and_get_upgrade_link, check_and_get_upgrade_link,
get_cosmetic_verified_display_price get_cosmetic_verified_display_price
) )
...@@ -136,7 +135,6 @@ class CoursewareIndex(View): ...@@ -136,7 +135,6 @@ class CoursewareIndex(View):
""" """
Render the index page. Render the index page.
""" """
check_access_to_course(request, self.course)
self._redirect_if_needed_to_pay_for_course() self._redirect_if_needed_to_pay_for_course()
self._prefetch_and_bind_course(request) self._prefetch_and_bind_course(request)
......
...@@ -9,7 +9,6 @@ from datetime import datetime, timedelta ...@@ -9,7 +9,6 @@ from datetime import datetime, timedelta
import analytics import analytics
import shoppingcart import shoppingcart
import survey.utils
import survey.views import survey.views
import waffle import waffle
from certificates import api as certs_api from certificates import api as certs_api
...@@ -91,7 +90,6 @@ from openedx.features.enterprise_support.api import data_sharing_consent_require ...@@ -91,7 +90,6 @@ from openedx.features.enterprise_support.api import data_sharing_consent_require
from rest_framework import status from rest_framework import status
from shoppingcart.utils import is_shopping_cart_enabled from shoppingcart.utils import is_shopping_cart_enabled
from student.models import CourseEnrollment, UserTestGroup from student.models import CourseEnrollment, UserTestGroup
from survey.utils import must_answer_survey
from util.cache import cache, cache_if_anonymous from util.cache import cache, cache_if_anonymous
from util.db import outer_atomic from util.db import outer_atomic
from util.milestones_helpers import get_prerequisite_courses_display from util.milestones_helpers import get_prerequisite_courses_display
...@@ -278,10 +276,6 @@ def course_info(request, course_id): ...@@ -278,10 +276,6 @@ def course_info(request, course_id):
if not user_is_enrolled and not can_self_enroll_in_course(course_key): if not user_is_enrolled and not can_self_enroll_in_course(course_key):
return redirect(reverse('dashboard')) return redirect(reverse('dashboard'))
# TODO: LEARNER-1865: Handle prereqs and course survey in new Course Home.
# Redirect the user if they are not yet allowed to view this course
check_access_to_course(request, course)
# LEARNER-170: Entrance exam is handled by new Course Outline. (DONE) # LEARNER-170: Entrance exam is handled by new Course Outline. (DONE)
# If the user needs to take an entrance exam to access this course, then we'll need # If the user needs to take an entrance exam to access this course, then we'll need
# to send them to that specific course module before allowing them into other areas # to send them to that specific course module before allowing them into other areas
...@@ -424,9 +418,6 @@ class CourseTabView(EdxFragmentView): ...@@ -424,9 +418,6 @@ class CourseTabView(EdxFragmentView):
with modulestore().bulk_operations(course_key): with modulestore().bulk_operations(course_key):
course = get_course_with_access(request.user, 'load', course_key) course = get_course_with_access(request.user, 'load', course_key)
try: try:
# Verify that the user has access to the course
check_access_to_course(request, course)
# Show warnings if the user has limited access # Show warnings if the user has limited access
self.register_user_access_warning_messages(request, course_key) self.register_user_access_warning_messages(request, course_key)
...@@ -739,8 +730,7 @@ def course_about(request, course_id): ...@@ -739,8 +730,7 @@ def course_about(request, course_id):
show_courseware_link = bool( show_courseware_link = bool(
( (
has_access(request.user, 'load', course) and has_access(request.user, 'load', course)
has_access(request.user, 'view_courseware_with_prerequisites', course)
) or settings.FEATURES.get('ENABLE_LMS_MIGRATION') ) or settings.FEATURES.get('ENABLE_LMS_MIGRATION')
) )
...@@ -921,9 +911,6 @@ def _progress(request, course_key, student_id): ...@@ -921,9 +911,6 @@ def _progress(request, course_key, student_id):
# NOTE: To make sure impersonation by instructor works, use # NOTE: To make sure impersonation by instructor works, use
# student instead of request.user in the rest of the function. # student instead of request.user in the rest of the function.
# Redirect the user if they are not yet allowed to view this course
check_access_to_course(request, course)
# The pre-fetching of groups is done to make auth checks not require an # The pre-fetching of groups is done to make auth checks not require an
# additional DB lookup (this kills the Progress page in particular). # additional DB lookup (this kills the Progress page in particular).
student = User.objects.prefetch_related("groups").get(id=student.id) student = User.objects.prefetch_related("groups").get(id=student.id)
...@@ -1311,7 +1298,7 @@ def course_survey(request, course_id): ...@@ -1311,7 +1298,7 @@ def course_survey(request, course_id):
""" """
course_key = CourseKey.from_string(course_id) course_key = CourseKey.from_string(course_id)
course = get_course_with_access(request.user, 'load', course_key) course = get_course_with_access(request.user, 'load', course_key, check_survey_complete=False)
redirect_url = reverse(course_home_url_name(course.id), args=[course_id]) redirect_url = reverse(course_home_url_name(course.id), args=[course_id])
...@@ -1721,22 +1708,3 @@ def get_financial_aid_courses(user): ...@@ -1721,22 +1708,3 @@ def get_financial_aid_courses(user):
) )
return financial_aid_courses return financial_aid_courses
def check_access_to_course(request, course):
"""
Raises Redirect exceptions if the user does not have course access.
"""
# TODO: LEARNER-1865: Handle prereqs in new Course Home.
# Redirect to the dashboard if not all prerequisites have been met
if not has_access(request.user, 'view_courseware_with_prerequisites', course):
log.info(
u'User %d tried to view course %s '
u'without fulfilling prerequisites',
request.user.id, unicode(course.id))
raise CourseAccessRedirect(reverse('dashboard'))
# TODO: LEARNER-1865: Handle course surveys in new Course Home.
# Redirect if the user must answer a survey before entering the course.
if must_answer_survey(course, request.user):
raise CourseAccessRedirect(reverse('course_survey', args=[unicode(course.id)]))
...@@ -42,6 +42,10 @@ def mobile_course_access(depth=0): ...@@ -42,6 +42,10 @@ def mobile_course_access(depth=0):
except CoursewareAccessException as error: except CoursewareAccessException as error:
return Response(data=error.to_json(), status=status.HTTP_404_NOT_FOUND) return Response(data=error.to_json(), status=status.HTTP_404_NOT_FOUND)
except CourseAccessRedirect as error: except CourseAccessRedirect as error:
# If the redirect contains information about the triggering AccessError,
# return the information contained in the AccessError.
if error.access_error is not None:
return Response(data=error.access_error.to_json(), status=status.HTTP_404_NOT_FOUND)
# Raise a 404 if the user does not have course access # Raise a 404 if the user does not have course access
raise Http404 raise Http404
return func(self, request, course=course, *args, **kwargs) return func(self, request, course=course, *args, **kwargs)
......
...@@ -4,7 +4,7 @@ Milestone related tests for the mobile_api ...@@ -4,7 +4,7 @@ Milestone related tests for the mobile_api
from django.conf import settings from django.conf import settings
from mock import patch from mock import patch
from courseware.access_response import MilestoneError from courseware.access_response import MilestoneAccessError
from courseware.tests.test_entrance_exam import add_entrance_exam_milestone, answer_entrance_exam_problem from courseware.tests.test_entrance_exam import add_entrance_exam_milestone, answer_entrance_exam_problem
from openedx.core.djangolib.testing.utils import get_mock_request from openedx.core.djangolib.testing.utils import get_mock_request
from util.milestones_helpers import add_prerequisite_course, fulfill_course_milestone from util.milestones_helpers import add_prerequisite_course, fulfill_course_milestone
...@@ -136,4 +136,4 @@ class MobileAPIMilestonesMixin(object): ...@@ -136,4 +136,4 @@ class MobileAPIMilestonesMixin(object):
self.api_response() self.api_response()
else: else:
response = self.api_response(expected_response_code=404) response = self.api_response(expected_response_code=404)
self.assertEqual(response.data, MilestoneError().to_json()) self.assertEqual(response.data, MilestoneAccessError().to_json())
...@@ -18,7 +18,7 @@ from certificates.api import generate_user_certificates ...@@ -18,7 +18,7 @@ from certificates.api import generate_user_certificates
from certificates.models import CertificateStatuses from certificates.models import CertificateStatuses
from certificates.tests.factories import GeneratedCertificateFactory from certificates.tests.factories import GeneratedCertificateFactory
from course_modes.models import CourseMode from course_modes.models import CourseMode
from courseware.access_response import MilestoneError, StartDateError, VisibilityError from courseware.access_response import MilestoneAccessError, StartDateError, VisibilityError
from lms.djangoapps.grades.tests.utils import mock_passing_grade from lms.djangoapps.grades.tests.utils import mock_passing_grade
from mobile_api.testutils import ( from mobile_api.testutils import (
MobileAPITestCase, MobileAPITestCase,
...@@ -155,7 +155,7 @@ class TestUserEnrollmentApi(UrlResetMixin, MobileAPITestCase, MobileAuthUserTest ...@@ -155,7 +155,7 @@ class TestUserEnrollmentApi(UrlResetMixin, MobileAPITestCase, MobileAuthUserTest
] ]
expected_error_codes = [ expected_error_codes = [
MilestoneError().error_code, # 'unfulfilled_milestones' MilestoneAccessError().error_code, # 'unfulfilled_milestones'
StartDateError(self.NEXT_WEEK).error_code, # 'course_not_started' StartDateError(self.NEXT_WEEK).error_code, # 'course_not_started'
VisibilityError().error_code, # 'not_visible_to_user' VisibilityError().error_code, # 'not_visible_to_user'
None, None,
......
...@@ -8,7 +8,7 @@ from django.contrib.auth.models import User ...@@ -8,7 +8,7 @@ from django.contrib.auth.models import User
from django.test.client import Client from django.test.client import Client
from survey.models import SurveyForm from survey.models import SurveyForm
from survey.utils import is_survey_required_for_course, must_answer_survey from survey.utils import is_survey_required_for_course, is_survey_required_and_unanswered
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory from xmodule.modulestore.tests.factories import CourseFactory
...@@ -89,28 +89,28 @@ class SurveyModelsTests(ModuleStoreTestCase): ...@@ -89,28 +89,28 @@ class SurveyModelsTests(ModuleStoreTestCase):
""" """
Assert that a new course which has a required survey but user has not answered it yet Assert that a new course which has a required survey but user has not answered it yet
""" """
self.assertTrue(must_answer_survey(self.course, self.student)) self.assertTrue(is_survey_required_and_unanswered(self.student, self.course))
temp_course = CourseFactory.create( temp_course = CourseFactory.create(
course_survey_required=False course_survey_required=False
) )
self.assertFalse(must_answer_survey(temp_course, self.student)) self.assertFalse(is_survey_required_and_unanswered(self.student, temp_course))
temp_course = CourseFactory.create( temp_course = CourseFactory.create(
course_survey_required=True, course_survey_required=True,
course_survey_name="NonExisting" course_survey_name="NonExisting"
) )
self.assertFalse(must_answer_survey(temp_course, self.student)) self.assertFalse(is_survey_required_and_unanswered(self.student, temp_course))
def test_user_has_answered_required_survey(self): def test_user_has_answered_required_survey(self):
""" """
Assert that a new course which has a required survey and user has answers for it Assert that a new course which has a required survey and user has answers for it
""" """
self.survey.save_user_answers(self.student, self.student_answers, None) self.survey.save_user_answers(self.student, self.student_answers, None)
self.assertFalse(must_answer_survey(self.course, self.student)) self.assertFalse(is_survey_required_and_unanswered(self.student, self.course))
def test_staff_must_answer_survey(self): def test_staff_must_answer_survey(self):
""" """
Assert that someone with staff level permissions does not have to answer the survey Assert that someone with staff level permissions does not have to answer the survey
""" """
self.assertFalse(must_answer_survey(self.course, self.staff)) self.assertFalse(is_survey_required_and_unanswered(self.staff, self.course))
""" """
Helper methods for Surveys Utilities for determining whether or not a survey needs to be completed.
""" """
from courseware.access import has_access from courseware.access import has_access
from survey.models import SurveyAnswer, SurveyForm from survey.models import SurveyForm, SurveyAnswer
def is_survey_required_for_course(course_descriptor): def is_survey_required_for_course(course_descriptor):
...@@ -11,17 +10,19 @@ def is_survey_required_for_course(course_descriptor): ...@@ -11,17 +10,19 @@ def is_survey_required_for_course(course_descriptor):
Returns whether a Survey is required for this course Returns whether a Survey is required for this course
""" """
# check to see that the Survey name has been defined in the CourseDescriptor # Check to see that the survey is required in the CourseDescriptor.
# and that the specified Survey exists if not getattr(course_descriptor, 'course_survey_required', False):
return False
return course_descriptor.course_survey_required and \ # Check that the specified Survey for the course exists.
SurveyForm.get(course_descriptor.course_survey_name, throw_if_not_found=False) return SurveyForm.get(course_descriptor.course_survey_name, throw_if_not_found=False)
def must_answer_survey(course_descriptor, user): def is_survey_required_and_unanswered(user, course_descriptor):
""" """
Returns whether a user needs to answer a required survey Returns whether a user is required to answer the survey and has yet to do so.
""" """
if not is_survey_required_for_course(course_descriptor): if not is_survey_required_for_course(course_descriptor):
return False return False
...@@ -29,13 +30,13 @@ def must_answer_survey(course_descriptor, user): ...@@ -29,13 +30,13 @@ def must_answer_survey(course_descriptor, user):
if user.is_anonymous(): if user.is_anonymous():
return False return False
# this will throw exception if not found, but a non existing survey name will # course staff do not need to answer survey
# be trapped in the above is_survey_required_for_course() method
survey = SurveyForm.get(course_descriptor.course_survey_name)
has_staff_access = has_access(user, 'staff', course_descriptor) has_staff_access = has_access(user, 'staff', course_descriptor)
if has_staff_access:
return False
# survey is required and it exists, let's see if user has answered the survey # survey is required and it exists, let's see if user has answered the survey
# course staff do not need to answer survey survey = SurveyForm.get(course_descriptor.course_survey_name)
answered_survey = SurveyAnswer.do_survey_answers_exist(survey, user) answered_survey = SurveyAnswer.do_survey_answers_exist(survey, user)
return not answered_survey and not has_staff_access if not answered_survey:
return True
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment