test_middleware.py 7.74 KB
Newer Older
1 2
# -*- coding: utf-8 -*-
"""Tests for tracking middleware."""
3
import ddt
4
from mock import patch
5
from mock import sentinel
6

7
from django.contrib.auth.models import User
8
from django.contrib.sessions.middleware import SessionMiddleware
9 10 11 12
from django.test import TestCase
from django.test.client import RequestFactory
from django.test.utils import override_settings

13
from eventtracking import tracker
14 15 16
from track.middleware import TrackMiddleware


17
@ddt.ddt
class TrackMiddlewareTestCase(TestCase):
    """Tests for TrackMiddleware: which requests get tracked and what tracking context they produce."""

    def setUp(self):
        super(TrackMiddlewareTestCase, self).setUp()
        self.track_middleware = TrackMiddleware()
        self.request_factory = RequestFactory()

        # Replace the real event emitter so each test can observe whether a
        # request was tracked without emitting anything.
        patcher = patch('track.views.server_track')
        self.mock_server_track = patcher.start()
        self.addCleanup(patcher.stop)

    def test_normal_request(self):
        """A request to an ordinary URL is tracked."""
        request = self.request_factory.get('/somewhere')
        self.track_middleware.process_request(request)
        self.assertTrue(self.mock_server_track.called)

    @ddt.data(
        ('HTTP_USER_AGENT', 'agent'),
        ('PATH_INFO', 'path'),
        ('HTTP_REFERER', 'referer'),
        ('HTTP_ACCEPT_LANGUAGE', 'accept_language'),
    )
    @ddt.unpack
    def test_request_with_latin1_characters(self, meta_key, context_key):
        """
        Request META values (headers and path) containing Latin-1 characters
        appear in the tracking context as the equivalent unicode text.
        """
        request = self.request_factory.get('/somewhere')
        request.META[meta_key] = 'test latin1 \xd3 \xe9 \xf1'  # pylint: disable=no-member

        context = self.get_context_for_request(request)
        # The expected text is spelled with escapes (U+00D3, U+00E9, U+00F1) so it
        # does not depend on the source-file encoding, nor on str.decode(), which
        # exists only on Python 2 byte strings.
        self.assertEqual(context[context_key], u'test latin1 \xd3 \xe9 \xf1')

    def test_default_filters_do_not_render_view(self):
        """Requests to URLs covered by the default ignore patterns are not tracked."""
        for url in ['/event', '/event/1', '/login', '/heartbeat']:
            request = self.request_factory.get(url)
            self.track_middleware.process_request(request)
            self.assertFalse(
                self.mock_server_track.called,
                'request to %s should not have been tracked' % url
            )
            self.mock_server_track.reset_mock()

    @override_settings(TRACKING_IGNORE_URL_PATTERNS=[])
    def test_reading_filtered_urls_from_settings(self):
        """With TRACKING_IGNORE_URL_PATTERNS empty, /event (ignored by default) is tracked."""
        request = self.request_factory.get('/event')
        self.track_middleware.process_request(request)
        self.assertTrue(self.mock_server_track.called)

    @override_settings(TRACKING_IGNORE_URL_PATTERNS=[r'^/some/excluded.*'])
    def test_anchoring_of_patterns_at_beginning(self):
        """With a custom ignore pattern, /excluded is still tracked but /some/excluded/url is not."""
        request = self.request_factory.get('/excluded')
        self.track_middleware.process_request(request)
        self.assertTrue(self.mock_server_track.called, 'request to /excluded should be tracked')
        self.mock_server_track.reset_mock()

        request = self.request_factory.get('/some/excluded/url')
        self.track_middleware.process_request(request)
        self.assertFalse(self.mock_server_track.called, 'request to /some/excluded/url should be ignored')

    def test_default_request_context(self):
        """A request with no user, session or extra headers yields these default context values."""
        context = self.get_context_for_path('/courses/')
        self.assertEqual(context, {
            'accept_language': '',
            'referer': '',
            'user_id': '',
            'session': '',
            'username': '',
            'ip': '127.0.0.1',
            'host': 'testserver',
            'agent': '',
            'path': '/courses/',
            'org_id': '',
            'course_id': '',
            'client_id': None,
        })

    def test_no_forward_for_header_ip_context(self):
        """Without an X-Forwarded-For header, the context ip is REMOTE_ADDR."""
        request = self.request_factory.get('/courses/')
        # Deliberately not RequestFactory's default REMOTE_ADDR (127.0.0.1), so the
        # assertion proves the value is read from the request.
        remote_addr = '10.1.2.3'

        request.META['REMOTE_ADDR'] = remote_addr
        context = self.get_context_for_request(request)

        self.assertEqual(context['ip'], remote_addr)

    def test_single_forward_for_header_ip_context(self):
        """A single X-Forwarded-For address takes precedence over REMOTE_ADDR."""
        request = self.request_factory.get('/courses/')
        remote_addr = '127.0.0.1'
        forwarded_ip = '11.22.33.44'

        request.META['REMOTE_ADDR'] = remote_addr
        request.META['HTTP_X_FORWARDED_FOR'] = forwarded_ip
        context = self.get_context_for_request(request)

        self.assertEqual(context['ip'], forwarded_ip)

    def test_multiple_forward_for_header_ip_context(self):
        """With several X-Forwarded-For addresses, the first one is used."""
        request = self.request_factory.get('/courses/')
        remote_addr = '127.0.0.1'
        forwarded_ip = '11.22.33.44, 10.0.0.1, 127.0.0.1'

        request.META['REMOTE_ADDR'] = remote_addr
        request.META['HTTP_X_FORWARDED_FOR'] = forwarded_ip
        context = self.get_context_for_request(request)

        self.assertEqual(context['ip'], '11.22.33.44')

    def test_request_in_course_context(self):
        """A course URL populates course_id and org_id from the path."""
        captured_context = self.get_context_for_path('/courses/test_org/test_course/test_run/foo')
        expected_context_subset = {
            'course_id': 'test_org/test_course/test_run',
            'org_id': 'test_org',
        }
        self.assert_dict_subset(captured_context, expected_context_subset)

    def test_request_with_user(self):
        """The request's user populates user_id and username."""
        user_id = 1
        username = sentinel.username

        request = self.request_factory.get('/courses/')
        request.user = User(pk=user_id, username=username)

        context = self.get_context_for_request(request)
        self.assert_dict_subset(context, {
            'user_id': user_id,
            'username': username,
        })

    def test_request_with_session(self):
        """The session key is reported encrypted, with the same length as the raw key."""
        request = self.request_factory.get('/courses/')
        SessionMiddleware().process_request(request)
        request.session.save()
        session_key = request.session.session_key
        expected_session_key = self.track_middleware.encrypt_session_key(session_key)
        self.assertEqual(len(session_key), len(expected_session_key))
        context = self.get_context_for_request(request)
        self.assert_dict_subset(context, {
            'session': expected_session_key,
        })

    @override_settings(SECRET_KEY='85920908f28904ed733fe576320db18cabd7b6cd')
    def test_session_key_encryption(self):
        """Session key encryption is deterministic for a fixed SECRET_KEY."""
        session_key = '665924b49a93e22b46ee9365abf28c2a'
        expected_session_key = '3b81f559d14130180065d635a4f35dd2'
        encrypted_session_key = self.track_middleware.encrypt_session_key(session_key)
        self.assertEqual(encrypted_session_key, expected_session_key)

    def test_request_headers(self):
        """REMOTE_ADDR, User-Agent and X-EDX-GA-CLIENT-ID populate ip, agent and client_id."""
        ip_address = '10.0.0.0'
        user_agent = 'UnitTest/1.0'
        client_id_header = '123.123'

        factory = RequestFactory(
            REMOTE_ADDR=ip_address, HTTP_USER_AGENT=user_agent, HTTP_X_EDX_GA_CLIENT_ID=client_id_header
        )
        request = factory.get('/some-path')
        context = self.get_context_for_request(request)

        self.assert_dict_subset(context, {
            'ip': ip_address,
            'agent': user_agent,
            'client_id': client_id_header,
        })

    # ----- helpers -----

    def get_context_for_path(self, path):
        """Return the tracking context generated for a GET request to ``path``."""
        return self.get_context_for_request(self.request_factory.get(path))

    def get_context_for_request(self, request):
        """
        Return the tracking context the middleware establishes for ``request``.

        The context is captured between process_request() and process_response().
        Afterwards, this also asserts that process_response() left the tracker
        with an empty context.
        """
        self.track_middleware.process_request(request)
        try:
            captured_context = tracker.get_tracker().resolve_context()
        finally:
            # Always run the response phase so the context set up by
            # process_request() is torn down even if capturing fails.
            self.track_middleware.process_response(request, None)

        self.assertEqual(tracker.get_tracker().resolve_context(), {})
        return captured_context

    def assert_dict_subset(self, superset, subset):
        """Assert that ``superset`` contains every key-value pair found in ``subset``."""
        for key, expected_value in subset.items():
            # Check presence first so a missing key fails as an assertion, not a KeyError.
            self.assertIn(key, superset)
            self.assertEqual(superset[key], expected_value)