Commit 8dd066ce by Matthew Piatetsky

Fix marketing site client login process

LEARNER-498
parent e85104f8
...@@ -77,19 +77,17 @@ class AbstractMarketingSiteDataLoaderTestMixin(DataLoaderTestMixin): ...@@ -77,19 +77,17 @@ class AbstractMarketingSiteDataLoaderTestMixin(DataLoaderTestMixin):
def mock_login_response(self, failure=False): def mock_login_response(self, failure=False):
url = self.api_url + 'user' url = self.api_url + 'user'
landing_url = '{base}users/{username}'.format(base=self.api_url, landing_url = '{base}admin'.format(base=self.api_url)
username=self.partner.marketing_site_api_username)
status = 500 if failure else 302 status = 500 if failure else 302
adding_headers = {} adding_headers = {}
if not failure: if not failure:
adding_headers['Location'] = landing_url adding_headers['Location'] = landing_url
responses.add(responses.POST, url, status=status, adding_headers=adding_headers) responses.add(responses.POST, url, status=status, adding_headers=adding_headers)
responses.add(responses.GET, landing_url)
responses.add( responses.add(
responses.GET, responses.GET,
'{root}admin'.format(root=self.api_url), landing_url,
status=(500 if failure else 200) status=(500 if failure else 200)
) )
......
...@@ -22,9 +22,8 @@ class MarketingSiteAPIClientTestMixin(TestCase): ...@@ -22,9 +22,8 @@ class MarketingSiteAPIClientTestMixin(TestCase):
def mock_login_response(self, status): def mock_login_response(self, status):
""" Mock the response of the marketing site login """ """ Mock the response of the marketing site login """
response_url = '{root}/users/{username}'.format( response_url = '{root}/admin'.format(
root=self.api_root, root=self.api_root
username=self.username
) )
def request_callback(request): # pylint: disable=unused-argument def request_callback(request): # pylint: disable=unused-argument
...@@ -48,15 +47,6 @@ class MarketingSiteAPIClientTestMixin(TestCase): ...@@ -48,15 +47,6 @@ class MarketingSiteAPIClientTestMixin(TestCase):
status=status status=status
) )
def mock_admin_response(self, status):
""" Test that we can access the admin """
response_url = '{root}/admin'.format(root=self.api_root)
responses.add(
responses.GET,
response_url,
status=status
)
def mock_csrf_token_response(self, status): def mock_csrf_token_response(self, status):
responses.add( responses.add(
responses.GET, responses.GET,
...@@ -95,7 +85,6 @@ class MarketingSitePublisherTestMixin(MarketingSiteAPIClientTestMixin): ...@@ -95,7 +85,6 @@ class MarketingSitePublisherTestMixin(MarketingSiteAPIClientTestMixin):
def mock_api_client(self, status): def mock_api_client(self, status):
self.mock_login_response(status) self.mock_login_response(status)
self.mock_admin_response(status)
self.mock_csrf_token_response(status) self.mock_csrf_token_response(status)
self.mock_user_id_response(status) self.mock_user_id_response(status)
......
...@@ -614,7 +614,7 @@ class ProgramTests(MarketingSitePublisherTestMixin): ...@@ -614,7 +614,7 @@ class ProgramTests(MarketingSitePublisherTestMixin):
with mock.patch.object(MarketingSitePublisher, '_get_form_build_id_and_form_token', return_value={}): with mock.patch.object(MarketingSitePublisher, '_get_form_build_id_and_form_token', return_value={}):
with mock.patch.object(MarketingSitePublisher, '_get_delete_alias_url', return_value='/foo'): with mock.patch.object(MarketingSitePublisher, '_get_delete_alias_url', return_value='/foo'):
self.program.save() self.program.save()
self.assert_responses_call_count(9) self.assert_responses_call_count(8)
@responses.activate @responses.activate
def test_xseries_program_save(self): def test_xseries_program_save(self):
...@@ -649,7 +649,7 @@ class ProgramTests(MarketingSitePublisherTestMixin): ...@@ -649,7 +649,7 @@ class ProgramTests(MarketingSitePublisherTestMixin):
self.mock_node_delete(204) self.mock_node_delete(204)
toggle_switch('publish_program_to_marketing_site', True) toggle_switch('publish_program_to_marketing_site', True)
self.program.delete() self.program.delete()
self.assert_responses_call_count(6) self.assert_responses_call_count(5)
@responses.activate @responses.activate
def test_delete_and_no_marketing_site(self): def test_delete_and_no_marketing_site(self):
......
...@@ -57,7 +57,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin): ...@@ -57,7 +57,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin):
self.mock_node_retrieval(self.program.uuid) self.mock_node_retrieval(self.program.uuid)
publisher = MarketingSitePublisher() publisher = MarketingSitePublisher()
node_id = publisher._get_node_id(self.api_client, self.program.uuid) # pylint: disable=protected-access node_id = publisher._get_node_id(self.api_client, self.program.uuid) # pylint: disable=protected-access
self.assert_responses_call_count(5) self.assert_responses_call_count(4)
self.assertEqual(node_id, self.node_id) self.assertEqual(node_id, self.node_id)
@responses.activate @responses.activate
...@@ -75,7 +75,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin): ...@@ -75,7 +75,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin):
publisher = MarketingSitePublisher() publisher = MarketingSitePublisher()
publish_data = publisher._get_node_data(self.program, self.user_id) # pylint: disable=protected-access publish_data = publisher._get_node_data(self.program, self.user_id) # pylint: disable=protected-access
publisher._edit_node(self.api_client, self.node_id, publish_data) # pylint: disable=protected-access publisher._edit_node(self.api_client, self.node_id, publish_data) # pylint: disable=protected-access
self.assert_responses_call_count(5) self.assert_responses_call_count(4)
@responses.activate @responses.activate
def test_edit_node_failed(self): def test_edit_node_failed(self):
...@@ -114,7 +114,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin): ...@@ -114,7 +114,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin):
with mock.patch.object(MarketingSitePublisher, '_get_headers', return_value={}): with mock.patch.object(MarketingSitePublisher, '_get_headers', return_value={}):
with mock.patch.object(MarketingSitePublisher, '_get_form_build_id_and_form_token', return_value={}): with mock.patch.object(MarketingSitePublisher, '_get_form_build_id_and_form_token', return_value={}):
publisher.publish_program(self.program) publisher.publish_program(self.program)
self.assert_responses_call_count(8) self.assert_responses_call_count(7)
@responses.activate @responses.activate
def test_publish_program_edit(self): def test_publish_program_edit(self):
...@@ -126,7 +126,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin): ...@@ -126,7 +126,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin):
with mock.patch.object(MarketingSitePublisher, '_get_headers', return_value={}): with mock.patch.object(MarketingSitePublisher, '_get_headers', return_value={}):
with mock.patch.object(MarketingSitePublisher, '_get_form_build_id_and_form_token', return_value={}): with mock.patch.object(MarketingSitePublisher, '_get_form_build_id_and_form_token', return_value={}):
publisher.publish_program(self.program) publisher.publish_program(self.program)
self.assert_responses_call_count(8) self.assert_responses_call_count(7)
@responses.activate @responses.activate
def test_publish_modified_program(self): def test_publish_modified_program(self):
...@@ -141,7 +141,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin): ...@@ -141,7 +141,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin):
with mock.patch.object(MarketingSitePublisher, '_get_form_build_id_and_form_token', return_value={}): with mock.patch.object(MarketingSitePublisher, '_get_form_build_id_and_form_token', return_value={}):
with mock.patch.object(MarketingSitePublisher, '_get_delete_alias_url', return_value='/foo'): with mock.patch.object(MarketingSitePublisher, '_get_delete_alias_url', return_value='/foo'):
publisher.publish_program(self.program) publisher.publish_program(self.program)
self.assert_responses_call_count(9) self.assert_responses_call_count(8)
@responses.activate @responses.activate
def test_get_alias_form(self): def test_get_alias_form(self):
...@@ -153,7 +153,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin): ...@@ -153,7 +153,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin):
self.mock_get_alias_form() self.mock_get_alias_form()
with mock.patch.object(MarketingSitePublisher, '_get_headers', return_value={}): with mock.patch.object(MarketingSitePublisher, '_get_headers', return_value={}):
publisher.publish_program(self.program) publisher.publish_program(self.program)
self.assert_responses_call_count(9) self.assert_responses_call_count(8)
@responses.activate @responses.activate
def test_get_delete_form(self): def test_get_delete_form(self):
...@@ -167,7 +167,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin): ...@@ -167,7 +167,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin):
with mock.patch.object(MarketingSitePublisher, '_get_headers', return_value={}): with mock.patch.object(MarketingSitePublisher, '_get_headers', return_value={}):
with mock.patch.object(MarketingSitePublisher, '_get_form_build_id_and_form_token', return_value={}): with mock.patch.object(MarketingSitePublisher, '_get_form_build_id_and_form_token', return_value={}):
publisher.publish_program(self.program) publisher.publish_program(self.program)
self.assert_responses_call_count(10) self.assert_responses_call_count(9)
@responses.activate @responses.activate
def test_get_alias_form_failed(self): def test_get_alias_form_failed(self):
...@@ -236,7 +236,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin): ...@@ -236,7 +236,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin):
self.mock_node_delete(204) self.mock_node_delete(204)
publisher = MarketingSitePublisher() publisher = MarketingSitePublisher()
publisher.delete_program(self.program) publisher.delete_program(self.program)
self.assert_responses_call_count(6) self.assert_responses_call_count(5)
@responses.activate @responses.activate
def test_publish_delete_non_existent_program(self): def test_publish_delete_non_existent_program(self):
...@@ -244,7 +244,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin): ...@@ -244,7 +244,7 @@ class MarketingSitePublisherTests(MarketingSitePublisherTestMixin):
self.mock_node_retrieval(self.program.uuid, exists=False) self.mock_node_retrieval(self.program.uuid, exists=False)
publisher = MarketingSitePublisher() publisher = MarketingSitePublisher()
publisher.delete_program(self.program) publisher.delete_program(self.program)
self.assert_responses_call_count(5) self.assert_responses_call_count(4)
@responses.activate @responses.activate
def test_publish_delete_xseries(self): def test_publish_delete_xseries(self):
......
...@@ -49,31 +49,27 @@ class MarketingSiteAPIClientTests(MarketingSiteAPIClientTestMixin): ...@@ -49,31 +49,27 @@ class MarketingSiteAPIClientTests(MarketingSiteAPIClientTestMixin):
@responses.activate @responses.activate
def test_init_session(self): def test_init_session(self):
self.mock_login_response(200) self.mock_login_response(200)
self.mock_admin_response(200)
session = self.api_client.init_session session = self.api_client.init_session
self.assert_responses_call_count(3) self.assert_responses_call_count(2)
self.assertIsNotNone(session) self.assertIsNotNone(session)
@responses.activate @responses.activate
def test_init_session_failed(self): def test_init_session_failed(self):
self.mock_login_response(500) self.mock_login_response(500)
self.mock_admin_response(500)
with self.assertRaises(MarketingSiteAPIClientException): with self.assertRaises(MarketingSiteAPIClientException):
self.api_client.init_session # pylint: disable=pointless-statement self.api_client.init_session # pylint: disable=pointless-statement
@responses.activate @responses.activate
def test_csrf_token(self): def test_csrf_token(self):
self.mock_login_response(200) self.mock_login_response(200)
self.mock_admin_response(200)
self.mock_csrf_token_response(200) self.mock_csrf_token_response(200)
csrf_token = self.api_client.csrf_token csrf_token = self.api_client.csrf_token
self.assert_responses_call_count(4) self.assert_responses_call_count(3)
self.assertEqual(self.csrf_token, csrf_token) self.assertEqual(self.csrf_token, csrf_token)
@responses.activate @responses.activate
def test_csrf_token_failed(self): def test_csrf_token_failed(self):
self.mock_login_response(200) self.mock_login_response(200)
self.mock_admin_response(200)
self.mock_csrf_token_response(500) self.mock_csrf_token_response(500)
with self.assertRaises(MarketingSiteAPIClientException): with self.assertRaises(MarketingSiteAPIClientException):
self.api_client.csrf_token # pylint: disable=pointless-statement self.api_client.csrf_token # pylint: disable=pointless-statement
...@@ -81,16 +77,14 @@ class MarketingSiteAPIClientTests(MarketingSiteAPIClientTestMixin): ...@@ -81,16 +77,14 @@ class MarketingSiteAPIClientTests(MarketingSiteAPIClientTestMixin):
@responses.activate @responses.activate
def test_user_id(self): def test_user_id(self):
self.mock_login_response(200) self.mock_login_response(200)
self.mock_admin_response(200)
self.mock_user_id_response(200) self.mock_user_id_response(200)
user_id = self.api_client.user_id user_id = self.api_client.user_id
self.assert_responses_call_count(4) self.assert_responses_call_count(3)
self.assertEqual(self.user_id, user_id) self.assertEqual(self.user_id, user_id)
@responses.activate @responses.activate
def test_user_id_failed(self): def test_user_id_failed(self):
self.mock_login_response(200) self.mock_login_response(200)
self.mock_admin_response(200)
self.mock_user_id_response(500) self.mock_user_id_response(500)
with self.assertRaises(MarketingSiteAPIClientException): with self.assertRaises(MarketingSiteAPIClientException):
self.api_client.user_id # pylint: disable=pointless-statement self.api_client.user_id # pylint: disable=pointless-statement
...@@ -98,10 +92,9 @@ class MarketingSiteAPIClientTests(MarketingSiteAPIClientTestMixin): ...@@ -98,10 +92,9 @@ class MarketingSiteAPIClientTests(MarketingSiteAPIClientTestMixin):
@responses.activate @responses.activate
def test_api_session(self): def test_api_session(self):
self.mock_login_response(200) self.mock_login_response(200)
self.mock_admin_response(200)
self.mock_csrf_token_response(200) self.mock_csrf_token_response(200)
api_session = self.api_client.api_session api_session = self.api_client.api_session
self.assert_responses_call_count(4) self.assert_responses_call_count(3)
self.assertIsNotNone(api_session) self.assertIsNotNone(api_session)
self.assertEqual(api_session.headers.get('Content-Type'), 'application/json') self.assertEqual(api_session.headers.get('Content-Type'), 'application/json')
self.assertEqual(api_session.headers.get('X-CSRF-Token'), self.csrf_token) self.assertEqual(api_session.headers.get('X-CSRF-Token'), self.csrf_token)
...@@ -109,7 +102,6 @@ class MarketingSiteAPIClientTests(MarketingSiteAPIClientTestMixin): ...@@ -109,7 +102,6 @@ class MarketingSiteAPIClientTests(MarketingSiteAPIClientTestMixin):
@responses.activate @responses.activate
def test_api_session_failed(self): def test_api_session_failed(self):
self.mock_login_response(500) self.mock_login_response(500)
self.mock_admin_response(500)
self.mock_csrf_token_response(500) self.mock_csrf_token_response(500)
with self.assertRaises(MarketingSiteAPIClientException): with self.assertRaises(MarketingSiteAPIClientException):
self.api_client.api_session # pylint: disable=pointless-statement self.api_client.api_session # pylint: disable=pointless-statement
...@@ -96,12 +96,10 @@ class MarketingSiteAPIClient(object): ...@@ -96,12 +96,10 @@ class MarketingSiteAPIClient(object):
'op': 'Log in', 'op': 'Log in',
} }
response = session.post(login_url, data=login_data) response = session.post(login_url, data=login_data)
expected_url = '{root}/users/{username}'.format(root=self.api_url, username=self.username)
admin_url = '{root}/admin'.format(root=self.api_url) admin_url = '{root}/admin'.format(root=self.api_url)
# Temporary way of checking whether the user has been logged into marketing site until # This is not a RESTful API so checking the status code is not enough
# the marketing site login flow is fixed # We also check that we were redirected to the admin page
can_access_admin = session.get(admin_url) if not (response.status_code == 200 and response.url == admin_url):
if not (can_access_admin.status_code == 200 and response.url == expected_url):
raise MarketingSiteAPIClientException('Marketing Site Login failed!') raise MarketingSiteAPIClientException('Marketing Site Login failed!')
return session return session
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment