Commit 8ef78b1c by James Cammarata

Fixing accelerated connection plugin

parent 00b8a242
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import base64
import json import json
import pipes import pipes
import subprocess import subprocess
...@@ -33,6 +34,7 @@ from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVar ...@@ -33,6 +34,7 @@ from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVar
from ansible.playbook.conditional import Conditional from ansible.playbook.conditional import Conditional
from ansible.playbook.task import Task from ansible.playbook.task import Task
from ansible.template import Templar from ansible.template import Templar
from ansible.utils.encrypt import key_for_hostname
from ansible.utils.listify import listify_lookup_plugin_terms from ansible.utils.listify import listify_lookup_plugin_terms
from ansible.utils.unicode import to_unicode from ansible.utils.unicode import to_unicode
from ansible.vars.unsafe_proxy import UnsafeProxy from ansible.vars.unsafe_proxy import UnsafeProxy
...@@ -309,7 +311,7 @@ class TaskExecutor: ...@@ -309,7 +311,7 @@ class TaskExecutor:
return dict(include=include_file, include_variables=include_variables) return dict(include=include_file, include_variables=include_variables)
# get the connection and the handler for this execution # get the connection and the handler for this execution
self._connection = self._get_connection(variables) self._connection = self._get_connection(variables=variables, templar=templar)
self._connection.set_host_overrides(host=self._host) self._connection.set_host_overrides(host=self._host)
self._handler = self._get_action_handler(connection=self._connection, templar=templar) self._handler = self._get_action_handler(connection=self._connection, templar=templar)
...@@ -466,7 +468,7 @@ class TaskExecutor: ...@@ -466,7 +468,7 @@ class TaskExecutor:
else: else:
return async_result return async_result
def _get_connection(self, variables): def _get_connection(self, variables, templar):
''' '''
Reads the connection property for the host, and returns the Reads the connection property for the host, and returns the
correct connection object from the list of connection plugins correct connection object from the list of connection plugins
...@@ -513,6 +515,38 @@ class TaskExecutor: ...@@ -513,6 +515,38 @@ class TaskExecutor:
if not connection: if not connection:
raise AnsibleError("the connection plugin '%s' was not found" % conn_type) raise AnsibleError("the connection plugin '%s' was not found" % conn_type)
if self._play_context.accelerate:
# launch the accelerated daemon here
ssh_connection = connection
handler = self._shared_loader_obj.action_loader.get(
'normal',
task=self._task,
connection=ssh_connection,
play_context=self._play_context,
loader=self._loader,
templar=templar,
shared_loader_obj=self._shared_loader_obj,
)
key = key_for_hostname(self._play_context.remote_addr)
accelerate_args = dict(
password=base64.b64encode(key.__str__()),
port=self._play_context.accelerate_port,
minutes=C.ACCELERATE_DAEMON_TIMEOUT,
ipv6=self._play_context.accelerate_ipv6,
debug=self._play_context.verbosity,
)
connection = self._shared_loader_obj.connection_loader.get('accelerate', self._play_context, self._new_stdin)
if not connection:
raise AnsibleError("the connection plugin '%s' was not found" % conn_type)
try:
connection._connect()
except AnsibleConnectionFailure:
res = handler._execute_module(module_name='accelerate', module_args=accelerate_args, task_vars=variables, delete_remote_tmp=False)
connection._connect()
return connection return connection
def _get_action_handler(self, connection, templar): def _get_action_handler(self, connection, templar):
......
...@@ -56,6 +56,7 @@ MAGIC_VARIABLE_MAPPING = dict( ...@@ -56,6 +56,7 @@ MAGIC_VARIABLE_MAPPING = dict(
remote_addr = ('ansible_ssh_host', 'ansible_host'), remote_addr = ('ansible_ssh_host', 'ansible_host'),
remote_user = ('ansible_ssh_user', 'ansible_user'), remote_user = ('ansible_ssh_user', 'ansible_user'),
port = ('ansible_ssh_port', 'ansible_port'), port = ('ansible_ssh_port', 'ansible_port'),
accelerate_port = ('ansible_accelerate_port',),
password = ('ansible_ssh_pass', 'ansible_password'), password = ('ansible_ssh_pass', 'ansible_password'),
private_key_file = ('ansible_ssh_private_key_file', 'ansible_private_key_file'), private_key_file = ('ansible_ssh_private_key_file', 'ansible_private_key_file'),
pipelining = ('ansible_ssh_pipelining', 'ansible_pipelining'), pipelining = ('ansible_ssh_pipelining', 'ansible_pipelining'),
...@@ -142,6 +143,9 @@ class PlayContext(Base): ...@@ -142,6 +143,9 @@ class PlayContext(Base):
_ssh_extra_args = FieldAttribute(isa='string') _ssh_extra_args = FieldAttribute(isa='string')
_connection_lockfd= FieldAttribute(isa='int') _connection_lockfd= FieldAttribute(isa='int')
_pipelining = FieldAttribute(isa='bool', default=C.ANSIBLE_SSH_PIPELINING) _pipelining = FieldAttribute(isa='bool', default=C.ANSIBLE_SSH_PIPELINING)
_accelerate = FieldAttribute(isa='bool', default=False)
_accelerate_ipv6 = FieldAttribute(isa='bool', default=False, always_post_validate=True)
_accelerate_port = FieldAttribute(isa='int', default=C.ACCELERATE_PORT, always_post_validate=True)
# privilege escalation fields # privilege escalation fields
_become = FieldAttribute(isa='bool') _become = FieldAttribute(isa='bool')
...@@ -199,6 +203,12 @@ class PlayContext(Base): ...@@ -199,6 +203,12 @@ class PlayContext(Base):
the play class. the play class.
''' '''
# special handling for accelerated mode, as it is set in a separate
# play option from the connection parameter
self.accelerate = play.accelerate
self.accelerate_ipv6 = play.accelerate_ipv6
self.accelerate_port = play.accelerate_port
if play.connection: if play.connection:
self.connection = play.connection self.connection = play.connection
......
...@@ -18,6 +18,11 @@ from __future__ import (absolute_import, division, print_function) ...@@ -18,6 +18,11 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import os
import stat
import time
import warnings
PASSLIB_AVAILABLE = False PASSLIB_AVAILABLE = False
try: try:
import passlib.hash import passlib.hash
...@@ -25,6 +30,34 @@ try: ...@@ -25,6 +30,34 @@ try:
except: except:
pass pass
KEYCZAR_AVAILABLE=False
try:
try:
# some versions of pycrypto may not have this?
from Crypto.pct_warnings import PowmInsecureWarning
except ImportError:
PowmInsecureWarning = RuntimeWarning
with warnings.catch_warnings(record=True) as warning_handler:
warnings.simplefilter("error", PowmInsecureWarning)
try:
import keyczar.errors as key_errors
from keyczar.keys import AesKey
except PowmInsecureWarning:
system_warning(
"The version of gmp you have installed has a known issue regarding " + \
"timing vulnerabilities when used with pycrypto. " + \
"If possible, you should update it (i.e. yum update gmp)."
)
warnings.resetwarnings()
warnings.simplefilter("ignore")
import keyczar.errors as key_errors
from keyczar.keys import AesKey
KEYCZAR_AVAILABLE=True
except ImportError:
pass
from ansible import constants as C
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
__all__ = ['do_encrypt'] __all__ = ['do_encrypt']
...@@ -47,3 +80,47 @@ def do_encrypt(result, encrypt, salt_size=None, salt=None): ...@@ -47,3 +80,47 @@ def do_encrypt(result, encrypt, salt_size=None, salt=None):
return result return result
def key_for_hostname(hostname):
# fireball mode is an implementation of ansible firing up zeromq via SSH
# to use no persistent daemons or key management
if not KEYCZAR_AVAILABLE:
raise AnsibleError("python-keyczar must be installed on the control machine to use accelerated modes")
key_path = os.path.expanduser(C.ACCELERATE_KEYS_DIR)
if not os.path.exists(key_path):
os.makedirs(key_path, mode=0700)
os.chmod(key_path, int(C.ACCELERATE_KEYS_DIR_PERMS, 8))
elif not os.path.isdir(key_path):
raise AnsibleError('ACCELERATE_KEYS_DIR is not a directory.')
if stat.S_IMODE(os.stat(key_path).st_mode) != int(C.ACCELERATE_KEYS_DIR_PERMS, 8):
raise AnsibleError('Incorrect permissions on the private key directory. Use `chmod 0%o %s` to correct this issue, and make sure any of the keys files contained within that directory are set to 0%o' % (int(C.ACCELERATE_KEYS_DIR_PERMS, 8), C.ACCELERATE_KEYS_DIR, int(C.ACCELERATE_KEYS_FILE_PERMS, 8)))
key_path = os.path.join(key_path, hostname)
# use new AES keys every 2 hours, which means fireball must not allow running for longer either
if not os.path.exists(key_path) or (time.time() - os.path.getmtime(key_path) > 60*60*2):
key = AesKey.Generate(size=256)
fd = os.open(key_path, os.O_WRONLY | os.O_CREAT, int(C.ACCELERATE_KEYS_FILE_PERMS, 8))
fh = os.fdopen(fd, 'w')
fh.write(str(key))
fh.close()
return key
else:
if stat.S_IMODE(os.stat(key_path).st_mode) != int(C.ACCELERATE_KEYS_FILE_PERMS, 8):
raise AnsibleError('Incorrect permissions on the key file for this host. Use `chmod 0%o %s` to correct this issue.' % (int(C.ACCELERATE_KEYS_FILE_PERMS, 8), key_path))
fh = open(key_path)
key = AesKey.Read(fh.read())
fh.close()
return key
def keyczar_encrypt(key, msg):
return key.Encrypt(msg.encode('utf-8'))
def keyczar_decrypt(key, msg):
try:
return key.Decrypt(msg)
except key_errors.InvalidSignatureError:
raise AnsibleError("decryption failed")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment