Commit 97c43ada by Daniel Friedman Committed by Andy Armstrong

Track user settings changed events on User save

parent 3f20a6f3
...@@ -37,11 +37,11 @@ from config_models.models import ConfigurationModel ...@@ -37,11 +37,11 @@ from config_models.models import ConfigurationModel
from track import contexts from track import contexts
from eventtracking import tracker from eventtracking import tracker
from importlib import import_module from importlib import import_module
from model_utils import FieldTracker
from opaque_keys.edx.locations import SlashSeparatedCourseKey from opaque_keys.edx.locations import SlashSeparatedCourseKey
import lms.lib.comment_client as cc import lms.lib.comment_client as cc
from util.model_utils import emit_field_changed_events, get_changed_fields_dict
from util.query import use_read_replica_if_available from util.query import use_read_replica_if_available
from xmodule_django.models import CourseKeyField, NoneToEmptyManager from xmodule_django.models import CourseKeyField, NoneToEmptyManager
from xmodule.modulestore.exceptions import ItemNotFoundError from xmodule.modulestore.exceptions import ItemNotFoundError
...@@ -255,9 +255,6 @@ class UserProfile(models.Model): ...@@ -255,9 +255,6 @@ class UserProfile(models.Model):
bio = models.CharField(blank=True, null=True, max_length=3000, db_index=False) bio = models.CharField(blank=True, null=True, max_length=3000, db_index=False)
profile_image_uploaded_at = models.DateTimeField(null=True) profile_image_uploaded_at = models.DateTimeField(null=True)
# Use FieldTracker to track changes to model instances.
tracker = FieldTracker()
@property @property
def has_profile_image(self): def has_profile_image(self):
""" """
...@@ -336,6 +333,11 @@ def user_profile_pre_save_callback(sender, **kwargs): ...@@ -336,6 +333,11 @@ def user_profile_pre_save_callback(sender, **kwargs):
if user_profile.requires_parental_consent() and user_profile.has_profile_image: if user_profile.requires_parental_consent() and user_profile.has_profile_image:
user_profile.profile_image_uploaded_at = None user_profile.profile_image_uploaded_at = None
# Cache "old" field values on the model instance so that they can be
# retrieved in the post_save callback when we emit an event with new and
# old field values.
user_profile._changed_fields = get_changed_fields_dict(user_profile, sender)
@receiver(post_save, sender=UserProfile) @receiver(post_save, sender=UserProfile)
def user_profile_post_save_callback(sender, **kwargs): def user_profile_post_save_callback(sender, **kwargs):
...@@ -345,40 +347,39 @@ def user_profile_post_save_callback(sender, **kwargs): ...@@ -345,40 +347,39 @@ def user_profile_post_save_callback(sender, **kwargs):
user_profile = kwargs['instance'] user_profile = kwargs['instance']
# pylint: disable=protected-access # pylint: disable=protected-access
emit_field_changed_events( emit_field_changed_events(
user_profile, user_profile.user, USER_SETTINGS_CHANGED_EVENT_NAME, sender._meta.db_table, ['meta'] user_profile,
user_profile.user,
USER_SETTINGS_CHANGED_EVENT_NAME,
sender._meta.db_table,
excluded_fields=['meta']
) )
def emit_field_changed_events(instance, user, event_name, db_table, excluded_fields=None): @receiver(pre_save, sender=User)
""" For the given model instance, save the fields that have changed since the last save. def user_pre_save_callback(sender, **kwargs):
"""
Requires that the Model uses 'model_utils.FieldTracker' and has assigned it to 'tracker'. Capture old fields on the user instance before save and cache them as a
private field on the current model for use in the post_save callback.
"""
user = kwargs['instance']
user._changed_fields = get_changed_fields_dict(user, sender)
Args:
instance (Model instance): the model instance that is being saved
user (User): the user that this instance is associated with
event_name (str): the name of the event to be emitted
db_table (str): the name of the table that we're modifying
excluded_fields (list): a list of field names for which events should
not be emitted
Returns: @receiver(post_save, sender=User)
None def user_post_save_callback(sender, **kwargs):
""" """
excluded_fields = excluded_fields or [] Emit analytics events after saving the User.
changed_fields = instance.tracker.changed() """
for field in changed_fields: user = kwargs['instance']
if field not in excluded_fields: # pylint: disable=protected-access
tracker.emit( emit_field_changed_events(
event_name, user,
{ user,
"setting": field, USER_SETTINGS_CHANGED_EVENT_NAME,
'old': changed_fields[field], sender._meta.db_table,
'new': getattr(instance, field), excluded_fields=['last_login'],
"user_id": user.id, hidden_fields=['password']
"table": db_table )
}
)
class UserSignupSource(models.Model): class UserSignupSource(models.Model):
......
...@@ -4,16 +4,20 @@ Test that various events are fired for models in the student app. ...@@ -4,16 +4,20 @@ Test that various events are fired for models in the student app.
""" """
from django.test import TestCase from django.test import TestCase
from django_countries.fields import Country
from student.models import PasswordHistory, USER_SETTINGS_CHANGED_EVENT_NAME
from student.tests.factories import UserFactory from student.tests.factories import UserFactory
from student.tests.tests import UserProfileEventTestMixin from student.tests.tests import UserSettingsEventTestMixin
class TestUserProfileEvents(UserProfileEventTestMixin, TestCase): class TestUserProfileEvents(UserSettingsEventTestMixin, TestCase):
""" """
Test that we emit field change events when UserProfile models are changed. Test that we emit field change events when UserProfile models are changed.
""" """
def setUp(self): def setUp(self):
super(TestUserProfileEvents, self).setUp() super(TestUserProfileEvents, self).setUp()
self.table = 'auth_userprofile'
self.user = UserFactory.create() self.user = UserFactory.create()
self.profile = self.user.profile self.profile = self.user.profile
self.reset_tracker() self.reset_tracker()
...@@ -25,7 +29,12 @@ class TestUserProfileEvents(UserProfileEventTestMixin, TestCase): ...@@ -25,7 +29,12 @@ class TestUserProfileEvents(UserProfileEventTestMixin, TestCase):
""" """
self.profile.year_of_birth = 1900 self.profile.year_of_birth = 1900
self.profile.save() self.profile.save()
self.assert_profile_event_emitted(setting='year_of_birth', old=None, new=self.profile.year_of_birth) self.assert_user_setting_event_emitted(setting='year_of_birth', old=None, new=self.profile.year_of_birth)
# Verify that we remove the temporary `_changed_fields` property from
# the model after we're done emitting events.
with self.assertRaises(AttributeError):
getattr(self.profile, '_changed_fields')
def test_change_many_fields(self): def test_change_many_fields(self):
""" """
...@@ -35,8 +44,8 @@ class TestUserProfileEvents(UserProfileEventTestMixin, TestCase): ...@@ -35,8 +44,8 @@ class TestUserProfileEvents(UserProfileEventTestMixin, TestCase):
self.profile.gender = u'o' self.profile.gender = u'o'
self.profile.bio = 'test bio' self.profile.bio = 'test bio'
self.profile.save() self.profile.save()
self.assert_profile_event_emitted(setting='bio', old=None, new=self.profile.bio) self.assert_user_setting_event_emitted(setting='bio', old=None, new=self.profile.bio)
self.assert_profile_event_emitted(setting='gender', old=u'm', new=u'o') self.assert_user_setting_event_emitted(setting='gender', old=u'm', new=u'o')
def test_unicode(self): def test_unicode(self):
""" """
...@@ -45,4 +54,69 @@ class TestUserProfileEvents(UserProfileEventTestMixin, TestCase): ...@@ -45,4 +54,69 @@ class TestUserProfileEvents(UserProfileEventTestMixin, TestCase):
old_name = self.profile.name old_name = self.profile.name
self.profile.name = u'Dånîél' self.profile.name = u'Dånîél'
self.profile.save() self.profile.save()
self.assert_profile_event_emitted(setting='name', old=old_name, new=self.profile.name) self.assert_user_setting_event_emitted(setting='name', old=old_name, new=self.profile.name)
def test_country(self):
"""
Verify that we properly serialize the JSON-unfriendly Country field.
"""
self.profile.country = Country(u'AL', 'dummy_flag_url')
self.profile.save()
self.assert_user_setting_event_emitted(setting='country', old=None, new=self.profile.country)
def test_excluded_field(self):
"""
Verify that we don't emit events for ignored fields.
"""
self.profile.meta = {u'foo': u'bar'}
self.profile.save()
self.assert_no_events_were_emitted()
class TestUserEvents(UserSettingsEventTestMixin, TestCase):
"""
Test that we emit field change events when User models are changed.
"""
def setUp(self):
super(TestUserEvents, self).setUp()
self.user = UserFactory.create()
self.reset_tracker()
self.table = 'auth_user'
def test_change_one_field(self):
"""
Verify that we emit an event when a single field changes on the user.
"""
old_username = self.user.username
self.user.username = u'new username'
self.user.save()
self.assert_user_setting_event_emitted(setting='username', old=old_username, new=self.user.username)
def test_change_many_fields(self):
"""
Verify that we emit one event per field when many fields change on the
user in one transaction.
"""
old_email = self.user.email
old_is_staff = self.user.is_staff
self.user.email = u'foo@bar.com'
self.user.is_staff = True
self.user.save()
self.assert_user_setting_event_emitted(setting='email', old=old_email, new=self.user.email)
self.assert_user_setting_event_emitted(setting='is_staff', old=old_is_staff, new=self.user.is_staff)
def test_password(self):
"""
Verify that password values are not included in the event payload.
"""
self.user.password = u'new password'
self.user.save()
self.assert_user_setting_event_emitted(setting='password', old=None, new=None)
def test_related_fields_ignored(self):
"""
Verify that we don't emit events for related fields.
"""
self.user.passwordhistory_set.add(PasswordHistory(password='new_password'))
self.user.save()
self.assert_no_events_were_emitted()
...@@ -486,14 +486,14 @@ class DashboardTest(ModuleStoreTestCase): ...@@ -486,14 +486,14 @@ class DashboardTest(ModuleStoreTestCase):
self.assertContains(response, expected_url) self.assertContains(response, expected_url)
class UserProfileEventTestMixin(EventTestMixin): class UserSettingsEventTestMixin(EventTestMixin):
""" """
Mixin for verifying that UserProfile events were emitted during a test. Mixin for verifying that user setting events were emitted during a test.
""" """
def setUp(self): def setUp(self):
super(UserProfileEventTestMixin, self).setUp('student.models.tracker') super(UserSettingsEventTestMixin, self).setUp('util.model_utils.tracker')
def assert_profile_event_emitted(self, **kwargs): def assert_user_setting_event_emitted(self, **kwargs):
""" """
Helper method to assert that we emit the expected user settings events. Helper method to assert that we emit the expected user settings events.
...@@ -501,7 +501,7 @@ class UserProfileEventTestMixin(EventTestMixin): ...@@ -501,7 +501,7 @@ class UserProfileEventTestMixin(EventTestMixin):
""" """
self.assert_event_emitted( self.assert_event_emitted(
USER_SETTINGS_CHANGED_EVENT_NAME, USER_SETTINGS_CHANGED_EVENT_NAME,
table='auth_userprofile', table=self.table, # pylint: disable=no-member
user_id=self.user.id, user_id=self.user.id,
**kwargs **kwargs
) )
......
"""
Utilities for django models.
"""
from eventtracking import tracker
from django.core.exceptions import ObjectDoesNotExist
from django.db.models.fields.related import RelatedField
from django_countries.fields import Country
def get_changed_fields_dict(instance, model_class):
"""
Helper method for tracking field changes on a model.
Given a model instance and class, return a dict whose keys are that
instance's fields which differ from the last saved ones and whose values
are the old values of those fields. Related fields are not considered.
Args:
instance (Model instance): the model instance with changes that are
being tracked
model_class (Model class): the class of the model instance we are
tracking
Returns:
dict: a mapping of field names to current database values of those
fields, or an empty dicit if the model is new
"""
try:
old_model = model_class.objects.get(pk=instance.pk)
except model_class.DoesNotExist:
# Object is new, so fields haven't technically changed. We'll return
# an empty dict as a default value.
return {}
else:
field_names = [
field[0].name for field in model_class._meta.get_fields_with_model()
]
changed_fields = {
field_name: getattr(old_model, field_name) for field_name in field_names
if getattr(old_model, field_name) != getattr(instance, field_name)
}
return changed_fields
def emit_field_changed_events(instance, user, event_name, db_table, excluded_fields=None, hidden_fields=None):
"""
For the given model instance, emit a setting changed event the fields that
have changed since the last save.
Note that this function expects that a `_changed_fields` dict has been set
as an attribute on `instance` (see `get_changed_fields_dict`.
Args:
instance (Model instance): the model instance that is being saved
user (User): the user that this instance is associated with
event_name (str): the name of the event to be emitted
db_table (str): the name of the table that we're modifying
excluded_fields (list): a list of field names for which events should
not be emitted
hidden_fields (list): a list of field names specifying fields whose
values should not be included in the event (None will be used
instead)
Returns:
None
"""
def clean_field(field_name, value):
"""
Prepare a field to be emitted in a JSON serializable format. If
`field_name` is a hidden field, return None.
"""
if field_name in hidden_fields:
return None
# Country is not JSON serializable. Return the country code.
if isinstance(value, Country):
if value.code:
return value.code
else:
return None
return value
excluded_fields = excluded_fields or []
hidden_fields = hidden_fields or []
changed_fields = getattr(instance, '_changed_fields', {})
for field_name in changed_fields:
if field_name not in excluded_fields:
tracker.emit(
event_name,
{
"setting": field_name,
'old': clean_field(field_name, changed_fields[field_name]),
'new': clean_field(field_name, getattr(instance, field_name)),
"user_id": user.id,
"table": db_table
}
)
# Remove the now inaccurate _changed_fields attribute.
if getattr(instance, '_changed_fields', None):
del instance._changed_fields
...@@ -279,23 +279,25 @@ class EventsTestMixin(object): ...@@ -279,23 +279,25 @@ class EventsTestMixin(object):
self.event_collection = MongoClient()["test"]["events"] self.event_collection = MongoClient()["test"]["events"]
self.reset_event_tracking() self.reset_event_tracking()
def assert_event_emitted_num_times(self, event_name, event_time, event_user_id, num_times_emitted): def assert_event_emitted_num_times(self, event_name, event_time, event_user_id, num_times_emitted, **kwargs):
""" """
Tests the number of times a particular event was emitted. Tests the number of times a particular event was emitted.
Extra kwargs get passed to the mongo query in the form: "event.<key>: value".
:param event_name: Expected event name (e.g., "edx.course.enrollment.activated") :param event_name: Expected event name (e.g., "edx.course.enrollment.activated")
:param event_time: Latest expected time, after which the event would fire (e.g., the beginning of the test case) :param event_time: Latest expected time, after which the event would fire (e.g., the beginning of the test case)
:param event_user_id: user_id expected in the event :param event_user_id: user_id expected in the event
:param num_times_emitted: number of times the event is expected to appear since the event_time :param num_times_emitted: number of times the event is expected to appear since the event_time
""" """
self.assertEqual( find_kwargs = {
self.event_collection.find( "name": event_name,
{ "time": {"$gt": event_time},
"name": event_name, "event.user_id": int(event_user_id),
"time": {"$gt": event_time}, }
"event.user_id": int(event_user_id), find_kwargs.update({"event.{}".format(key): value for key, value in kwargs.items()})
} matching_events = self.event_collection.find(find_kwargs)
).count(), num_times_emitted self.assertEqual(matching_events.count(), num_times_emitted, '\n'.join(str(event) for event in matching_events))
)
def reset_event_tracking(self): def reset_event_tracking(self):
""" """
...@@ -315,7 +317,6 @@ class EventsTestMixin(object): ...@@ -315,7 +317,6 @@ class EventsTestMixin(object):
def verify_events_of_type(self, event_type, expected_events, expected_referers=None): def verify_events_of_type(self, event_type, expected_events, expected_referers=None):
"""Verify that the expected events of a given type were logged. """Verify that the expected events of a given type were logged.
Args: Args:
event_type (str): The type of event to be verified. event_type (str): The type of event to be verified.
expected_events (list): A list of dicts representing the events that should expected_events (list): A list of dicts representing the events that should
......
...@@ -19,10 +19,11 @@ class AccountSettingsTestMixin(EventsTestMixin, WebAppTest): ...@@ -19,10 +19,11 @@ class AccountSettingsTestMixin(EventsTestMixin, WebAppTest):
Mixin with helper methods to test the account settings page. Mixin with helper methods to test the account settings page.
""" """
USERNAME = "test" USERNAME = u"test"
PASSWORD = "testpass" PASSWORD = "testpass"
EMAIL = u"test@example.com" EMAIL = u"test@example.com"
CHANGE_INITIATED_EVENT_NAME = u"edx.user.settings.change_initiated" CHANGE_INITIATED_EVENT_NAME = u"edx.user.settings.change_initiated"
USER_SETTINGS_CHANGED_EVENT_NAME = 'edx.user.settings.changed'
ACCOUNT_SETTINGS_REFERER = u"/account/settings" ACCOUNT_SETTINGS_REFERER = u"/account/settings"
def setUp(self): def setUp(self):
...@@ -35,6 +36,26 @@ class AccountSettingsTestMixin(EventsTestMixin, WebAppTest): ...@@ -35,6 +36,26 @@ class AccountSettingsTestMixin(EventsTestMixin, WebAppTest):
self.browser, username=self.USERNAME, password=self.PASSWORD, email=self.EMAIL self.browser, username=self.USERNAME, password=self.PASSWORD, email=self.EMAIL
).visit().get_user_id() ).visit().get_user_id()
def assert_event_emitted_num_times(self, setting, num_times):
"""
Verify a particular user settings change event was emitted a certain
number of times.
"""
# pylint disable=no-member
super(AccountSettingsTestMixin, self).assert_event_emitted_num_times(
self.USER_SETTINGS_CHANGED_EVENT_NAME, self.start_time, self.user_id, num_times, setting=setting
)
def verify_settings_changed_events(self, events):
"""
Verify a particular set of account settings change events were fired.
"""
expected_referers = [self.ACCOUNT_SETTINGS_REFERER] * len(events)
for event in events:
event[u'user_id'] = long(self.user_id)
event[u'table'] = u"auth_userprofile"
self.verify_events_of_type(self.USER_SETTINGS_CHANGED_EVENT_NAME, events, expected_referers=expected_referers)
class DashboardMenuTest(AccountSettingsTestMixin, WebAppTest): class DashboardMenuTest(AccountSettingsTestMixin, WebAppTest):
""" """
...@@ -207,7 +228,18 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest): ...@@ -207,7 +228,18 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest):
[u'another name', self.USERNAME], [u'another name', self.USERNAME],
) )
self.assert_event_emitted_num_times('edx.user.settings.changed', self.start_time, self.user_id, 2) self.verify_settings_changed_events(
[{
u"setting": u"name",
u"old": self.USERNAME,
u"new": u"another name",
},
{
u"setting": u"name",
u"old": u'another name',
u"new": self.USERNAME,
}]
)
def test_email_field(self): def test_email_field(self):
""" """
...@@ -241,6 +273,9 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest): ...@@ -241,6 +273,9 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest):
], ],
[self.ACCOUNT_SETTINGS_REFERER, self.ACCOUNT_SETTINGS_REFERER] [self.ACCOUNT_SETTINGS_REFERER, self.ACCOUNT_SETTINGS_REFERER]
) )
# Email is not saved until user confirms, so no events should have been
# emitted.
self.assert_event_emitted_num_times('email', 0)
def test_password_field(self): def test_password_field(self):
""" """
...@@ -263,6 +298,9 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest): ...@@ -263,6 +298,9 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest):
}], }],
[self.ACCOUNT_SETTINGS_REFERER] [self.ACCOUNT_SETTINGS_REFERER]
) )
# Like email, since the user has not confirmed their password change,
# the field has not yet changed, so no events will have been emitted.
self.assert_event_emitted_num_times('password', 0)
@skip( @skip(
'On bokchoy test servers, language changes take a few reloads to fully realize ' 'On bokchoy test servers, language changes take a few reloads to fully realize '
...@@ -290,7 +328,18 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest): ...@@ -290,7 +328,18 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest):
u'', u'',
[u'Bachelor\'s degree', u''], [u'Bachelor\'s degree', u''],
) )
self.assert_event_emitted_num_times('edx.user.settings.changed', self.start_time, self.user_id, 2) self.verify_settings_changed_events(
[{
u"setting": u"level_of_education",
u"old": None,
u"new": u'b',
},
{
u"setting": u"level_of_education",
u"old": u'b',
u"new": None,
}]
)
def test_gender_field(self): def test_gender_field(self):
""" """
...@@ -302,7 +351,18 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest): ...@@ -302,7 +351,18 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest):
u'', u'',
[u'Female', u''], [u'Female', u''],
) )
self.assert_event_emitted_num_times('edx.user.settings.changed', self.start_time, self.user_id, 2) self.verify_settings_changed_events(
[{
u"setting": u"gender",
u"old": None,
u"new": u'f',
},
{
u"setting": u"gender",
u"old": u'f',
u"new": None,
}]
)
def test_year_of_birth_field(self): def test_year_of_birth_field(self):
""" """
...@@ -310,13 +370,25 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest): ...@@ -310,13 +370,25 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest):
""" """
# Note that when we clear the year_of_birth here we're firing an event. # Note that when we clear the year_of_birth here we're firing an event.
self.assertEqual(self.account_settings_page.value_for_dropdown_field('year_of_birth', ''), '') self.assertEqual(self.account_settings_page.value_for_dropdown_field('year_of_birth', ''), '')
self.reset_event_tracking()
self._test_dropdown_field( self._test_dropdown_field(
u'year_of_birth', u'year_of_birth',
u'Year of Birth', u'Year of Birth',
u'', u'',
[u'1980', u''], [u'1980', u''],
) )
self.assert_event_emitted_num_times('edx.user.settings.changed', self.start_time, self.user_id, 3) self.verify_settings_changed_events(
[{
u"setting": u"year_of_birth",
u"old": None,
u"new": 1980,
},
{
u"setting": u"year_of_birth",
u"old": 1980,
u"new": None,
}]
)
def test_country_field(self): def test_country_field(self):
""" """
...@@ -328,6 +400,18 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest): ...@@ -328,6 +400,18 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest):
u'', u'',
[u'Pakistan', u'Palau'], [u'Pakistan', u'Palau'],
) )
self.verify_settings_changed_events(
[{
u"setting": u"country",
u"old": None,
u"new": u'PK',
},
{
u"setting": u"country",
u"old": u'PK',
u"new": u'PW',
}]
)
def test_preferred_language_field(self): def test_preferred_language_field(self):
""" """
......
...@@ -14,7 +14,7 @@ from PIL import Image ...@@ -14,7 +14,7 @@ from PIL import Image
from rest_framework.test import APITestCase, APIClient from rest_framework.test import APITestCase, APIClient
from student.tests.factories import UserFactory from student.tests.factories import UserFactory
from student.tests.tests import UserProfileEventTestMixin from student.tests.tests import UserSettingsEventTestMixin
from ...user_api.accounts.image_helpers import ( from ...user_api.accounts.image_helpers import (
set_has_profile_image, set_has_profile_image,
...@@ -29,7 +29,7 @@ TEST_PASSWORD = "test" ...@@ -29,7 +29,7 @@ TEST_PASSWORD = "test"
TEST_UPLOAD_DT = datetime.datetime(2002, 1, 9, 15, 43, 01, tzinfo=UTC) TEST_UPLOAD_DT = datetime.datetime(2002, 1, 9, 15, 43, 01, tzinfo=UTC)
class ProfileImageEndpointTestCase(UserProfileEventTestMixin, APITestCase): class ProfileImageEndpointTestCase(UserSettingsEventTestMixin, APITestCase):
""" """
Base class / shared infrastructure for tests of profile_image "upload" and Base class / shared infrastructure for tests of profile_image "upload" and
"remove" endpoints. "remove" endpoints.
...@@ -47,6 +47,7 @@ class ProfileImageEndpointTestCase(UserProfileEventTestMixin, APITestCase): ...@@ -47,6 +47,7 @@ class ProfileImageEndpointTestCase(UserProfileEventTestMixin, APITestCase):
self.url = reverse(self._view_name, kwargs={'username': self.user.username}) self.url = reverse(self._view_name, kwargs={'username': self.user.username})
self.client.login(username=self.user.username, password=TEST_PASSWORD) self.client.login(username=self.user.username, password=TEST_PASSWORD)
self.storage = get_profile_image_storage() self.storage = get_profile_image_storage()
self.table = 'auth_userprofile'
# this assertion is made here as a sanity check because all tests # this assertion is made here as a sanity check because all tests
# assume user.profile.has_profile_image is False by default # assume user.profile.has_profile_image is False by default
self.assertFalse(self.user.profile.has_profile_image) self.assertFalse(self.user.profile.has_profile_image)
...@@ -114,7 +115,7 @@ class ProfileImageUploadTestCase(ProfileImageEndpointTestCase): ...@@ -114,7 +115,7 @@ class ProfileImageUploadTestCase(ProfileImageEndpointTestCase):
Make sure we emit a UserProfile event corresponding to the Make sure we emit a UserProfile event corresponding to the
profile_image_uploaded_at field changing. profile_image_uploaded_at field changing.
""" """
self.assert_profile_event_emitted( self.assert_user_setting_event_emitted(
setting='profile_image_uploaded_at', old=None, new=TEST_UPLOAD_DT setting='profile_image_uploaded_at', old=None, new=TEST_UPLOAD_DT
) )
...@@ -279,7 +280,7 @@ class ProfileImageRemoveTestCase(ProfileImageEndpointTestCase): ...@@ -279,7 +280,7 @@ class ProfileImageRemoveTestCase(ProfileImageEndpointTestCase):
Make sure we emit a UserProfile event corresponding to the Make sure we emit a UserProfile event corresponding to the
profile_image_uploaded_at field changing. profile_image_uploaded_at field changing.
""" """
self.assert_profile_event_emitted( self.assert_user_setting_event_emitted(
setting='profile_image_uploaded_at', old=TEST_UPLOAD_DT, new=None setting='profile_image_uploaded_at', old=TEST_UPLOAD_DT, new=None
) )
......
...@@ -6,6 +6,7 @@ import datetime ...@@ -6,6 +6,7 @@ import datetime
import logging import logging
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from django.utils.timezone import utc
from rest_framework import permissions, status from rest_framework import permissions, status
from rest_framework.parsers import MultiPartParser, FormParser from rest_framework.parsers import MultiPartParser, FormParser
from rest_framework.response import Response from rest_framework.response import Response
...@@ -31,7 +32,7 @@ def _make_upload_dt(): ...@@ -31,7 +32,7 @@ def _make_upload_dt():
Generate a server-side timestamp for the upload. This is in a separate Generate a server-side timestamp for the upload. This is in a separate
function so its behavior can be overridden in tests. function so its behavior can be overridden in tests.
""" """
return datetime.datetime.utcnow() return datetime.datetime.utcnow().replace(tzinfo=utc)
class ProfileImageUploadView(APIView): class ProfileImageUploadView(APIView):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment