Commit daf533c8 by James Cammarata

V2 fixes

* PluginLoader class will now be more selective about loading some
  plugin classes if a required base class is specified (used to avoid
  loading v1 plugins whose APIs have changed significantly)
* Added the ability for the connection info class to read values from a
  given host's variables, to support "magic" variables
* Added some more magic variables to the VariableManager output
* Fixed a bug in the ActionBase class, where the module configuration
  code was not correctly handling unicode
parent f141ec96
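The magic-variable bullets above refer to the MAGIC_VARIABLE_MAPPING table added in the first hunk below. As a rough, self-contained sketch of how alias tuples in such a mapping resolve against a host's variables (the plain dict, the Info class, and the sample values here are illustrative stand-ins, not the committed code):

```python
# Illustrative sketch only: a plain dict stands in for host.get_vars(),
# and Info stands in for ConnectionInformation.
MAGIC_VARIABLE_MAPPING = dict(
    remote_addr=('ansible_ssh_host', 'ansible_host'),
    remote_user=('ansible_ssh_user', 'ansible_user'),
    port=('ansible_ssh_port', 'ansible_port'),
)

class Info(object):
    remote_addr = None
    remote_user = None
    port = None

host_vars = {'ansible_ssh_host': '10.0.0.5', 'ansible_user': 'deploy'}  # made-up values

info = Info()
for attr, aliases in MAGIC_VARIABLE_MAPPING.items():
    for alias in aliases:
        # no break, mirroring the committed loop: if several aliases are
        # present, the alias listed last wins
        if alias in host_vars:
            setattr(info, attr, host_vars[alias])

print(info.remote_addr, info.remote_user, info.port)  # 10.0.0.5 deploy None
```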
@@ -29,6 +29,20 @@ from ansible.errors import AnsibleError
 
 __all__ = ['ConnectionInformation']
 
+# the magic variable mapping dictionary below is used to translate
+# host/inventory variables to fields in the ConnectionInformation
+# object. The dictionary values are tuples, to account for aliases
+# in variable names.
+MAGIC_VARIABLE_MAPPING = dict(
+    connection = ('ansible_connection',),
+    remote_addr = ('ansible_ssh_host', 'ansible_host'),
+    remote_user = ('ansible_ssh_user', 'ansible_user'),
+    port = ('ansible_ssh_port', 'ansible_port'),
+    password = ('ansible_ssh_pass', 'ansible_password'),
+    private_key_file = ('ansible_ssh_private_key_file', 'ansible_private_key_file'),
+    shell = ('ansible_shell_type',),
+)
+
 class ConnectionInformation:
@@ -51,6 +65,7 @@ class ConnectionInformation:
         self.port = None
         self.private_key_file = C.DEFAULT_PRIVATE_KEY_FILE
         self.timeout = C.DEFAULT_TIMEOUT
+        self.shell = None
 
         # privilege escalation
         self.become = None
@@ -170,7 +185,7 @@ class ConnectionInformation:
             else:
                 setattr(self, field, value)
 
-    def set_task_override(self, task):
+    def set_task_and_host_override(self, task, host):
         '''
         Sets attributes from the task if they are set, which will override
         those from the play.
@@ -179,12 +194,22 @@ class ConnectionInformation:
         new_info = ConnectionInformation()
         new_info.copy(self)
 
+        # loop through a subset of attributes on the task object and set
+        # connection fields based on their values
         for attr in ('connection', 'remote_user', 'become', 'become_user', 'become_pass', 'become_method', 'environment', 'no_log'):
             if hasattr(task, attr):
                 attr_val = getattr(task, attr)
                 if attr_val:
                     setattr(new_info, attr, attr_val)
 
+        # finally, use the MAGIC_VARIABLE_MAPPING dictionary to update this
+        # connection info object with 'magic' variables from inventory
+        variables = host.get_vars()
+        for (attr, variable_names) in MAGIC_VARIABLE_MAPPING.iteritems():
+            for variable_name in variable_names:
+                if variable_name in variables:
+                    setattr(new_info, attr, variables[variable_name])
+
         return new_info
 
     def make_become_cmd(self, cmd, executable, become_settings=None):
...
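As the hunk above is written, set_task_and_host_override layers values in a fixed order: the play/command-line values are copied first, task attributes are applied next, and the host's magic variables are applied last, so an inventory variable wins for any field set in more than one place. A rough sketch of that precedence with plain dicts standing in for the real objects (names and values below are illustrative):

```python
# Illustrative precedence sketch: play defaults, then task attributes,
# then host 'magic' variables (applied last, so they win).
play_defaults = {'remote_user': 'root', 'port': 22}
task_attrs = {'remote_user': 'app'}                        # task-level override
host_vars = {'ansible_ssh_port': 2222}                     # inventory override
mapping = {'port': ('ansible_ssh_port', 'ansible_port')}   # subset of MAGIC_VARIABLE_MAPPING

new_info = dict(play_defaults)          # new_info.copy(self)
new_info.update(task_attrs)             # loop over task attributes
for attr, aliases in mapping.items():   # loop over magic variables
    for alias in aliases:
        if alias in host_vars:
            new_info[attr] = host_vars[alias]

print(new_info)  # {'remote_user': 'app', 'port': 2222}
```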
@@ -111,7 +111,7 @@ class WorkerProcess(multiprocessing.Process):
 
                 # apply the given task's information to the connection info,
                 # which may override some fields already set by the play or
                 # the options specified on the command line
-                new_connection_info = connection_info.set_task_override(task)
+                new_connection_info = connection_info.set_task_and_host_override(task=task, host=host)
 
                 # execute the task and build a TaskResult from the result
                 debug("running TaskExecutor() for %s/%s" % (host, task))
...
@@ -55,9 +55,10 @@ class PluginLoader:
     The first match is used.
     '''
 
-    def __init__(self, class_name, package, config, subdir, aliases={}):
+    def __init__(self, class_name, package, config, subdir, aliases={}, required_base_class=None):
 
         self.class_name = class_name
+        self.base_class = required_base_class
         self.package = package
         self.config = config
         self.subdir = subdir
@@ -87,11 +88,12 @@ class PluginLoader:
         config = data.get('config')
         subdir = data.get('subdir')
         aliases = data.get('aliases')
+        base_class = data.get('base_class')
         PATH_CACHE[class_name] = data.get('PATH_CACHE')
         PLUGIN_PATH_CACHE[class_name] = data.get('PLUGIN_PATH_CACHE')
 
-        self.__init__(class_name, package, config, subdir, aliases)
+        self.__init__(class_name, package, config, subdir, aliases, base_class)
         self._extra_dirs = data.get('_extra_dirs', [])
         self._searched_paths = data.get('_searched_paths', set())
@@ -102,6 +104,7 @@ class PluginLoader:
         return dict(
             class_name = self.class_name,
+            base_class = self.base_class,
             package = self.package,
             config = self.config,
             subdir = self.subdir,
@@ -268,9 +271,13 @@ class PluginLoader:
             self._module_cache[path] = imp.load_source('.'.join([self.package, name]), path)
 
         if kwargs.get('class_only', False):
-            return getattr(self._module_cache[path], self.class_name)
+            obj = getattr(self._module_cache[path], self.class_name)
         else:
-            return getattr(self._module_cache[path], self.class_name)(*args, **kwargs)
+            obj = getattr(self._module_cache[path], self.class_name)(*args, **kwargs)
+
+        if self.base_class and self.base_class not in [base.__name__ for base in obj.__class__.__bases__]:
+            return None
+
+        return obj
 
     def all(self, *args, **kwargs):
         ''' instantiates all plugins with the same arguments '''
@@ -291,6 +298,9 @@ class PluginLoader:
             else:
                 obj = getattr(self._module_cache[path], self.class_name)(*args, **kwargs)
 
+            if self.base_class and self.base_class not in [base.__name__ for base in obj.__class__.__bases__]:
+                continue
+
             # set extra info on the module, in case we want it later
             setattr(obj, '_original_path', path)
             yield obj
@@ -299,21 +309,22 @@ action_loader = PluginLoader(
     'ActionModule',
     'ansible.plugins.action',
     C.DEFAULT_ACTION_PLUGIN_PATH,
-    'action_plugins'
+    'action_plugins',
+    required_base_class='ActionBase',
 )
 
 cache_loader = PluginLoader(
     'CacheModule',
     'ansible.plugins.cache',
     C.DEFAULT_CACHE_PLUGIN_PATH,
-    'cache_plugins'
+    'cache_plugins',
 )
 
 callback_loader = PluginLoader(
     'CallbackModule',
     'ansible.plugins.callback',
     C.DEFAULT_CALLBACK_PLUGIN_PATH,
-    'callback_plugins'
+    'callback_plugins',
 )
 
 connection_loader = PluginLoader(
@@ -321,7 +332,8 @@ connection_loader = PluginLoader(
     'ansible.plugins.connections',
     C.DEFAULT_CONNECTION_PLUGIN_PATH,
     'connection_plugins',
-    aliases={'paramiko': 'paramiko_ssh'}
+    aliases={'paramiko': 'paramiko_ssh'},
+    required_base_class='ConnectionBase',
 )
 
 shell_loader = PluginLoader(
@@ -335,28 +347,29 @@ module_loader = PluginLoader(
     '',
     'ansible.modules',
     C.DEFAULT_MODULE_PATH,
-    'library'
+    'library',
 )
 
 lookup_loader = PluginLoader(
     'LookupModule',
     'ansible.plugins.lookup',
     C.DEFAULT_LOOKUP_PLUGIN_PATH,
-    'lookup_plugins'
+    'lookup_plugins',
+    required_base_class='LookupBase',
 )
 
 vars_loader = PluginLoader(
     'VarsModule',
     'ansible.plugins.vars',
     C.DEFAULT_VARS_PLUGIN_PATH,
-    'vars_plugins'
+    'vars_plugins',
 )
 
 filter_loader = PluginLoader(
     'FilterModule',
     'ansible.plugins.filter',
     C.DEFAULT_FILTER_PLUGIN_PATH,
-    'filter_plugins'
+    'filter_plugins',
 )
 
 fragment_loader = PluginLoader(
@@ -371,4 +384,5 @@ strategy_loader = PluginLoader(
     'ansible.plugins.strategies',
     None,
     'strategy_plugins',
+    required_base_class='StrategyBase',
 )
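One property of the base-class gate added above: it inspects only the direct bases of the loaded class (obj.__class__.__bases__), not the full inheritance chain, so a plugin that inherits the required base indirectly would be filtered out by this check as written. A small standalone demonstration (class names here are illustrative):

```python
# Demonstrates that __class__.__bases__ lists only direct parents.
class ActionBase(object):
    pass

class ActionModule(ActionBase):      # inherits the required base directly
    pass

class FancyModule(ActionModule):     # inherits it only indirectly
    pass

required_base = 'ActionBase'

def passes(obj):
    return required_base in [base.__name__ for base in obj.__class__.__bases__]

print(passes(ActionModule()))  # True
print(passes(FancyModule()))   # False, although isinstance(FancyModule(), ActionBase) is True
```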
@@ -34,6 +34,7 @@ from ansible.parsing.utils.jsonify import jsonify
 from ansible.plugins import shell_loader
 from ansible.utils.debug import debug
+from ansible.utils.unicode import to_bytes
 
 class ActionBase:
@@ -51,20 +52,20 @@ class ActionBase:
         self._loader = loader
         self._templar = templar
         self._shared_loader_obj = shared_loader_obj
-        self._shell = self.get_shell()
 
-        self._supports_check_mode = True
-
-    def get_shell(self):
-
-        if hasattr(self._connection, '_shell'):
-            shell_plugin = getattr(self._connection, '_shell', '')
-        else:
-            shell_plugin = shell_loader.get(os.path.basename(C.DEFAULT_EXECUTABLE))
-            if shell_plugin is None:
-                shell_plugin = shell_loader.get('sh')
-
-        return shell_plugin
+        # load the shell plugin for this action/connection
+        if self._connection_info.shell:
+            shell_type = self._connection_info.shell
+        elif hasattr(connection, '_shell'):
+            shell_type = getattr(connection, '_shell')
+        else:
+            shell_type = os.path.basename(C.DEFAULT_EXECUTABLE)
+
+        self._shell = shell_loader.get(shell_type)
+        if not self._shell:
+            raise AnsibleError("Invalid shell type specified (%s), or the plugin for that shell type is missing." % shell_type)
+
+        self._supports_check_mode = True
 
     def _configure_module(self, module_name, module_args):
         '''
@@ -201,18 +202,13 @@ class ActionBase:
         Copies the module data out to the temporary module path.
         '''
 
-        if type(data) == dict:
+        if isinstance(data, dict):
             data = jsonify(data)
 
         afd, afile = tempfile.mkstemp()
         afo = os.fdopen(afd, 'w')
        try:
-            # FIXME: is this still necessary?
-            #if not isinstance(data, unicode):
-            #    #ensure the data is valid UTF-8
-            #    data = data.decode('utf-8')
-            #else:
-            #    data = data.encode('utf-8')
+            data = to_bytes(data, errors='strict')
             afo.write(data)
         except Exception as e:
             #raise AnsibleError("failure encoding into utf-8: %s" % str(e))
...
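The change above replaces the commented-out encode/decode logic with a single to_bytes() call before writing the module data to the temporary file. A minimal standalone sketch of that idea in modern Python (the to_bytes helper below is a hypothetical stand-in for ansible.utils.unicode.to_bytes, which may behave differently):

```python
# Sketch: ensure module data is UTF-8 bytes before writing it to the
# temp file, so non-ASCII arguments do not raise during the write.
import os
import tempfile

def to_bytes(data, errors='strict'):
    # hypothetical stand-in: encode text to UTF-8, pass bytes through
    if isinstance(data, str):
        return data.encode('utf-8', errors)
    return data

data = 'msg=\u00fcber'                      # module args containing non-ASCII text
afd, afile = tempfile.mkstemp()
with os.fdopen(afd, 'wb') as afo:           # binary mode, since we write bytes
    afo.write(to_bytes(data, errors='strict'))
os.unlink(afile)
```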
@@ -212,9 +212,13 @@ class VariableManager:
 
         # FIXME: make sure all special vars are here
         # Finally, we create special vars
-        if host and self._inventory is not None:
-            hostvars = HostVars(vars_manager=self, inventory=self._inventory, loader=loader)
-            all_vars['hostvars'] = hostvars
+
+        if host:
+            all_vars['groups'] = [group.name for group in host.get_groups()]
+
+            if self._inventory is not None:
+                hostvars = HostVars(vars_manager=self, inventory=self._inventory, loader=loader)
+                all_vars['hostvars'] = hostvars
 
         if self._inventory is not None:
             all_vars['inventory_dir'] = self._inventory.basedir()
...
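For reference, the new 'groups' entry above is simply the list of names of the groups the host belongs to. A tiny illustrative sketch with stand-in objects (the group names are made up):

```python
# Illustrative only: what all_vars['groups'] would contain for a host
# that belongs to the groups 'all' and 'webservers'.
class Group(object):
    def __init__(self, name):
        self.name = name

class Host(object):
    def get_groups(self):
        return [Group('all'), Group('webservers')]

all_vars = {}
host = Host()
all_vars['groups'] = [group.name for group in host.get_groups()]
print(all_vars['groups'])  # ['all', 'webservers']
```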