Commit 1754de33 by Michael DeHaan

Misc code cleanup, mostly whitespace preferences, removing unused imports, plus…

Misc code cleanup, mostly whitespace preferences, removing unused imports, plus a few fixes here and there.
parent 4b739313
...@@ -15,20 +15,23 @@ ...@@ -15,20 +15,23 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#######################################################
import utils import utils
import sys import sys
import getpass import getpass
import os import os
import subprocess import subprocess
####################################################### cowsay = None
if os.path.exists("/usr/bin/cowsay"):
cowsay = "/usr/bin/cowsay"
elif os.path.exists("/usr/games/cowsay"):
cowsay = "/usr/games/cowsay"
class AggregateStats(object): class AggregateStats(object):
''' holds stats about per-host activity during playbook runs ''' ''' holds stats about per-host activity during playbook runs '''
def __init__(self): def __init__(self):
self.processed = {} self.processed = {}
self.failures = {} self.failures = {}
self.ok = {} self.ok = {}
...@@ -78,6 +81,7 @@ class AggregateStats(object): ...@@ -78,6 +81,7 @@ class AggregateStats(object):
def regular_generic_msg(hostname, result, oneline, caption): def regular_generic_msg(hostname, result, oneline, caption):
''' output on the result of a module run that is not command ''' ''' output on the result of a module run that is not command '''
if not oneline: if not oneline:
return "%s | %s >> %s\n" % (hostname, caption, utils.jsonify(result,format=True)) return "%s | %s >> %s\n" % (hostname, caption, utils.jsonify(result,format=True))
else: else:
...@@ -85,30 +89,23 @@ def regular_generic_msg(hostname, result, oneline, caption): ...@@ -85,30 +89,23 @@ def regular_generic_msg(hostname, result, oneline, caption):
def banner(msg): def banner(msg):
res = ""
global COWSAY
if os.path.exists("/usr/bin/cowsay"):
COWSAY = "/usr/bin/cowsay"
elif os.path.exists("/usr/games/cowsay"):
COWSAY = "/usr/games/cowsay"
else:
COWSAY = None
if COWSAY != None: if cowsay != None:
cmd = subprocess.Popen("%s -W 60 \"%s\"" % (COWSAY, msg), cmd = subprocess.Popen("%s -W 60 \"%s\"" % (cowsay, msg),
stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
(out, err) = cmd.communicate() (out, err) = cmd.communicate()
res = "%s\n" % out return "%s\n" % out
else: else:
res = "\n%s ********************* " % msg return "\n%s ********************* " % msg
return res
def command_generic_msg(hostname, result, oneline, caption): def command_generic_msg(hostname, result, oneline, caption):
''' output the result of a command run ''' ''' output the result of a command run '''
rc = result.get('rc', '0') rc = result.get('rc', '0')
stdout = result.get('stdout','') stdout = result.get('stdout','')
stderr = result.get('stderr', '') stderr = result.get('stderr', '')
msg = result.get('msg', '') msg = result.get('msg', '')
if not oneline: if not oneline:
buf = "%s | %s | rc=%s >>\n" % (hostname, caption, result.get('rc',0)) buf = "%s | %s | rc=%s >>\n" % (hostname, caption, result.get('rc',0))
if stdout: if stdout:
...@@ -117,8 +114,7 @@ def command_generic_msg(hostname, result, oneline, caption): ...@@ -117,8 +114,7 @@ def command_generic_msg(hostname, result, oneline, caption):
buf += stderr buf += stderr
if msg: if msg:
buf += msg buf += msg
buf += "\n" return buf + "\n"
return buf
else: else:
if stderr: if stderr:
return "%s | %s | rc=%s | (stdout) %s (stderr) %s\n" % (hostname, caption, rc, stdout, stderr) return "%s | %s | rc=%s | (stdout) %s (stderr) %s\n" % (hostname, caption, rc, stdout, stderr)
...@@ -127,6 +123,7 @@ def command_generic_msg(hostname, result, oneline, caption): ...@@ -127,6 +123,7 @@ def command_generic_msg(hostname, result, oneline, caption):
def host_report_msg(hostname, module_name, result, oneline): def host_report_msg(hostname, module_name, result, oneline):
''' summarize the JSON results for a particular host ''' ''' summarize the JSON results for a particular host '''
failed = utils.is_failed(result) failed = utils.is_failed(result)
if module_name in [ 'command', 'shell', 'raw' ] and 'ansible_job_id' not in result: if module_name in [ 'command', 'shell', 'raw' ] and 'ansible_job_id' not in result:
if not failed: if not failed:
...@@ -181,17 +178,21 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks): ...@@ -181,17 +178,21 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks):
''' callbacks for use by /usr/bin/ansible ''' ''' callbacks for use by /usr/bin/ansible '''
def __init__(self): def __init__(self):
# set by /usr/bin/ansible later # set by /usr/bin/ansible later
self.options = None self.options = None
self._async_notified = {} self._async_notified = {}
def on_failed(self, host, res): def on_failed(self, host, res):
self._on_any(host,res) self._on_any(host,res)
def on_ok(self, host, res): def on_ok(self, host, res):
self._on_any(host,res) self._on_any(host,res)
def on_unreachable(self, host, res): def on_unreachable(self, host, res):
if type(res) == dict: if type(res) == dict:
res = res.get('msg','') res = res.get('msg','')
print "%s | FAILED => %s" % (host, res) print "%s | FAILED => %s" % (host, res)
...@@ -205,12 +206,15 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks): ...@@ -205,12 +206,15 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks):
pass pass
def on_error(self, host, err): def on_error(self, host, err):
print >>sys.stderr, "err: [%s] => %s\n" % (host, err) print >>sys.stderr, "err: [%s] => %s\n" % (host, err)
def on_no_hosts(self): def on_no_hosts(self):
print >>sys.stderr, "no hosts matched\n" print >>sys.stderr, "no hosts matched\n"
def on_async_poll(self, host, res, jid, clock): def on_async_poll(self, host, res, jid, clock):
if jid not in self._async_notified: if jid not in self._async_notified:
self._async_notified[jid] = clock + 1 self._async_notified[jid] = clock + 1
if self._async_notified[jid] > clock: if self._async_notified[jid] > clock:
...@@ -218,12 +222,15 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks): ...@@ -218,12 +222,15 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks):
print "<job %s> polling, %ss remaining"%(jid, clock) print "<job %s> polling, %ss remaining"%(jid, clock)
def on_async_ok(self, host, res, jid): def on_async_ok(self, host, res, jid):
print "<job %s> finished on %s => %s"%(jid, host, utils.jsonify(res,format=True)) print "<job %s> finished on %s => %s"%(jid, host, utils.jsonify(res,format=True))
def on_async_failed(self, host, res, jid): def on_async_failed(self, host, res, jid):
print "<job %s> FAILED on %s => %s"%(jid, host, utils.jsonify(res,format=True)) print "<job %s> FAILED on %s => %s"%(jid, host, utils.jsonify(res,format=True))
def _on_any(self, host, result): def _on_any(self, host, result):
print host_report_msg(host, self.options.module_name, result, self.options.one_line) print host_report_msg(host, self.options.module_name, result, self.options.one_line)
if self.options.tree: if self.options.tree:
utils.write_tree_file(self.options.tree, host, utils.json(result,format=True)) utils.write_tree_file(self.options.tree, host, utils.json(result,format=True))
...@@ -234,17 +241,21 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks): ...@@ -234,17 +241,21 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
''' callbacks used for Runner() from /usr/bin/ansible-playbook ''' ''' callbacks used for Runner() from /usr/bin/ansible-playbook '''
def __init__(self, stats, verbose=False): def __init__(self, stats, verbose=False):
self.stats = stats self.stats = stats
self._async_notified = {} self._async_notified = {}
self.verbose = verbose self.verbose = verbose
def on_unreachable(self, host, msg): def on_unreachable(self, host, msg):
print "fatal: [%s] => %s" % (host, msg) print "fatal: [%s] => %s" % (host, msg)
def on_failed(self, host, results): def on_failed(self, host, results):
print "failed: [%s] => %s\n" % (host, utils.jsonify(results)) print "failed: [%s] => %s\n" % (host, utils.jsonify(results))
def on_ok(self, host, host_result): def on_ok(self, host, host_result):
# show verbose output for non-setup module results if --verbose is used # show verbose output for non-setup module results if --verbose is used
if not self.verbose or host_result.get("verbose_override",None) is not None: if not self.verbose or host_result.get("verbose_override",None) is not None:
print "ok: [%s]" % (host) print "ok: [%s]" % (host)
...@@ -252,15 +263,19 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks): ...@@ -252,15 +263,19 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
print "ok: [%s] => %s" % (host, utils.jsonify(host_result)) print "ok: [%s] => %s" % (host, utils.jsonify(host_result))
def on_error(self, host, err): def on_error(self, host, err):
print >>sys.stderr, "err: [%s] => %s\n" % (host, err) print >>sys.stderr, "err: [%s] => %s\n" % (host, err)
def on_skipped(self, host): def on_skipped(self, host):
print "skipping: [%s]\n" % host print "skipping: [%s]\n" % host
def on_no_hosts(self): def on_no_hosts(self):
print "no hosts matched or remaining\n" print "no hosts matched or remaining\n"
def on_async_poll(self, host, res, jid, clock): def on_async_poll(self, host, res, jid, clock):
if jid not in self._async_notified: if jid not in self._async_notified:
self._async_notified[jid] = clock + 1 self._async_notified[jid] = clock + 1
if self._async_notified[jid] > clock: if self._async_notified[jid] > clock:
...@@ -268,9 +283,11 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks): ...@@ -268,9 +283,11 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
print "<job %s> polling, %ss remaining"%(jid, clock) print "<job %s> polling, %ss remaining"%(jid, clock)
def on_async_ok(self, host, res, jid): def on_async_ok(self, host, res, jid):
print "<job %s> finished on %s"%(jid, host) print "<job %s> finished on %s"%(jid, host)
def on_async_failed(self, host, res, jid): def on_async_failed(self, host, res, jid):
print "<job %s> FAILED on %s"%(jid, host) print "<job %s> FAILED on %s"%(jid, host)
######################################################################## ########################################################################
...@@ -279,34 +296,45 @@ class PlaybookCallbacks(object): ...@@ -279,34 +296,45 @@ class PlaybookCallbacks(object):
''' playbook.py callbacks used by /usr/bin/ansible-playbook ''' ''' playbook.py callbacks used by /usr/bin/ansible-playbook '''
def __init__(self, verbose=False): def __init__(self, verbose=False):
self.verbose = verbose self.verbose = verbose
def on_start(self): def on_start(self):
pass pass
def on_notify(self, host, handler): def on_notify(self, host, handler):
pass pass
def on_task_start(self, name, is_conditional): def on_task_start(self, name, is_conditional):
msg = "TASK: [%s]" % name msg = "TASK: [%s]" % name
if is_conditional: if is_conditional:
msg = "NOTIFIED: [%s]" % name msg = "NOTIFIED: [%s]" % name
print banner(msg) print banner(msg)
def on_vars_prompt(self, varname, private=True): def on_vars_prompt(self, varname, private=True):
msg = 'input for %s: ' % varname msg = 'input for %s: ' % varname
if private: if private:
return getpass.getpass(msg) return getpass.getpass(msg)
return raw_input(msg) return raw_input(msg)
def on_setup(self): def on_setup(self):
print banner("GATHERING FACTS") print banner("GATHERING FACTS")
def on_import_for_host(self, host, imported_file): def on_import_for_host(self, host, imported_file):
print "%s: importing %s" % (host, imported_file) print "%s: importing %s" % (host, imported_file)
def on_not_import_for_host(self, host, missing_file): def on_not_import_for_host(self, host, missing_file):
print "%s: not importing file: %s" % (host, missing_file) print "%s: not importing file: %s" % (host, missing_file)
def on_play_start(self, pattern): def on_play_start(self, pattern):
print banner("PLAY [%s]" % pattern) print banner("PLAY [%s]" % pattern)
...@@ -14,17 +14,12 @@ ...@@ -14,17 +14,12 @@
# #
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#
import os import os
DEFAULT_HOST_LIST = os.environ.get('ANSIBLE_HOSTS', '/etc/ansible/hosts')
DEFAULT_HOST_LIST = os.environ.get('ANSIBLE_HOSTS', DEFAULT_MODULE_PATH = os.environ.get('ANSIBLE_LIBRARY', '/usr/share/ansible')
'/etc/ansible/hosts') DEFAULT_REMOTE_TMP = os.environ.get('ANSIBLE_REMOTE_TMP', '$HOME/.ansible/tmp')
DEFAULT_MODULE_PATH = os.environ.get('ANSIBLE_LIBRARY',
'/usr/share/ansible')
DEFAULT_REMOTE_TMP = os.environ.get('ANSIBLE_REMOTE_TMP',
'$HOME/.ansible/tmp')
DEFAULT_MODULE_NAME = 'command' DEFAULT_MODULE_NAME = 'command'
DEFAULT_PATTERN = '*' DEFAULT_PATTERN = '*'
......
...@@ -15,11 +15,8 @@ ...@@ -15,11 +15,8 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
class AnsibleError(Exception): class AnsibleError(Exception):
""" ''' The base Ansible exception from which all others should subclass '''
The base Ansible exception from which all others should subclass.
"""
def __init__(self, msg): def __init__(self, msg):
self.msg = msg self.msg = msg
...@@ -27,11 +24,9 @@ class AnsibleError(Exception): ...@@ -27,11 +24,9 @@ class AnsibleError(Exception):
def __str__(self): def __str__(self):
return self.msg return self.msg
class AnsibleFileNotFound(AnsibleError): class AnsibleFileNotFound(AnsibleError):
pass pass
class AnsibleConnectionFailed(AnsibleError): class AnsibleConnectionFailed(AnsibleError):
pass pass
...@@ -15,16 +15,11 @@ ...@@ -15,16 +15,11 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#############################################
# from ansible import errors
class Group(object): class Group(object):
""" ''' a group of ansible hosts '''
Group of ansible hosts
"""
def __init__(self, name=None): def __init__(self, name=None):
self.name = name self.name = name
self.hosts = [] self.hosts = []
self.vars = {} self.vars = {}
...@@ -34,19 +29,23 @@ class Group(object): ...@@ -34,19 +29,23 @@ class Group(object):
raise Exception("group name is required") raise Exception("group name is required")
def add_child_group(self, group): def add_child_group(self, group):
if self == group: if self == group:
raise Exception("can't add group to itself") raise Exception("can't add group to itself")
self.child_groups.append(group) self.child_groups.append(group)
group.parent_groups.append(self) group.parent_groups.append(self)
def add_host(self, host): def add_host(self, host):
self.hosts.append(host) self.hosts.append(host)
host.add_group(self) host.add_group(self)
def set_variable(self, key, value): def set_variable(self, key, value):
self.vars[key] = value self.vars[key] = value
def get_hosts(self): def get_hosts(self):
hosts = [] hosts = []
for kid in self.child_groups: for kid in self.child_groups:
hosts.extend(kid.get_hosts()) hosts.extend(kid.get_hosts())
...@@ -54,6 +53,7 @@ class Group(object): ...@@ -54,6 +53,7 @@ class Group(object):
return hosts return hosts
def get_variables(self): def get_variables(self):
vars = {} vars = {}
# FIXME: verify this variable override order is what we want # FIXME: verify this variable override order is what we want
for ancestor in self.get_ancestors(): for ancestor in self.get_ancestors():
...@@ -62,6 +62,7 @@ class Group(object): ...@@ -62,6 +62,7 @@ class Group(object):
return vars return vars
def _get_ancestors(self): def _get_ancestors(self):
results = {} results = {}
for g in self.parent_groups: for g in self.parent_groups:
results[g.name] = g results[g.name] = g
...@@ -69,8 +70,6 @@ class Group(object): ...@@ -69,8 +70,6 @@ class Group(object):
return results return results
def get_ancestors(self): def get_ancestors(self):
return self._get_ancestors().values()
return self._get_ancestors().values()
...@@ -15,17 +15,14 @@ ...@@ -15,17 +15,14 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#############################################
from ansible import errors from ansible import errors
import ansible.constants as C import ansible.constants as C
class Host(object): class Host(object):
""" ''' a single ansible host '''
Group of ansible hosts
"""
def __init__(self, name=None, port=None): def __init__(self, name=None, port=None):
self.name = name self.name = name
self.vars = {} self.vars = {}
self.groups = [] self.groups = []
...@@ -36,12 +33,15 @@ class Host(object): ...@@ -36,12 +33,15 @@ class Host(object):
raise Exception("host name is required") raise Exception("host name is required")
def add_group(self, group): def add_group(self, group):
self.groups.append(group) self.groups.append(group)
def set_variable(self, key, value): def set_variable(self, key, value):
self.vars[key]=value self.vars[key]=value
def get_groups(self): def get_groups(self):
groups = {} groups = {}
for g in self.groups: for g in self.groups:
groups[g.name] = g groups[g.name] = g
...@@ -51,6 +51,7 @@ class Host(object): ...@@ -51,6 +51,7 @@ class Host(object):
return groups.values() return groups.values()
def get_variables(self): def get_variables(self):
results = {} results = {}
for group in self.groups: for group in self.groups:
results.update(group.get_variables()) results.update(group.get_variables())
......
...@@ -54,6 +54,7 @@ class InventoryParser(object): ...@@ -54,6 +54,7 @@ class InventoryParser(object):
# delta asdf=jkl favcolor=red # delta asdf=jkl favcolor=red
def _parse_base_groups(self): def _parse_base_groups(self):
# FIXME: refactor
ungrouped = Group(name='ungrouped') ungrouped = Group(name='ungrouped')
all = Group(name='all') all = Group(name='all')
......
...@@ -26,9 +26,7 @@ from ansible import errors ...@@ -26,9 +26,7 @@ from ansible import errors
from ansible import utils from ansible import utils
class InventoryScript(object): class InventoryScript(object):
""" ''' Host inventory parser for ansible using external inventory scripts. '''
Host inventory parser for ansible using external inventory scripts.
"""
def __init__(self, filename=C.DEFAULT_HOST_LIST): def __init__(self, filename=C.DEFAULT_HOST_LIST):
...@@ -39,6 +37,7 @@ class InventoryScript(object): ...@@ -39,6 +37,7 @@ class InventoryScript(object):
self.groups = self._parse() self.groups = self._parse()
def _parse(self): def _parse(self):
groups = {} groups = {}
self.raw = utils.parse_json(self.data) self.raw = utils.parse_json(self.data)
all=Group('all') all=Group('all')
...@@ -55,4 +54,3 @@ class InventoryScript(object): ...@@ -55,4 +54,3 @@ class InventoryScript(object):
all.add_child_group(group) all.add_child_group(group)
return groups return groups
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#############################################
import ansible.constants as C import ansible.constants as C
from ansible.inventory.host import Host from ansible.inventory.host import Host
from ansible.inventory.group import Group from ansible.inventory.group import Group
...@@ -24,9 +22,7 @@ from ansible import errors ...@@ -24,9 +22,7 @@ from ansible import errors
from ansible import utils from ansible import utils
class InventoryParserYaml(object): class InventoryParserYaml(object):
""" ''' Host inventory parser for ansible '''
Host inventory for ansible.
"""
def __init__(self, filename=C.DEFAULT_HOST_LIST): def __init__(self, filename=C.DEFAULT_HOST_LIST):
...@@ -37,6 +33,7 @@ class InventoryParserYaml(object): ...@@ -37,6 +33,7 @@ class InventoryParserYaml(object):
self._parse(data) self._parse(data)
def _make_host(self, hostname): def _make_host(self, hostname):
if hostname in self._hosts: if hostname in self._hosts:
return self._hosts[hostname] return self._hosts[hostname]
else: else:
...@@ -47,6 +44,7 @@ class InventoryParserYaml(object): ...@@ -47,6 +44,7 @@ class InventoryParserYaml(object):
# see file 'test/yaml_hosts' for syntax # see file 'test/yaml_hosts' for syntax
def _parse(self, data): def _parse(self, data):
# FIXME: refactor into subfunctions
all = Group('all') all = Group('all')
ungrouped = Group('ungrouped') ungrouped = Group('ungrouped')
...@@ -134,3 +132,5 @@ class InventoryParserYaml(object): ...@@ -134,3 +132,5 @@ class InventoryParserYaml(object):
# make sure ungrouped.hosts is the complement of grouped_hosts # make sure ungrouped.hosts is the complement of grouped_hosts
ungrouped_hosts = [host for host in ungrouped.hosts if host not in grouped_hosts] ungrouped_hosts = [host for host in ungrouped.hosts if host not in grouped_hosts]
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#############################################
import ansible.inventory import ansible.inventory
import ansible.runner import ansible.runner
import ansible.constants as C import ansible.constants as C
...@@ -26,8 +24,6 @@ import os ...@@ -26,8 +24,6 @@ import os
import collections import collections
from play import Play from play import Play
#############################################
class PlayBook(object): class PlayBook(object):
''' '''
runs an ansible playbook, given as a datastructure or YAML filename. runs an ansible playbook, given as a datastructure or YAML filename.
......
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#############################################
from ansible import errors from ansible import errors
from ansible import utils from ansible import utils
...@@ -93,9 +91,7 @@ class Task(object): ...@@ -93,9 +91,7 @@ class Task(object):
raise errors.AnsibleError("with_items must be a list, got: %s" % self.with_items) raise errors.AnsibleError("with_items must be a list, got: %s" % self.with_items)
self.module_vars['items'] = self.with_items self.module_vars['items'] = self.with_items
# tags allow certain parts of a playbook to be run without running the whole playbook
# tags allow certain parts of a playbook to be run without
# running the whole playbook
apply_tags = ds.get('tags', None) apply_tags = ds.get('tags', None)
if apply_tags is not None: if apply_tags is not None:
if type(apply_tags) in [ str, unicode ]: if type(apply_tags) in [ str, unicode ]:
...@@ -104,5 +100,3 @@ class Task(object): ...@@ -104,5 +100,3 @@ class Task(object):
self.tags.extend(apply_tags) self.tags.extend(apply_tags)
self.tags.extend(import_tags) self.tags.extend(import_tags)
...@@ -228,7 +228,9 @@ class Runner(object): ...@@ -228,7 +228,9 @@ class Runner(object):
afo.close() afo.close()
remote = os.path.join(tmp, name) remote = os.path.join(tmp, name)
try:
conn.put_file(afile, remote) conn.put_file(afile, remote)
finally:
os.unlink(afile) os.unlink(afile)
return remote return remote
...@@ -632,6 +634,7 @@ class Runner(object): ...@@ -632,6 +634,7 @@ class Runner(object):
result = None result = None
handler = getattr(self, "_execute_%s" % self.module_name, None) handler = getattr(self, "_execute_%s" % self.module_name, None)
if handler: if handler:
result = handler(conn, tmp) result = handler(conn, tmp)
else: else:
......
...@@ -14,21 +14,11 @@ ...@@ -14,21 +14,11 @@
# #
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#
################################################
import warnings
import traceback import traceback
import os import os
import time
import re
import shutil import shutil
import subprocess import subprocess
import pipes
import socket
import random
from ansible import errors from ansible import errors
class LocalConnection(object): class LocalConnection(object):
...@@ -45,6 +35,7 @@ class LocalConnection(object): ...@@ -45,6 +35,7 @@ class LocalConnection(object):
def exec_command(self, cmd, tmp_path, sudo_user, sudoable=False): def exec_command(self, cmd, tmp_path, sudo_user, sudoable=False):
''' run a command on the local host ''' ''' run a command on the local host '''
if self.runner.sudo and sudoable: if self.runner.sudo and sudoable:
cmd = "sudo -s %s" % cmd cmd = "sudo -s %s" % cmd
if self.runner.sudo_pass: if self.runner.sudo_pass:
...@@ -60,6 +51,7 @@ class LocalConnection(object): ...@@ -60,6 +51,7 @@ class LocalConnection(object):
def put_file(self, in_path, out_path): def put_file(self, in_path, out_path):
''' transfer a file from local to local ''' ''' transfer a file from local to local '''
if not os.path.exists(in_path): if not os.path.exists(in_path):
raise errors.AnsibleFileNotFound("file or module does not exist: %s" % in_path) raise errors.AnsibleFileNotFound("file or module does not exist: %s" % in_path)
try: try:
...@@ -77,5 +69,4 @@ class LocalConnection(object): ...@@ -77,5 +69,4 @@ class LocalConnection(object):
def close(self): def close(self):
''' terminate the connection; nothing to do here ''' ''' terminate the connection; nothing to do here '''
pass pass
...@@ -14,24 +14,19 @@ ...@@ -14,24 +14,19 @@
# #
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#
################################################
import warnings import warnings
import traceback import traceback
import os import os
import time
import re import re
import shutil import shutil
import subprocess import subprocess
import pipes import pipes
import socket import socket
import random import random
from ansible import errors from ansible import errors
# prevent paramiko warning noise
# see http://stackoverflow.com/questions/3920502/ # prevent paramiko warning noise -- see http://stackoverflow.com/questions/3920502/
HAVE_PARAMIKO=False HAVE_PARAMIKO=False
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
...@@ -52,27 +47,20 @@ class ParamikoConnection(object): ...@@ -52,27 +47,20 @@ class ParamikoConnection(object):
if port is None: if port is None:
self.port = self.runner.remote_port self.port = self.runner.remote_port
def _get_conn(self): def connect(self):
''' activates the connection object '''
if not HAVE_PARAMIKO: if not HAVE_PARAMIKO:
raise errors.AnsibleError("paramiko is not installed") raise errors.AnsibleError("paramiko is not installed")
user = self.runner.remote_user user = self.runner.remote_user
ssh = paramiko.SSHClient() ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
try: try:
ssh.connect( ssh.connect(self.host, username=user, allow_agent=True, look_for_keys=True,
self.host, key_filename=self.runner.private_key_file, password=self.runner.remote_pass,
username=user, timeout=self.runner.timeout, port=self.port)
allow_agent=True,
look_for_keys=True,
key_filename=self.runner.private_key_file,
password=self.runner.remote_pass,
timeout=self.runner.timeout,
port=self.port
)
except Exception, e: except Exception, e:
msg = str(e) msg = str(e)
if "PID check failed" in msg: if "PID check failed" in msg:
...@@ -84,17 +72,12 @@ class ParamikoConnection(object): ...@@ -84,17 +72,12 @@ class ParamikoConnection(object):
else: else:
raise errors.AnsibleConnectionFailed(msg) raise errors.AnsibleConnectionFailed(msg)
return ssh self.ssh = ssh
def connect(self):
''' connect to the remote host '''
self.ssh = self._get_conn()
return self return self
def exec_command(self, cmd, tmp_path, sudo_user, sudoable=False): def exec_command(self, cmd, tmp_path, sudo_user, sudoable=False):
''' run a command on the remote host ''' ''' run a command on the remote host '''
bufsize = 4096 bufsize = 4096
chan = self.ssh.get_transport().open_session() chan = self.ssh.get_transport().open_session()
chan.get_pty() chan.get_pty()
...@@ -128,10 +111,7 @@ class ParamikoConnection(object): ...@@ -128,10 +111,7 @@ class ParamikoConnection(object):
except socket.timeout: except socket.timeout:
raise errors.AnsibleError('ssh timed out waiting for sudo.\n' + sudo_output) raise errors.AnsibleError('ssh timed out waiting for sudo.\n' + sudo_output)
stdin = chan.makefile('wb', bufsize) return (chan.makefile('wb', bufsize), chan.makefile('rb', bufsize), '')
stdout = chan.makefile('rb', bufsize)
stderr = '' # stderr goes to stdout when using a pty, so this will never output anything.
return stdin, stdout, stderr
def put_file(self, in_path, out_path): def put_file(self, in_path, out_path):
''' transfer a file from local to remote ''' ''' transfer a file from local to remote '''
...@@ -141,21 +121,19 @@ class ParamikoConnection(object): ...@@ -141,21 +121,19 @@ class ParamikoConnection(object):
try: try:
sftp.put(in_path, out_path) sftp.put(in_path, out_path)
except IOError: except IOError:
traceback.print_exc()
raise errors.AnsibleError("failed to transfer file to %s" % out_path) raise errors.AnsibleError("failed to transfer file to %s" % out_path)
sftp.close() sftp.close()
def fetch_file(self, in_path, out_path): def fetch_file(self, in_path, out_path):
''' save a remote file to the specified path '''
sftp = self.ssh.open_sftp() sftp = self.ssh.open_sftp()
try: try:
sftp.get(in_path, out_path) sftp.get(in_path, out_path)
except IOError: except IOError:
traceback.print_exc()
raise errors.AnsibleError("failed to transfer file from %s" % in_path) raise errors.AnsibleError("failed to transfer file from %s" % in_path)
sftp.close() sftp.close()
def close(self): def close(self):
''' terminate the connection ''' ''' terminate the connection '''
self.ssh.close() self.ssh.close()
...@@ -16,17 +16,13 @@ ...@@ -16,17 +16,13 @@
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# #
################################################
import os import os
import time
import subprocess import subprocess
import shlex import shlex
import pipes import pipes
import random import random
import select import select
import fcntl import fcntl
from ansible import errors from ansible import errors
class SSHConnection(object): class SSHConnection(object):
...@@ -39,6 +35,7 @@ class SSHConnection(object): ...@@ -39,6 +35,7 @@ class SSHConnection(object):
def connect(self): def connect(self):
''' connect to the remote host ''' ''' connect to the remote host '''
self.common_args = [] self.common_args = []
extra_args = os.getenv("ANSIBLE_SSH_ARGS", None) extra_args = os.getenv("ANSIBLE_SSH_ARGS", None)
if extra_args is not None: if extra_args is not None:
...@@ -134,5 +131,6 @@ class SSHConnection(object): ...@@ -134,5 +131,6 @@ class SSHConnection(object):
raise errors.AnsibleError("failed to transfer file from %s:\n%s\n%s" % (in_path, stdout, stderr)) raise errors.AnsibleError("failed to transfer file from %s:\n%s\n%s" % (in_path, stdout, stderr))
def close(self): def close(self):
''' terminate the connection ''' ''' not applicable since we're executing openssh binaries '''
pass pass
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
###############################################################
import sys import sys
import os import os
import shlex import shlex
...@@ -45,15 +43,18 @@ except ImportError: ...@@ -45,15 +43,18 @@ except ImportError:
def err(msg): def err(msg):
''' print an error message to stderr ''' ''' print an error message to stderr '''
print >> sys.stderr, msg print >> sys.stderr, msg
def exit(msg, rc=1): def exit(msg, rc=1):
''' quit with an error to stdout and a failure code ''' ''' quit with an error to stdout and a failure code '''
err(msg) err(msg)
sys.exit(rc) sys.exit(rc)
def jsonify(result, format=False): def jsonify(result, format=False):
''' format JSON output (uncompressed or uncompressed) ''' ''' format JSON output (uncompressed or uncompressed) '''
result2 = result.copy() result2 = result.copy()
if format: if format:
return json.dumps(result2, sort_keys=True, indent=4) return json.dumps(result2, sort_keys=True, indent=4)
...@@ -62,6 +63,7 @@ def jsonify(result, format=False): ...@@ -62,6 +63,7 @@ def jsonify(result, format=False):
def write_tree_file(tree, hostname, buf): def write_tree_file(tree, hostname, buf):
''' write something into treedir/hostname ''' ''' write something into treedir/hostname '''
# TODO: might be nice to append playbook runs per host in a similar way # TODO: might be nice to append playbook runs per host in a similar way
# in which case, we'd want append mode. # in which case, we'd want append mode.
path = os.path.join(tree, hostname) path = os.path.join(tree, hostname)
...@@ -71,10 +73,12 @@ def write_tree_file(tree, hostname, buf): ...@@ -71,10 +73,12 @@ def write_tree_file(tree, hostname, buf):
def is_failed(result): def is_failed(result):
''' is a given JSON result a failed result? ''' ''' is a given JSON result a failed result? '''
return ((result.get('rc', 0) != 0) or (result.get('failed', False) in [ True, 'True', 'true'])) return ((result.get('rc', 0) != 0) or (result.get('failed', False) in [ True, 'True', 'true']))
def prepare_writeable_dir(tree): def prepare_writeable_dir(tree):
''' make sure a directory exists and is writeable ''' ''' make sure a directory exists and is writeable '''
if tree != '/': if tree != '/':
tree = os.path.realpath(os.path.expanduser(tree)) tree = os.path.realpath(os.path.expanduser(tree))
if not os.path.exists(tree): if not os.path.exists(tree):
...@@ -87,6 +91,7 @@ def prepare_writeable_dir(tree): ...@@ -87,6 +91,7 @@ def prepare_writeable_dir(tree):
def path_dwim(basedir, given): def path_dwim(basedir, given):
''' make relative paths work like folks expect ''' ''' make relative paths work like folks expect '''
if given.startswith("/"): if given.startswith("/"):
return given return given
elif given.startswith("~/"): elif given.startswith("~/"):
...@@ -96,10 +101,12 @@ def path_dwim(basedir, given): ...@@ -96,10 +101,12 @@ def path_dwim(basedir, given):
def json_loads(data): def json_loads(data):
''' parse a JSON string and return a data structure ''' ''' parse a JSON string and return a data structure '''
return json.loads(data) return json.loads(data)
def parse_json(data): def parse_json(data):
''' this version for module return data only ''' ''' this version for module return data only '''
try: try:
return json.loads(data) return json.loads(data)
except: except:
...@@ -132,6 +139,7 @@ _LISTRE = re.compile(r"(\w+)\[(\d+)\]") ...@@ -132,6 +139,7 @@ _LISTRE = re.compile(r"(\w+)\[(\d+)\]")
def _varLookup(name, vars): def _varLookup(name, vars):
''' find the contents of a possibly complex variable in vars. ''' ''' find the contents of a possibly complex variable in vars. '''
path = name.split('.') path = name.split('.')
space = vars space = vars
for part in path: for part in path:
...@@ -153,19 +161,14 @@ _KEYCRE = re.compile(r"\$(?P<complex>\{){0,1}((?(complex)[\w\.\[\]]+|\w+))(?(com ...@@ -153,19 +161,14 @@ _KEYCRE = re.compile(r"\$(?P<complex>\{){0,1}((?(complex)[\w\.\[\]]+|\w+))(?(com
def varLookup(varname, vars): def varLookup(varname, vars):
''' helper function used by varReplace ''' ''' helper function used by varReplace '''
m = _KEYCRE.search(varname) m = _KEYCRE.search(varname)
if not m: if not m:
return None return None
return _varLookup(m.group(2), vars) return _varLookup(m.group(2), vars)
def varReplace(raw, vars): def varReplace(raw, vars):
'''Perform variable replacement of $vars ''' Perform variable replacement of $variables in string raw using vars dictionary '''
@param raw: String to perform substitution on.
@param vars: Dictionary of variables to replace. Key is variable name
(without $ prefix). Value is replacement string.
@return: Input raw string with substituted values.
'''
# this code originally from yum # this code originally from yum
done = [] # Completed chunks to return done = [] # Completed chunks to return
...@@ -191,13 +194,14 @@ def varReplace(raw, vars): ...@@ -191,13 +194,14 @@ def varReplace(raw, vars):
def _template(text, vars, setup_cache=None): def _template(text, vars, setup_cache=None):
''' run a text buffer through the templating engine ''' ''' run a text buffer through the templating engine '''
vars = vars.copy() vars = vars.copy()
vars['hostvars'] = setup_cache vars['hostvars'] = setup_cache
return varReplace(unicode(text), vars) return varReplace(unicode(text), vars)
def template(text, vars, setup_cache=None): def template(text, vars, setup_cache=None):
''' run a text buffer through the templating engine ''' run a text buffer through the templating engine until it no longer changes '''
until it no longer changes '''
prev_text = '' prev_text = ''
depth = 0 depth = 0
while prev_text != text: while prev_text != text:
...@@ -210,6 +214,7 @@ def template(text, vars, setup_cache=None): ...@@ -210,6 +214,7 @@ def template(text, vars, setup_cache=None):
def template_from_file(basedir, path, vars, setup_cache): def template_from_file(basedir, path, vars, setup_cache):
''' run a file through the templating engine ''' ''' run a file through the templating engine '''
environment = jinja2.Environment(loader=jinja2.FileSystemLoader(basedir), trim_blocks=False) environment = jinja2.Environment(loader=jinja2.FileSystemLoader(basedir), trim_blocks=False)
data = codecs.open(path_dwim(basedir, path), encoding="utf8").read() data = codecs.open(path_dwim(basedir, path), encoding="utf8").read()
t = environment.from_string(data) t = environment.from_string(data)
...@@ -222,10 +227,12 @@ def template_from_file(basedir, path, vars, setup_cache): ...@@ -222,10 +227,12 @@ def template_from_file(basedir, path, vars, setup_cache):
def parse_yaml(data): def parse_yaml(data):
''' convert a yaml string to a data structure ''' ''' convert a yaml string to a data structure '''
return yaml.load(data) return yaml.load(data)
def parse_yaml_from_file(path): def parse_yaml_from_file(path):
''' convert a yaml file to a data structure ''' ''' convert a yaml file to a data structure '''
try: try:
data = file(path).read() data = file(path).read()
except IOError: except IOError:
...@@ -234,6 +241,7 @@ def parse_yaml_from_file(path): ...@@ -234,6 +241,7 @@ def parse_yaml_from_file(path):
def parse_kv(args): def parse_kv(args):
''' convert a string of key/value items to a dict ''' ''' convert a string of key/value items to a dict '''
options = {} options = {}
if args is not None: if args is not None:
vargs = shlex.split(args, posix=True) vargs = shlex.split(args, posix=True)
...@@ -245,6 +253,7 @@ def parse_kv(args): ...@@ -245,6 +253,7 @@ def parse_kv(args):
def md5(filename): def md5(filename):
''' Return MD5 hex digest of local file, or None if file is not present. ''' ''' Return MD5 hex digest of local file, or None if file is not present. '''
if not os.path.exists(filename): if not os.path.exists(filename):
return None return None
digest = _md5() digest = _md5()
...@@ -263,6 +272,7 @@ def md5(filename): ...@@ -263,6 +272,7 @@ def md5(filename):
class SortedOptParser(optparse.OptionParser): class SortedOptParser(optparse.OptionParser):
'''Optparser which sorts the options by opt before outputting --help''' '''Optparser which sorts the options by opt before outputting --help'''
def format_help(self, formatter=None): def format_help(self, formatter=None):
self.option_list.sort(key=operator.methodcaller('get_opt_string')) self.option_list.sort(key=operator.methodcaller('get_opt_string'))
return optparse.OptionParser.format_help(self, formatter=None) return optparse.OptionParser.format_help(self, formatter=None)
...@@ -321,4 +331,3 @@ def base_parser(constants=C, usage="", output_opts=False, runas_opts=False, asyn ...@@ -321,4 +331,3 @@ def base_parser(constants=C, usage="", output_opts=False, runas_opts=False, asyn
return parser return parser
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment