Unverified Commit 07bb3ea5 by Calen Pennington Committed by Gabe Mulley

Make the exclude orgs a bit more natural by returning a filtered query

parent fa7b9a13
......@@ -78,7 +78,6 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver):
def __attrs_post_init__(self):
# TODO: in the next refactor of this task, pass in current_datetime instead of reproducing it here
self.current_datetime = self.target_datetime - datetime.timedelta(days=self.day_offset)
self.exclude_orgs, self.org_list = self.get_course_org_filter()
def send(self, msg_type):
for (user, language, context) in self.schedules_for_bin():
......@@ -133,11 +132,7 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver):
**schedule_day_equals_target_day_filter
).order_by(order_by)
if self.org_list is not None:
if self.exclude_orgs:
schedules = schedules.exclude(enrollment__course__org__in=self.org_list)
else:
schedules = schedules.filter(enrollment__course__org__in=self.org_list)
schedules = self.filter_by_org(schedules)
if "read_replica" in settings.DATABASES:
schedules = schedules.using("read_replica")
......@@ -153,7 +148,7 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver):
return schedules
def get_course_org_filter(self):
def filter_by_org(self, schedules):
"""
Given the configuration of sites, get the list of orgs that should be included or excluded from this send.
......@@ -165,7 +160,6 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver):
try:
site_config = self.site.configuration
org_list = site_config.get_value('course_org_filter')
exclude_orgs = False
if not org_list:
not_orgs = set()
for other_site_config in SiteConfiguration.objects.all():
......@@ -175,15 +169,13 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver):
not_orgs.add(other)
else:
not_orgs.update(other)
org_list = list(not_orgs)
exclude_orgs = True
return schedules.exclude(enrollment__course__org__in=not_orgs)
elif not isinstance(org_list, list):
org_list = [org_list]
return schedules.filter(enrollment__course__org=org_list)
except SiteConfiguration.DoesNotExist:
org_list = None
exclude_orgs = False
return schedules
return exclude_orgs, org_list
return schedules.filter(enrollment__course__org__in=org_list)
def schedules_for_bin(self):
schedules = self.get_schedules_with_target_date_by_bin_and_orgs()
......
......@@ -30,28 +30,40 @@ class TestBinnedSchedulesBaseResolver(CacheIsolationTestCase):
bin_num=2,
)
@ddt.data(
'course1'
)
def test_get_course_org_filter_equal(self, course_org_filter):
self.site_config.values['course_org_filter'] = course_org_filter
self.site_config.save()
mock_query = Mock()
result = self.resolver.filter_by_org(mock_query)
self.assertEqual(result, mock_query.filter.return_value)
mock_query.filter.assert_called_once_with(enrollment__course__org=course_org_filter)
@ddt.unpack
@ddt.data(
('course1', ['course1']),
(['course1', 'course2'], ['course1', 'course2'])
)
def test_get_course_org_filter_include(self, course_org_filter, expected_org_list):
def test_get_course_org_filter_include__in(self, course_org_filter, expected_org_list):
self.site_config.values['course_org_filter'] = course_org_filter
self.site_config.save()
exclude_orgs, org_list = self.resolver.get_course_org_filter()
assert not exclude_orgs
assert org_list == expected_org_list
mock_query = Mock()
result = self.resolver.filter_by_org(mock_query)
self.assertEqual(result, mock_query.filter.return_value)
mock_query.filter.assert_called_once_with(enrollment__course__org__in=expected_org_list)
@ddt.unpack
@ddt.data(
(None, []),
('course1', [u'course1']),
(['course1', 'course2'], [u'course1', u'course2'])
(None, set([])),
('course1', set([u'course1'])),
(['course1', 'course2'], set([u'course1', u'course2']))
)
def test_get_course_org_filter_exclude(self, course_org_filter, expected_org_list):
def test_get_course_org_filter_exclude__in(self, course_org_filter, expected_org_list):
SiteConfigurationFactory.create(
values={'course_org_filter': course_org_filter},
)
exclude_orgs, org_list = self.resolver.get_course_org_filter()
assert exclude_orgs
self.assertItemsEqual(org_list, expected_org_list)
mock_query = Mock()
result = self.resolver.filter_by_org(mock_query)
mock_query.exclude.assert_called_once_with(enrollment__course__org__in=expected_org_list)
self.assertEqual(result, mock_query.exclude.return_value)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment