Commit 995aa8e2 by James Cammarata Committed by vagrant

Making task includes dynamic and fixing many other bugs

Dynamic task includes still need some work, this is a rough first version.
* doesn't work with handler sections of playbooks yet
* when using include + with*, the insertion order is backwards
* fix potential for task lists to be unsynchronized when using the linear
  strategy, as the include conditional could be predicated on an inventory
  variable
parent 62a6378c
...@@ -19,4 +19,4 @@ ...@@ -19,4 +19,4 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
__version__ = '1.v2' __version__ = '2.0'
...@@ -63,8 +63,9 @@ class PlayState: ...@@ -63,8 +63,9 @@ class PlayState:
self._parent_iterator = parent_iterator self._parent_iterator = parent_iterator
self._run_state = ITERATING_SETUP self._run_state = ITERATING_SETUP
self._failed_state = FAILED_NONE self._failed_state = FAILED_NONE
self._task_list = parent_iterator._play.compile()
self._gather_facts = parent_iterator._play.gather_facts self._gather_facts = parent_iterator._play.gather_facts
#self._task_list = parent_iterator._play.compile()
self._task_list = parent_iterator._task_list[:]
self._host = host self._host = host
self._cur_block = None self._cur_block = None
...@@ -209,6 +210,19 @@ class PlayState: ...@@ -209,6 +210,19 @@ class PlayState:
elif self._run_state == ITERATING_ALWAYS: elif self._run_state == ITERATING_ALWAYS:
self._failed_state = FAILED_ALWAYS self._failed_state = FAILED_ALWAYS
def add_tasks(self, task_list):
if self._run_state == ITERATING_TASKS:
before = self._task_list[:self._cur_task_pos]
after = self._task_list[self._cur_task_pos:]
self._task_list = before + task_list + after
elif self._run_state == ITERATING_RESCUE:
before = self._cur_block.rescue[:self._cur_rescue_pos]
after = self._cur_block.rescue[self._cur_rescue_pos:]
self._cur_block.rescue = before + task_list + after
elif self._run_state == ITERATING_ALWAYS:
before = self._cur_block.always[:self._cur_always_pos]
after = self._cur_block.always[self._cur_always_pos:]
self._cur_block.always = before + task_list + after
class PlayIterator: class PlayIterator:
...@@ -235,6 +249,7 @@ class PlayIterator: ...@@ -235,6 +249,7 @@ class PlayIterator:
new_play = play.copy() new_play = play.copy()
new_play.post_validate(all_vars, fail_on_undefined=False) new_play.post_validate(all_vars, fail_on_undefined=False)
self._task_list = new_play.compile()
for host in inventory.get_hosts(new_play.hosts): for host in inventory.get_hosts(new_play.hosts):
if self._first_host is None: if self._first_host is None:
self._first_host = host self._first_host = host
...@@ -267,3 +282,22 @@ class PlayIterator: ...@@ -267,3 +282,22 @@ class PlayIterator:
self._host_entries[host.get_name()].mark_failed() self._host_entries[host.get_name()].mark_failed()
def get_original_task(self, task):
'''
Finds the task in the task list which matches the UUID of the given task.
The executor engine serializes/deserializes objects as they are passed through
the different processes, and not all data structures are preserved. This method
allows us to find the original task passed into the executor engine.
'''
for t in self._task_list:
if t._uuid == task._uuid:
return t
return None
def add_tasks(self, host, task_list):
if host.name not in self._host_entries:
raise AnsibleError("invalid host (%s) specified for playbook iteration (expanding task list)" % host)
self._host_entries[host.name].add_tasks(task_list)
...@@ -137,7 +137,12 @@ class ResultProcess(multiprocessing.Process): ...@@ -137,7 +137,12 @@ class ResultProcess(multiprocessing.Process):
result_items = [ result._result ] result_items = [ result._result ]
for result_item in result_items: for result_item in result_items:
if 'add_host' in result_item: if 'include' in result_item:
include_variables = result_item.get('include_variables', dict())
if 'item' in result_item:
include_variables['item'] = result_item['item']
self._send_result(('include', result._host, result._task, result_item['include'], include_variables))
elif 'add_host' in result_item:
# this task added a new host (add_host module) # this task added a new host (add_host module)
self._send_result(('add_host', result_item)) self._send_result(('add_host', result_item))
elif 'add_group' in result_item: elif 'add_group' in result_item:
......
...@@ -186,6 +186,14 @@ class TaskExecutor: ...@@ -186,6 +186,14 @@ class TaskExecutor:
# Now we do final validation on the task, which sets all fields to their final values # Now we do final validation on the task, which sets all fields to their final values
self._task.post_validate(variables) self._task.post_validate(variables)
# if this task is a TaskInclude, we just return now with a success code so the
# main thread can expand the task list for the given host
if self._task.action == 'include':
include_variables = self._task.args.copy()
include_file = include_variables.get('_raw_params')
del include_variables['_raw_params']
return dict(changed=True, include=include_file, include_variables=include_variables)
# And filter out any fields which were set to default(omit), and got the omit token value # And filter out any fields which were set to default(omit), and got the omit token value
omit_token = variables.get('omit') omit_token = variables.get('omit')
if omit_token is not None: if omit_token is not None:
...@@ -204,7 +212,6 @@ class TaskExecutor: ...@@ -204,7 +212,6 @@ class TaskExecutor:
# with the registered variable value later on when testing conditions # with the registered variable value later on when testing conditions
vars_copy = variables.copy() vars_copy = variables.copy()
debug("starting attempt loop") debug("starting attempt loop")
result = None result = None
for attempt in range(retries): for attempt in range(retries):
......
Subproject commit 095f8681dbdfd2e9247446822e953287c9bca66c Subproject commit 34784b7a617aa35d3b994c9f0795567afc6fb0b0
...@@ -253,7 +253,7 @@ class ModuleArgsParser: ...@@ -253,7 +253,7 @@ class ModuleArgsParser:
# walk the input dictionary to see we recognize a module name # walk the input dictionary to see we recognize a module name
for (item, value) in iteritems(self._task_ds): for (item, value) in iteritems(self._task_ds):
if item in module_loader or item == 'meta': if item in module_loader or item == 'meta' or item == 'include':
# finding more than one module name is a problem # finding more than one module name is a problem
if action is not None: if action is not None:
raise AnsibleParserError("conflicting action statements", obj=self._task_ds) raise AnsibleParserError("conflicting action statements", obj=self._task_ds)
......
...@@ -116,7 +116,7 @@ class Base: ...@@ -116,7 +116,7 @@ class Base:
self.validate() self.validate()
# cache the datastructure internally # cache the datastructure internally
self._ds = ds setattr(self, '_ds', ds)
# return the constructed object # return the constructed object
return self return self
...@@ -231,13 +231,14 @@ class Base: ...@@ -231,13 +231,14 @@ class Base:
as field attributes. as field attributes.
''' '''
#debug("starting serialization of %s" % self.__class__.__name__)
repr = dict() repr = dict()
for (name, attribute) in iteritems(self._get_base_attributes()): for (name, attribute) in iteritems(self._get_base_attributes()):
repr[name] = getattr(self, name) repr[name] = getattr(self, name)
#debug("done serializing %s" % self.__class__.__name__) # serialize the uuid field
repr['uuid'] = getattr(self, '_uuid')
return repr return repr
def deserialize(self, data): def deserialize(self, data):
...@@ -248,7 +249,6 @@ class Base: ...@@ -248,7 +249,6 @@ class Base:
and extended. and extended.
''' '''
#debug("starting deserialization of %s" % self.__class__.__name__)
assert isinstance(data, dict) assert isinstance(data, dict)
for (name, attribute) in iteritems(self._get_base_attributes()): for (name, attribute) in iteritems(self._get_base_attributes()):
...@@ -256,7 +256,9 @@ class Base: ...@@ -256,7 +256,9 @@ class Base:
setattr(self, name, data[name]) setattr(self, name, data[name])
else: else:
setattr(self, name, attribute.default) setattr(self, name, attribute.default)
#debug("done deserializing %s" % self.__class__.__name__)
# restore the UUID field
setattr(self, '_uuid', data.get('uuid'))
def __getattr__(self, needle): def __getattr__(self, needle):
......
...@@ -25,7 +25,6 @@ from ansible.playbook.conditional import Conditional ...@@ -25,7 +25,6 @@ from ansible.playbook.conditional import Conditional
from ansible.playbook.helpers import load_list_of_tasks from ansible.playbook.helpers import load_list_of_tasks
from ansible.playbook.role import Role from ansible.playbook.role import Role
from ansible.playbook.taggable import Taggable from ansible.playbook.taggable import Taggable
from ansible.playbook.task_include import TaskInclude
class Block(Base, Conditional, Taggable): class Block(Base, Conditional, Taggable):
...@@ -178,7 +177,8 @@ class Block(Base, Conditional, Taggable): ...@@ -178,7 +177,8 @@ class Block(Base, Conditional, Taggable):
serialize method serialize method
''' '''
from ansible.playbook.task_include import TaskInclude #from ansible.playbook.task_include import TaskInclude
from ansible.playbook.task import Task
# unpack the when attribute, which is the only one we want # unpack the when attribute, which is the only one we want
self.when = data.get('when') self.when = data.get('when')
...@@ -193,7 +193,7 @@ class Block(Base, Conditional, Taggable): ...@@ -193,7 +193,7 @@ class Block(Base, Conditional, Taggable):
# if there was a serialized task include, unpack it too # if there was a serialized task include, unpack it too
ti_data = data.get('task_include') ti_data = data.get('task_include')
if ti_data: if ti_data:
ti = TaskInclude() ti = Task()
ti.deserialize(ti_data) ti.deserialize(ti_data)
self._task_include = ti self._task_include = ti
......
...@@ -92,7 +92,7 @@ class Conditional: ...@@ -92,7 +92,7 @@ class Conditional:
elif "is defined" in original: elif "is defined" in original:
return False return False
else: else:
raise AnsibleError("error while evaluating conditional: %s" % original) raise AnsibleError("error while evaluating conditional: %s (%s)" % (original, presented))
elif val == "True": elif val == "True":
return True return True
elif val == "False": elif val == "False":
......
...@@ -62,7 +62,7 @@ def load_list_of_tasks(ds, block=None, role=None, task_include=None, use_handler ...@@ -62,7 +62,7 @@ def load_list_of_tasks(ds, block=None, role=None, task_include=None, use_handler
# we import here to prevent a circular dependency with imports # we import here to prevent a circular dependency with imports
from ansible.playbook.handler import Handler from ansible.playbook.handler import Handler
from ansible.playbook.task import Task from ansible.playbook.task import Task
from ansible.playbook.task_include import TaskInclude #from ansible.playbook.task_include import TaskInclude
assert type(ds) == list assert type(ds) == list
...@@ -71,26 +71,27 @@ def load_list_of_tasks(ds, block=None, role=None, task_include=None, use_handler ...@@ -71,26 +71,27 @@ def load_list_of_tasks(ds, block=None, role=None, task_include=None, use_handler
if not isinstance(task, dict): if not isinstance(task, dict):
raise AnsibleParserError("task/handler entries must be dictionaries (got a %s)" % type(task), obj=ds) raise AnsibleParserError("task/handler entries must be dictionaries (got a %s)" % type(task), obj=ds)
if 'include' in task: #if 'include' in task:
cur_basedir = None # cur_basedir = None
if isinstance(task, AnsibleBaseYAMLObject) and loader: # if isinstance(task, AnsibleBaseYAMLObject) and loader:
pos_info = task.get_position_info() # pos_info = task.get_position_info()
new_basedir = os.path.dirname(pos_info[0]) # new_basedir = os.path.dirname(pos_info[0])
cur_basedir = loader.get_basedir() # cur_basedir = loader.get_basedir()
loader.set_basedir(new_basedir) # loader.set_basedir(new_basedir)
t = TaskInclude.load( # t = TaskInclude.load(
task, # task,
block=block, # block=block,
role=role, # role=role,
task_include=task_include, # task_include=task_include,
use_handlers=use_handlers, # use_handlers=use_handlers,
loader=loader # loader=loader
) # )
if cur_basedir and loader: # if cur_basedir and loader:
loader.set_basedir(cur_basedir) # loader.set_basedir(cur_basedir)
else: #else:
if True:
if use_handlers: if use_handlers:
t = Handler.load(task, block=block, role=role, task_include=task_include, variable_manager=variable_manager, loader=loader) t = Handler.load(task, block=block, role=role, task_include=task_include, variable_manager=variable_manager, loader=loader)
else: else:
......
...@@ -33,7 +33,8 @@ from ansible.playbook.block import Block ...@@ -33,7 +33,8 @@ from ansible.playbook.block import Block
from ansible.playbook.conditional import Conditional from ansible.playbook.conditional import Conditional
from ansible.playbook.role import Role from ansible.playbook.role import Role
from ansible.playbook.taggable import Taggable from ansible.playbook.taggable import Taggable
from ansible.playbook.task_include import TaskInclude
__all__ = ['Task']
class Task(Base, Conditional, Taggable): class Task(Base, Conditional, Taggable):
...@@ -93,6 +94,7 @@ class Task(Base, Conditional, Taggable): ...@@ -93,6 +94,7 @@ class Task(Base, Conditional, Taggable):
_sudo_pass = FieldAttribute(isa='string') _sudo_pass = FieldAttribute(isa='string')
_transport = FieldAttribute(isa='string') _transport = FieldAttribute(isa='string')
_until = FieldAttribute(isa='list') # ? _until = FieldAttribute(isa='list') # ?
_vars = FieldAttribute(isa='dict', default=dict())
def __init__(self, block=None, role=None, task_include=None): def __init__(self, block=None, role=None, task_include=None):
''' constructors a task, without the Task.load classmethod, it will be pretty blank ''' ''' constructors a task, without the Task.load classmethod, it will be pretty blank '''
...@@ -201,7 +203,7 @@ class Task(Base, Conditional, Taggable): ...@@ -201,7 +203,7 @@ class Task(Base, Conditional, Taggable):
super(Task, self).post_validate(all_vars=all_vars, fail_on_undefined=fail_on_undefined) super(Task, self).post_validate(all_vars=all_vars, fail_on_undefined=fail_on_undefined)
def get_vars(self): def get_vars(self):
all_vars = dict() all_vars = self.vars.copy()
if self._task_include: if self._task_include:
all_vars.update(self._task_include.get_vars()) all_vars.update(self._task_include.get_vars())
...@@ -256,6 +258,10 @@ class Task(Base, Conditional, Taggable): ...@@ -256,6 +258,10 @@ class Task(Base, Conditional, Taggable):
return data return data
def deserialize(self, data): def deserialize(self, data):
# import is here to avoid import loops
#from ansible.playbook.task_include import TaskInclude
block_data = data.get('block') block_data = data.get('block')
self._dep_chain = data.get('dep_chain', []) self._dep_chain = data.get('dep_chain', [])
...@@ -274,7 +280,8 @@ class Task(Base, Conditional, Taggable): ...@@ -274,7 +280,8 @@ class Task(Base, Conditional, Taggable):
ti_data = data.get('task_include') ti_data = data.get('task_include')
if ti_data: if ti_data:
ti = TaskInclude() #ti = TaskInclude()
ti = Task()
ti.deserialize(ti_data) ti.deserialize(ti_data)
self._task_include = ti self._task_include = ti
del data['task_include'] del data['task_include']
......
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible.errors import AnsibleParserError
from ansible.parsing.splitter import split_args, parse_kv
from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping
from ansible.playbook.attribute import Attribute, FieldAttribute
from ansible.playbook.base import Base
from ansible.playbook.conditional import Conditional
from ansible.playbook.helpers import load_list_of_blocks, compile_block_list
from ansible.playbook.taggable import Taggable
from ansible.plugins import lookup_loader
__all__ = ['TaskInclude']
class TaskInclude(Base, Conditional, Taggable):
'''
A class used to wrap the use of `include: /some/other/file.yml`
within a task list, which may return a list of Task objects and/or
more TaskInclude objects.
'''
# the description field is used mainly internally to
# show a nice reprsentation of this class, rather than
# simply using __class__.__name__
__desc__ = "task include statement"
#-----------------------------------------------------------------
# Attributes
_name = FieldAttribute(isa='string')
_include = FieldAttribute(isa='string')
_loop = FieldAttribute(isa='string', private=True)
_loop_args = FieldAttribute(isa='list', private=True)
_vars = FieldAttribute(isa='dict', default=dict())
def __init__(self, block=None, role=None, task_include=None, use_handlers=False):
self._block = block
self._role = role
self._task_include = task_include
self._use_handlers = use_handlers
self._task_blocks = []
super(TaskInclude, self).__init__()
@staticmethod
def load(data, block=None, role=None, task_include=None, use_handlers=False, variable_manager=None, loader=None):
ti = TaskInclude(block=block, role=role, task_include=None, use_handlers=use_handlers)
return ti.load_data(data, variable_manager=variable_manager, loader=loader)
def munge(self, ds):
'''
Regorganizes the data for a TaskInclude datastructure to line
up with what we expect the proper attributes to be
'''
assert isinstance(ds, dict)
# the new, cleaned datastructure, which will have legacy
# items reduced to a standard structure
new_ds = AnsibleMapping()
if isinstance(ds, AnsibleBaseYAMLObject):
new_ds.copy_position_info(ds)
for (k,v) in ds.iteritems():
if k == 'include':
self._munge_include(ds, new_ds, k, v)
elif k.replace("with_", "") in lookup_loader:
self._munge_loop(ds, new_ds, k, v)
else:
# some basic error checking, to make sure vars are properly
# formatted and do not conflict with k=v parameters
# FIXME: we could merge these instead, but controlling the order
# in which they're encountered could be difficult
if k == 'vars':
if 'vars' in new_ds:
raise AnsibleParserError("include parameters cannot be mixed with 'vars' entries for include statements", obj=ds)
elif not isinstance(v, dict):
raise AnsibleParserError("vars for include statements must be specified as a dictionary", obj=ds)
new_ds[k] = v
return new_ds
def _munge_include(self, ds, new_ds, k, v):
'''
Splits the include line up into filename and parameters
'''
# The include line must include at least one item, which is the filename
# to include. Anything after that should be regarded as a parameter to the include
items = split_args(v)
if len(items) == 0:
raise AnsibleParserError("include statements must specify the file name to include", obj=ds)
else:
# FIXME/TODO: validate that items[0] is a file, which also
# exists and is readable
new_ds['include'] = items[0]
if len(items) > 1:
# rejoin the parameter portion of the arguments and
# then use parse_kv() to get a dict of params back
params = parse_kv(" ".join(items[1:]))
if 'vars' in new_ds:
# FIXME: see fixme above regarding merging vars
raise AnsibleParserError("include parameters cannot be mixed with 'vars' entries for include statements", obj=ds)
new_ds['vars'] = params
def _munge_loop(self, ds, new_ds, k, v):
''' take a lookup plugin name and store it correctly '''
loop_name = k.replace("with_", "")
if new_ds.get('loop') is not None:
raise AnsibleError("duplicate loop in task: %s" % loop_name)
new_ds['loop'] = loop_name
new_ds['loop_args'] = v
def _load_include(self, attr, ds):
''' loads the file name specified in the ds and returns a list of blocks '''
data = self._loader.load_from_file(ds)
if not isinstance(data, list):
raise AnsibleParsingError("included task files must contain a list of tasks", obj=ds)
self._task_blocks = load_list_of_blocks(
data,
parent_block=self._block,
task_include=self,
role=self._role,
use_handlers=self._use_handlers,
loader=self._loader
)
return ds
def compile(self):
'''
Returns the task list for the included tasks.
'''
task_list = []
task_list.extend(compile_block_list(self._task_blocks))
return task_list
def get_vars(self):
'''
Returns the vars for this task include, but also first merges in
those from any parent task include which may exist.
'''
all_vars = dict()
if self._task_include:
all_vars.update(self._task_include.get_vars())
if self._block:
all_vars.update(self._block.get_vars())
all_vars.update(self.vars)
return all_vars
def serialize(self):
data = super(TaskInclude, self).serialize()
if self._block:
data['block'] = self._block.serialize()
if self._role:
data['role'] = self._role.serialize()
if self._task_include:
data['task_include'] = self._task_include.serialize()
return data
def deserialize(self, data):
# import here to prevent circular importing issues
from ansible.playbook.block import Block
from ansible.playbook.role import Role
block_data = data.get('block')
if block_data:
b = Block()
b.deserialize(block_data)
self._block = b
del data['block']
role_data = data.get('role')
if role_data:
r = Role()
r.deserialize(role_data)
self._role = r
del data['role']
ti_data = data.get('task_include')
if ti_data:
ti = TaskInclude()
ti.deserialize(ti_data)
self._task_include = ti
del data['task_include']
super(TaskInclude, self).deserialize(data)
def evaluate_conditional(self, all_vars):
if self._task_include is not None:
if not self._task_include.evaluate_conditional(all_vars):
return False
if self._block is not None:
if not self._block.evaluate_conditional(all_vars):
return False
elif self._role is not None:
if not self._role.evaluate_conditional(all_vars):
return False
return super(TaskInclude, self).evaluate_conditional(all_vars)
def set_loader(self, loader):
self._loader = loader
if self._block:
self._block.set_loader(loader)
elif self._task_include:
self._task_include.set_loader(loader)
...@@ -50,14 +50,17 @@ class CallbackModule(CallbackBase): ...@@ -50,14 +50,17 @@ class CallbackModule(CallbackBase):
def runner_on_ok(self, task, result): def runner_on_ok(self, task, result):
if result._result.get('changed', False): if result._task.action == 'include':
msg = 'included: %s for %s' % (result._task.args.get('_raw_params'), result._host.name)
color = 'cyan'
elif result._result.get('changed', False):
msg = "changed: [%s]" % result._host.get_name() msg = "changed: [%s]" % result._host.get_name()
color = 'yellow' color = 'yellow'
else: else:
msg = "ok: [%s]" % result._host.get_name() msg = "ok: [%s]" % result._host.get_name()
color = 'green' color = 'green'
if (self._display._verbosity > 0 or 'verbose_always' in result._result) and result._task.action != 'setup': if (self._display._verbosity > 0 or 'verbose_always' in result._result) and result._task.action not in ('setup', 'include'):
indent = None indent = None
if 'verbose_always' in result._result: if 'verbose_always' in result._result:
indent = 4 indent = 4
......
...@@ -27,7 +27,8 @@ from ansible.errors import * ...@@ -27,7 +27,8 @@ from ansible.errors import *
from ansible.inventory.host import Host from ansible.inventory.host import Host
from ansible.inventory.group import Group from ansible.inventory.group import Group
from ansible.playbook.helpers import compile_block_list from ansible.playbook.handler import Handler
from ansible.playbook.helpers import load_list_of_blocks, compile_block_list
from ansible.playbook.role import ROLE_CACHE, hash_params from ansible.playbook.role import ROLE_CACHE, hash_params
from ansible.plugins import module_loader from ansible.plugins import module_loader
from ansible.utils.debug import debug from ansible.utils.debug import debug
...@@ -111,7 +112,7 @@ class StrategyBase: ...@@ -111,7 +112,7 @@ class StrategyBase:
return return
debug("exiting _queue_task() for %s/%s" % (host, task)) debug("exiting _queue_task() for %s/%s" % (host, task))
def _process_pending_results(self): def _process_pending_results(self, iterator):
''' '''
Reads results off the final queue and takes appropriate action Reads results off the final queue and takes appropriate action
based on the result (executing callbacks, updating state, etc.). based on the result (executing callbacks, updating state, etc.).
...@@ -155,6 +156,22 @@ class StrategyBase: ...@@ -155,6 +156,22 @@ class StrategyBase:
if entry == hashed_entry : if entry == hashed_entry :
role_obj._had_task_run = True role_obj._had_task_run = True
elif result[0] == 'include':
host = result[1]
task = result[2]
include_file = result[3]
include_vars = result[4]
if isinstance(task, Handler):
# FIXME: figure out how to make includes work for handlers
pass
else:
original_task = iterator.get_original_task(task)
if original_task._role:
include_file = self._loader.path_dwim_relative(original_task._role._role_path, 'tasks', include_file)
new_tasks = self._load_included_file(original_task, include_file, include_vars)
iterator.add_tasks(host, new_tasks)
elif result[0] == 'add_host': elif result[0] == 'add_host':
task_result = result[1] task_result = result[1]
new_host_info = task_result.get('add_host', dict()) new_host_info = task_result.get('add_host', dict())
...@@ -194,7 +211,7 @@ class StrategyBase: ...@@ -194,7 +211,7 @@ class StrategyBase:
except Queue.Empty: except Queue.Empty:
pass pass
def _wait_on_pending_results(self): def _wait_on_pending_results(self, iterator):
''' '''
Wait for the shared counter to drop to zero, using a short sleep Wait for the shared counter to drop to zero, using a short sleep
between checks to ensure we don't spin lock between checks to ensure we don't spin lock
...@@ -202,7 +219,7 @@ class StrategyBase: ...@@ -202,7 +219,7 @@ class StrategyBase:
while self._pending_results > 0 and not self._tqm._terminated: while self._pending_results > 0 and not self._tqm._terminated:
debug("waiting for pending results (%d left)" % self._pending_results) debug("waiting for pending results (%d left)" % self._pending_results)
self._process_pending_results() self._process_pending_results(iterator)
if self._tqm._terminated: if self._tqm._terminated:
break break
time.sleep(0.01) time.sleep(0.01)
...@@ -275,6 +292,33 @@ class StrategyBase: ...@@ -275,6 +292,33 @@ class StrategyBase:
# and add the host to the group # and add the host to the group
new_group.add_host(actual_host) new_group.add_host(actual_host)
def _load_included_file(self, task, include_file, include_vars):
'''
Loads an included YAML file of tasks, applying the optional set of variables.
'''
data = self._loader.load_from_file(include_file)
if not isinstance(data, list):
raise AnsibleParsingError("included task files must contain a list of tasks", obj=ds)
is_handler = isinstance(task, Handler)
block_list = load_list_of_blocks(
data,
parent_block=task._block,
task_include=task,
role=task._role,
use_handlers=is_handler,
loader=self._loader
)
task_list = compile_block_list(block_list)
for t in task_list:
t.vars = include_vars.copy()
return task_list
def cleanup(self, iterator, connection_info): def cleanup(self, iterator, connection_info):
''' '''
Iterates through failed hosts and runs any outstanding rescue/always blocks Iterates through failed hosts and runs any outstanding rescue/always blocks
...@@ -322,10 +366,10 @@ class StrategyBase: ...@@ -322,10 +366,10 @@ class StrategyBase:
self._callback.playbook_on_cleanup_task_start(task.get_name()) self._callback.playbook_on_cleanup_task_start(task.get_name())
self._queue_task(host, task, task_vars, connection_info) self._queue_task(host, task, task_vars, connection_info)
self._process_pending_results() self._process_pending_results(iterator)
# no more work, wait until the queue is drained # no more work, wait until the queue is drained
self._wait_on_pending_results() self._wait_on_pending_results(iterator)
return result return result
...@@ -346,7 +390,7 @@ class StrategyBase: ...@@ -346,7 +390,7 @@ class StrategyBase:
handler_name = handler.get_name() handler_name = handler.get_name()
if handler_name in self._notified_handlers and len(self._notified_handlers[handler_name]): if handler_name in self._notified_handlers and len(self._notified_handlers[handler_name]):
if not len(self.get_hosts_remaining()): if not len(self.get_hosts_remaining(iterator._play)):
self._callback.playbook_on_no_hosts_remaining() self._callback.playbook_on_no_hosts_remaining()
result = False result = False
break break
...@@ -358,9 +402,9 @@ class StrategyBase: ...@@ -358,9 +402,9 @@ class StrategyBase:
self._queue_task(host, handler, task_vars, connection_info) self._queue_task(host, handler, task_vars, connection_info)
handler.flag_for_host(host) handler.flag_for_host(host)
self._process_pending_results() self._process_pending_results(iterator)
self._wait_on_pending_results() self._wait_on_pending_results(iterator)
# wipe the notification list # wipe the notification list
self._notified_handlers[handler_name] = [] self._notified_handlers[handler_name] = []
......
...@@ -96,10 +96,10 @@ class StrategyModule(StrategyBase): ...@@ -96,10 +96,10 @@ class StrategyModule(StrategyBase):
self._blocked_hosts[host.get_name()] = True self._blocked_hosts[host.get_name()] = True
self._queue_task(host, task, task_vars, connection_info) self._queue_task(host, task, task_vars, connection_info)
self._process_pending_results() self._process_pending_results(iterator)
debug("done queuing things up, now waiting for results queue to drain") debug("done queuing things up, now waiting for results queue to drain")
self._wait_on_pending_results() self._wait_on_pending_results(iterator)
debug("results queue empty") debug("results queue empty")
except (IOError, EOFError), e: except (IOError, EOFError), e:
debug("got IOError/EOFError in task loop: %s" % e) debug("got IOError/EOFError in task loop: %s" % e)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment