Commit c5379890 by Dennis Jen, committed by Daniel Friedman

Added AWS request signing.

parent 6a67dc83
from collections import defaultdict
from importlib import import_module
from django.db.models import Q
from rest_framework.authtoken.models import Token
from analytics_data_api.v0.models import ProblemResponseAnswerDistribution


def delete_user_auth_token(username):
    """
@@ -47,49 +45,6 @@ def matching_tuple(answer):
    )
def consolidate_answers(problem):
    """ Attempt to consolidate erroneously randomized answers. """
    answer_sets = defaultdict(list)
    match_tuple_sets = defaultdict(set)

    for answer in problem:
        answer.consolidated_variant = False
        answer_sets[answer.value_id].append(answer)
        match_tuple_sets[answer.value_id].add(matching_tuple(answer))

    # If a part has more than one unique tuple of matching fields, do not consolidate.
    for _, match_tuple_set in match_tuple_sets.iteritems():
        if len(match_tuple_set) > 1:
            return problem

    consolidated_answers = []

    for _, answers in answer_sets.iteritems():
        consolidated_answer = None

        if len(answers) == 1:
            consolidated_answers.append(answers[0])
            continue

        for answer in answers:
            if consolidated_answer:
                if isinstance(consolidated_answer, ProblemResponseAnswerDistribution):
                    consolidated_answer.count += answer.count
                else:
                    consolidated_answer.first_response_count += answer.first_response_count
                    consolidated_answer.last_response_count += answer.last_response_count
            else:
                consolidated_answer = answer
                consolidated_answer.variant = None
                consolidated_answer.consolidated_variant = True

        consolidated_answers.append(consolidated_answer)

    return consolidated_answers
def dictfetchall(cursor):
    """Returns all rows from a cursor as a dict"""
@@ -98,3 +53,10 @@ def dictfetchall(cursor):
        dict(zip([col[0] for col in desc], row))
        for row in cursor.fetchall()
    ]
def load_fully_qualified_definition(definition):
    """ Returns the class given the full definition. """
    module_name, class_name = definition.rsplit('.', 1)
    module = import_module(module_name)
    return getattr(module, class_name)
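For context, a minimal usage sketch of this helper; the dotted path below is simply the connection class added later in this commit:

# Resolve a class from its fully qualified dotted path.
connection_class = load_fully_qualified_definition(
    'analytics_data_api.v0.connections.BotoHttpConnection')
assert connection_class.__name__ == 'BotoHttpConnection'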
@@ -2,6 +2,8 @@ from django.apps import AppConfig
from django.conf import settings
from elasticsearch_dsl import connections

from analytics_data_api.utils import load_fully_qualified_definition
class ApiAppConfig(AppConfig):
@@ -10,4 +12,17 @@ class ApiAppConfig(AppConfig):
    def ready(self):
        super(ApiAppConfig, self).ready()
        if settings.ELASTICSEARCH_LEARNERS_HOST:
            connections.connections.create_connection(hosts=[settings.ELASTICSEARCH_LEARNERS_HOST])
            connection_params = {'hosts': [settings.ELASTICSEARCH_LEARNERS_HOST]}
            if settings.ELASTICSEARCH_CONNECTION_CLASS:
                connection_params['connection_class'] = \
                    load_fully_qualified_definition(settings.ELASTICSEARCH_CONNECTION_CLASS)

                # aws settings
                connection_params['aws_access_key_id'] = settings.ELASTICSEARCH_AWS_ACCESS_KEY_ID
                connection_params['aws_secret_access_key'] = settings.ELASTICSEARCH_AWS_SECRET_ACCESS_KEY
                connection_params['region'] = settings.ELASTICSEARCH_CONNECTION_DEFAULT_REGION

            # Remove 'None' values so that we don't overwrite defaults
            connection_params = {key: val for key, val in connection_params.items() if val is not None}
            connections.connections.create_connection(**connection_params)
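Assuming all of the new settings are populated, the wiring above reduces to roughly the following call; this is a sketch with placeholder host, credentials, and region, not code from the commit:

from elasticsearch_dsl import connections

from analytics_data_api.v0.connections import BotoHttpConnection

# Extra kwargs are forwarded to the connection class, here BotoHttpConnection.
connections.connections.create_connection(
    hosts=['search-example.us-east-1.es.amazonaws.com'],
    connection_class=BotoHttpConnection,
    aws_access_key_id='access_key',
    aws_secret_access_key='secret',
    region='us-east-1',
)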
import json
import time

from boto.connection import AWSAuthConnection
from elasticsearch import Connection


class BotoHttpConnection(Connection):
    """
    Uses AWS configured connection to sign requests before they're sent to
    elasticsearch nodes.
    """

    connection = None

    def __init__(self, host='localhost', port=443, aws_access_key_id=None, aws_secret_access_key=None,
                 region=None, **kwargs):
        super(BotoHttpConnection, self).__init__(host=host, port=port, **kwargs)
        connection_params = {'host': host, 'port': port}

        # If not provided, boto will attempt to use default environment variables to fill
        # the access credentials.
        connection_params['aws_access_key_id'] = aws_access_key_id
        connection_params['aws_secret_access_key'] = aws_secret_access_key
        connection_params['region'] = region

        # Remove 'None' values so that we don't overwrite defaults
        connection_params = {key: val for key, val in connection_params.items() if val is not None}
        self.connection = ESConnection(**connection_params)

    # pylint: disable=unused-argument
    def perform_request(self, method, url, params=None, body=None, timeout=None, ignore=()):
        """
        Called when making requests to elasticsearch. Requests are signed, and
        the http status, headers, and response body are returned.

        Note: the "timeout" kwarg is ignored in this case. Boto manages the timeout
        and the default is 70 seconds.
        See: https://github.com/boto/boto/blob/develop/boto/connection.py#L533
        """
        if not isinstance(body, basestring):
            body = json.dumps(body)

        start = time.time()
        response = self.connection.make_request(method, url, params=params, data=body)
        duration = time.time() - start
        raw_data = response.read()

        # raise errors based on http status codes and let the client handle them
        if not (200 <= response.status < 300) and response.status not in ignore:
            self.log_request_fail(method, url, body, duration, response.status)
            self._raise_error(response.status, raw_data)

        self.log_request_success(method, url, url, body, response.status, raw_data, duration)
        return response.status, dict(response.getheaders()), raw_data
class ESConnection(AWSAuthConnection):
    """
    Use to sign requests for an AWS hosted elasticsearch cluster.
    """

    def __init__(self, *args, **kwargs):
        region = kwargs.pop('region', None)
        kwargs.setdefault('is_secure', True)
        super(ESConnection, self).__init__(*args, **kwargs)
        self.auth_region_name = region
        self.auth_service_name = 'es'

    def _required_auth_capability(self):
        """
        Supplies the capability of the auth handler so that requests to AWS are
        signed using HMAC v4.
        """
        return ['hmac-v4']
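For illustration, a minimal sketch of using the new connection class directly with the low-level elasticsearch client; the host, credentials, and region are placeholders, and the keys can be omitted to fall back to boto's standard credential lookup:

from elasticsearch import Elasticsearch

from analytics_data_api.v0.connections import BotoHttpConnection

# Extra kwargs are passed through the transport down to BotoHttpConnection,
# so every request is signed before it reaches the AWS-hosted cluster.
es = Elasticsearch(
    hosts=['search-example.us-east-1.es.amazonaws.com'],
    connection_class=BotoHttpConnection,
    aws_access_key_id='access_key',
    aws_secret_access_key='secret',
    region='us-east-1',
)
es.info()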
import socket

from django.test import TestCase
from elasticsearch.exceptions import ElasticsearchException
from mock import patch

from analytics_data_api.v0.connections import BotoHttpConnection, ESConnection


class ESConnectionTests(TestCase):

    def test_constructor_params(self):
        connection = ESConnection('mockservice.cc-zone-1.amazonaws.com',
                                  aws_access_key_id='access_key',
                                  aws_secret_access_key='secret',
                                  region='region_123')
        self.assertEqual(connection.auth_region_name, 'region_123')
        self.assertEqual(connection.aws_access_key_id, 'access_key')
        self.assertEqual(connection.aws_secret_access_key, 'secret')

    def test_signing(self):
        connection = ESConnection('mockservice.cc-zone-1.amazonaws.com',
                                  aws_access_key_id='my_access_key',
                                  aws_secret_access_key='secret',
                                  region='region_123')

        # create a request and sign it
        request = connection.build_base_http_request('GET', '/', None)
        request.authorize(connection)

        # confirm the header contains signing method and key id
        auth_header = request.headers['Authorization']
        self.assertTrue('AWS4-HMAC-SHA256' in auth_header)
        self.assertTrue('my_access_key' in auth_header)

    def test_timeout(self):
        def fake_connection(_address):
            raise socket.timeout('fake error')

        socket.create_connection = fake_connection
        connection = ESConnection('mockservice.cc-zone-1.amazonaws.com',
                                  aws_access_key_id='access_key',
                                  aws_secret_access_key='secret',
                                  region='region_123')
        connection.num_retries = 0
        with self.assertRaises(socket.error):
            connection.make_request('GET', 'https://example.com')


class BotoHttpConnectionTests(TestCase):

    @patch('analytics_data_api.v0.connections.ESConnection.make_request')
    def test_perform_request_success(self, mock_response):
        mock_response.return_value.status = 200
        connection = BotoHttpConnection(aws_access_key_id='access_key', aws_secret_access_key='secret')
        with patch('elasticsearch.connection.base.logger.info') as mock_logger:
            status, _header, _data = connection.perform_request('get', 'http://example.com')
            self.assertEqual(status, 200)
            self.assertGreater(mock_logger.call_count, 0)

    @patch('analytics_data_api.v0.connections.ESConnection.make_request')
    def test_perform_request_error(self, mock_response):
        mock_response.return_value.status = 500
        connection = BotoHttpConnection(aws_access_key_id='access_key', aws_secret_access_key='secret')
        with self.assertRaises(ElasticsearchException):
            with patch('elasticsearch.connection.base.logger.debug') as mock_logger:
                connection.perform_request('get', 'http://example.com')
                self.assertGreater(mock_logger.call_count, 0)
@@ -2,6 +2,7 @@
API methods for module level data.
"""
from collections import defaultdict
from itertools import groupby
from django.db import OperationalError
@@ -19,7 +20,7 @@ from analytics_data_api.v0.serializers import (
    GradeDistributionSerializer,
    SequentialOpenDistributionSerializer,
)
from analytics_data_api.utils import consolidate_answers
from analytics_data_api.utils import matching_tuple


class ProblemResponseAnswerDistributionView(generics.ListAPIView):
@@ -55,6 +56,48 @@ class ProblemResponseAnswerDistributionView(generics.ListAPIView):
    serializer_class = ConsolidatedAnswerDistributionSerializer
    allow_empty = False
    @classmethod
    def consolidate_answers(cls, problem):
        """ Attempt to consolidate erroneously randomized answers. """
        answer_sets = defaultdict(list)
        match_tuple_sets = defaultdict(set)

        for answer in problem:
            answer.consolidated_variant = False
            answer_sets[answer.value_id].append(answer)
            match_tuple_sets[answer.value_id].add(matching_tuple(answer))

        # If a part has more than one unique tuple of matching fields, do not consolidate.
        for _, match_tuple_set in match_tuple_sets.iteritems():
            if len(match_tuple_set) > 1:
                return problem

        consolidated_answers = []

        for _, answers in answer_sets.iteritems():
            consolidated_answer = None

            if len(answers) == 1:
                consolidated_answers.append(answers[0])
                continue

            for answer in answers:
                if consolidated_answer:
                    if isinstance(consolidated_answer, ProblemResponseAnswerDistribution):
                        consolidated_answer.count += answer.count
                    else:
                        consolidated_answer.first_response_count += answer.first_response_count
                        consolidated_answer.last_response_count += answer.last_response_count
                else:
                    consolidated_answer = answer
                    consolidated_answer.variant = None
                    consolidated_answer.consolidated_variant = True

            consolidated_answers.append(consolidated_answer)

        return consolidated_answers
    def get_queryset(self):
        """Select all the answer distribution responses having to do with this usage of the problem."""
        problem_id = self.kwargs.get('problem_id')
@@ -69,7 +112,7 @@ class ProblemResponseAnswerDistributionView(generics.ListAPIView):
        consolidated_rows = []

        for _, part in groupby(queryset, lambda x: x.part_id):
            consolidated_rows += consolidate_answers(list(part))
            consolidated_rows += self.consolidate_answers(list(part))

        return consolidated_rows
...
@@ -55,6 +55,16 @@ DATABASES = {
ELASTICSEARCH_LEARNERS_HOST = environ.get('ELASTICSEARCH_LEARNERS_HOST', None)
ELASTICSEARCH_LEARNERS_INDEX = environ.get('ELASTICSEARCH_LEARNERS_INDEX', None)
ELASTICSEARCH_LEARNERS_UPDATE_INDEX = environ.get('ELASTICSEARCH_LEARNERS_UPDATE_INDEX', None)

# Access credentials for signing requests to AWS.
# For more information see http://docs.aws.amazon.com/general/latest/gr/signing_aws_api_requests.html
ELASTICSEARCH_AWS_ACCESS_KEY_ID = None
ELASTICSEARCH_AWS_SECRET_ACCESS_KEY = None

# Override the default elasticsearch connection class; useful for request signing,
# e.g. 'analytics_data_api.v0.connections.BotoHttpConnection'
ELASTICSEARCH_CONNECTION_CLASS = None

# Only needed with BotoHttpConnection, e.g. 'us-east-1'
ELASTICSEARCH_CONNECTION_DEFAULT_REGION = None
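To turn signing on, a deployment would override these defaults, for example in a private settings module. The sketch below is illustrative only; the environment variable names simply mirror the existing ELASTICSEARCH_* pattern and are not defined by this commit:

# Hypothetical overrides; ELASTICSEARCH_CONNECTION_CLASS must point at the new class.
ELASTICSEARCH_AWS_ACCESS_KEY_ID = environ.get('ELASTICSEARCH_AWS_ACCESS_KEY_ID', None)
ELASTICSEARCH_AWS_SECRET_ACCESS_KEY = environ.get('ELASTICSEARCH_AWS_SECRET_ACCESS_KEY', None)
ELASTICSEARCH_CONNECTION_CLASS = 'analytics_data_api.v0.connections.BotoHttpConnection'
ELASTICSEARCH_CONNECTION_DEFAULT_REGION = environ.get('ELASTICSEARCH_CONNECTION_DEFAULT_REGION', 'us-east-1')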
########## END ELASTICSEARCH CONFIGURATION

########## GENERAL CONFIGURATION
...
boto==2.22.1 # MIT
Django==1.7.5 # BSD License
Markdown==2.6 # BSD
django-model-utils==2.2 # BSD
djangorestframework==2.4.4 # BSD
ipython==2.4.1 # BSD
django-rest-swagger==0.2.8 # BSD
djangorestframework-csv==1.3.3 # BSD
django-countries==3.2 # MIT
elasticsearch-dsl==0.0.9 # Apache 2.0
-e git+https://github.com/edx/opaque-keys.git@d45d0bd8d64c69531be69178b9505b5d38806ce0#egg=opaque-keys