Commit 8ef78b1c by James Cammarata

Fixing accelerated connection plugin

parent 00b8a242
......@@ -19,6 +19,7 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import base64
import json
import pipes
import subprocess
......@@ -33,6 +34,7 @@ from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVar
from ansible.playbook.conditional import Conditional
from ansible.playbook.task import Task
from ansible.template import Templar
from ansible.utils.encrypt import key_for_hostname
from ansible.utils.listify import listify_lookup_plugin_terms
from ansible.utils.unicode import to_unicode
from ansible.vars.unsafe_proxy import UnsafeProxy
......@@ -309,7 +311,7 @@ class TaskExecutor:
return dict(include=include_file, include_variables=include_variables)
# get the connection and the handler for this execution
self._connection = self._get_connection(variables)
self._connection = self._get_connection(variables=variables, templar=templar)
self._connection.set_host_overrides(host=self._host)
self._handler = self._get_action_handler(connection=self._connection, templar=templar)
......@@ -466,7 +468,7 @@ class TaskExecutor:
else:
return async_result
def _get_connection(self, variables):
def _get_connection(self, variables, templar):
'''
Reads the connection property for the host, and returns the
correct connection object from the list of connection plugins
......@@ -513,6 +515,38 @@ class TaskExecutor:
if not connection:
raise AnsibleError("the connection plugin '%s' was not found" % conn_type)
if self._play_context.accelerate:
# launch the accelerated daemon here
ssh_connection = connection
handler = self._shared_loader_obj.action_loader.get(
'normal',
task=self._task,
connection=ssh_connection,
play_context=self._play_context,
loader=self._loader,
templar=templar,
shared_loader_obj=self._shared_loader_obj,
)
key = key_for_hostname(self._play_context.remote_addr)
accelerate_args = dict(
password=base64.b64encode(key.__str__()),
port=self._play_context.accelerate_port,
minutes=C.ACCELERATE_DAEMON_TIMEOUT,
ipv6=self._play_context.accelerate_ipv6,
debug=self._play_context.verbosity,
)
connection = self._shared_loader_obj.connection_loader.get('accelerate', self._play_context, self._new_stdin)
if not connection:
raise AnsibleError("the connection plugin '%s' was not found" % conn_type)
try:
connection._connect()
except AnsibleConnectionFailure:
res = handler._execute_module(module_name='accelerate', module_args=accelerate_args, task_vars=variables, delete_remote_tmp=False)
connection._connect()
return connection
def _get_action_handler(self, connection, templar):
......
......@@ -56,6 +56,7 @@ MAGIC_VARIABLE_MAPPING = dict(
remote_addr = ('ansible_ssh_host', 'ansible_host'),
remote_user = ('ansible_ssh_user', 'ansible_user'),
port = ('ansible_ssh_port', 'ansible_port'),
accelerate_port = ('ansible_accelerate_port',),
password = ('ansible_ssh_pass', 'ansible_password'),
private_key_file = ('ansible_ssh_private_key_file', 'ansible_private_key_file'),
pipelining = ('ansible_ssh_pipelining', 'ansible_pipelining'),
......@@ -142,6 +143,9 @@ class PlayContext(Base):
_ssh_extra_args = FieldAttribute(isa='string')
_connection_lockfd= FieldAttribute(isa='int')
_pipelining = FieldAttribute(isa='bool', default=C.ANSIBLE_SSH_PIPELINING)
_accelerate = FieldAttribute(isa='bool', default=False)
_accelerate_ipv6 = FieldAttribute(isa='bool', default=False, always_post_validate=True)
_accelerate_port = FieldAttribute(isa='int', default=C.ACCELERATE_PORT, always_post_validate=True)
# privilege escalation fields
_become = FieldAttribute(isa='bool')
......@@ -199,6 +203,12 @@ class PlayContext(Base):
the play class.
'''
# special handling for accelerated mode, as it is set in a separate
# play option from the connection parameter
self.accelerate = play.accelerate
self.accelerate_ipv6 = play.accelerate_ipv6
self.accelerate_port = play.accelerate_port
if play.connection:
self.connection = play.connection
......
......@@ -18,6 +18,11 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import os
import stat
import time
import warnings
PASSLIB_AVAILABLE = False
try:
import passlib.hash
......@@ -25,6 +30,34 @@ try:
except:
pass
KEYCZAR_AVAILABLE=False
try:
try:
# some versions of pycrypto may not have this?
from Crypto.pct_warnings import PowmInsecureWarning
except ImportError:
PowmInsecureWarning = RuntimeWarning
with warnings.catch_warnings(record=True) as warning_handler:
warnings.simplefilter("error", PowmInsecureWarning)
try:
import keyczar.errors as key_errors
from keyczar.keys import AesKey
except PowmInsecureWarning:
system_warning(
"The version of gmp you have installed has a known issue regarding " + \
"timing vulnerabilities when used with pycrypto. " + \
"If possible, you should update it (i.e. yum update gmp)."
)
warnings.resetwarnings()
warnings.simplefilter("ignore")
import keyczar.errors as key_errors
from keyczar.keys import AesKey
KEYCZAR_AVAILABLE=True
except ImportError:
pass
from ansible import constants as C
from ansible.errors import AnsibleError
__all__ = ['do_encrypt']
......@@ -47,3 +80,47 @@ def do_encrypt(result, encrypt, salt_size=None, salt=None):
return result
def key_for_hostname(hostname):
# fireball mode is an implementation of ansible firing up zeromq via SSH
# to use no persistent daemons or key management
if not KEYCZAR_AVAILABLE:
raise AnsibleError("python-keyczar must be installed on the control machine to use accelerated modes")
key_path = os.path.expanduser(C.ACCELERATE_KEYS_DIR)
if not os.path.exists(key_path):
os.makedirs(key_path, mode=0700)
os.chmod(key_path, int(C.ACCELERATE_KEYS_DIR_PERMS, 8))
elif not os.path.isdir(key_path):
raise AnsibleError('ACCELERATE_KEYS_DIR is not a directory.')
if stat.S_IMODE(os.stat(key_path).st_mode) != int(C.ACCELERATE_KEYS_DIR_PERMS, 8):
raise AnsibleError('Incorrect permissions on the private key directory. Use `chmod 0%o %s` to correct this issue, and make sure any of the keys files contained within that directory are set to 0%o' % (int(C.ACCELERATE_KEYS_DIR_PERMS, 8), C.ACCELERATE_KEYS_DIR, int(C.ACCELERATE_KEYS_FILE_PERMS, 8)))
key_path = os.path.join(key_path, hostname)
# use new AES keys every 2 hours, which means fireball must not allow running for longer either
if not os.path.exists(key_path) or (time.time() - os.path.getmtime(key_path) > 60*60*2):
key = AesKey.Generate(size=256)
fd = os.open(key_path, os.O_WRONLY | os.O_CREAT, int(C.ACCELERATE_KEYS_FILE_PERMS, 8))
fh = os.fdopen(fd, 'w')
fh.write(str(key))
fh.close()
return key
else:
if stat.S_IMODE(os.stat(key_path).st_mode) != int(C.ACCELERATE_KEYS_FILE_PERMS, 8):
raise AnsibleError('Incorrect permissions on the key file for this host. Use `chmod 0%o %s` to correct this issue.' % (int(C.ACCELERATE_KEYS_FILE_PERMS, 8), key_path))
fh = open(key_path)
key = AesKey.Read(fh.read())
fh.close()
return key
def keyczar_encrypt(key, msg):
return key.Encrypt(msg.encode('utf-8'))
def keyczar_decrypt(key, msg):
try:
return key.Decrypt(msg)
except key_errors.InvalidSignatureError:
raise AnsibleError("decryption failed")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment