Add type checking for module arguments, converting as much as possible

Converts to list from comma-separated strings, and to dicts from
comma-separated, key=value strings.

Fixes #2126.
parent 5b31feb7
...@@ -179,6 +179,7 @@ class AnsibleModule(object): ...@@ -179,6 +179,7 @@ class AnsibleModule(object):
if not bypass_checks: if not bypass_checks:
self._check_required_arguments() self._check_required_arguments()
self._check_argument_values()
self._check_argument_types() self._check_argument_types()
self._check_mutually_exclusive(mutually_exclusive) self._check_mutually_exclusive(mutually_exclusive)
self._check_required_together(required_together) self._check_required_together(required_together)
...@@ -535,7 +536,7 @@ class AnsibleModule(object): ...@@ -535,7 +536,7 @@ class AnsibleModule(object):
if len(missing) > 0: if len(missing) > 0:
self.fail_json(msg="missing required arguments: %s" % ",".join(missing)) self.fail_json(msg="missing required arguments: %s" % ",".join(missing))
def _check_argument_types(self): def _check_argument_values(self):
''' ensure all arguments have the requested values, and there are no stray arguments ''' ''' ensure all arguments have the requested values, and there are no stray arguments '''
for (k,v) in self.argument_spec.iteritems(): for (k,v) in self.argument_spec.iteritems():
choices = v.get('choices',None) choices = v.get('choices',None)
...@@ -550,6 +551,45 @@ class AnsibleModule(object): ...@@ -550,6 +551,45 @@ class AnsibleModule(object):
else: else:
self.fail_json(msg="internal error: do not know how to interpret argument_spec") self.fail_json(msg="internal error: do not know how to interpret argument_spec")
def _check_argument_types(self):
''' ensure all arguments have the requested type '''
for (k, v) in self.argument_spec.iteritems():
wanted = v.get('type', None)
if wanted is None:
continue
if k not in self.params:
continue
value = self.params[k]
is_invalid = False
if wanted == 'str':
if not isinstance(value, basestring):
self.params[k] = str(value)
elif wanted == 'list':
if not isinstance(value, list):
if isinstance(value, basestring):
self.params[k] = value.split(",")
else:
is_invalid = True
elif wanted == 'dict':
if not isinstance(value, dict):
if isinstance(value, basestring):
self.params[k] = dict([x.split("=", 1) for x in value.split(",")])
else:
is_invalid = True
elif wanted == 'bool':
if not isinstance(value, bool):
if isinstance(value, basestring):
self.params[k] = self.boolean(value)
else:
is_invalid = True
else:
self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k))
if is_invalid:
self.fail_json(msg="argument %s is of invalid type: %s, required: %s" % (k, type(value), wanted))
def _set_defaults(self, pre=True): def _set_defaults(self, pre=True):
for (k,v) in self.argument_spec.iteritems(): for (k,v) in self.argument_spec.iteritems():
default = v.get('default', None) default = v.get('default', None)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment