Commit 62d79568 by James Cammarata

Creating playbook executor and dependent classes

parent b6c3670f
......@@ -18,3 +18,5 @@
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
__version__ = '1.v2'
......@@ -104,6 +104,7 @@ YAML_FILENAME_EXTENSIONS = [ "", ".yml", ".yaml", ".json" ]
DEFAULTS='defaults'
# configurable things
DEFAULT_DEBUG = get_config(p, DEFAULTS, 'debug', 'ANSIBLE_DEBUG', False, boolean=True)
DEFAULT_HOST_LIST = shell_expand_path(get_config(p, DEFAULTS, 'hostfile', 'ANSIBLE_HOSTS', '/etc/ansible/hosts'))
DEFAULT_MODULE_PATH = get_config(p, DEFAULTS, 'library', 'ANSIBLE_LIBRARY', None)
DEFAULT_ROLES_PATH = shell_expand_path(get_config(p, DEFAULTS, 'roles_path', 'ANSIBLE_ROLES_PATH', '/etc/ansible/roles'))
......
......@@ -21,7 +21,7 @@ __metaclass__ = type
import os
from ansible.parsing.yaml.strings import *
from ansible.errors.yaml_strings import *
class AnsibleError(Exception):
'''
......@@ -45,12 +45,12 @@ class AnsibleError(Exception):
self._obj = obj
self._show_content = show_content
if isinstance(self._obj, AnsibleBaseYAMLObject):
if obj and isinstance(obj, AnsibleBaseYAMLObject):
extended_error = self._get_extended_error()
if extended_error:
self.message = '%s\n\n%s' % (message, extended_error)
self.message = 'ERROR! %s\n\n%s' % (message, extended_error)
else:
self.message = message
self.message = 'ERROR! %s' % message
def __str__(self):
return self.message
......@@ -98,8 +98,9 @@ class AnsibleError(Exception):
(target_line, prev_line) = self._get_error_lines_from_file(src_file, line_number - 1)
if target_line:
stripped_line = target_line.replace(" ","")
arrow_line = (" " * (col_number-1)) + "^"
error_message += "%s\n%s\n%s\n" % (prev_line.rstrip(), target_line.rstrip(), arrow_line)
arrow_line = (" " * (col_number-1)) + "^ here"
#header_line = ("=" * 73)
error_message += "\nThe offending line appears to be:\n\n%s\n%s\n%s\n" % (prev_line.rstrip(), target_line.rstrip(), arrow_line)
# common error/remediation checking here:
# check for unquoted vars starting lines
......@@ -158,3 +159,11 @@ class AnsibleModuleError(AnsibleRuntimeError):
class AnsibleConnectionFailure(AnsibleRuntimeError):
''' the transport / connection_plugin had a fatal error '''
pass
class AnsibleFilterError(AnsibleRuntimeError):
''' a templating failure '''
pass
class AnsibleUndefinedVariable(AnsibleRuntimeError):
''' a templating failure '''
pass
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import pipes
import random
from ansible import constants as C
__all__ = ['ConnectionInformation']
class ConnectionInformation:
'''
This class is used to consolidate the connection information for
hosts in a play and child tasks, where the task may override some
connection/authentication information.
'''
def __init__(self, play=None, options=None):
# FIXME: implement the new methodology here for supporting
# various different auth escalation methods (becomes, etc.)
self.connection = C.DEFAULT_TRANSPORT
self.remote_user = 'root'
self.password = ''
self.port = 22
self.su = False
self.su_user = ''
self.su_pass = ''
self.sudo = False
self.sudo_user = ''
self.sudo_pass = ''
self.verbosity = 0
self.only_tags = set()
self.skip_tags = set()
if play:
self.set_play(play)
if options:
self.set_options(options)
def set_play(self, play):
'''
Configures this connection information instance with data from
the play class.
'''
if play.connection:
self.connection = play.connection
self.remote_user = play.remote_user
self.password = ''
self.port = int(play.port) if play.port else 22
self.su = play.su
self.su_user = play.su_user
self.su_pass = play.su_pass
self.sudo = play.sudo
self.sudo_user = play.sudo_user
self.sudo_pass = play.sudo_pass
def set_options(self, options):
'''
Configures this connection information instance with data from
options specified by the user on the command line. These have a
higher precedence than those set on the play or host.
'''
# FIXME: set other values from options here?
self.verbosity = options.verbosity
if options.connection:
self.connection = options.connection
# get the tag info from options, converting a comma-separated list
# of values into a proper list if need be
if isinstance(options.tags, list):
self.only_tags.update(options.tags)
elif isinstance(options.tags, basestring):
self.only_tags.update(options.tags.split(','))
if isinstance(options.skip_tags, list):
self.skip_tags.update(options.skip_tags)
elif isinstance(options.skip_tags, basestring):
self.skip_tags.update(options.skip_tags.split(','))
def copy(self, ci):
'''
Copies the connection info from another connection info object, used
when merging in data from task overrides.
'''
self.connection = ci.connection
self.remote_user = ci.remote_user
self.password = ci.password
self.port = ci.port
self.su = ci.su
self.su_user = ci.su_user
self.su_pass = ci.su_pass
self.sudo = ci.sudo
self.sudo_user = ci.sudo_user
self.sudo_pass = ci.sudo_pass
self.verbosity = ci.verbosity
self.only_tags = ci.only_tags.copy()
self.skip_tags = ci.skip_tags.copy()
def set_task_override(self, task):
'''
Sets attributes from the task if they are set, which will override
those from the play.
'''
new_info = ConnectionInformation()
new_info.copy(self)
for attr in ('connection', 'remote_user', 'su', 'su_user', 'su_pass', 'sudo', 'sudo_user', 'sudo_pass'):
if hasattr(task, attr):
attr_val = getattr(task, attr)
if attr_val:
setattr(new_info, attr, attr_val)
return new_info
def make_sudo_cmd(self, sudo_exe, executable, cmd):
"""
Helper function for wrapping commands with sudo.
Rather than detect if sudo wants a password this time, -k makes
sudo always ask for a password if one is required. Passing a quoted
compound command to sudo (or sudo -s) directly doesn't work, so we
shellquote it with pipes.quote() and pass the quoted string to the
user's shell. We loop reading output until we see the randomly-
generated sudo prompt set with the -p option.
"""
randbits = ''.join(chr(random.randint(ord('a'), ord('z'))) for x in xrange(32))
prompt = '[sudo via ansible, key=%s] password: ' % randbits
success_key = 'SUDO-SUCCESS-%s' % randbits
sudocmd = '%s -k && %s %s -S -p "%s" -u %s %s -c %s' % (
sudo_exe, sudo_exe, C.DEFAULT_SUDO_FLAGS, prompt,
self.sudo_user, executable or '$SHELL',
pipes.quote('echo %s; %s' % (success_key, cmd))
)
#return ('/bin/sh -c ' + pipes.quote(sudocmd), prompt, success_key)
return (sudocmd, prompt, success_key)
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from multiprocessing.managers import SyncManager, BaseProxy
from ansible.playbook.handler import Handler
from ansible.playbook.task import Task
from ansible.playbook.play import Play
from ansible.errors import AnsibleError
__all__ = ['AnsibleManager']
class VariableManagerWrapper:
'''
This class simply acts as a wrapper around the VariableManager class,
since manager proxies expect a new object to be returned rather than
any existing one. Using this wrapper, a shared proxy can be created
and an existing VariableManager class assigned to it, which can then
be accessed through the exposed proxy methods.
'''
def __init__(self):
self._vm = None
def get_vars(self, loader, play=None, host=None, task=None):
return self._vm.get_vars(loader=loader, play=play, host=host, task=task)
def set_variable_manager(self, vm):
self._vm = vm
def set_host_variable(self, host, varname, value):
self._vm.set_host_variable(host, varname, value)
def set_host_facts(self, host, facts):
self._vm.set_host_facts(host, facts)
class AnsibleManager(SyncManager):
'''
This is our custom manager class, which exists only so we may register
the new proxy below
'''
pass
AnsibleManager.register(
typeid='VariableManagerWrapper',
callable=VariableManagerWrapper,
)
# (c) 2013-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# from python and deps
from cStringIO import StringIO
import inspect
import json
import os
import shlex
# from Ansible
from ansible import __version__
from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.parsing.utils.jsonify import jsonify
REPLACER = "#<<INCLUDE_ANSIBLE_MODULE_COMMON>>"
REPLACER_ARGS = "\"<<INCLUDE_ANSIBLE_MODULE_ARGS>>\""
REPLACER_COMPLEX = "\"<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>\""
REPLACER_WINDOWS = "# POWERSHELL_COMMON"
REPLACER_VERSION = "\"<<ANSIBLE_VERSION>>\""
class ModuleReplacer(object):
"""
The Replacer is used to insert chunks of code into modules before
transfer. Rather than doing classical python imports, this allows for more
efficient transfer in a no-bootstrapping scenario by not moving extra files
over the wire, and also takes care of embedding arguments in the transferred
modules.
This version is done in such a way that local imports can still be
used in the module code, so IDEs don't have to be aware of what is going on.
Example:
from ansible.module_utils.basic import *
... will result in the insertion basic.py into the module
from the module_utils/ directory in the source tree.
All modules are required to import at least basic, though there will also
be other snippets.
# POWERSHELL_COMMON
Also results in the inclusion of the common code in powershell.ps1
"""
# ******************************************************************************
def __init__(self, strip_comments=False):
# FIXME: these members need to be prefixed with '_' and the rest of the file fixed
this_file = inspect.getfile(inspect.currentframe())
# we've moved the module_common relative to the snippets, so fix the path
self.snippet_path = os.path.join(os.path.dirname(this_file), '..', 'module_utils')
self.strip_comments = strip_comments
# ******************************************************************************
def slurp(self, path):
if not os.path.exists(path):
raise AnsibleError("imported module support code does not exist at %s" % path)
fd = open(path)
data = fd.read()
fd.close()
return data
def _find_snippet_imports(self, module_data, module_path):
"""
Given the source of the module, convert it to a Jinja2 template to insert
module code and return whether it's a new or old style module.
"""
module_style = 'old'
if REPLACER in module_data:
module_style = 'new'
elif 'from ansible.module_utils.' in module_data:
module_style = 'new'
elif 'WANT_JSON' in module_data:
module_style = 'non_native_want_json'
output = StringIO()
lines = module_data.split('\n')
snippet_names = []
for line in lines:
if REPLACER in line:
output.write(self.slurp(os.path.join(self.snippet_path, "basic.py")))
snippet_names.append('basic')
if REPLACER_WINDOWS in line:
ps_data = self.slurp(os.path.join(self.snippet_path, "powershell.ps1"))
output.write(ps_data)
snippet_names.append('powershell')
elif line.startswith('from ansible.module_utils.'):
tokens=line.split(".")
import_error = False
if len(tokens) != 3:
import_error = True
if " import *" not in line:
import_error = True
if import_error:
raise AnsibleError("error importing module in %s, expecting format like 'from ansible.module_utils.basic import *'" % module_path)
snippet_name = tokens[2].split()[0]
snippet_names.append(snippet_name)
output.write(self.slurp(os.path.join(self.snippet_path, snippet_name + ".py")))
else:
if self.strip_comments and line.startswith("#") or line == '':
pass
output.write(line)
output.write("\n")
if not module_path.endswith(".ps1"):
# Unixy modules
if len(snippet_names) > 0 and not 'basic' in snippet_names:
raise AnsibleError("missing required import in %s: from ansible.module_utils.basic import *" % module_path)
else:
# Windows modules
if len(snippet_names) > 0 and not 'powershell' in snippet_names:
raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path)
return (output.getvalue(), module_style)
# ******************************************************************************
def modify_module(self, module_path, module_args):
with open(module_path) as f:
# read in the module source
module_data = f.read()
(module_data, module_style) = self._find_snippet_imports(module_data, module_path)
#module_args_json = jsonify(module_args)
module_args_json = json.dumps(module_args)
encoded_args = repr(module_args_json.encode('utf-8'))
# these strings should be part of the 'basic' snippet which is required to be included
module_data = module_data.replace(REPLACER_VERSION, repr(__version__))
module_data = module_data.replace(REPLACER_ARGS, "''")
module_data = module_data.replace(REPLACER_COMPLEX, encoded_args)
# FIXME: we're not passing around an inject dictionary anymore, so
# this needs to be fixed with whatever method we use for vars
# like this moving forward
#if module_style == 'new':
# facility = C.DEFAULT_SYSLOG_FACILITY
# if 'ansible_syslog_facility' in inject:
# facility = inject['ansible_syslog_facility']
# module_data = module_data.replace('syslog.LOG_USER', "syslog.%s" % facility)
lines = module_data.split("\n")
shebang = None
if lines[0].startswith("#!"):
shebang = lines[0].strip()
args = shlex.split(str(shebang[2:]))
interpreter = args[0]
interpreter_config = 'ansible_%s_interpreter' % os.path.basename(interpreter)
# FIXME: more inject stuff here...
#if interpreter_config in inject:
# lines[0] = shebang = "#!%s %s" % (inject[interpreter_config], " ".join(args[1:]))
# module_data = "\n".join(lines)
return (module_data, module_style, shebang)
......@@ -19,17 +19,110 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import signal
from ansible import constants as C
from ansible.errors import *
from ansible.executor.task_queue_manager import TaskQueueManager
from ansible.playbook import Playbook
from ansible.utils.debug import debug
class PlaybookExecutor:
def __init__(self, list_of_plays=[]):
# self.tqm = TaskQueueManager(forks)
assert False
'''
This is the primary class for executing playbooks, and thus the
basis for bin/ansible-playbook operation.
'''
def __init__(self, playbooks, inventory, variable_manager, loader, options):
self._playbooks = playbooks
self._inventory = inventory
self._variable_manager = variable_manager
self._loader = loader
self._options = options
self._tqm = TaskQueueManager(inventory=inventory, callback='default', variable_manager=variable_manager, loader=loader, options=options)
def run(self):
'''
Run the given playbook, based on the settings in the play which
may limit the runs to serialized groups, etc.
'''
signal.signal(signal.SIGINT, self._cleanup)
try:
for playbook_path in self._playbooks:
pb = Playbook.load(playbook_path, variable_manager=self._variable_manager, loader=self._loader)
# FIXME: playbook entries are just plays, so we should rename them
for play in pb.get_entries():
self._inventory.remove_restriction()
# Create a temporary copy of the play here, so we can run post_validate
# on it without the templating changes affecting the original object.
all_vars = self._variable_manager.get_vars(loader=self._loader, play=play)
new_play = play.copy()
new_play.post_validate(all_vars, ignore_undefined=True)
result = True
for batch in self._get_serialized_batches(new_play):
if len(batch) == 0:
raise AnsibleError("No hosts matched the list specified in the play", obj=play._ds)
# restrict the inventory to the hosts in the serialized batch
self._inventory.restrict_to_hosts(batch)
# and run it...
result = self._tqm.run(play=play)
if not result:
break
if not result:
# FIXME: do something here, to signify the playbook execution failed
self._cleanup()
return 1
except:
self._cleanup()
raise
self._cleanup()
return 0
def _cleanup(self, signum=None, framenum=None):
self._tqm.cleanup()
def _get_serialized_batches(self, play):
'''
Returns a list of hosts, subdivided into batches based on
the serial size specified in the play.
'''
# make sure we have a unique list of hosts
all_hosts = self._inventory.get_hosts(play.hosts)
# check to see if the serial number was specified as a percentage,
# and convert it to an integer value based on the number of hosts
if isinstance(play.serial, basestring) and play.serial.endswith('%'):
serial_pct = int(play.serial.replace("%",""))
serial = int((serial_pct/100.0) * len(all_hosts))
else:
serial = int(play.serial)
# if the serial count was not specified or is invalid, default to
# a list of all hosts, otherwise split the list of hosts into chunks
# which are based on the serial size
if serial <= 0:
return [all_hosts]
else:
serialized_batches = []
while len(all_hosts) > 0:
play_hosts = []
for x in range(serial):
if len(all_hosts) > 0:
play_hosts.append(all_hosts.pop(0))
def run(self):
# for play in list_of_plays:
# for block in play.blocks:
# # block must know it’s playbook class and context
# tqm.enqueue(block)
# tqm.go()...
assert False
serialized_batches.append(play_hosts)
return serialized_batches
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
class PlaybookState:
'''
A helper class, which keeps track of the task iteration
state for a given playbook. This is used in the PlaybookIterator
class on a per-host basis.
'''
def __init__(self, parent_iterator):
self._parent_iterator = parent_iterator
self._cur_play = 0
self._task_list = None
self._cur_task_pos = 0
self._done = False
def next(self, peek=False):
'''
Determines and returns the next available task from the playbook,
advancing through the list of plays as it goes.
'''
task = None
# we save these locally so that we can peek at the next task
# without updating the internal state of the iterator
cur_play = self._cur_play
task_list = self._task_list
cur_task_pos = self._cur_task_pos
while True:
# when we hit the end of the playbook entries list, we set a flag
# and return None to indicate we're there
# FIXME: accessing the entries and parent iterator playbook members
# should be done through accessor functions
if self._done or cur_play > len(self._parent_iterator._playbook._entries) - 1:
self._done = True
return None
# initialize the task list by calling the .compile() method
# on the play, which will call compile() for all child objects
if task_list is None:
task_list = self._parent_iterator._playbook._entries[cur_play].compile()
# if we've hit the end of this plays task list, move on to the next
# and reset the position values for the next iteration
if cur_task_pos > len(task_list) - 1:
cur_play += 1
task_list = None
cur_task_pos = 0
continue
else:
# FIXME: do tag/conditional evaluation here and advance
# the task position if it should be skipped without
# returning a task
task = task_list[cur_task_pos]
cur_task_pos += 1
# Skip the task if it is the member of a role which has already
# been run, unless the role allows multiple executions
if task._role:
# FIXME: this should all be done via member functions
# instead of direct access to internal variables
if task._role.has_run() and not task._role._metadata._allow_duplicates:
continue
# Break out of the while loop now that we have our task
break
# If we're not just peeking at the next task, save the internal state
if not peek:
self._cur_play = cur_play
self._task_list = task_list
self._cur_task_pos = cur_task_pos
return task
class PlaybookIterator:
'''
The main iterator class, which keeps the state of the playbook
on a per-host basis using the above PlaybookState class.
'''
def __init__(self, inventory, log_manager, playbook):
self._playbook = playbook
self._log_manager = log_manager
self._host_entries = dict()
self._first_host = None
# build the per-host dictionary of playbook states
for host in inventory.get_hosts():
if self._first_host is None:
self._first_host = host
self._host_entries[host.get_name()] = PlaybookState(parent_iterator=self)
def get_next_task(self, peek=False):
''' returns the next task for host[0] '''
return self._host_entries[self._first_host.get_name()].next(peek=peek)
def get_next_task_for_host(self, host, peek=False):
''' fetch the next task for the given host '''
if host.get_name() not in self._host_entries:
raise AnsibleError("invalid host specified for playbook iteration")
return self._host_entries[host.get_name()].next(peek=peek)
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import Queue
import multiprocessing
import os
import signal
import sys
import time
import traceback
HAS_ATFORK=True
try:
from Crypto.Random import atfork
except ImportError:
HAS_ATFORK=False
from ansible.executor.task_result import TaskResult
from ansible.playbook.handler import Handler
from ansible.playbook.task import Task
from ansible.utils.debug import debug
__all__ = ['ResultProcess']
class ResultProcess(multiprocessing.Process):
'''
The result worker thread, which reads results from the results
queue and fires off callbacks/etc. as necessary.
'''
def __init__(self, final_q, workers):
# takes a task queue manager as the sole param:
self._final_q = final_q
self._workers = workers
self._cur_worker = 0
self._terminated = False
super(ResultProcess, self).__init__()
def _send_result(self, result):
debug("sending result: %s" % (result,))
self._final_q.put(result, block=False)
debug("done sending result")
def _read_worker_result(self):
result = None
starting_point = self._cur_worker
while True:
(worker_prc, main_q, rslt_q) = self._workers[self._cur_worker]
self._cur_worker += 1
if self._cur_worker >= len(self._workers):
self._cur_worker = 0
try:
if not rslt_q.empty():
debug("worker %d has data to read" % self._cur_worker)
result = rslt_q.get(block=False)
debug("got a result from worker %d: %s" % (self._cur_worker, result))
break
except Queue.Empty:
pass
if self._cur_worker == starting_point:
break
return result
def terminate(self):
self._terminated = True
super(ResultProcess, self).terminate()
def run(self):
'''
The main thread execution, which reads from the results queue
indefinitely and sends callbacks/etc. when results are received.
'''
if HAS_ATFORK:
atfork()
while True:
try:
result = self._read_worker_result()
if result is None:
time.sleep(0.1)
continue
host_name = result._host.get_name()
# send callbacks, execute other options based on the result status
if result.is_failed():
#self._callback.runner_on_failed(result._task, result)
self._send_result(('host_task_failed', result))
elif result.is_unreachable():
#self._callback.runner_on_unreachable(result._task, result)
self._send_result(('host_unreachable', result))
elif result.is_skipped():
#self._callback.runner_on_skipped(result._task, result)
self._send_result(('host_task_skipped', result))
else:
#self._callback.runner_on_ok(result._task, result)
self._send_result(('host_task_ok', result))
# if this task is notifying a handler, do it now
if result._task.notify:
# The shared dictionary for notified handlers is a proxy, which
# does not detect when sub-objects within the proxy are modified.
# So, per the docs, we reassign the list so the proxy picks up and
# notifies all other threads
for notify in result._task.notify:
self._send_result(('notify_handler', notify, result._host))
# if this task is registering facts, do that now
if 'ansible_facts' in result._result:
if result._task.action in ('set_fact', 'include_vars'):
for (key, value) in result._result['ansible_facts'].iteritems():
self._send_result(('set_host_var', result._host, key, value))
else:
self._send_result(('set_host_facts', result._host, result._result['ansible_facts']))
# if this task is registering a result, do it now
if result._task.register:
self._send_result(('set_host_var', result._host, result._task.register, result._result))
except Queue.Empty:
pass
except (KeyboardInterrupt, IOError, EOFError):
break
except:
# FIXME: we should probably send a proper callback here instead of
# simply dumping a stack trace on the screen
traceback.print_exc()
break
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import Queue
import multiprocessing
import os
import signal
import sys
import time
import traceback
HAS_ATFORK=True
try:
from Crypto.Random import atfork
except ImportError:
HAS_ATFORK=False
from ansible.errors import AnsibleError, AnsibleConnectionFailure
from ansible.executor.task_executor import TaskExecutor
from ansible.executor.task_result import TaskResult
from ansible.playbook.handler import Handler
from ansible.playbook.task import Task
from ansible.utils.debug import debug
__all__ = ['ExecutorProcess']
class WorkerProcess(multiprocessing.Process):
'''
The worker thread class, which uses TaskExecutor to run tasks
read from a job queue and pushes results into a results queue
for reading later.
'''
def __init__(self, tqm, main_q, rslt_q, loader, new_stdin):
# takes a task queue manager as the sole param:
self._main_q = main_q
self._rslt_q = rslt_q
self._loader = loader
# dupe stdin, if we have one
try:
fileno = sys.stdin.fileno()
except ValueError:
fileno = None
self._new_stdin = new_stdin
if not new_stdin and fileno is not None:
try:
self._new_stdin = os.fdopen(os.dup(fileno))
except OSError, e:
# couldn't dupe stdin, most likely because it's
# not a valid file descriptor, so we just rely on
# using the one that was passed in
pass
super(WorkerProcess, self).__init__()
def run(self):
'''
Called when the process is started, and loops indefinitely
until an error is encountered (typically an IOerror from the
queue pipe being disconnected). During the loop, we attempt
to pull tasks off the job queue and run them, pushing the result
onto the results queue. We also remove the host from the blocked
hosts list, to signify that they are ready for their next task.
'''
if HAS_ATFORK:
atfork()
while True:
task = None
try:
if not self._main_q.empty():
debug("there's work to be done!")
(host, task, job_vars, connection_info) = self._main_q.get(block=False)
debug("got a task/handler to work on: %s" % task)
new_connection_info = connection_info.set_task_override(task)
# execute the task and build a TaskResult from the result
debug("running TaskExecutor() for %s/%s" % (host, task))
executor_result = TaskExecutor(host, task, job_vars, new_connection_info, self._loader).run()
debug("done running TaskExecutor() for %s/%s" % (host, task))
task_result = TaskResult(host, task, executor_result)
# put the result on the result queue
debug("sending task result")
self._rslt_q.put(task_result, block=False)
debug("done sending task result")
else:
time.sleep(0.1)
except Queue.Empty:
pass
except (IOError, EOFError, KeyboardInterrupt):
break
except AnsibleConnectionFailure:
try:
if task:
task_result = TaskResult(host, task, dict(unreachable=True))
self._rslt_q.put(task_result, block=False)
except:
# FIXME: most likely an abort, catch those kinds of errors specifically
break
except Exception, e:
debug("WORKER EXCEPTION: %s" % e)
debug("WORKER EXCEPTION: %s" % traceback.format_exc())
try:
if task:
task_result = TaskResult(host, task, dict(failed=True, exception=True, stdout=traceback.format_exc()))
self._rslt_q.put(task_result, block=False)
except:
# FIXME: most likely an abort, catch those kinds of errors specifically
break
debug("WORKER PROCESS EXITING")
......@@ -19,14 +19,196 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.executor.connection_info import ConnectionInformation
from ansible.plugins import lookup_loader, connection_loader, action_loader
from ansible.utils.debug import debug
__all__ = ['TaskExecutor']
import json
import time
class TaskExecutor:
def __init__(self, task, host):
pass
'''
This is the main worker class for the executor pipeline, which
handles loading an action plugin to actually dispatch the task to
a given host. This class roughly corresponds to the old Runner()
class.
'''
def __init__(self, host, task, job_vars, connection_info, loader):
self._host = host
self._task = task
self._job_vars = job_vars
self._connection_info = connection_info
self._loader = loader
def run(self):
'''
The main executor entrypoint, where we determine if the specified
task requires looping and either runs the task with
'''
debug("in run()")
items = self._get_loop_items()
if items:
if len(items) > 0:
item_results = self._run_loop(items)
res = dict(results=item_results)
else:
res = dict(changed=False, skipped=True, skipped_reason='No items in the list', results=[])
else:
debug("calling self._execute()")
res = self._execute()
debug("_execute() done")
debug("dumping result to json")
result = json.dumps(res)
debug("done dumping result, returning")
return result
def _get_loop_items(self):
'''
Loads a lookup plugin to handle the with_* portion of a task (if specified),
and returns the items result.
'''
items = None
if self._task.loop and self._task.loop in lookup_loader:
items = lookup_loader.get(self._task.loop).run(self._task.loop_args)
return items
def _run_loop(self, items):
'''
Runs the task with the loop items specified and collates the result
into an array named 'results' which is inserted into the final result
along with the item for which the loop ran.
'''
results = []
# FIXME: squash items into a flat list here for those modules
# which support it (yum, apt, etc.) but make it smarter
# than it is today?
for item in items:
res = self._execute()
res['item'] = item
results.append(res)
return results
def _execute(self):
'''
The primary workhorse of the executor system, this runs the task
on the specified host (which may be the delegated_to host) and handles
the retry/until and block rescue/always execution
'''
connection = self._get_connection()
handler = self._get_action_handler(connection=connection)
# check to see if this task should be skipped, due to it being a member of a
# role which has already run (and whether that role allows duplicate execution)
if self._task._role and self._task._role.has_run():
# If there is no metadata, the default behavior is to not allow duplicates,
# if there is metadata, check to see if the allow_duplicates flag was set to true
if self._task._role._metadata is None or self._task._role._metadata and not self._task._role._metadata.allow_duplicates:
debug("task belongs to a role which has already run, but does not allow duplicate execution")
return dict(skipped=True, skip_reason='This role has already been run, but does not allow duplicates')
if not self._task.evaluate_conditional(self._job_vars):
debug("when evaulation failed, skipping this task")
return dict(skipped=True, skip_reason='Conditional check failed')
if not self._task.evaluate_tags(self._connection_info.only_tags, self._connection_info.skip_tags):
debug("Tags don't match, skipping this task")
return dict(skipped=True, skip_reason='Skipped due to specified tags')
retries = self._task.retries
if retries <= 0:
retries = 1
delay = self._task.delay
if delay < 0:
delay = 0
debug("starting attempt loop")
result = None
for attempt in range(retries):
if attempt > 0:
# FIXME: this should use the callback mechanism
print("FAILED - RETRYING: %s (%d retries left)" % (self._task, retries-attempt))
result['attempts'] = attempt + 1
debug("running the handler")
result = handler.run(task_vars=self._job_vars)
debug("handler run complete")
if self._task.until:
# TODO: implement until logic (pseudo logic follows...)
# if VariableManager.check_conditional(cond, extra_vars=(dict(result=result))):
# break
pass
elif 'failed' not in result and result.get('rc', 0) == 0:
# if the result is not failed, stop trying
break
if attempt < retries - 1:
time.sleep(delay)
debug("attempt loop complete, returning result")
return result
def _get_connection(self):
'''
Reads the connection property for the host, and returns the
correct connection object from the list of connection plugins
'''
# FIXME: delegate_to calculation should be done here
# FIXME: calculation of connection params/auth stuff should be done here
# FIXME: add all port/connection type munging here (accelerated mode,
# fixing up options for ssh, etc.)? and 'smart' conversion
conn_type = self._connection_info.connection
if conn_type == 'smart':
conn_type = 'ssh'
connection = connection_loader.get(conn_type, self._host, self._connection_info)
if not connection:
raise AnsibleError("the connection plugin '%s' was not found" % conn_type)
connection.connect()
return connection
def _get_action_handler(self, connection):
'''
Returns the correct action plugin to handle the requestion task action
'''
def run(self):
# returns TaskResult
pass
if self._task.action in action_loader:
if self._task.async != 0:
raise AnsibleError("async mode is not supported with the %s module" % module_name)
handler_name = self._task.action
elif self._task.async == 0:
handler_name = 'normal'
else:
handler_name = 'async'
handler = action_loader.get(
handler_name,
task=self._task,
connection=connection,
connection_info=self._connection_info,
loader=self._loader
)
if not handler:
raise AnsibleError("the handler '%s' was not found" % handler_name)
return handler
......@@ -19,18 +19,191 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
class TaskQueueManagerHostPlaybookIterator:
import multiprocessing
import os
import socket
import sys
def __init__(self, host, playbook):
pass
from ansible.errors import AnsibleError
from ansible.executor.connection_info import ConnectionInformation
#from ansible.executor.manager import AnsibleManager
from ansible.executor.play_iterator import PlayIterator
from ansible.executor.process.worker import WorkerProcess
from ansible.executor.process.result import ResultProcess
from ansible.plugins import callback_loader, strategy_loader
def get_next_task(self):
assert False
from ansible.utils.debug import debug
def is_blocked(self):
# depending on strategy, either
# ‘linear’ -- all prev tasks must be completed for all hosts
# ‘free’ -- this host doesn’t have any more work to do
assert False
__all__ = ['TaskQueueManager']
class TaskQueueManager:
'''
This class handles the multiprocessing requirements of Ansible by
creating a pool of worker forks, a result handler fork, and a
manager object with shared datastructures/queues for coordinating
work between all processes.
The queue manager is responsible for loading the play strategy plugin,
which dispatches the Play's tasks to hosts.
'''
def __init__(self, inventory, callback, variable_manager, loader, options):
self._inventory = inventory
self._variable_manager = variable_manager
self._loader = loader
self._options = options
# a special flag to help us exit cleanly
self._terminated = False
# create and start the multiprocessing manager
#self._manager = AnsibleManager()
#self._manager.start()
# this dictionary is used to keep track of notified handlers
self._notified_handlers = dict()
# dictionaries to keep track of failed/unreachable hosts
self._failed_hosts = dict()
self._unreachable_hosts = dict()
self._final_q = multiprocessing.Queue()
# FIXME: hard-coded the default callback plugin here, which
# should be configurable.
self._callback = callback_loader.get(callback)
# create the pool of worker threads, based on the number of forks specified
try:
fileno = sys.stdin.fileno()
except ValueError:
fileno = None
self._workers = []
for i in range(self._options.forks):
# duplicate stdin, if possible
new_stdin = None
if fileno is not None:
try:
new_stdin = os.fdopen(os.dup(fileno))
except OSError, e:
# couldn't dupe stdin, most likely because it's
# not a valid file descriptor, so we just rely on
# using the one that was passed in
pass
main_q = multiprocessing.Queue()
rslt_q = multiprocessing.Queue()
prc = WorkerProcess(self, main_q, rslt_q, loader, new_stdin)
prc.start()
self._workers.append((prc, main_q, rslt_q))
self._result_prc = ResultProcess(self._final_q, self._workers)
self._result_prc.start()
def _initialize_notified_handlers(self, handlers):
'''
Clears and initializes the shared notified handlers dict with entries
for each handler in the play, which is an empty array that will contain
inventory hostnames for those hosts triggering the handler.
'''
# Zero the dictionary first by removing any entries there.
# Proxied dicts don't support iteritems, so we have to use keys()
for key in self._notified_handlers.keys():
del self._notified_handlers[key]
# FIXME: there is a block compile helper for this...
handler_list = []
for handler_block in handlers:
handler_list.extend(handler_block.compile())
# then initalize it with the handler names from the handler list
for handler in handler_list:
self._notified_handlers[handler.get_name()] = []
def run(self, play):
'''
Iterates over the roles/tasks in a play, using the given (or default)
strategy for queueing tasks. The default is the linear strategy, which
operates like classic Ansible by keeping all hosts in lock-step with
a given task (meaning no hosts move on to the next task until all hosts
are done with the current task).
'''
connection_info = ConnectionInformation(play, self._options)
self._callback.set_connection_info(connection_info)
# run final validation on the play now, to make sure fields are templated
# FIXME: is this even required? Everything is validated and merged at the
# task level, so else in the play needs to be templated
#all_vars = self._vmw.get_vars(loader=self._dlw, play=play)
#all_vars = self._vmw.get_vars(loader=self._loader, play=play)
#play.post_validate(all_vars=all_vars)
self._callback.playbook_on_play_start(play.name)
# initialize the shared dictionary containing the notified handlers
self._initialize_notified_handlers(play.handlers)
# load the specified strategy (or the default linear one)
strategy = strategy_loader.get(play.strategy, self)
if strategy is None:
raise AnsibleError("Invalid play strategy specified: %s" % play.strategy, obj=play._ds)
# build the iterator
iterator = PlayIterator(inventory=self._inventory, play=play)
# and run the play using the strategy
return strategy.run(iterator, connection_info)
def cleanup(self):
debug("RUNNING CLEANUP")
self.terminate()
self._final_q.close()
self._result_prc.terminate()
for (worker_prc, main_q, rslt_q) in self._workers:
rslt_q.close()
main_q.close()
worker_prc.terminate()
def get_inventory(self):
return self._inventory
def get_callback(self):
return self._callback
def get_variable_manager(self):
return self._variable_manager
def get_loader(self):
return self._loader
def get_server_pipe(self):
return self._server_pipe
def get_client_pipe(self):
return self._client_pipe
def get_pending_results(self):
return self._pending_results
def get_allow_processing(self):
return self._allow_processing
def get_notified_handlers(self):
return self._notified_handlers
def get_workers(self):
return self._workers[:]
def terminate(self):
self._terminated = True
......@@ -19,3 +19,39 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible.parsing import DataLoader
class TaskResult:
'''
This class is responsible for interpretting the resulting data
from an executed task, and provides helper methods for determining
the result of a given task.
'''
def __init__(self, host, task, return_data):
self._host = host
self._task = task
if isinstance(return_data, dict):
self._result = return_data.copy()
else:
self._result = DataLoader().load(return_data)
def is_changed(self):
return self._check_key('changed')
def is_skipped(self):
return self._check_key('skipped')
def is_failed(self):
return self._check_key('failed') or self._result.get('rc', 0) != 0
def is_unreachable(self):
return self._check_key('unreachable')
def _check_key(self, key):
if 'results' in self._result:
flag = False
for res in self._result.get('results', []):
flag |= res.get(key, False)
else:
return self._result.get(key, False)
# (c) 2012, Zettar Inc.
# Written by Chin Fang <fangchin@zettar.com>
#
# This file is part of Ansible
#
# This module is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this software. If not, see <http://www.gnu.org/licenses/>.
#
'''
This module is for enhancing ansible's inventory parsing capability such
that it can deal with hostnames specified using a simple pattern in the
form of [beg:end], example: [1:5], [a:c], [D:G]. If beg is not specified,
it defaults to 0.
If beg is given and is left-zero-padded, e.g. '001', it is taken as a
formatting hint when the range is expanded. e.g. [001:010] is to be
expanded into 001, 002 ...009, 010.
Note that when beg is specified with left zero padding, then the length of
end must be the same as that of beg, else an exception is raised.
'''
import string
from ansible import errors
def detect_range(line = None):
'''
A helper function that checks a given host line to see if it contains
a range pattern described in the docstring above.
Returnes True if the given line contains a pattern, else False.
'''
if 0 <= line.find("[") < line.find(":") < line.find("]"):
return True
else:
return False
def expand_hostname_range(line = None):
'''
A helper function that expands a given line that contains a pattern
specified in top docstring, and returns a list that consists of the
expanded version.
The '[' and ']' characters are used to maintain the pseudo-code
appearance. They are replaced in this function with '|' to ease
string splitting.
References: http://ansible.github.com/patterns.html#hosts-and-groups
'''
all_hosts = []
if line:
# A hostname such as db[1:6]-node is considered to consists
# three parts:
# head: 'db'
# nrange: [1:6]; range() is a built-in. Can't use the name
# tail: '-node'
# Add support for multiple ranges in a host so:
# db[01:10:3]node-[01:10]
# - to do this we split off at the first [...] set, getting the list
# of hosts and then repeat until none left.
# - also add an optional third parameter which contains the step. (Default: 1)
# so range can be [01:10:2] -> 01 03 05 07 09
# FIXME: make this work for alphabetic sequences too.
(head, nrange, tail) = line.replace('[','|',1).replace(']','|',1).split('|')
bounds = nrange.split(":")
if len(bounds) != 2 and len(bounds) != 3:
raise errors.AnsibleError("host range incorrectly specified")
beg = bounds[0]
end = bounds[1]
if len(bounds) == 2:
step = 1
else:
step = bounds[2]
if not beg:
beg = "0"
if not end:
raise errors.AnsibleError("host range end value missing")
if beg[0] == '0' and len(beg) > 1:
rlen = len(beg) # range length formatting hint
if rlen != len(end):
raise errors.AnsibleError("host range format incorrectly specified!")
fill = lambda _: str(_).zfill(rlen) # range sequence
else:
fill = str
try:
i_beg = string.ascii_letters.index(beg)
i_end = string.ascii_letters.index(end)
if i_beg > i_end:
raise errors.AnsibleError("host range format incorrectly specified!")
seq = string.ascii_letters[i_beg:i_end+1]
except ValueError: # not an alpha range
seq = range(int(beg), int(end)+1, int(step))
for rseq in seq:
hname = ''.join((head, fill(rseq), tail))
if detect_range(hname):
all_hosts.extend( expand_hostname_range( hname ) )
else:
all_hosts.append(hname)
return all_hosts
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from ansible.utils.debug import debug
class Group:
''' a group of ansible hosts '''
#__slots__ = [ 'name', 'hosts', 'vars', 'child_groups', 'parent_groups', 'depth', '_hosts_cache' ]
def __init__(self, name=None):
self.depth = 0
self.name = name
self.hosts = []
self.vars = {}
self.child_groups = []
self.parent_groups = []
self._hosts_cache = None
#self.clear_hosts_cache()
#if self.name is None:
# raise Exception("group name is required")
def __repr__(self):
return self.get_name()
def __getstate__(self):
return self.serialize()
def __setstate__(self, data):
return self.deserialize(data)
def serialize(self):
parent_groups = []
for parent in self.parent_groups:
parent_groups.append(parent.serialize())
result = dict(
name=self.name,
vars=self.vars.copy(),
parent_groups=parent_groups,
depth=self.depth,
)
debug("serializing group, result is: %s" % result)
return result
def deserialize(self, data):
debug("deserializing group, data is: %s" % data)
self.__init__()
self.name = data.get('name')
self.vars = data.get('vars', dict())
parent_groups = data.get('parent_groups', [])
for parent_data in parent_groups:
g = Group()
g.deserialize(parent_data)
self.parent_groups.append(g)
def get_name(self):
return self.name
def add_child_group(self, group):
if self == group:
raise Exception("can't add group to itself")
# don't add if it's already there
if not group in self.child_groups:
self.child_groups.append(group)
# update the depth of the child
group.depth = max([self.depth+1, group.depth])
# update the depth of the grandchildren
group._check_children_depth()
# now add self to child's parent_groups list, but only if there
# isn't already a group with the same name
if not self.name in [g.name for g in group.parent_groups]:
group.parent_groups.append(self)
self.clear_hosts_cache()
def _check_children_depth(self):
for group in self.child_groups:
group.depth = max([self.depth+1, group.depth])
group._check_children_depth()
def add_host(self, host):
self.hosts.append(host)
host.add_group(self)
self.clear_hosts_cache()
def set_variable(self, key, value):
self.vars[key] = value
def clear_hosts_cache(self):
self._hosts_cache = None
for g in self.parent_groups:
g.clear_hosts_cache()
def get_hosts(self):
if self._hosts_cache is None:
self._hosts_cache = self._get_hosts()
return self._hosts_cache
def _get_hosts(self):
hosts = []
seen = {}
for kid in self.child_groups:
kid_hosts = kid.get_hosts()
for kk in kid_hosts:
if kk not in seen:
seen[kk] = 1
hosts.append(kk)
for mine in self.hosts:
if mine not in seen:
seen[mine] = 1
hosts.append(mine)
return hosts
def get_vars(self):
return self.vars.copy()
def _get_ancestors(self):
results = {}
for g in self.parent_groups:
results[g.name] = g
results.update(g._get_ancestors())
return results
def get_ancestors(self):
return self._get_ancestors().values()
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible import constants as C
from ansible.inventory.group import Group
from ansible.utils.vars import combine_vars
__all__ = ['Host']
class Host:
''' a single ansible host '''
#__slots__ = [ 'name', 'vars', 'groups' ]
def __getstate__(self):
return self.serialize()
def __setstate__(self, data):
return self.deserialize(data)
def serialize(self):
groups = []
for group in self.groups:
groups.append(group.serialize())
return dict(
name=self.name,
vars=self.vars.copy(),
ipv4_address=self.ipv4_address,
ipv6_address=self.ipv6_address,
port=self.port,
gathered_facts=self._gathered_facts,
groups=groups,
)
def deserialize(self, data):
self.__init__()
self.name = data.get('name')
self.vars = data.get('vars', dict())
self.ipv4_address = data.get('ipv4_address', '')
self.ipv6_address = data.get('ipv6_address', '')
self.port = data.get('port')
groups = data.get('groups', [])
for group_data in groups:
g = Group()
g.deserialize(group_data)
self.groups.append(g)
def __init__(self, name=None, port=None):
self.name = name
self.vars = {}
self.groups = []
self.ipv4_address = name
self.ipv6_address = name
if port and port != C.DEFAULT_REMOTE_PORT:
self.port = int(port)
else:
self.port = C.DEFAULT_REMOTE_PORT
self._gathered_facts = False
def __repr__(self):
return self.get_name()
def get_name(self):
return self.name
@property
def gathered_facts(self):
return self._gathered_facts
def set_gathered_facts(self, gathered):
self._gathered_facts = gathered
def add_group(self, group):
self.groups.append(group)
def set_variable(self, key, value):
self.vars[key]=value
def get_groups(self):
groups = {}
for g in self.groups:
groups[g.name] = g
ancestors = g.get_ancestors()
for a in ancestors:
groups[a.name] = a
return groups.values()
def get_vars(self):
results = {}
groups = self.get_groups()
for group in sorted(groups, key=lambda g: g.depth):
results = combine_vars(results, group.get_vars())
results = combine_vars(results, self.vars)
results['inventory_hostname'] = self.name
results['inventory_hostname_short'] = self.name.split('.')[0]
results['group_names'] = sorted([ g.name for g in groups if g.name != 'all'])
return results
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#############################################
import ast
import shlex
import re
from ansible import constants as C
from ansible.errors import *
from ansible.inventory.host import Host
from ansible.inventory.group import Group
from ansible.inventory.expand_hosts import detect_range
from ansible.inventory.expand_hosts import expand_hostname_range
class InventoryParser(object):
"""
Host inventory for ansible.
"""
def __init__(self, filename=C.DEFAULT_HOST_LIST):
with open(filename) as fh:
self.lines = fh.readlines()
self.groups = {}
self.hosts = {}
self._parse()
def _parse(self):
self._parse_base_groups()
self._parse_group_children()
self._add_allgroup_children()
self._parse_group_variables()
return self.groups
@staticmethod
def _parse_value(v):
if "#" not in v:
try:
return ast.literal_eval(v)
# Using explicit exceptions.
# Likely a string that literal_eval does not like. We wil then just set it.
except ValueError:
# For some reason this was thought to be malformed.
pass
except SyntaxError:
# Is this a hash with an equals at the end?
pass
return v
# [webservers]
# alpha
# beta:2345
# gamma sudo=True user=root
# delta asdf=jkl favcolor=red
def _add_allgroup_children(self):
for group in self.groups.values():
if group.depth == 0 and group.name != 'all':
self.groups['all'].add_child_group(group)
def _parse_base_groups(self):
# FIXME: refactor
ungrouped = Group(name='ungrouped')
all = Group(name='all')
all.add_child_group(ungrouped)
self.groups = dict(all=all, ungrouped=ungrouped)
active_group_name = 'ungrouped'
for line in self.lines:
line = self._before_comment(line).strip()
if line.startswith("[") and line.endswith("]"):
active_group_name = line.replace("[","").replace("]","")
if ":vars" in line or ":children" in line:
active_group_name = active_group_name.rsplit(":", 1)[0]
if active_group_name not in self.groups:
new_group = self.groups[active_group_name] = Group(name=active_group_name)
active_group_name = None
elif active_group_name not in self.groups:
new_group = self.groups[active_group_name] = Group(name=active_group_name)
elif line.startswith(";") or line == '':
pass
elif active_group_name:
tokens = shlex.split(line)
if len(tokens) == 0:
continue
hostname = tokens[0]
port = C.DEFAULT_REMOTE_PORT
# Three cases to check:
# 0. A hostname that contains a range pesudo-code and a port
# 1. A hostname that contains just a port
if hostname.count(":") > 1:
# Possible an IPv6 address, or maybe a host line with multiple ranges
# IPv6 with Port XXX:XXX::XXX.port
# FQDN foo.example.com
if hostname.count(".") == 1:
(hostname, port) = hostname.rsplit(".", 1)
elif ("[" in hostname and
"]" in hostname and
":" in hostname and
(hostname.rindex("]") < hostname.rindex(":")) or
("]" not in hostname and ":" in hostname)):
(hostname, port) = hostname.rsplit(":", 1)
hostnames = []
if detect_range(hostname):
hostnames = expand_hostname_range(hostname)
else:
hostnames = [hostname]
for hn in hostnames:
host = None
if hn in self.hosts:
host = self.hosts[hn]
else:
host = Host(name=hn, port=port)
self.hosts[hn] = host
if len(tokens) > 1:
for t in tokens[1:]:
if t.startswith('#'):
break
try:
(k,v) = t.split("=", 1)
except ValueError, e:
raise AnsibleError("Invalid ini entry: %s - %s" % (t, str(e)))
if k == 'ansible_ssh_host':
host.ipv4_address = self._parse_value(v)
else:
host.set_variable(k, self._parse_value(v))
self.groups[active_group_name].add_host(host)
# [southeast:children]
# atlanta
# raleigh
def _parse_group_children(self):
group = None
for line in self.lines:
line = line.strip()
if line is None or line == '':
continue
if line.startswith("[") and ":children]" in line:
line = line.replace("[","").replace(":children]","")
group = self.groups.get(line, None)
if group is None:
group = self.groups[line] = Group(name=line)
elif line.startswith("#") or line.startswith(";"):
pass
elif line.startswith("["):
group = None
elif group:
kid_group = self.groups.get(line, None)
if kid_group is None:
raise AnsibleError("child group is not defined: (%s)" % line)
else:
group.add_child_group(kid_group)
# [webservers:vars]
# http_port=1234
# maxRequestsPerChild=200
def _parse_group_variables(self):
group = None
for line in self.lines:
line = line.strip()
if line.startswith("[") and ":vars]" in line:
line = line.replace("[","").replace(":vars]","")
group = self.groups.get(line, None)
if group is None:
raise AnsibleError("can't add vars to undefined group: %s" % line)
elif line.startswith("#") or line.startswith(";"):
pass
elif line.startswith("["):
group = None
elif line == '':
pass
elif group:
if "=" not in line:
raise AnsibleError("variables assigned to group must be in key=value form")
else:
(k, v) = [e.strip() for e in line.split("=", 1)]
group.set_variable(k, self._parse_value(v))
def get_host_variables(self, host):
return {}
def _before_comment(self, msg):
''' what's the part of a string before a comment? '''
msg = msg.replace("\#","**NOT_A_COMMENT**")
msg = msg.split("#")[0]
msg = msg.replace("**NOT_A_COMMENT**","#")
return msg
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#############################################
import os
import subprocess
import ansible.constants as C
from ansible.inventory.host import Host
from ansible.inventory.group import Group
from ansible.module_utils.basic import json_dict_unicode_to_bytes
from ansible import utils
from ansible import errors
import sys
class InventoryScript(object):
''' Host inventory parser for ansible using external inventory scripts. '''
def __init__(self, filename=C.DEFAULT_HOST_LIST):
# Support inventory scripts that are not prefixed with some
# path information but happen to be in the current working
# directory when '.' is not in PATH.
self.filename = os.path.abspath(filename)
cmd = [ self.filename, "--list" ]
try:
sp = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
except OSError, e:
raise errors.AnsibleError("problem running %s (%s)" % (' '.join(cmd), e))
(stdout, stderr) = sp.communicate()
self.data = stdout
# see comment about _meta below
self.host_vars_from_top = None
self.groups = self._parse(stderr)
def _parse(self, err):
all_hosts = {}
# not passing from_remote because data from CMDB is trusted
self.raw = utils.parse_json(self.data)
self.raw = json_dict_unicode_to_bytes(self.raw)
all = Group('all')
groups = dict(all=all)
group = None
if 'failed' in self.raw:
sys.stderr.write(err + "\n")
raise errors.AnsibleError("failed to parse executable inventory script results: %s" % self.raw)
for (group_name, data) in self.raw.items():
# in Ansible 1.3 and later, a "_meta" subelement may contain
# a variable "hostvars" which contains a hash for each host
# if this "hostvars" exists at all then do not call --host for each
# host. This is for efficiency and scripts should still return data
# if called with --host for backwards compat with 1.2 and earlier.
if group_name == '_meta':
if 'hostvars' in data:
self.host_vars_from_top = data['hostvars']
continue
if group_name != all.name:
group = groups[group_name] = Group(group_name)
else:
group = all
host = None
if not isinstance(data, dict):
data = {'hosts': data}
# is not those subkeys, then simplified syntax, host with vars
elif not any(k in data for k in ('hosts','vars')):
data = {'hosts': [group_name], 'vars': data}
if 'hosts' in data:
if not isinstance(data['hosts'], list):
raise errors.AnsibleError("You defined a group \"%s\" with bad "
"data for the host list:\n %s" % (group_name, data))
for hostname in data['hosts']:
if not hostname in all_hosts:
all_hosts[hostname] = Host(hostname)
host = all_hosts[hostname]
group.add_host(host)
if 'vars' in data:
if not isinstance(data['vars'], dict):
raise errors.AnsibleError("You defined a group \"%s\" with bad "
"data for variables:\n %s" % (group_name, data))
for k, v in data['vars'].iteritems():
if group.name == all.name:
all.set_variable(k, v)
else:
group.set_variable(k, v)
# Separate loop to ensure all groups are defined
for (group_name, data) in self.raw.items():
if group_name == '_meta':
continue
if isinstance(data, dict) and 'children' in data:
for child_name in data['children']:
if child_name in groups:
groups[group_name].add_child_group(groups[child_name])
for group in groups.values():
if group.depth == 0 and group.name != 'all':
all.add_child_group(group)
return groups
def get_host_variables(self, host):
""" Runs <script> --host <hostname> to determine additional host variables """
if self.host_vars_from_top is not None:
got = self.host_vars_from_top.get(host.name, {})
return got
cmd = [self.filename, "--host", host.name]
try:
sp = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
except OSError, e:
raise errors.AnsibleError("problem running %s (%s)" % (' '.join(cmd), e))
(out, err) = sp.communicate()
if out.strip() == '':
return dict()
try:
return json_dict_unicode_to_bytes(utils.parse_json(out))
except ValueError:
raise errors.AnsibleError("could not parse post variable response: %s, %s" % (cmd, out))
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
# (c) 2014, Serge van Ginderachter <serge@vanginderachter.be>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
class VarsModule(object):
"""
Loads variables for groups and/or hosts
"""
def __init__(self, inventory):
""" constructor """
self.inventory = inventory
self.inventory_basedir = inventory.basedir()
def run(self, host, vault_password=None):
""" For backwards compatibility, when only vars per host were retrieved
This method should return both host specific vars as well as vars
calculated from groups it is a member of """
return {}
def get_host_vars(self, host, vault_password=None):
""" Get host specific variables. """
return {}
def get_group_vars(self, group, vault_password=None):
""" Get group specific variables. """
return {}
# (c) 2012, Michael DeHaan <michael.dehaan@gmail.com>
# 2013, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
......@@ -15,19 +15,3 @@
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from ansible.utils import template
import ansible.utils as utils
class LookupModule(object):
def __init__(self, basedir=None, **kwargs):
self.basedir = basedir
def run(self, terms, inject=None, **kwargs):
terms = utils.listify_lookup_plugin_terms(terms, self.basedir, inject)
ret = []
for term in terms:
ret.append(template.template_from_file(self.basedir, term, inject))
return ret
# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# Copyright (c), Michael DeHaan <michael.dehaan@gmail.com>, 2012-2013
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
AXAPI_PORT_PROTOCOLS = {
'tcp': 2,
'udp': 3,
}
AXAPI_VPORT_PROTOCOLS = {
'tcp': 2,
'udp': 3,
'fast-http': 9,
'http': 11,
'https': 12,
}
def a10_argument_spec():
return dict(
host=dict(type='str', required=True),
username=dict(type='str', aliases=['user', 'admin'], required=True),
password=dict(type='str', aliases=['pass', 'pwd'], required=True, no_log=True),
write_config=dict(type='bool', default=False)
)
def axapi_failure(result):
if 'response' in result and result['response'].get('status') == 'fail':
return True
return False
def axapi_call(module, url, post=None):
'''
Returns a datastructure based on the result of the API call
'''
rsp, info = fetch_url(module, url, data=post)
if not rsp or info['status'] >= 400:
module.fail_json(msg="failed to connect (status code %s), error was %s" % (info['status'], info.get('msg', 'no error given')))
try:
raw_data = rsp.read()
data = json.loads(raw_data)
except ValueError:
# at least one API call (system.action.write_config) returns
# XML even when JSON is requested, so do some minimal handling
# here to prevent failing even when the call succeeded
if 'status="ok"' in raw_data.lower():
data = {"response": {"status": "OK"}}
else:
data = {"response": {"status": "fail", "err": {"msg": raw_data}}}
except:
module.fail_json(msg="could not read the result from the host")
finally:
rsp.close()
return data
def axapi_authenticate(module, base_url, username, password):
url = '%s&method=authenticate&username=%s&password=%s' % (base_url, username, password)
result = axapi_call(module, url)
if axapi_failure(result):
return module.fail_json(msg=result['response']['err']['msg'])
sessid = result['session_id']
return base_url + '&session_id=' + sessid
def axapi_enabled_disabled(flag):
'''
The axapi uses 0/1 integer values for flags, rather than strings
or booleans, so convert the given flag to a 0 or 1. For now, params
are specified as strings only so thats what we check.
'''
if flag == 'enabled':
return 1
else:
return 0
def axapi_get_port_protocol(protocol):
return AXAPI_PORT_PROTOCOLS.get(protocol.lower(), None)
def axapi_get_vport_protocol(protocol):
return AXAPI_VPORT_PROTOCOLS.get(protocol.lower(), None)
# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# Copyright (c), Michael DeHaan <michael.dehaan@gmail.com>, 2012-2013
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
try:
from distutils.version import LooseVersion
HAS_LOOSE_VERSION = True
except:
HAS_LOOSE_VERSION = False
AWS_REGIONS = [
'ap-northeast-1',
'ap-southeast-1',
'ap-southeast-2',
'cn-north-1',
'eu-central-1',
'eu-west-1',
'sa-east-1',
'us-east-1',
'us-west-1',
'us-west-2',
'us-gov-west-1',
]
def aws_common_argument_spec():
return dict(
ec2_url=dict(),
aws_secret_key=dict(aliases=['ec2_secret_key', 'secret_key'], no_log=True),
aws_access_key=dict(aliases=['ec2_access_key', 'access_key']),
validate_certs=dict(default=True, type='bool'),
security_token=dict(no_log=True),
profile=dict(),
)
def ec2_argument_spec():
spec = aws_common_argument_spec()
spec.update(
dict(
region=dict(aliases=['aws_region', 'ec2_region'], choices=AWS_REGIONS),
)
)
return spec
def boto_supports_profile_name():
return hasattr(boto.ec2.EC2Connection, 'profile_name')
def get_aws_connection_info(module):
# Check module args for credentials, then check environment vars
# access_key
ec2_url = module.params.get('ec2_url')
access_key = module.params.get('aws_access_key')
secret_key = module.params.get('aws_secret_key')
security_token = module.params.get('security_token')
region = module.params.get('region')
profile_name = module.params.get('profile')
validate_certs = module.params.get('validate_certs')
if not ec2_url:
if 'EC2_URL' in os.environ:
ec2_url = os.environ['EC2_URL']
elif 'AWS_URL' in os.environ:
ec2_url = os.environ['AWS_URL']
if not access_key:
if 'EC2_ACCESS_KEY' in os.environ:
access_key = os.environ['EC2_ACCESS_KEY']
elif 'AWS_ACCESS_KEY_ID' in os.environ:
access_key = os.environ['AWS_ACCESS_KEY_ID']
elif 'AWS_ACCESS_KEY' in os.environ:
access_key = os.environ['AWS_ACCESS_KEY']
else:
# in case access_key came in as empty string
access_key = None
if not secret_key:
if 'EC2_SECRET_KEY' in os.environ:
secret_key = os.environ['EC2_SECRET_KEY']
elif 'AWS_SECRET_ACCESS_KEY' in os.environ:
secret_key = os.environ['AWS_SECRET_ACCESS_KEY']
elif 'AWS_SECRET_KEY' in os.environ:
secret_key = os.environ['AWS_SECRET_KEY']
else:
# in case secret_key came in as empty string
secret_key = None
if not region:
if 'EC2_REGION' in os.environ:
region = os.environ['EC2_REGION']
elif 'AWS_REGION' in os.environ:
region = os.environ['AWS_REGION']
else:
# boto.config.get returns None if config not found
region = boto.config.get('Boto', 'aws_region')
if not region:
region = boto.config.get('Boto', 'ec2_region')
if not security_token:
if 'AWS_SECURITY_TOKEN' in os.environ:
security_token = os.environ['AWS_SECURITY_TOKEN']
else:
# in case security_token came in as empty string
security_token = None
boto_params = dict(aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
security_token=security_token)
# profile_name only works as a key in boto >= 2.24
# so only set profile_name if passed as an argument
if profile_name:
if not boto_supports_profile_name():
module.fail_json("boto does not support profile_name before 2.24")
boto_params['profile_name'] = profile_name
if validate_certs and HAS_LOOSE_VERSION and LooseVersion(boto.Version) >= LooseVersion("2.6.0"):
boto_params['validate_certs'] = validate_certs
return region, ec2_url, boto_params
def get_ec2_creds(module):
''' for compatibility mode with old modules that don't/can't yet
use ec2_connect method '''
region, ec2_url, boto_params = get_aws_connection_info(module)
return ec2_url, boto_params['aws_access_key_id'], boto_params['aws_secret_access_key'], region
def boto_fix_security_token_in_profile(conn, profile_name):
''' monkey patch for boto issue boto/boto#2100 '''
profile = 'profile ' + profile_name
if boto.config.has_option(profile, 'aws_security_token'):
conn.provider.set_security_token(boto.config.get(profile, 'aws_security_token'))
return conn
def connect_to_aws(aws_module, region, **params):
conn = aws_module.connect_to_region(region, **params)
if params.get('profile_name'):
conn = boto_fix_security_token_in_profile(conn, params['profile_name'])
return conn
def ec2_connect(module):
""" Return an ec2 connection"""
region, ec2_url, boto_params = get_aws_connection_info(module)
# If we have a region specified, connect to its endpoint.
if region:
try:
ec2 = connect_to_aws(boto.ec2, region, **boto_params)
except boto.exception.NoAuthHandlerFound, e:
module.fail_json(msg=str(e))
# Otherwise, no region so we fallback to the old connection method
elif ec2_url:
try:
ec2 = boto.connect_ec2_endpoint(ec2_url, **boto_params)
except boto.exception.NoAuthHandlerFound, e:
module.fail_json(msg=str(e))
else:
module.fail_json(msg="Either region or ec2_url must be specified")
return ec2
This source diff could not be displayed because it is too large. You can view the blob instead.
# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# Copyright (c), Franck Cuny <franck.cuny@gmail.com>, 2014
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import pprint
USER_AGENT_PRODUCT="Ansible-gce"
USER_AGENT_VERSION="v1"
def gce_connect(module):
"""Return a Google Cloud Engine connection."""
service_account_email = module.params.get('service_account_email', None)
pem_file = module.params.get('pem_file', None)
project_id = module.params.get('project_id', None)
# If any of the values are not given as parameters, check the appropriate
# environment variables.
if not service_account_email:
service_account_email = os.environ.get('GCE_EMAIL', None)
if not project_id:
project_id = os.environ.get('GCE_PROJECT', None)
if not pem_file:
pem_file = os.environ.get('GCE_PEM_FILE_PATH', None)
# If we still don't have one or more of our credentials, attempt to
# get the remaining values from the libcloud secrets file.
if service_account_email is None or pem_file is None:
try:
import secrets
except ImportError:
secrets = None
if hasattr(secrets, 'GCE_PARAMS'):
if not service_account_email:
service_account_email = secrets.GCE_PARAMS[0]
if not pem_file:
pem_file = secrets.GCE_PARAMS[1]
keyword_params = getattr(secrets, 'GCE_KEYWORD_PARAMS', {})
if not project_id:
project_id = keyword_params.get('project', None)
# If we *still* don't have the credentials we need, then it's time to
# just fail out.
if service_account_email is None or pem_file is None or project_id is None:
module.fail_json(msg='Missing GCE connection parameters in libcloud '
'secrets file.')
return None
try:
gce = get_driver(Provider.GCE)(service_account_email, pem_file, datacenter=module.params.get('zone'), project=project_id)
gce.connection.user_agent_append("%s/%s" % (
USER_AGENT_PRODUCT, USER_AGENT_VERSION))
except (RuntimeError, ValueError), e:
module.fail_json(msg=str(e), changed=False)
except Exception, e:
module.fail_json(msg=unexpected_error_msg(e), changed=False)
return gce
def unexpected_error_msg(error):
"""Create an error string based on passed in error."""
return 'Unexpected response: ' + pprint.pformat(vars(error))
# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# Copyright (c), Michael DeHaan <michael.dehaan@gmail.com>, 2012-2013
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import hmac
import urlparse
try:
from hashlib import sha1
except ImportError:
import sha as sha1
HASHED_KEY_MAGIC = "|1|"
def add_git_host_key(module, url, accept_hostkey=True, create_dir=True):
""" idempotently add a git url hostkey """
fqdn = get_fqdn(module.params['repo'])
if fqdn:
known_host = check_hostkey(module, fqdn)
if not known_host:
if accept_hostkey:
rc, out, err = add_host_key(module, fqdn, create_dir=create_dir)
if rc != 0:
module.fail_json(msg="failed to add %s hostkey: %s" % (fqdn, out + err))
else:
module.fail_json(msg="%s has an unknown hostkey. Set accept_hostkey to True or manually add the hostkey prior to running the git module" % fqdn)
def get_fqdn(repo_url):
""" chop the hostname out of a giturl """
result = None
if "@" in repo_url and "://" not in repo_url:
# most likely a git@ or ssh+git@ type URL
repo_url = repo_url.split("@", 1)[1]
if ":" in repo_url:
repo_url = repo_url.split(":")[0]
result = repo_url
elif "/" in repo_url:
repo_url = repo_url.split("/")[0]
result = repo_url
elif "://" in repo_url:
# this should be something we can parse with urlparse
parts = urlparse.urlparse(repo_url)
if 'ssh' not in parts[0] and 'git' not in parts[0]:
# don't try and scan a hostname that's not ssh
return None
# parts[1] will be empty on python2.4 on ssh:// or git:// urls, so
# ensure we actually have a parts[1] before continuing.
if parts[1] != '':
result = parts[1]
if ":" in result:
result = result.split(":")[0]
if "@" in result:
result = result.split("@", 1)[1]
return result
def check_hostkey(module, fqdn):
return not not_in_host_file(module, fqdn)
# this is a variant of code found in connection_plugins/paramiko.py and we should modify
# the paramiko code to import and use this.
def not_in_host_file(self, host):
if 'USER' in os.environ:
user_host_file = os.path.expandvars("~${USER}/.ssh/known_hosts")
else:
user_host_file = "~/.ssh/known_hosts"
user_host_file = os.path.expanduser(user_host_file)
host_file_list = []
host_file_list.append(user_host_file)
host_file_list.append("/etc/ssh/ssh_known_hosts")
host_file_list.append("/etc/ssh/ssh_known_hosts2")
hfiles_not_found = 0
for hf in host_file_list:
if not os.path.exists(hf):
hfiles_not_found += 1
continue
try:
host_fh = open(hf)
except IOError, e:
hfiles_not_found += 1
continue
else:
data = host_fh.read()
host_fh.close()
for line in data.split("\n"):
if line is None or " " not in line:
continue
tokens = line.split()
if tokens[0].find(HASHED_KEY_MAGIC) == 0:
# this is a hashed known host entry
try:
(kn_salt,kn_host) = tokens[0][len(HASHED_KEY_MAGIC):].split("|",2)
hash = hmac.new(kn_salt.decode('base64'), digestmod=sha1)
hash.update(host)
if hash.digest() == kn_host.decode('base64'):
return False
except:
# invalid hashed host key, skip it
continue
else:
# standard host file entry
if host in tokens[0]:
return False
return True
def add_host_key(module, fqdn, key_type="rsa", create_dir=False):
""" use ssh-keyscan to add the hostkey """
result = False
keyscan_cmd = module.get_bin_path('ssh-keyscan', True)
if 'USER' in os.environ:
user_ssh_dir = os.path.expandvars("~${USER}/.ssh/")
user_host_file = os.path.expandvars("~${USER}/.ssh/known_hosts")
else:
user_ssh_dir = "~/.ssh/"
user_host_file = "~/.ssh/known_hosts"
user_ssh_dir = os.path.expanduser(user_ssh_dir)
if not os.path.exists(user_ssh_dir):
if create_dir:
try:
os.makedirs(user_ssh_dir, 0700)
except:
module.fail_json(msg="failed to create host key directory: %s" % user_ssh_dir)
else:
module.fail_json(msg="%s does not exist" % user_ssh_dir)
elif not os.path.isdir(user_ssh_dir):
module.fail_json(msg="%s is not a directory" % user_ssh_dir)
this_cmd = "%s -t %s %s" % (keyscan_cmd, key_type, fqdn)
rc, out, err = module.run_command(this_cmd)
module.append_to_file(user_host_file, out)
return rc, out, err
# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# Copyright (c) 2014 Hewlett-Packard Development Company, L.P.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
def openstack_argument_spec():
# Consume standard OpenStack environment variables.
# This is mainly only useful for ad-hoc command line operation as
# in playbooks one would assume variables would be used appropriately
OS_AUTH_URL=os.environ.get('OS_AUTH_URL', 'http://127.0.0.1:35357/v2.0/')
OS_PASSWORD=os.environ.get('OS_PASSWORD', None)
OS_REGION_NAME=os.environ.get('OS_REGION_NAME', None)
OS_USERNAME=os.environ.get('OS_USERNAME', 'admin')
OS_TENANT_NAME=os.environ.get('OS_TENANT_NAME', OS_USERNAME)
spec = dict(
login_username = dict(default=OS_USERNAME),
auth_url = dict(default=OS_AUTH_URL),
region_name = dict(default=OS_REGION_NAME),
availability_zone = dict(default=None),
)
if OS_PASSWORD:
spec['login_password'] = dict(default=OS_PASSWORD)
else:
spec['login_password'] = dict(required=True)
if OS_TENANT_NAME:
spec['login_tenant_name'] = dict(default=OS_TENANT_NAME)
else:
spec['login_tenant_name'] = dict(required=True)
return spec
def openstack_find_nova_addresses(addresses, ext_tag, key_name=None):
ret = []
for (k, v) in addresses.iteritems():
if key_name and k == key_name:
ret.extend([addrs['addr'] for addrs in v])
else:
for interface_spec in v:
if 'OS-EXT-IPS:type' in interface_spec and interface_spec['OS-EXT-IPS:type'] == ext_tag:
ret.append(interface_spec['addr'])
return ret
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# Copyright (c), Michael DeHaan <michael.dehaan@gmail.com>, 2014, and others
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Helper function to parse Ansible JSON arguments from a file passed as
# the single argument to the module
# Example: $params = Parse-Args $args
Function Parse-Args($arguments)
{
$parameters = New-Object psobject;
If ($arguments.Length -gt 0)
{
$parameters = Get-Content $arguments[0] | ConvertFrom-Json;
}
$parameters;
}
# Helper function to set an "attribute" on a psobject instance in powershell.
# This is a convenience to make adding Members to the object easier and
# slightly more pythonic
# Example: Set-Attr $result "changed" $true
Function Set-Attr($obj, $name, $value)
{
# If the provided $obj is undefined, define one to be nice
If (-not $obj.GetType)
{
$obj = New-Object psobject
}
$obj | Add-Member -Force -MemberType NoteProperty -Name $name -Value $value
}
# Helper function to convert a powershell object to JSON to echo it, exiting
# the script
# Example: Exit-Json $result
Function Exit-Json($obj)
{
# If the provided $obj is undefined, define one to be nice
If (-not $obj.GetType)
{
$obj = New-Object psobject
}
echo $obj | ConvertTo-Json -Depth 99
Exit
}
# Helper function to add the "msg" property and "failed" property, convert the
# powershell object to JSON and echo it, exiting the script
# Example: Fail-Json $result "This is the failure message"
Function Fail-Json($obj, $message = $null)
{
# If we weren't given 2 args, and the only arg was a string, create a new
# psobject and use the arg as the failure message
If ($message -eq $null -and $obj.GetType().Name -eq "String")
{
$message = $obj
$obj = New-Object psobject
}
# If the first args is undefined or not an object, make it an object
ElseIf (-not $obj.GetType -or $obj.GetType().Name -ne "PSCustomObject")
{
$obj = New-Object psobject
}
Set-Attr $obj "msg" $message
Set-Attr $obj "failed" $true
echo $obj | ConvertTo-Json -Depth 99
Exit 1
}
# Helper function to get an "attribute" from a psobject instance in powershell.
# This is a convenience to make getting Members from an object easier and
# slightly more pythonic
# Example: $attr = Get-Attr $response "code" -default "1"
#Note that if you use the failifempty option, you do need to specify resultobject as well.
Function Get-Attr($obj, $name, $default = $null,$resultobj, $failifempty=$false, $emptyattributefailmessage)
{
# Check if the provided Member $name exists in $obj and return it or the
# default
If ($obj.$name.GetType)
{
$obj.$name
}
Elseif($failifempty -eq $false)
{
$default
}
else
{
if (!$emptyattributefailmessage) {$emptyattributefailmessage = "Missing required argument: $name"}
Fail-Json -obj $resultobj -message $emptyattributefailmessage
}
return
}
# Helper filter/pipeline function to convert a value to boolean following current
# Ansible practices
# Example: $is_true = "true" | ConvertTo-Bool
Function ConvertTo-Bool
{
param(
[parameter(valuefrompipeline=$true)]
$obj
)
$boolean_strings = "yes", "on", "1", "true", 1
$obj_string = [string]$obj
if (($obj.GetType().Name -eq "Boolean" -and $obj) -or $boolean_strings -contains $obj_string.ToLower())
{
$true
}
Else
{
$false
}
return
}
# (c) 2014 James Cammarata, <jcammarata@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
def _get_quote_state(token, quote_char):
'''
the goal of this block is to determine if the quoted string
is unterminated in which case it needs to be put back together
'''
# the char before the current one, used to see if
# the current character is escaped
prev_char = None
for idx, cur_char in enumerate(token):
if idx > 0:
prev_char = token[idx-1]
if cur_char in '"\'' and prev_char != '\\':
if quote_char:
if cur_char == quote_char:
quote_char = None
else:
quote_char = cur_char
return quote_char
def _count_jinja2_blocks(token, cur_depth, open_token, close_token):
'''
this function counts the number of opening/closing blocks for a
given opening/closing type and adjusts the current depth for that
block based on the difference
'''
num_open = token.count(open_token)
num_close = token.count(close_token)
if num_open != num_close:
cur_depth += (num_open - num_close)
if cur_depth < 0:
cur_depth = 0
return cur_depth
def split_args(args):
'''
Splits args on whitespace, but intelligently reassembles
those that may have been split over a jinja2 block or quotes.
When used in a remote module, we won't ever have to be concerned about
jinja2 blocks, however this function is/will be used in the
core portions as well before the args are templated.
example input: a=b c="foo bar"
example output: ['a=b', 'c="foo bar"']
Basically this is a variation shlex that has some more intelligence for
how Ansible needs to use it.
'''
# the list of params parsed out of the arg string
# this is going to be the result value when we are donei
params = []
# here we encode the args, so we have a uniform charset to
# work with, and split on white space
args = args.strip()
try:
args = args.encode('utf-8')
do_decode = True
except UnicodeDecodeError:
do_decode = False
items = args.split('\n')
# iterate over the tokens, and reassemble any that may have been
# split on a space inside a jinja2 block.
# ex if tokens are "{{", "foo", "}}" these go together
# These variables are used
# to keep track of the state of the parsing, since blocks and quotes
# may be nested within each other.
quote_char = None
inside_quotes = False
print_depth = 0 # used to count nested jinja2 {{ }} blocks
block_depth = 0 # used to count nested jinja2 {% %} blocks
comment_depth = 0 # used to count nested jinja2 {# #} blocks
# now we loop over each split chunk, coalescing tokens if the white space
# split occurred within quotes or a jinja2 block of some kind
for itemidx,item in enumerate(items):
# we split on spaces and newlines separately, so that we
# can tell which character we split on for reassembly
# inside quotation characters
tokens = item.strip().split(' ')
line_continuation = False
for idx,token in enumerate(tokens):
# if we hit a line continuation character, but
# we're not inside quotes, ignore it and continue
# on to the next token while setting a flag
if token == '\\' and not inside_quotes:
line_continuation = True
continue
# store the previous quoting state for checking later
was_inside_quotes = inside_quotes
quote_char = _get_quote_state(token, quote_char)
inside_quotes = quote_char is not None
# multiple conditions may append a token to the list of params,
# so we keep track with this flag to make sure it only happens once
# append means add to the end of the list, don't append means concatenate
# it to the end of the last token
appended = False
# if we're inside quotes now, but weren't before, append the token
# to the end of the list, since we'll tack on more to it later
# otherwise, if we're inside any jinja2 block, inside quotes, or we were
# inside quotes (but aren't now) concat this token to the last param
if inside_quotes and not was_inside_quotes:
params.append(token)
appended = True
elif print_depth or block_depth or comment_depth or inside_quotes or was_inside_quotes:
if idx == 0 and not inside_quotes and was_inside_quotes:
params[-1] = "%s%s" % (params[-1], token)
elif len(tokens) > 1:
spacer = ''
if idx > 0:
spacer = ' '
params[-1] = "%s%s%s" % (params[-1], spacer, token)
else:
spacer = ''
if not params[-1].endswith('\n') and idx == 0:
spacer = '\n'
params[-1] = "%s%s%s" % (params[-1], spacer, token)
appended = True
# if the number of paired block tags is not the same, the depth has changed, so we calculate that here
# and may append the current token to the params (if we haven't previously done so)
prev_print_depth = print_depth
print_depth = _count_jinja2_blocks(token, print_depth, "{{", "}}")
if print_depth != prev_print_depth and not appended:
params.append(token)
appended = True
prev_block_depth = block_depth
block_depth = _count_jinja2_blocks(token, block_depth, "{%", "%}")
if block_depth != prev_block_depth and not appended:
params.append(token)
appended = True
prev_comment_depth = comment_depth
comment_depth = _count_jinja2_blocks(token, comment_depth, "{#", "#}")
if comment_depth != prev_comment_depth and not appended:
params.append(token)
appended = True
# finally, if we're at zero depth for all blocks and not inside quotes, and have not
# yet appended anything to the list of params, we do so now
if not (print_depth or block_depth or comment_depth) and not inside_quotes and not appended and token != '':
params.append(token)
# if this was the last token in the list, and we have more than
# one item (meaning we split on newlines), add a newline back here
# to preserve the original structure
if len(items) > 1 and itemidx != len(items) - 1 and not line_continuation:
if not params[-1].endswith('\n') or item == '':
params[-1] += '\n'
# always clear the line continuation flag
line_continuation = False
# If we're done and things are not at zero depth or we're still inside quotes,
# raise an error to indicate that the args were unbalanced
if print_depth or block_depth or comment_depth or inside_quotes:
raise Exception("error while splitting arguments, either an unbalanced jinja2 block or quotes")
# finally, we decode each param back to the unicode it was in the arg string
if do_decode:
params = [x.decode('utf-8') for x in params]
return params
def is_quoted(data):
return len(data) > 0 and (data[0] == '"' and data[-1] == '"' or data[0] == "'" and data[-1] == "'")
def unquote(data):
''' removes first and last quotes from a string, if the string starts and ends with the same quotes '''
if is_quoted(data):
return data[1:-1]
return data
Subproject commit cb69744bcee4b4217d83b4a30006635ba69e2aa0
Subproject commit c16601fffac87c941eb15263f24552e91641963d
......@@ -18,3 +18,4 @@
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
class Host:
def __init__(self, name):
self._name = name
self._connection = None
self._ipv4_address = ''
self._ipv6_address = ''
self._port = 22
self._vars = dict()
def __repr__(self):
return self.get_name()
def get_name(self):
return self._name
def get_groups(self):
return []
def set_variable(self, name, value):
''' sets a variable for this host '''
self._vars[name] = value
def get_vars(self):
''' returns all variables for this host '''
all_vars = self._vars.copy()
all_vars.update(dict(inventory_hostname=self._name))
return all_vars
......@@ -19,3 +19,203 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import json
import os
from yaml import load, YAMLError
from ansible.errors import AnsibleParserError
from ansible.errors.yaml_strings import YAML_SYNTAX_ERROR
from ansible.parsing.vault import VaultLib
from ansible.parsing.splitter import unquote
from ansible.parsing.yaml.loader import AnsibleLoader
from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject
class DataLoader():
'''
The DataLoader class is used to load and parse YAML or JSON content,
either from a given file name or from a string that was previously
read in through other means. A Vault password can be specified, and
any vault-encrypted files will be decrypted.
Data read from files will also be cached, so the file will never be
read from disk more than once.
Usage:
dl = DataLoader()
(or)
dl = DataLoader(vault_password='foo')
ds = dl.load('...')
ds = dl.load_from_file('/path/to/file')
'''
def __init__(self, vault_password=None):
self._basedir = '.'
self._vault_password = vault_password
self._FILE_CACHE = dict()
self._vault = VaultLib(password=vault_password)
def load(self, data, file_name='<string>', show_content=True):
'''
Creates a python datastructure from the given data, which can be either
a JSON or YAML string.
'''
try:
# we first try to load this data as JSON
return json.loads(data)
except:
try:
# if loading JSON failed for any reason, we go ahead
# and try to parse it as YAML instead
return self._safe_load(data, file_name=file_name)
except YAMLError as yaml_exc:
self._handle_error(yaml_exc, file_name, show_content)
def load_from_file(self, file_name):
''' Loads data from a file, which can contain either JSON or YAML. '''
file_name = self.path_dwim(file_name)
# if the file has already been read in and cached, we'll
# return those results to avoid more file/vault operations
if file_name in self._FILE_CACHE:
return self._FILE_CACHE[file_name]
# read the file contents and load the data structure from them
(file_data, show_content) = self._get_file_contents(file_name)
parsed_data = self.load(data=file_data, file_name=file_name, show_content=show_content)
# cache the file contents for next time
self._FILE_CACHE[file_name] = parsed_data
return parsed_data
def path_exists(self, path):
return os.path.exists(path)
def is_directory(self, path):
return os.path.isdir(path)
def is_file(self, path):
return os.path.isfile(path)
def _safe_load(self, stream, file_name=None):
''' Implements yaml.safe_load(), except using our custom loader class. '''
loader = AnsibleLoader(stream, file_name)
try:
return loader.get_single_data()
finally:
loader.dispose()
def _get_file_contents(self, file_name):
'''
Reads the file contents from the given file name, and will decrypt them
if they are found to be vault-encrypted.
'''
if not self.path_exists(file_name) or not self.is_file(file_name):
raise AnsibleParserError("the file_name '%s' does not exist, or is not readable" % file_name)
show_content = True
try:
with open(file_name, 'r') as f:
data = f.read()
if self._vault.is_encrypted(data):
data = self._vault.decrypt(data)
show_content = False
return (data, show_content)
except (IOError, OSError) as e:
raise AnsibleParserError("an error occured while trying to read the file '%s': %s" % (file_name, str(e)))
def _handle_error(self, yaml_exc, file_name, show_content):
'''
Optionally constructs an object (AnsibleBaseYAMLObject) to encapsulate the
file name/position where a YAML exception occured, and raises an AnsibleParserError
to display the syntax exception information.
'''
# if the YAML exception contains a problem mark, use it to construct
# an object the error class can use to display the faulty line
err_obj = None
if hasattr(yaml_exc, 'problem_mark'):
err_obj = AnsibleBaseYAMLObject()
err_obj.set_position_info(file_name, yaml_exc.problem_mark.line + 1, yaml_exc.problem_mark.column + 1)
raise AnsibleParserError(YAML_SYNTAX_ERROR, obj=err_obj, show_content=show_content)
def get_basedir(self):
''' returns the current basedir '''
return self._basedir
def set_basedir(self, basedir):
''' sets the base directory, used to find files when a relative path is given '''
if basedir is not None:
self._basedir = basedir
def path_dwim(self, given):
'''
make relative paths work like folks expect.
'''
given = unquote(given)
if given.startswith("/"):
return os.path.abspath(given)
elif given.startswith("~"):
return os.path.abspath(os.path.expanduser(given))
else:
return os.path.abspath(os.path.join(self._basedir, given))
def path_dwim_relative(self, role_path, dirname, source):
''' find one file in a directory one level up in a dir named dirname relative to current '''
basedir = os.path.dirname(role_path)
if os.path.islink(basedir):
# FIXME:
#basedir = unfrackpath(basedir)
template2 = os.path.join(basedir, dirname, source)
else:
template2 = os.path.join(basedir, '..', dirname, source)
source1 = os.path.join(role_path, dirname, source)
if os.path.exists(source1):
return source1
cur_basedir = self._basedir
self.set_basedir(basedir)
source2 = self.path_dwim(template2)
if os.path.exists(source2):
self.set_basedir(cur_basedir)
return source2
obvious_local_path = self.path_dwim(source)
if os.path.exists(obvious_local_path):
self.set_basedir(cur_basedir)
return obvious_local_path
self.set_basedir(cur_basedir)
return source2 # which does not exist
#def __getstate__(self):
# data = dict(
# basedir = self._basedir,
# vault_password = self._vault_password,
# FILE_CACHE = self._FILE_CACHE,
# )
# return data
#def __setstate__(self, data):
# self._basedir = data.get('basedir', '.')
# self._FILE_CACHE = data.get('FILE_CACHE', dict())
# self._vault_password = data.get('vault_password', '')
#
# self._vault = VaultLib(password=self._vault_password)
......@@ -20,9 +20,10 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from six import iteritems, string_types
from types import NoneType
from ansible.errors import AnsibleParserError
from ansible.plugins import module_finder
from ansible.plugins import module_loader
from ansible.parsing.splitter import parse_kv
class ModuleArgsParser:
......@@ -120,7 +121,7 @@ class ModuleArgsParser:
(action, args) = self._normalize_new_style_args(thing)
# this can occasionally happen, simplify
if 'args' in args:
if args and 'args' in args:
args = args['args']
return (action, args)
......@@ -144,8 +145,11 @@ class ModuleArgsParser:
elif isinstance(thing, string_types):
# form is like: local_action: copy src=a dest=b ... pretty common
args = parse_kv(thing)
elif isinstance(thing, NoneType):
# this can happen with modules which take no params, like ping:
args = None
else:
raise AnsibleParsingError("unexpected parameter type in action: %s" % type(thing), obj=self._task_ds)
raise AnsibleParserError("unexpected parameter type in action: %s" % type(thing), obj=self._task_ds)
return args
def _normalize_new_style_args(self, thing):
......@@ -180,7 +184,7 @@ class ModuleArgsParser:
else:
# need a dict or a string, so giving up
raise AnsibleParsingError("unexpected parameter type in action: %s" % type(thing), obj=self._task_ds)
raise AnsibleParserError("unexpected parameter type in action: %s" % type(thing), obj=self._task_ds)
return (action, args)
......@@ -224,7 +228,7 @@ class ModuleArgsParser:
# walk the input dictionary to see we recognize a module name
for (item, value) in iteritems(self._task_ds):
if item in module_finder:
if item in module_loader:
# finding more than one module name is a problem
if action is not None:
raise AnsibleParserError("conflicting action statements", obj=self._task_ds)
......
# FIXME: header
try:
import json
except ImportError:
import simplejson as json
def jsonify(result, format=False):
''' format JSON output (uncompressed or uncompressed) '''
if result is None:
return "{}"
result2 = result.copy()
for key, value in result2.items():
if type(value) is str:
result2[key] = value.decode('utf-8', 'ignore')
indent = None
if format:
indent = 4
try:
return json.dumps(result2, sort_keys=True, indent=indent, ensure_ascii=False)
except UnicodeDecodeError:
return json.dumps(result2, sort_keys=True, indent=indent)
......@@ -19,156 +19,3 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import json
import os
from yaml import load, YAMLError
from ansible.errors import AnsibleParserError
from ansible.parsing.vault import VaultLib
from ansible.parsing.splitter import unquote
from ansible.parsing.yaml.loader import AnsibleLoader
from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject
from ansible.parsing.yaml.strings import YAML_SYNTAX_ERROR
class DataLoader():
'''
The DataLoader class is used to load and parse YAML or JSON content,
either from a given file name or from a string that was previously
read in through other means. A Vault password can be specified, and
any vault-encrypted files will be decrypted.
Data read from files will also be cached, so the file will never be
read from disk more than once.
Usage:
dl = DataLoader()
(or)
dl = DataLoader(vault_password='foo')
ds = dl.load('...')
ds = dl.load_from_file('/path/to/file')
'''
_FILE_CACHE = dict()
def __init__(self, vault_password=None):
self._basedir = '.'
self._vault = VaultLib(password=vault_password)
def load(self, data, file_name='<string>', show_content=True):
'''
Creates a python datastructure from the given data, which can be either
a JSON or YAML string.
'''
try:
# we first try to load this data as JSON
return json.loads(data)
except:
try:
# if loading JSON failed for any reason, we go ahead
# and try to parse it as YAML instead
return self._safe_load(data, file_name=file_name)
except YAMLError as yaml_exc:
self._handle_error(yaml_exc, file_name, show_content)
def load_from_file(self, file_name):
''' Loads data from a file, which can contain either JSON or YAML. '''
file_name = self.path_dwim(file_name)
# if the file has already been read in and cached, we'll
# return those results to avoid more file/vault operations
if file_name in self._FILE_CACHE:
return self._FILE_CACHE[file_name]
# read the file contents and load the data structure from them
(file_data, show_content) = self._get_file_contents(file_name)
parsed_data = self.load(data=file_data, file_name=file_name, show_content=show_content)
# cache the file contents for next time
self._FILE_CACHE[file_name] = parsed_data
return parsed_data
def path_exists(self, path):
return os.path.exists(path)
def is_directory(self, path):
return os.path.isdir(path)
def is_file(self, path):
return os.path.isfile(path)
def _safe_load(self, stream, file_name=None):
''' Implements yaml.safe_load(), except using our custom loader class. '''
loader = AnsibleLoader(stream, file_name)
try:
return loader.get_single_data()
finally:
loader.dispose()
def _get_file_contents(self, file_name):
'''
Reads the file contents from the given file name, and will decrypt them
if they are found to be vault-encrypted.
'''
if not self.path_exists(file_name) or not self.is_file(file_name):
raise AnsibleParserError("the file_name '%s' does not exist, or is not readable" % file_name)
show_content = True
try:
with open(file_name, 'r') as f:
data = f.read()
if self._vault.is_encrypted(data):
data = self._vault.decrypt(data)
show_content = False
return (data, show_content)
except (IOError, OSError) as e:
raise AnsibleParserError("an error occurred while trying to read the file '%s': %s" % (file_name, str(e)))
def _handle_error(self, yaml_exc, file_name, show_content):
'''
Optionally constructs an object (AnsibleBaseYAMLObject) to encapsulate the
file name/position where a YAML exception occurred, and raises an AnsibleParserError
to display the syntax exception information.
'''
# if the YAML exception contains a problem mark, use it to construct
# an object the error class can use to display the faulty line
err_obj = None
if hasattr(yaml_exc, 'problem_mark'):
err_obj = AnsibleBaseYAMLObject()
err_obj.set_position_info(file_name, yaml_exc.problem_mark.line + 1, yaml_exc.problem_mark.column + 1)
raise AnsibleParserError(YAML_SYNTAX_ERROR, obj=err_obj, show_content=show_content)
def get_basedir(self):
''' returns the current basedir '''
return self._basedir
def set_basedir(self, basedir):
''' sets the base directory, used to find files when a relative path is given '''
if basedir is not None:
self._basedir = basedir
def path_dwim(self, given):
'''
make relative paths work like folks expect.
'''
given = unquote(given)
if given.startswith("/"):
return os.path.abspath(given)
elif given.startswith("~"):
return os.path.abspath(os.path.expanduser(given))
else:
return os.path.abspath(os.path.join(self._basedir, given))
......@@ -22,7 +22,7 @@ __metaclass__ = type
import os
from ansible.errors import AnsibleError, AnsibleParserError
from ansible.parsing.yaml import DataLoader
from ansible.parsing import DataLoader
from ansible.playbook.attribute import Attribute, FieldAttribute
from ansible.playbook.play import Play
from ansible.plugins import push_basedir
......@@ -33,34 +33,33 @@ __all__ = ['Playbook']
class Playbook:
def __init__(self, loader=None):
def __init__(self, loader):
# Entries in the datastructure of a playbook may
# be either a play or an include statement
self._entries = []
self._basedir = '.'
if loader:
self._loader = loader
else:
self._loader = DataLoader()
self._basedir = os.getcwd()
self._loader = loader
@staticmethod
def load(file_name, loader=None):
def load(file_name, variable_manager=None, loader=None):
pb = Playbook(loader=loader)
pb._load_playbook_data(file_name)
pb._load_playbook_data(file_name=file_name, variable_manager=variable_manager)
return pb
def _load_playbook_data(self, file_name):
def _load_playbook_data(self, file_name, variable_manager):
# add the base directory of the file to the data loader,
# so that it knows where to find relatively pathed files
basedir = os.path.dirname(file_name)
self._loader.set_basedir(basedir)
if os.path.isabs(file_name):
self._basedir = os.path.dirname(file_name)
else:
self._basedir = os.path.normpath(os.path.join(self._basedir, os.path.dirname(file_name)))
# set the loaders basedir
self._loader.set_basedir(self._basedir)
# also add the basedir to the list of module directories
push_basedir(basedir)
push_basedir(self._basedir)
ds = self._loader.load_from_file(file_name)
ds = self._loader.load_from_file(os.path.basename(file_name))
if not isinstance(ds, list):
raise AnsibleParserError("playbooks must be a list of plays", obj=ds)
......@@ -72,11 +71,14 @@ class Playbook:
raise AnsibleParserError("playbook entries must be either a valid play or an include statement", obj=entry)
if 'include' in entry:
entry_obj = PlaybookInclude.load(entry, loader=self._loader)
entry_obj = PlaybookInclude.load(entry, variable_manager=variable_manager, loader=self._loader)
else:
entry_obj = Play.load(entry, loader=self._loader)
entry_obj = Play.load(entry, variable_manager=variable_manager, loader=self._loader)
self._entries.append(entry_obj)
def get_loader(self):
return self._loader
def get_entries(self):
return self._entries[:]
......@@ -21,11 +21,12 @@ __metaclass__ = type
class Attribute:
def __init__(self, isa=None, private=False, default=None):
def __init__(self, isa=None, private=False, default=None, required=False):
self.isa = isa
self.private = private
self.default = default
self.required = required
class FieldAttribute(Attribute):
pass
......@@ -19,22 +19,36 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import uuid
from inspect import getmembers
from io import FileIO
from six import iteritems, string_types
from jinja2.exceptions import UndefinedError
from ansible.errors import AnsibleParserError
from ansible.parsing import DataLoader
from ansible.playbook.attribute import Attribute, FieldAttribute
from ansible.parsing.yaml import DataLoader
from ansible.template import Templar
from ansible.utils.boolean import boolean
from ansible.utils.debug import debug
from ansible.template import template
class Base:
def __init__(self):
# initialize the data loader, this will be provided later
# when the object is actually loaded
# initialize the data loader and variable manager, which will be provided
# later when the object is actually loaded
self._loader = None
self._variable_manager = None
# every object gets a random uuid:
self._uuid = uuid.uuid4()
# each class knows attributes set upon it, see Task.py for example
self._attributes = dict()
......@@ -60,11 +74,15 @@ class Base:
return ds
def load_data(self, ds, loader=None):
def load_data(self, ds, variable_manager=None, loader=None):
''' walk the input datastructure and assign any values '''
assert ds is not None
# the variable manager class is used to manage and merge variables
# down to a single dictionary for reference in templating, etc.
self._variable_manager = variable_manager
# the data loader class is used to parse data from strings and files
if loader is not None:
self._loader = loader
......@@ -94,13 +112,24 @@ class Base:
else:
self._attributes[name] = ds[name]
# return the constructed object
# run early, non-critical validation
self.validate()
# cache the datastructure internally
self._ds = ds
# return the constructed object
return self
def get_ds(self):
return self._ds
def get_loader(self):
return self._loader
def get_variable_manager(self):
return self._variable_manager
def _validate_attributes(self, ds):
'''
Ensures that there are no keys in the datastructure which do
......@@ -112,7 +141,7 @@ class Base:
if key not in valid_attrs:
raise AnsibleParserError("'%s' is not a valid attribute for a %s" % (key, self.__class__.__name__), obj=ds)
def validate(self):
def validate(self, all_vars=dict()):
''' validation that is done at parse time, not load time '''
# walk all fields in the object
......@@ -121,16 +150,111 @@ class Base:
# run validator only if present
method = getattr(self, '_validate_%s' % name, None)
if method:
method(self, attribute)
method(attribute, name, getattr(self, name))
def copy(self):
'''
Create a copy of this object and return it.
'''
new_me = self.__class__()
def post_validate(self, runner_context):
for (name, attribute) in iteritems(self._get_base_attributes()):
setattr(new_me, name, getattr(self, name))
new_me._loader = self._loader
new_me._variable_manager = self._variable_manager
return new_me
def post_validate(self, all_vars=dict(), ignore_undefined=False):
'''
we can't tell that everything is of the right type until we have
all the variables. Run basic types (from isa) as well as
any _post_validate_<foo> functions.
'''
raise exception.NotImplementedError
basedir = None
if self._loader is not None:
basedir = self._loader.get_basedir()
templar = Templar(basedir=basedir, variables=all_vars)
for (name, attribute) in iteritems(self._get_base_attributes()):
if getattr(self, name) is None:
if not attribute.required:
continue
else:
raise AnsibleParserError("the field '%s' is required but was not set" % name)
try:
# if the attribute contains a variable, template it now
value = templar.template(getattr(self, name))
# run the post-validator if present
method = getattr(self, '_post_validate_%s' % name, None)
if method:
method(self, attribute, value)
else:
# otherwise, just make sure the attribute is of the type it should be
if attribute.isa == 'string':
value = unicode(value)
elif attribute.isa == 'int':
value = int(value)
elif attribute.isa == 'bool':
value = boolean(value)
elif attribute.isa == 'list':
if not isinstance(value, list):
value = [ value ]
elif attribute.isa == 'dict' and not isinstance(value, dict):
raise TypeError()
# and assign the massaged value back to the attribute field
setattr(self, name, value)
except (TypeError, ValueError), e:
#raise AnsibleParserError("the field '%s' has an invalid value, and could not be converted to an %s" % (name, attribute.isa), obj=self.get_ds())
raise AnsibleParserError("the field '%s' has an invalid value (%s), and could not be converted to an %s. Error was: %s" % (name, value, attribute.isa, e))
except UndefinedError:
if not ignore_undefined:
raise AnsibleParserError("the field '%s' has an invalid value, which appears to include a variable that is undefined" % (name,))
def serialize(self):
'''
Serializes the object derived from the base object into
a dictionary of values. This only serializes the field
attributes for the object, so this may need to be overridden
for any classes which wish to add additional items not stored
as field attributes.
'''
debug("starting serialization of %s" % self.__class__.__name__)
repr = dict()
for (name, attribute) in iteritems(self._get_base_attributes()):
repr[name] = getattr(self, name)
debug("done serializing %s" % self.__class__.__name__)
return repr
def deserialize(self, data):
'''
Given a dictionary of values, load up the field attributes for
this object. As with serialize(), if there are any non-field
attribute data members, this method will need to be overridden
and extended.
'''
debug("starting deserialization of %s" % self.__class__.__name__)
assert isinstance(data, dict)
for (name, attribute) in iteritems(self._get_base_attributes()):
if name in data:
setattr(self, name, data[name])
else:
setattr(self, name, attribute.default)
debug("done deserializing %s" % self.__class__.__name__)
def __getattr__(self, needle):
......@@ -146,3 +270,11 @@ class Base:
return self._attributes[needle]
raise AttributeError("attribute not found: %s" % needle)
def __getstate__(self):
return self.serialize()
def __setstate__(self, data):
self.__init__()
self.deserialize(data)
......@@ -21,25 +21,28 @@ __metaclass__ = type
from ansible.playbook.attribute import Attribute, FieldAttribute
from ansible.playbook.base import Base
from ansible.playbook.conditional import Conditional
from ansible.playbook.helpers import load_list_of_tasks
from ansible.playbook.role import Role
from ansible.playbook.taggable import Taggable
from ansible.playbook.task_include import TaskInclude
class Block(Base):
class Block(Base, Conditional, Taggable):
_block = FieldAttribute(isa='list')
_rescue = FieldAttribute(isa='list')
_always = FieldAttribute(isa='list')
_tags = FieldAttribute(isa='list', default=[])
_when = FieldAttribute(isa='list', default=[])
_block = FieldAttribute(isa='list')
_rescue = FieldAttribute(isa='list')
_always = FieldAttribute(isa='list')
# for future consideration? this would be functionally
# similar to the 'else' clause for exceptions
#_otherwise = FieldAttribute(isa='list')
def __init__(self, parent_block=None, role=None, task_include=None):
def __init__(self, parent_block=None, role=None, task_include=None, use_handlers=False):
self._parent_block = parent_block
self._role = role
self._role = role
self._task_include = task_include
self._use_handlers = use_handlers
super(Block, self).__init__()
def get_variables(self):
......@@ -48,9 +51,9 @@ class Block(Base):
return dict()
@staticmethod
def load(data, parent_block=None, role=None, task_include=None, loader=None):
b = Block(parent_block=parent_block, role=role, task_include=task_include)
return b.load_data(data, loader=loader)
def load(data, parent_block=None, role=None, task_include=None, use_handlers=False, variable_manager=None, loader=None):
b = Block(parent_block=parent_block, role=role, task_include=task_include, use_handlers=use_handlers)
return b.load_data(data, variable_manager=variable_manager, loader=loader)
def munge(self, ds):
'''
......@@ -70,17 +73,17 @@ class Block(Base):
return ds
def _load_block(self, attr, ds):
return load_list_of_tasks(ds, block=self, loader=self._loader)
return load_list_of_tasks(ds, block=self, role=self._role, variable_manager=self._variable_manager, loader=self._loader, use_handlers=self._use_handlers)
def _load_rescue(self, attr, ds):
return load_list_of_tasks(ds, block=self, loader=self._loader)
return load_list_of_tasks(ds, block=self, role=self._role, variable_manager=self._variable_manager, loader=self._loader, use_handlers=self._use_handlers)
def _load_always(self, attr, ds):
return load_list_of_tasks(ds, block=self, loader=self._loader)
return load_list_of_tasks(ds, block=self, role=self._role, variable_manager=self._variable_manager, loader=self._loader, use_handlers=self._use_handlers)
# not currently used
#def _load_otherwise(self, attr, ds):
# return self._load_list_of_tasks(ds, block=self, loader=self._loader)
# return self._load_list_of_tasks(ds, block=self, role=self._role, variable_manager=self._variable_manager, loader=self._loader, use_handlers=self._use_handlers)
def compile(self):
'''
......@@ -93,3 +96,75 @@ class Block(Base):
task_list.extend(task.compile())
return task_list
def copy(self):
new_me = super(Block, self).copy()
new_me._use_handlers = self._use_handlers
new_me._parent_block = None
if self._parent_block:
new_me._parent_block = self._parent_block.copy()
new_me._role = None
if self._role:
new_me._role = self._role
new_me._task_include = None
if self._task_include:
new_me._task_include = self._task_include.copy()
return new_me
def serialize(self):
'''
Override of the default serialize method, since when we're serializing
a task we don't want to include the attribute list of tasks.
'''
data = dict(when=self.when)
if self._role is not None:
data['role'] = self._role.serialize()
return data
def deserialize(self, data):
'''
Override of the default deserialize method, to match the above overridden
serialize method
'''
# unpack the when attribute, which is the only one we want
self.when = data.get('when')
# if there was a serialized role, unpack it too
role_data = data.get('role')
if role_data:
r = Role()
r.deserialize(role_data)
self._role = r
def evaluate_conditional(self, all_vars):
if self._parent_block is not None:
if not self._parent_block.evaluate_conditional(all_vars):
return False
if self._role is not None:
if not self._role.evaluate_conditional(all_vars):
return False
return super(Block, self).evaluate_conditional(all_vars)
def get_tags(self):
tags = set(self.tags[:])
if self._parent_block:
tags.update(self._parent_block.get_tags())
if self._role:
tags.update(self._role.get_tags())
return tags
#def get_conditionals(self):
# conditionals = set(self.when[:])
# if self._parent_block:
# conditionals.update(self._parent_block.get_conditionals())
# if self._role:
# conditionals.update(self._role.get_conditionals())
# return conditionals
......@@ -19,16 +19,79 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible.errors import *
from ansible.playbook.attribute import FieldAttribute
from ansible.template import Templar
class Conditional:
def __init__(self, task):
self._task = task
self._conditionals = []
'''
This is a mix-in class, to be used with Base to allow the object
to be run conditionally when a condition is met or skipped.
'''
_when = FieldAttribute(isa='list', default=[])
def __init__(self):
super(Conditional, self).__init__()
def _validate_when(self, attr, name, value):
if not isinstance(value, list):
setattr(self, name, [ value ])
def evaluate_conditional(self, all_vars):
'''
Loops through the conditionals set on this object, returning
False if any of them evaluate as such.
'''
templar = Templar(variables=all_vars)
for conditional in self.when:
if not self._check_conditional(conditional, templar):
return False
return True
def _check_conditional(self, conditional, templar):
'''
This method does the low-level evaluation of each conditional
set on this object, using jinja2 to wrap the conditionals for
evaluation.
'''
if conditional is None or conditional == '':
return True
elif not isinstance(conditional, basestring):
return conditional
conditional = conditional.replace("jinja2_compare ","")
# allow variable names
#if conditional in inject and '-' not in str(inject[conditional]):
# conditional = inject[conditional]
conditional = templar.template(conditional, convert_bare=True)
original = str(conditional).replace("jinja2_compare ","")
# a Jinja2 evaluation that results in something Python can eval!
presented = "{%% if %s %%} True {%% else %%} False {%% endif %%}" % conditional
conditional = templar.template(presented)
def evaluate(self, context):
pass
val = conditional.strip()
if val == presented:
# the templating failed, meaning most likely a
# variable was undefined. If we happened to be
# looking for an undefined variable, return True,
# otherwise fail
if "is undefined" in conditional:
return True
elif "is defined" in conditional:
return False
else:
raise AnsibleError("error while evaluating conditional: %s" % original)
elif val == "True":
return True
elif val == "False":
return False
else:
raise AnsibleError("unable to evaluate conditional: %s" % original)
def push(self, conditionals):
if not isinstance(conditionals, list):
conditionals = [ conditionals ]
self._conditionals.extend(conditionals)
......@@ -19,22 +19,35 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from v2.errors import AnsibleError
from v2.inventory import Host
from v2.playbook import Task
from ansible.errors import AnsibleError
#from ansible.inventory.host import Host
from ansible.playbook.task import Task
class Handler(Task):
def __init__(self):
pass
def __init__(self, block=None, role=None, task_include=None):
self._flagged_hosts = []
super(Handler, self).__init__(block=block, role=role, task_include=task_include)
def __repr__(self):
''' returns a human readable representation of the handler '''
return "HANDLER: %s" % self.get_name()
@staticmethod
def load(data, block=None, role=None, task_include=None, variable_manager=None, loader=None):
t = Handler(block=block, role=role, task_include=task_include)
return t.load_data(data, variable_manager=variable_manager, loader=loader)
def flag_for_host(self, host):
assert instanceof(host, Host)
pass
#assert instanceof(host, Host)
if host not in self._flagged_hosts:
self._flagged_hosts.append(host)
def has_triggered(self):
return self._triggered
def has_triggered(self, host):
return host in self._flagged_hosts
def set_triggered(self, triggered):
assert instanceof(triggered, bool)
self._triggered = triggered
def serialize(self):
result = super(Handler, self).serialize()
result['is_handler'] = True
return result
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment