Unverified Commit 07bb3ea5 by Calen Pennington Committed by Gabe Mulley

Make the exclude orgs a bit more natural by returning a filtered query

parent fa7b9a13
...@@ -78,7 +78,6 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver): ...@@ -78,7 +78,6 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver):
def __attrs_post_init__(self): def __attrs_post_init__(self):
# TODO: in the next refactor of this task, pass in current_datetime instead of reproducing it here # TODO: in the next refactor of this task, pass in current_datetime instead of reproducing it here
self.current_datetime = self.target_datetime - datetime.timedelta(days=self.day_offset) self.current_datetime = self.target_datetime - datetime.timedelta(days=self.day_offset)
self.exclude_orgs, self.org_list = self.get_course_org_filter()
def send(self, msg_type): def send(self, msg_type):
for (user, language, context) in self.schedules_for_bin(): for (user, language, context) in self.schedules_for_bin():
...@@ -133,11 +132,7 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver): ...@@ -133,11 +132,7 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver):
**schedule_day_equals_target_day_filter **schedule_day_equals_target_day_filter
).order_by(order_by) ).order_by(order_by)
if self.org_list is not None: schedules = self.filter_by_org(schedules)
if self.exclude_orgs:
schedules = schedules.exclude(enrollment__course__org__in=self.org_list)
else:
schedules = schedules.filter(enrollment__course__org__in=self.org_list)
if "read_replica" in settings.DATABASES: if "read_replica" in settings.DATABASES:
schedules = schedules.using("read_replica") schedules = schedules.using("read_replica")
...@@ -153,7 +148,7 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver): ...@@ -153,7 +148,7 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver):
return schedules return schedules
def get_course_org_filter(self): def filter_by_org(self, schedules):
""" """
Given the configuration of sites, get the list of orgs that should be included or excluded from this send. Given the configuration of sites, get the list of orgs that should be included or excluded from this send.
...@@ -165,7 +160,6 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver): ...@@ -165,7 +160,6 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver):
try: try:
site_config = self.site.configuration site_config = self.site.configuration
org_list = site_config.get_value('course_org_filter') org_list = site_config.get_value('course_org_filter')
exclude_orgs = False
if not org_list: if not org_list:
not_orgs = set() not_orgs = set()
for other_site_config in SiteConfiguration.objects.all(): for other_site_config in SiteConfiguration.objects.all():
...@@ -175,15 +169,13 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver): ...@@ -175,15 +169,13 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver):
not_orgs.add(other) not_orgs.add(other)
else: else:
not_orgs.update(other) not_orgs.update(other)
org_list = list(not_orgs) return schedules.exclude(enrollment__course__org__in=not_orgs)
exclude_orgs = True
elif not isinstance(org_list, list): elif not isinstance(org_list, list):
org_list = [org_list] return schedules.filter(enrollment__course__org=org_list)
except SiteConfiguration.DoesNotExist: except SiteConfiguration.DoesNotExist:
org_list = None return schedules
exclude_orgs = False
return exclude_orgs, org_list return schedules.filter(enrollment__course__org__in=org_list)
def schedules_for_bin(self): def schedules_for_bin(self):
schedules = self.get_schedules_with_target_date_by_bin_and_orgs() schedules = self.get_schedules_with_target_date_by_bin_and_orgs()
......
...@@ -30,28 +30,40 @@ class TestBinnedSchedulesBaseResolver(CacheIsolationTestCase): ...@@ -30,28 +30,40 @@ class TestBinnedSchedulesBaseResolver(CacheIsolationTestCase):
bin_num=2, bin_num=2,
) )
@ddt.data(
'course1'
)
def test_get_course_org_filter_equal(self, course_org_filter):
self.site_config.values['course_org_filter'] = course_org_filter
self.site_config.save()
mock_query = Mock()
result = self.resolver.filter_by_org(mock_query)
self.assertEqual(result, mock_query.filter.return_value)
mock_query.filter.assert_called_once_with(enrollment__course__org=course_org_filter)
@ddt.unpack @ddt.unpack
@ddt.data( @ddt.data(
('course1', ['course1']),
(['course1', 'course2'], ['course1', 'course2']) (['course1', 'course2'], ['course1', 'course2'])
) )
def test_get_course_org_filter_include(self, course_org_filter, expected_org_list): def test_get_course_org_filter_include__in(self, course_org_filter, expected_org_list):
self.site_config.values['course_org_filter'] = course_org_filter self.site_config.values['course_org_filter'] = course_org_filter
self.site_config.save() self.site_config.save()
exclude_orgs, org_list = self.resolver.get_course_org_filter() mock_query = Mock()
assert not exclude_orgs result = self.resolver.filter_by_org(mock_query)
assert org_list == expected_org_list self.assertEqual(result, mock_query.filter.return_value)
mock_query.filter.assert_called_once_with(enrollment__course__org__in=expected_org_list)
@ddt.unpack @ddt.unpack
@ddt.data( @ddt.data(
(None, []), (None, set([])),
('course1', [u'course1']), ('course1', set([u'course1'])),
(['course1', 'course2'], [u'course1', u'course2']) (['course1', 'course2'], set([u'course1', u'course2']))
) )
def test_get_course_org_filter_exclude(self, course_org_filter, expected_org_list): def test_get_course_org_filter_exclude__in(self, course_org_filter, expected_org_list):
SiteConfigurationFactory.create( SiteConfigurationFactory.create(
values={'course_org_filter': course_org_filter}, values={'course_org_filter': course_org_filter},
) )
exclude_orgs, org_list = self.resolver.get_course_org_filter() mock_query = Mock()
assert exclude_orgs result = self.resolver.filter_by_org(mock_query)
self.assertItemsEqual(org_list, expected_org_list) mock_query.exclude.assert_called_once_with(enrollment__course__org__in=expected_org_list)
self.assertEqual(result, mock_query.exclude.return_value)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment