# (c) 2012, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.
#

try:
    import json
except ImportError:
    import simplejson as json

24 25
import fnmatch
import multiprocessing
26
import signal
27 28
import os
import traceback
29 30
import paramiko # non-core dependency
import ansible.constants as C 
31
import Queue
32

33
def _executor_hook(job_queue, result_queue):
34
    ''' callback used by multiprocessing pool '''
35 36 37 38 39 40 41 42
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    while not job_queue.empty():
        try:
            job = job_queue.get(block=False)
            runner, host = job
            result_queue.put(runner._executor(host))
        except Queue.Empty:
            pass
43 44 45

class Runner(object):

    def __init__(self,
        host_list=C.DEFAULT_HOST_LIST,
        module_path=C.DEFAULT_MODULE_PATH,
        module_name=C.DEFAULT_MODULE_NAME,
        module_args=C.DEFAULT_MODULE_ARGS,
        forks=C.DEFAULT_FORKS,
        timeout=C.DEFAULT_TIMEOUT,
        pattern=C.DEFAULT_PATTERN,
        remote_user=C.DEFAULT_REMOTE_USER,
        remote_pass=C.DEFAULT_REMOTE_PASS,
        verbose=False):

        '''
        Constructor
        host_list   -- file on disk listing hosts to manage, or an array of hostnames
        pattern ------ a fnmatch pattern selecting some of the hosts in host_list;
                       several patterns may be joined with ';'
        module_path -- location of ansible library on disk
        module_name -- which module to run
        module_args -- arguments to pass to module, as a list of strings
                       (copy/template read their options from "key=value" items)
        forks -------- how parallel should we be? 1 is extra debuggable.
        timeout ------ SSH connection timeout, passed to paramiko's connect()
        remote_user -- who to login as (default root)
        remote_pass -- provide only if you don't want to use keys or ssh-agent
        verbose ------ stored on the runner; not read by this class
        '''

        # save input values
        self.host_list   = self.parse_hosts(host_list)
        self.module_path = module_path
        self.module_name = module_name
        self.forks       = forks
        self.pattern     = pattern
        self.module_args = module_args
        self.timeout     = timeout
        self.verbose     = verbose
        self.remote_user = remote_user
        self.remote_pass = remote_pass

        # connection -> remote temp directory, filled by _get_tmp_path
        self._tmp_paths = {}

    @classmethod
    def parse_hosts(cls, host_list):
        '''
        parse the host inventory file if not sent as an array.

        a list is returned unchanged. anything else is treated as a path
        (with ~ expanded) to a file listing one host per line. blank lines
        come back as '' entries, which _matches() never selects.
        '''
        if isinstance(host_list, list):
            return host_list
        path = os.path.expanduser(host_list)
        # close the inventory file promptly instead of leaking the handle
        with open(path) as inventory:
            return inventory.read().split("\n")

    def _matches(self, host_name, pattern=None):
        '''
        returns True if host_name is matched by the pattern.
        pattern defaults to self.pattern when not given.
        '''
        # a pattern is in fnmatch format but more than one pattern
        # can be strung together with semicolons. ex:
        #   atlanta-web*.example.com;dc-web*.example.com
        if host_name == '':
            return False
        if pattern is None:
            pattern = self.pattern
        return any(fnmatch.fnmatch(host_name, subpattern)
                   for subpattern in pattern.split(";"))

    def _connect(self, host):
        '''
        obtains a paramiko connection to the host.
        on success, returns [True, connection]
        on failure, returns [False, error message str]
        '''
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            # try paramiko; self.timeout bounds the TCP connect
            ssh.connect(host, username=self.remote_user, allow_agent=True,
                look_for_keys=True, password=self.remote_pass,
                timeout=self.timeout)
        except Exception as e:
            # it failed somehow: release any half-open transport and
            # return the failure string
            ssh.close()
            return [ False, str(e) ]
        return [ True, ssh ]

    def _return_from_module(self, conn, host, result):
        '''
        helper function to handle JSON parsing of results.

        closes conn, then returns [host, True, parsed_json], or
        [host, False, raw_output_string] when result is not valid JSON.
        '''
        # disconnect from paramiko/SSH and drop the tmp-path cache entry
        # so closed connections are not kept alive by self._tmp_paths
        conn.close()
        self._tmp_paths.pop(conn, None)
        try:
            # try to parse the JSON response
            return [ host, True, json.loads(result) ]
        except (ValueError, TypeError):
            # it failed, say so, but return the string anyway
            return [ host, False, result ]

    def _delete_remote_files(self, conn, files):
        '''
        deletes one or more remote files.
        files is normally a list of paths; a single path string is also
        accepted and treated as a one-element list.
        '''
        # iterating a bare string would run 'rm -f' once per character,
        # starting with 'rm -f /' for absolute paths, so wrap it
        if hasattr(files, 'split'):
            files = [ files ]
        for filename in files:
            self._exec_command(conn, "rm -f %s" % filename)

    def _transfer_file(self, conn, source, dest):
        ''' uploads the local file source to the remote path dest over SFTP '''
        self.remote_log(conn, 'COPY local:%s remote:%s' % (source, dest))
        sftp = conn.open_sftp()
        try:
            sftp.put(source, dest)
        finally:
            sftp.close()

    def _transfer_module(self, conn):
        '''
        transfers a module file to the remote side to execute it,
        but does not execute it yet. returns the remote path.
        '''
        outpath = self._copy_module(conn)
        self._exec_command(conn, "chmod +x %s" % outpath)
        return outpath

    def _execute_module(self, conn, outpath, module_args=None):
        '''
        runs a module that has already been transferred, then deletes it.
        module_args overrides self.module_args for this call only.
        returns the command's stdout.
        '''
        cmd = self._command(outpath, module_args)
        result = self._exec_command(conn, cmd)
        self._delete_remote_files(conn, [ outpath ])
        return result

    def _execute_normal_module(self, conn, host):
        '''
        transfer & execute a module that is not 'copy' or 'template'
        because those require extra work.
        '''
        module = self._transfer_module(conn)
        result = self._execute_module(conn, module)
        return self._return_from_module(conn, host, result)

    def _parse_kv(self, args):
        '''
        helper function to convert a list of "key=value" strings to a dict.
        items without '=' are ignored; only the first '=' splits, so
        values may themselves contain '='.
        '''
        options = {}
        for item in args:
            if "=" in item:
                key, value = item.split("=", 1)
                options[key] = value
        return options

    def _execute_copy(self, conn, host):
        '''
        handler for file transfer operations.

        uploads the local 'src' to a remote tmp dir, then runs the 'copy'
        module to put it at 'dest'. self.module_args is left untouched:
        with forks=1 every host goes through this same runner object, so
        overwriting it would feed this host's remote tmp path to the next
        host as its local source.
        '''
        # load up options
        options = self._parse_kv(self.module_args)
        source = options['src']
        dest   = options['dest']

        # transfer the file to a remote tmp location
        tmp_src = self._get_tmp_path(conn) + os.path.basename(source)
        self._transfer_file(conn, source, tmp_src)

        # install the copy module (_executor only dispatches here when
        # self.module_name is 'copy')
        module = self._transfer_module(conn)

        # run the copy module against the uploaded file, then clean it up
        copy_args = [ "src=%s" % tmp_src, "dest=%s" % dest ]
        result = self._execute_module(conn, module, copy_args)
        self._delete_remote_files(conn, [ tmp_src ])
        return self._return_from_module(conn, host, result)

    def _execute_template(self, conn, host):
        '''
        handler for template operations.

        uploads the local 'src' template to a remote tmp dir, then runs
        the 'template' module with 'dest' and 'metadata' (default
        /etc/ansible/setup). like _execute_copy, self.module_args is not
        modified so the runner stays reusable across hosts.
        '''
        # load up options
        options  = self._parse_kv(self.module_args)
        source   = options['src']
        dest     = options['dest']
        metadata = options.get('metadata', '/etc/ansible/setup')

        # first copy the source template over
        temppath = self._get_tmp_path(conn) + os.path.basename(source)
        self._transfer_file(conn, source, temppath)

        # install the template module (_executor only dispatches here
        # when self.module_name is 'template')
        module = self._transfer_module(conn)

        # run the template module, then clean up the uploaded template
        template_args = [ "src=%s" % temppath, "dest=%s" % dest, "metadata=%s" % metadata ]
        result = self._execute_module(conn, module, template_args)
        self._delete_remote_files(conn, [ temppath ])
        return self._return_from_module(conn, host, result)

    def _executor(self, host):
        '''
        callback executed in parallel for each host.
        returns [hostname, ok, extra] where extra is the parsed module
        result on success, or an error string / raw module output.
        '''
        ok, conn = self._connect(host)
        if not ok:
            return [ host, False, conn ]

        # copy and template need a file uploaded first; every other
        # module is transferred and run as-is
        if self.module_name == 'copy':
            return self._execute_copy(conn, host)
        if self.module_name == 'template':
            return self._execute_template(conn, host)
        return self._execute_normal_module(conn, host)

    def _command(self, outpath, module_args=None):
        '''
        form up a command string for running over SSH.
        module_args defaults to self.module_args.
        '''
        if module_args is None:
            module_args = self.module_args
        return "%s %s" % (outpath, " ".join(module_args))

    def remote_log(self, conn, msg):
        ''' this is the function we use to log things: sends msg to the remote syslog via logger '''
        # NOTE(review): %r relies on Python's repr quoting to shield msg from
        # the remote shell; a msg containing a single quote is repr'd in
        # double quotes, inside which the shell still expands $ and `.
        # TODO: maybe make that optional
        conn.exec_command('/usr/bin/logger -t ansible -p auth.info %r' % msg)

    def _exec_command(self, conn, cmd):
        '''
        execute a command string over SSH, return its stdout as one string.
        the command is logged to the remote syslog first.
        '''
        msg = '%s: %s' % (self.module_name, cmd)
        self.remote_log(conn, msg)
        stdin, stdout, stderr = conn.exec_command(cmd)
        # readlines() keeps each line's trailing newline, so join with ''
        # (joining with "\n" doubled every line break)
        return "".join(stdout.readlines())

    def _get_tmp_path(self, conn):
        '''
        gets a temporary directory on a remote box, created with mktemp on
        first use and cached per connection. the path ends with '/'.
        '''
        if conn not in self._tmp_paths:
            output = self._exec_command(conn, "mktemp -d /tmp/ansible.XXXXXX")
            # NOTE(review): if mktemp prints nothing this becomes '/';
            # consider failing the host instead
            self._tmp_paths[conn] = output.split("\n")[0] + '/'
        return self._tmp_paths[conn]

    def _copy_module(self, conn):
        ''' transfer a module over SFTP, does not run it; returns the remote path '''
        in_path = os.path.expanduser(
            os.path.join(self.module_path, self.module_name)
        )
        out_path = self._get_tmp_path(conn) + self.module_name
        sftp = conn.open_sftp()
        try:
            sftp.put(in_path, out_path)
        finally:
            sftp.close()
        return out_path

    def match_hosts(self, pattern):
        ''' return all matched hosts fitting a pattern '''
        return [ h for h in self.host_list if self._matches(h, pattern) ]

    def _run_parallel(self, hosts):
        '''
        run _executor for each host across up to self.forks worker
        processes (via _executor_hook) and return the collected results.
        '''
        job_queue = multiprocessing.Queue()
        result_queue = multiprocessing.Queue()
        for host in hosts:
            job_queue.put((self, host))

        # no point starting more workers than there are jobs
        workers = []
        for i in range(min(self.forks, len(hosts))):
            worker = multiprocessing.Process(target=_executor_hook,
                                             args=(job_queue, result_queue))
            worker.start()
            workers.append(worker)

        results = []
        try:
            # read results while workers run: a process that has put data
            # on a queue does not exit until it is flushed to the pipe, so
            # joining first can deadlock once the results fill the pipe
            while len(results) < len(hosts) and \
                    any(w.is_alive() for w in workers):
                try:
                    results.append(result_queue.get(timeout=0.1))
                except Queue.Empty:
                    pass
            for worker in workers:
                worker.join()
        except KeyboardInterrupt:
            print('parent received ctrl-c')
            for worker in workers:
                worker.terminate()
                worker.join()

        # collect anything that arrived after the last liveness check
        while True:
            try:
                results.append(result_queue.get(block=False))
            except Queue.Empty:
                break
        return results

    def _partition_results(self, matched, results):
        '''
        sort [host, ok, result] triples into the dict returned by run().
        matched hosts with no result at all are reported as dark with {}.
        '''
        contacted = {}
        dark = {}
        for (host, is_ok, result) in results:
            if is_ok:
                contacted[host] = result
            else:
                dark[host] = result

        # hosts which were contacted but never got a chance
        # to return a result before we exited/ctrl-c'd
        # perhaps these shouldn't be 'dark' but I'm not sure if they fit
        # anywhere else.
        for host in matched:
            if host not in contacted and host not in dark:
                dark[host] = {}

        return { "contacted" : contacted, "dark" : dark }

    def run(self):
        '''
        xfer & run module on all matched hosts.

        returns {"contacted": {host: parsed module result},
                 "dark": {host: error string, raw non-JSON output, or {}}}
        '''
        # find hosts that match the pattern
        matched = self.match_hosts(self.pattern)

        # attack pool of hosts in N forks; with forks == 1 run inline so
        # exceptions surface directly (easier debugging)
        if self.forks > 1:
            results = self._run_parallel(matched)
        else:
            results = [ self._executor(h) for h in matched ]

        return self._partition_results(matched, results)