Commit c0c5b3a1 by Eugeny Kolpakov

Merge pull request #398 from open-craft/eugeny/export-discussion-participation-tests

Tests for export discussion participation command
parents b250411c 262675c9
...@@ -17,7 +17,7 @@ from lms.lib.comment_client.user import User ...@@ -17,7 +17,7 @@ from lms.lib.comment_client.user import User
import django_comment_client.utils as utils import django_comment_client.utils as utils
class _Fields: class DiscussionExportFields(object):
""" Container class for field names """ """ Container class for field names """
USER_ID = u"id" USER_ID = u"id"
USERNAME = u"username" USERNAME = u"username"
...@@ -32,18 +32,6 @@ class _Fields: ...@@ -32,18 +32,6 @@ class _Fields:
COMMENTS_GENERATED = u"num_comments_generated" COMMENTS_GENERATED = u"num_comments_generated"
def _make_social_stats(threads=0, comments=0, replies=0, upvotes=0, followers=0, comments_generated=0):
""" Builds social stats with values specified """
return {
_Fields.THREADS: threads,
_Fields.COMMENTS: comments,
_Fields.REPLIES: replies,
_Fields.UPVOTES: upvotes,
_Fields.FOLOWERS: followers,
_Fields.COMMENTS_GENERATED: comments_generated,
}
class Command(BaseCommand): class Command(BaseCommand):
""" """
Exports discussion participation per course Exports discussion participation per course
...@@ -99,48 +87,6 @@ class Command(BaseCommand): ...@@ -99,48 +87,6 @@ class Command(BaseCommand):
), ),
) )
row_order = [
_Fields.USERNAME, _Fields.EMAIL, _Fields.FIRST_NAME, _Fields.LAST_NAME, _Fields.USER_ID,
_Fields.THREADS, _Fields.COMMENTS, _Fields.REPLIES,
_Fields.UPVOTES, _Fields.FOLOWERS, _Fields.COMMENTS_GENERATED
]
def _get_users(self, course_key):
""" Returns users enrolled to a course as dictionary user_id => user """
users = CourseEnrollment.users_enrolled_in(course_key)
return {user.id: user for user in users}
def _get_social_stats(self, course_key, end_date=None, thread_type=None):
""" Gets social stats for course with specified filter parameters """
date = dateutil.parser.parse(end_date) if end_date else None
return {
int(user_id): data for user_id, data
in User.all_social_stats(str(course_key), end_date=date, thread_type=thread_type).iteritems()
}
def _merge_user_data_and_social_stats(self, userdata, social_stats):
""" Merges user data (email, username, etc.) and discussion stats """
result = []
for user_id, user in userdata.iteritems():
user_record = {
_Fields.USER_ID: user.id,
_Fields.USERNAME: user.username,
_Fields.EMAIL: user.email,
_Fields.FIRST_NAME: user.first_name,
_Fields.LAST_NAME: user.last_name,
}
stats = social_stats.get(user_id, _make_social_stats())
result.append(utils.merge_dict(user_record, stats))
return result
def _output(self, data, output_stream):
""" Exports data in csv format to specified output stream """
csv_writer = csv.DictWriter(output_stream, self.row_order)
csv_writer.writeheader()
for row in sorted(data, key=lambda item: item['username']):
to_write = {key: value for key, value in row.items() if key in self.row_order}
csv_writer.writerow(to_write)
def _get_filter_string_representation(self, options): def _get_filter_string_representation(self, options):
""" Builds human-readable filter parameters representation """ """ Builds human-readable filter parameters representation """
filter_strs = [] filter_strs = []
...@@ -178,20 +124,92 @@ class Command(BaseCommand): ...@@ -178,20 +124,92 @@ class Command(BaseCommand):
if not course: if not course:
raise CommandError("Invalid course id: {}".format(course_key)) raise CommandError("Invalid course id: {}".format(course_key))
users = self._get_users(course_key) raw_end_date = options.get(self.END_DATE_PARAMETER, None)
social_stats = self._get_social_stats( end_date = dateutil.parser.parse(raw_end_date) if raw_end_date else None
data = Extractor().extract(
course_key, course_key,
end_date=options.get(self.END_DATE_PARAMETER, None), end_date=end_date,
thread_type=options.get(self.THREAD_TYPE_PARAMETER, None) thread_type=(options.get(self.THREAD_TYPE_PARAMETER, None))
) )
merged_data = self._merge_user_data_and_social_stats(users, social_stats)
filter_str = self._get_filter_string_representation(options) filter_str = self._get_filter_string_representation(options)
self.stdout.write("Writing social stats ({filters}) to {file}\n".format( self.stdout.write("Writing social stats ({}) to {}\n".format(filter_str, output_file_location))
filters=filter_str, file=output_file_location
))
with open(output_file_location, 'wb') as output_stream: with open(output_file_location, 'wb') as output_stream:
self._output(merged_data, output_stream) Exporter(output_stream).export(data)
self.stdout.write("Success!\n") self.stdout.write("Success!\n")
class Extractor(object):
""" Extracts discussion participation data from db and cs_comments_service """
@classmethod
def _make_social_stats(cls, threads=0, comments=0, replies=0, upvotes=0, followers=0, comments_generated=0):
""" Builds social stats with values specified """
return {
DiscussionExportFields.THREADS: threads,
DiscussionExportFields.COMMENTS: comments,
DiscussionExportFields.REPLIES: replies,
DiscussionExportFields.UPVOTES: upvotes,
DiscussionExportFields.FOLOWERS: followers,
DiscussionExportFields.COMMENTS_GENERATED: comments_generated,
}
def _get_users(self, course_key):
""" Returns users enrolled to a course as dictionary user_id => user """
users = CourseEnrollment.users_enrolled_in(course_key)
return {user.id: user for user in users}
def _get_social_stats(self, course_key, end_date=None, thread_type=None):
""" Gets social stats for course with specified filter parameters """
return {
int(user_id): data for user_id, data
in User.all_social_stats(str(course_key), end_date=end_date, thread_type=thread_type).iteritems()
}
def _merge_user_data_and_social_stats(self, userdata, social_stats):
""" Merges user data (email, username, etc.) and discussion stats """
result = []
for user_id, user in userdata.iteritems():
user_record = {
DiscussionExportFields.USER_ID: user.id,
DiscussionExportFields.USERNAME: user.username,
DiscussionExportFields.EMAIL: user.email,
DiscussionExportFields.FIRST_NAME: user.first_name,
DiscussionExportFields.LAST_NAME: user.last_name,
}
stats = social_stats.get(user_id, self._make_social_stats())
result.append(utils.merge_dict(user_record, stats))
return result
def extract(self, course_key, end_date=None, thread_type=None):
""" Extracts and merges data according to course key and filter parameters """
users = self._get_users(course_key)
social_stats = self._get_social_stats(
course_key,
end_date=end_date,
thread_type=thread_type
)
return self._merge_user_data_and_social_stats(users, social_stats)
class Exporter(object):
""" Exports data to csv """
def __init__(self, output_stream):
self.stream = output_stream
row_order = [
DiscussionExportFields.USERNAME, DiscussionExportFields.EMAIL, DiscussionExportFields.FIRST_NAME,
DiscussionExportFields.LAST_NAME, DiscussionExportFields.USER_ID,
DiscussionExportFields.THREADS, DiscussionExportFields.COMMENTS, DiscussionExportFields.REPLIES,
DiscussionExportFields.UPVOTES, DiscussionExportFields.FOLOWERS, DiscussionExportFields.COMMENTS_GENERATED
]
def export(self, data):
""" Exports data in csv format to specified output stream """
csv_writer = csv.DictWriter(self.stream, self.row_order)
csv_writer.writeheader()
for row in sorted(data, key=lambda item: item['username']):
to_write = {key: value for key, value in row.items() if key in self.row_order}
csv_writer.writerow(to_write)
...@@ -656,7 +656,7 @@ class FormatFilenameTests(TestCase): ...@@ -656,7 +656,7 @@ class FormatFilenameTests(TestCase):
("normal_with_alnum.csv", "normal_with_alnum.csv"), ("normal_with_alnum.csv", "normal_with_alnum.csv"),
("normal_with_multiple_extensions.dot.csv", "normal_with_multiple_extensions.dot.csv"), ("normal_with_multiple_extensions.dot.csv", "normal_with_multiple_extensions.dot.csv"),
("contains/slashes.html", "containsslashes.html"), ("contains/slashes.html", "containsslashes.html"),
("contains_symbols!@#$%^&*+=\|,.html", "contains_symbols.html"), (r"contains_symbols!@#$%^&*+=\|,.html", "contains_symbols.html"),
("contains spaces.org", "contains_spaces.org"), ("contains spaces.org", "contains_spaces.org"),
) )
def test_format_filename(self, raw_filename, expected_output): def test_format_filename(self, raw_filename, expected_output):
......
""" Utility functions for django_comment_client """
import string
import pytz import pytz
import string # http://www.logilab.org/ticket/2481 pylint: disable=deprecated-module
from collections import defaultdict from collections import defaultdict
import logging import logging
from datetime import datetime from datetime import datetime
...@@ -16,6 +17,7 @@ from django_comment_client.permissions import check_permissions_by_view, cached_ ...@@ -16,6 +17,7 @@ from django_comment_client.permissions import check_permissions_by_view, cached_
from edxmako import lookup_template from edxmako import lookup_template
import pystache_custom as pystache import pystache_custom as pystache
from xmodule.modulestore.django import modulestore from xmodule.modulestore.django import modulestore
from django.utils.timezone import UTC from django.utils.timezone import UTC
from opaque_keys.edx.locations import i4xEncoder from opaque_keys.edx.locations import i4xEncoder
...@@ -440,7 +442,7 @@ def safe_content(content, course_id, is_staff=False): ...@@ -440,7 +442,7 @@ def safe_content(content, course_id, is_staff=False):
return content return content
def format_filename(s): def format_filename(filename):
"""Take a string and return a valid filename constructed from the string. """Take a string and return a valid filename constructed from the string.
Uses a whitelist approach: any characters not present in valid_chars are Uses a whitelist approach: any characters not present in valid_chars are
removed. Also spaces are replaced with underscores. removed. Also spaces are replaced with underscores.
...@@ -453,6 +455,6 @@ def format_filename(s): ...@@ -453,6 +455,6 @@ def format_filename(s):
https://gist.github.com/seanh/93666 https://gist.github.com/seanh/93666
""" """
valid_chars = "-_.() %s%s" % (string.ascii_letters, string.digits) valid_chars = "-_.() %s%s" % (string.ascii_letters, string.digits)
filename = ''.join(c for c in s if c in valid_chars) filename = ''.join(c for c in filename if c in valid_chars)
filename = filename.replace(' ', '_') filename = filename.replace(' ', '_')
return filename return filename
""" Unit tests for comment_client package"""
import ddt
import mock
from datetime import datetime
from django.test import TestCase
from opaque_keys.edx.locator import CourseLocator
from lms.lib.comment_client import User, CommentClientRequestError
from lms.lib.comment_client.user import get_user_social_stats
TEST_ORG = 'test_org'
TEST_COURSE_ID = 'test_id'
TEST_RUN = 'test_run'
@ddt.ddt
class UserTests(TestCase):
""" Tests for User model """
@ddt.unpack
@ddt.data(
(CourseLocator(TEST_ORG, TEST_COURSE_ID, TEST_RUN), None, None, {}),
(CourseLocator(TEST_ORG, TEST_COURSE_ID, TEST_RUN), datetime(2015, 01, 01), None, {'1': 1}),
(CourseLocator("edX", "DemoX", "now"), datetime(2014, 12, 03, 18, 15, 44), None, {'1': {'num_threads': 10}}),
(CourseLocator("edX", "DemoX", "now"), datetime(2016, 03, 17, 22, 54, 03), 'discussion', {}),
(CourseLocator("Umbrella", "ZMB101", "T1"), datetime(2016, 03, 17, 22, 54, 03), 'question', {'num_threads': 5}),
)
def test_all_social_stats_sends_correct_request(self, course_key, end_date, thread_type, expected_result):
"""
Tests that all_social_stats classmethod invokes get_user_social_stats with correct parameters
when optional parameters are explicitly specified
"""
with mock.patch("lms.lib.comment_client.user.get_user_social_stats") as patched_stats:
patched_stats.return_value = expected_result
result = User.all_social_stats(course_key, end_date, thread_type)
self.assertEqual(result, expected_result)
patched_stats.assert_called_once_with('*', course_key, end_date=end_date, thread_type=thread_type)
def test_all_social_stats_defaults(self):
"""
Tests that all_social_stats classmethod invokes get_user_social_stats with correct parameters
when optional parameters are omitted
"""
with mock.patch("lms.lib.comment_client.user.get_user_social_stats") as patched_stats:
patched_stats.return_value = {}
course_key = CourseLocator("edX", "demoX", "now")
User.all_social_stats(course_key)
patched_stats.assert_called_once_with('*', course_key, end_date=None, thread_type=None)
@ddt.ddt
class UtilityTests(TestCase):
""" Tests for utility functions found in user module """
def test_get_user_social_stats_given_none_course_id_raises(self):
with self.assertRaises(CommentClientRequestError):
get_user_social_stats('irrelevant', None)
@ddt.unpack
@ddt.data(
(1, CourseLocator("edX", "DemoX", "now"), None, None, "api/v1/users/1/social_stats", {}, {}),
(
2, CourseLocator("edX", "DemoX", "now"), datetime(2015, 01, 01), None,
"api/v1/users/2/social_stats", {'end_date': "2015-01-01T00:00:00"}, {'2': {'num_threads': 2}}
),
(
17, CourseLocator("otherX", "CourseX", "later"), datetime(2016, 07, 15), 'discussion',
"api/v1/users/44/social_stats", {'end_date': "2016-07-15T00:00:00", 'thread_type': 'discussion'},
{'2': {'num_threads': 42, 'num_comments': 7}}
),
(
42, CourseLocator("otherX", "CourseX", "later"), datetime(2011, 01, 9, 17, 24, 22), 'question',
"some/unrelated/url", {'end_date': "2011-01-09T17:24:22", 'thread_type': 'question'},
{'28': {'num_threads': 15, 'num_comments': 96}}
),
)
def test_get_user_social_stats(self, user_id, course_id, end_date, thread_type,
expected_url, expected_data, expected_result):
""" Tests get_user_social_stats utility function """
expected_data['course_id'] = course_id
with mock.patch("lms.lib.comment_client.user._url_for_user_social_stats") as patched_url_for_social_stats, \
mock.patch("lms.lib.comment_client.user.perform_request") as patched_perform_request:
patched_perform_request.return_value = expected_result
patched_url_for_social_stats.return_value = expected_url
result = get_user_social_stats(user_id, course_id, end_date=end_date, thread_type=thread_type)
patched_url_for_social_stats.assert_called_with(user_id)
patched_perform_request.assert_called_with('get', expected_url, expected_data)
self.assertEqual(result, expected_result)
...@@ -119,6 +119,7 @@ class User(models.Model): ...@@ -119,6 +119,7 @@ class User(models.Model):
@classmethod @classmethod
def all_social_stats(cls, course_id, end_date=None, thread_type=None): def all_social_stats(cls, course_id, end_date=None, thread_type=None):
""" Get social stats for all users participating in a course """
return get_user_social_stats('*', course_id, end_date=end_date, thread_type=thread_type) return get_user_social_stats('*', course_id, end_date=end_date, thread_type=thread_type)
def _retrieve(self, *args, **kwargs): def _retrieve(self, *args, **kwargs):
...@@ -153,6 +154,7 @@ class User(models.Model): ...@@ -153,6 +154,7 @@ class User(models.Model):
def get_user_social_stats(user_id, course_id, end_date=None, thread_type=None): def get_user_social_stats(user_id, course_id, end_date=None, thread_type=None):
""" Queries cs_comments_service for social_stats """
if not course_id: if not course_id:
raise CommentClientRequestError("Must provide course_id when retrieving social stats for the user") raise CommentClientRequestError("Must provide course_id when retrieving social stats for the user")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment