Commit 54b820fc by Max Rothman Committed by Kevin Falcone

Add tests and address comments and bugfixes

parent cd42cdf4
...@@ -5,7 +5,7 @@ DOCUMENTATION = """ ...@@ -5,7 +5,7 @@ DOCUMENTATION = """
module: mongodb_replica_set module: mongodb_replica_set
short_description: Modify replica set config. short_description: Modify replica set config.
description: description:
- Modify replica set config, including modifing/adding/removing members from a replica set - Modify replica set config, including modifying/adding/removing members from a replica set
changing replica set options, and initiating the replica set if necessary. changing replica set options, and initiating the replica set if necessary.
Uses replSetReconfig and replSetInitiate. Uses replSetReconfig and replSetInitiate.
version_added: "1.9" version_added: "1.9"
...@@ -95,7 +95,7 @@ import json, copy ...@@ -95,7 +95,7 @@ import json, copy
from urllib import quote_plus from urllib import quote_plus
########### Mongo API calls ########### ########### Mongo API calls ###########
def get_replset(client): def get_replset():
# Not using `replSetGetConfig` because it's not supported in MongoDB 2.x. # Not using `replSetGetConfig` because it's not supported in MongoDB 2.x.
try: try:
rs_config = client.local.system.replset.find_one() rs_config = client.local.system.replset.find_one()
...@@ -123,9 +123,11 @@ def reconfig_replset(rs_config): ...@@ -123,9 +123,11 @@ def reconfig_replset(rs_config):
def get_rs_config_id(): def get_rs_config_id():
try: try:
replset_name = client.admin.command('getCmdLineOpts')['parsed']['replication']['replSetName'] return client.admin.command('getCmdLineOpts')['parsed']['replication']['replSetName']
except OperationFailure, KeyError as e: except (OperationFailure, KeyError) as e:
module.fail_json(msg="Unable to get replSet name. Was mongod started with --replSet? Error: " + e.message) module.fail_json(msg=("Unable to get replSet name. "
"Was mongod started with --replSet, "
"or was replication.replSetName set in the config file? Error: ") + e.message)
########### Helper functions ########### ########### Helper functions ###########
...@@ -134,11 +136,12 @@ def set_member_ids(members, old_members=None): ...@@ -134,11 +136,12 @@ def set_member_ids(members, old_members=None):
Set the _id property of members who don't already have one. Set the _id property of members who don't already have one.
Prefer the _id of the "matching" member from `old_members`. Prefer the _id of the "matching" member from `old_members`.
''' '''
#Add a little padding to ensure we don't run out of IDs
available_ids = set(range(len(members)*2)) available_ids = set(range(len(members)*2))
available_ids -= {m['_id'] for m in members} available_ids -= {m['_id'] for m in members if '_id' in m}
if old_members is not None: if old_members is not None:
available_ids -= {m['_id'] for m in old_members} available_ids -= {m['_id'] for m in old_members}
available_ids = list(available_ids).sort(reverse=True) available_ids = list(sorted(available_ids, reverse=True))
for member in members: for member in members:
if '_id' not in member: if '_id' not in member:
...@@ -147,7 +150,6 @@ def set_member_ids(members, old_members=None): ...@@ -147,7 +150,6 @@ def set_member_ids(members, old_members=None):
member['_id'] = match['_id'] if match is not None else available_ids.pop() member['_id'] = match['_id'] if match is not None else available_ids.pop()
else: else:
member['_id'] = available_ids.pop() member['_id'] = available_ids.pop()
return members
def get_matching_member(member, members): def get_matching_member(member, members):
'''Return the rs_member from `members` that "matches" `member` (currently on host)''' '''Return the rs_member from `members` that "matches" `member` (currently on host)'''
...@@ -155,62 +157,77 @@ def get_matching_member(member, members): ...@@ -155,62 +157,77 @@ def get_matching_member(member, members):
return match[0] if len(match) > 0 else None return match[0] if len(match) > 0 else None
def members_match(new, old): def members_match(new, old):
"Compare 2 lists of members, discounting their `_id`s and matching on a custom criterion" "Compare 2 lists of members, discounting their `_id`s and matching on hostname"
if len(new) != len(old): if len(new) != len(old):
return False return False
for old_member in old: for old_member in old:
new_member = get_matching_member(old_member, new) new_member = get_matching_member(old_member, new).copy()
if old_member.copy().pop('_id') != new_member.copy().pop('_id'): #Don't compare on _id
new_member.pop('_id', None)
old_member = old_member.copy()
old_member.pop('_id', None)
if old_member != new_member:
return False return False
return True return True
def fix_host_port(rs_config): def fix_host_port(rs_config):
"Fix host, port to host:port" '''Fix host, port to host:port'''
if 'members' in rs_config: if 'members' in rs_config:
if not isinstance(rs_config['members'], list) if not isinstance(rs_config['members'], list):
module.fail_json(msg='rs_config.members must be a list') module.fail_json(msg='rs_config.members must be a list')
for member in rs_config['members']: for member in rs_config['members']:
if ':' not in member['host']: if ':' not in member['host']:
member['host'] = '{}:{}'.format(member['host'], member.get('port', 27017)) member['host'] = '{}:{}'.format(member['host'], member.get('port', 27017))
if 'port' in member:
del member['port'] del member['port']
def update_replset(rs_config): def update_replset(rs_config):
changed = False changed = False
old_rs_config = get_replset(client) old_rs_config = get_replset()
fix_host_port(module, rs_config) #fix host, port to host:port fix_host_port(rs_config) #fix host, port to host:port
#Decide whether we need to initialize #Decide whether we need to initialize
if old_rs_config is None: if old_rs_config is None:
changed = True changed = True
if '_id' not in rs_config: if '_id' not in rs_config:
rs_config['_id'] = get_rs_config_id(client) #Errors if no replSet specified to mongod rs_config['_id'] = get_rs_config_id() #Errors if no replSet specified to mongod
rs_config['members'] = set_member_ids(rs_config['members']) #Noop if all _ids are set set_member_ids(rs_config['members']) #Noop if all _ids are set
#Don't set the version, it'll auto-set #Don't set the version, it'll auto-set
initialize_replset(client, rs_config) initialize_replset(rs_config)
else: else:
rs_config_non_collections = {k:v for k,v in rs_config.items() old_rs_config_scalars = {k:v for k,v in old_rs_config.items() if not isinstance(v, (list, dict))}
if not isinstance(v, list) if not isinstance(v, dict)}
old_rs_config_non_collections = {k:v for k,v in old_rs_config.items() rs_config_scalars = {k:v for k,v in rs_config.items() if not isinstance(v, (list, dict))}
if not isinstance(v, list) if not isinstance(v, dict)} if '_id' not in rs_config_scalars and '_id' in old_rs_config_scalars:
# _id is going to be managed, don't compare on it
# Is the provided doc "different" from the one currently on the cluster? del old_rs_config_scalars['_id']
if rs_config_non_collections != old_rs_config_non_collections \ if 'version' not in rs_config and 'version' in old_rs_config_scalars:
or rs_config['settings'] != old_rs_config['settings'] \ # version is going to be managed, don't compare on it
del old_rs_config_scalars['version']
# Special comparison to test whether 2 rs_configs are "equivalent"
# We can't simply use == because of special logic in `members_match()`
# 1. Compare the scalars (i.e. non-collections)
# 2. Compare the "settings" dict
# 3. Compare the members dicts using `members_match()`
# Since the only nested structures in the rs_config spec are "members" and "settings",
# if all of the above 3 match, the structures are equivalent.
if rs_config_scalars != old_rs_config_scalars \
or rs_config.get('settings') != old_rs_config.get('settings') \
or not members_match(rs_config['members'], old_rs_config['members']): or not members_match(rs_config['members'], old_rs_config['members']):
changed=True changed=True
if '_id' not in rs_config: if '_id' not in rs_config:
rs_config['_id'] = get_rs_config_id(client) #Errors if no replSet specified to mongod rs_config['_id'] = old_rs_config['_id']
if 'version' not in rs_config: if 'version' not in rs_config:
#Using manual increment to prevent race condition #Using manual increment to prevent race condition
rs_config['version'] = old_rs_config['version'] + 1 rs_config['version'] = old_rs_config['version'] + 1
#Noop if all _ids are set set_member_ids(rs_config['members'], old_rs_config['members']) #Noop if all _ids are set
rs_config['members'] = set_member_ids(rs_config['members'], old_rs_config['members'])
reconfig_replset(module, client, rs_config) reconfig_replset(rs_config)
#Validate it worked #Validate it worked
if changed: if changed:
...@@ -235,11 +252,11 @@ def get_mongo_uri(host, port, username, password, auth_database): ...@@ -235,11 +252,11 @@ def get_mongo_uri(host, port, username, password, auth_database):
return mongo_uri return mongo_uri
def primary_client(some_host, some_port, username, password, auth_database): def primary_client(some_host, some_port, username, password, auth_database):
""" '''
Given a member of a replica set, find out who the primary is Given a member of a replica set, find out who the primary is
and provide a client that is connected to the primary for running and provide a client that is connected to the primary for running
commands. commands.
""" '''
mongo_uri = get_mongo_uri(some_host, some_port, username, password, auth_database) mongo_uri = get_mongo_uri(some_host, some_port, username, password, auth_database)
client = MongoClient(mongo_uri) client = MongoClient(mongo_uri)
try: try:
...@@ -268,19 +285,22 @@ def validate_args(): ...@@ -268,19 +285,22 @@ def validate_args():
auth_database = dict(required=False, type='str'), auth_database = dict(required=False, type='str'),
rs_host = dict(required=False, type='str', default="localhost"), rs_host = dict(required=False, type='str', default="localhost"),
rs_port = dict(required=False, type='int', default=27017), rs_port = dict(required=False, type='int', default=27017),
rs_config = dict(required=True, type='dict') rs_config = dict(required=True, type='dict'),
force = dict(required=False, type='bool', default=False), force = dict(required=False, type='bool', default=False),
) )
module = AnsibleModule(argument_spec=arg_spec, supports_check_mode=False) module = AnsibleModule(argument_spec=arg_spec, supports_check_mode=False)
username = module.params.get('username')
password = module.params.get('password')
if (username and not password) or (password and not username): if (username and not password) or (password and not username):
module.fail_json(msg="Must provide both username and password or neither.") module.fail_json(msg="Must provide both username and password or neither.")
return module return module
if __name__ == '__main__': if __name__ == '__main__':
module = validate_args(): module = validate_args()
if not pymongo_found: if not pymongo_found:
module.fail_json(msg="The python pymongo module is not installed.") module.fail_json(msg="The python pymongo module is not installed.")
...@@ -292,4 +312,4 @@ if __name__ == '__main__': ...@@ -292,4 +312,4 @@ if __name__ == '__main__':
rs_port = module.params['rs_port'] rs_port = module.params['rs_port']
client = primary_client(module, rs_host, rs_port, username, password, auth_database) client = primary_client(module, rs_host, rs_port, username, password, auth_database)
update_replset(module, client, module['rs_config']) update_replset(module['rs_config'])
# Tests for mongodb_replica_set ansible module
#
# How to run these tests:
# 1. move this file to playbooks/library
# 2. rename mongodb_replica_set to mongodb_replica_set.py
# 3. python test_mongodb_replica_set.py
import mongodb_replica_set as mrs
import unittest, mock
from urllib import quote_plus
from copy import deepcopy
class TestNoPatchingMongodbReplicaSet(unittest.TestCase):
def test_host_port_transformation(self):
unfixed = {
'members': [
{'host': 'foo.bar'},
{'host': 'bar.baz', 'port': 1234},
{'host': 'baz.bing:54321'}
]}
fixed = {
'members': [
{'host': 'foo.bar:27017'},
{'host': 'bar.baz:1234'},
{'host': 'baz.bing:54321'}
]}
mrs.fix_host_port(unfixed)
self.assertEqual(fixed, unfixed)
fixed_2 = deepcopy(fixed)
mrs.fix_host_port(fixed_2)
self.assertEqual(fixed, fixed_2)
def test_member_id_managed(self):
new = [
{'host': 'foo.bar', '_id': 1},
{'host': 'bar.baz'},
{'host': 'baz.bing'}
]
old = [
{'host': 'baz.bing', '_id': 0}
]
fixed = deepcopy(new)
mrs.set_member_ids(fixed, old)
#test that each id is unique
unique_ids = {m['_id'] for m in fixed}
self.assertEqual(len(unique_ids), len(new))
#test that it "prefers" the "matching" one in old_members
self.assertEqual(fixed[0]['_id'], new[0]['_id'])
self.assertEqual(fixed[2]['_id'], old[0]['_id'])
self.assertIn('_id', fixed[1])
def test_mongo_uri_escaped(self):
host = username = password = auth_database = ':!@#$%/'
port = 1234
uri = mrs.get_mongo_uri(host=host, port=port, username=username, password=password, auth_database=auth_database)
self.assertEqual(uri, "mongodb://{un}:{pw}@{host}:{port}/{db}".format(
un=quote_plus(username), pw=quote_plus(password),
host=quote_plus(host), port=port, db=quote_plus(auth_database),
))
rs_id = 'a replset id'
members = [
{'host': 'foo.bar:1234'},
{'host': 'bar.baz:4321'},
]
old_rs_config = {
'version': 1,
'_id': rs_id,
'members': [
{'_id': 0, 'host': 'foo.bar:1234',},
{'_id': 1, 'host': 'bar.baz:4321',},
]
}
new_rs_config = {
'version': 2,
'_id': rs_id,
'members': [
{'_id': 0, 'host': 'foo.bar:1234',},
{'_id': 1, 'host': 'bar.baz:4321',},
{'_id': 2, 'host': 'baz.bing:27017',},
]
}
rs_config = {
'members': [
{'host': 'foo.bar', 'port': 1234,},
{'host': 'bar.baz', 'port': 4321,},
{'host': 'baz.bing', 'port': 27017,},
]
}
def init_replset_mock(f):
get_replset_initialize_mock = mock.patch.object(mrs, 'get_replset',
side_effect=(None, deepcopy(new_rs_config)))
initialize_replset_mock = mock.patch.object(mrs, 'initialize_replset')
return get_replset_initialize_mock(initialize_replset_mock(f))
def update_replset_mock(f):
get_replset_update_mock = mock.patch.object(mrs, 'get_replset',
side_effect=(deepcopy(old_rs_config), deepcopy(new_rs_config)))
reconfig_replset_mock = mock.patch.object(mrs, 'reconfig_replset')
return get_replset_update_mock(reconfig_replset_mock(f))
@mock.patch.object(mrs, 'get_rs_config_id', return_value=rs_id)
@mock.patch.object(mrs, 'client', create=True)
@mock.patch.object(mrs, 'module', create=True)
class TestPatchingMongodbReplicaSet(unittest.TestCase):
@update_replset_mock
def test_version_managed(self, _1, _2, module, *args):
# Version set automatically on initialize
mrs.update_replset(deepcopy(rs_config))
new_version = module.exit_json.call_args[1]['config']['version']
self.assertEqual(old_rs_config['version'], new_version - 1)
@init_replset_mock
def test_doc_id_managed_on_initialize(self, _1, _2, module, *args):
#old_rs_config provided by init_replset_mock via mrs.get_replset().
#That returns None on the first call, so it falls through to get_rs_config_id(),
#which is also mocked.
mrs.update_replset(deepcopy(rs_config))
new_id = module.exit_json.call_args[1]['config']['_id']
self.assertEqual(rs_id, new_id)
@update_replset_mock
def test_doc_id_managed_on_update(self, _1, _2, module, *args):
#old_rs_config provided by update_replset_mock via mrs.get_replset()
mrs.update_replset(deepcopy(rs_config))
new_id = module.exit_json.call_args[1]['config']['_id']
self.assertEqual(rs_id, new_id)
@init_replset_mock
def test_initialize_if_necessary(self, initialize_replset, _2, module, *args):
mrs.update_replset(deepcopy(rs_config))
self.assertTrue(initialize_replset.called)
#self.assertFalse(reconfig_replset.called)
@update_replset_mock
def test_reconfig_if_necessary(self, reconfig_replset, _2, module, *args):
mrs.update_replset(deepcopy(rs_config))
self.assertTrue(reconfig_replset.called)
#self.assertFalse(initialize_replset.called)
@update_replset_mock
def test_not_changed_when_docs_match(self, _1, _2, module, *args):
rs_config = {'members': members} #This way the docs "match", but aren't identical
mrs.update_replset(deepcopy(rs_config))
changed = module.exit_json.call_args[1]['changed']
self.assertFalse(changed)
@update_replset_mock
def test_ignores_magic_given_full_doc(self, _1, _2, module, _3, get_rs_config_id, *args):
mrs.update_replset(deepcopy(new_rs_config))
new_doc = module.exit_json.call_args[1]['config']
self.assertEqual(new_doc, new_rs_config)
self.assertFalse(get_rs_config_id.called)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment