Commit 07888d97 by willmcgugan@gmail.com

Applied patch to sftp that searches for ssh keys in default locations

parent dccbc1f3
...@@ -14,7 +14,6 @@ import threading ...@@ -14,7 +14,6 @@ import threading
import os import os
import paramiko import paramiko
from getpass import getuser from getpass import getuser
from binascii import hexlify
from fs.base import * from fs.base import *
from fs.path import * from fs.path import *
...@@ -34,17 +33,18 @@ else: ...@@ -34,17 +33,18 @@ else:
class thread_local(object): class thread_local(object):
def __init__(self): def __init__(self):
self._map = {} self._map = {}
def __getattr__(self,attr):
def __getattr__(self, attr):
try: try:
return self._map[(threading.currentThread().ident,attr)] return self._map[(threading.currentThread().ident, attr)]
except KeyError: except KeyError:
raise AttributeError, attr raise AttributeError(attr)
def __setattr__(self,attr,value):
self._map[(threading.currentThread().ident,attr)] = value
def __setattr__(self, attr, value):
self._map[(threading.currentThread().ident, attr)] = value
if not hasattr(paramiko.SFTPFile,"__enter__"): if not hasattr(paramiko.SFTPFile, "__enter__"):
paramiko.SFTPFile.__enter__ = lambda self: self paramiko.SFTPFile.__enter__ = lambda self: self
paramiko.SFTPFile.__exit__ = lambda self,et,ev,tb: self.close() and False paramiko.SFTPFile.__exit__ = lambda self,et,ev,tb: self.close() and False
...@@ -54,7 +54,7 @@ class SFTPFS(FS): ...@@ -54,7 +54,7 @@ class SFTPFS(FS):
This is basically a compatibility wrapper for the excellent SFTPClient This is basically a compatibility wrapper for the excellent SFTPClient
class in the paramiko module. class in the paramiko module.
""" """
_meta = { 'thread_safe' : True, _meta = { 'thread_safe' : True,
...@@ -70,8 +70,17 @@ class SFTPFS(FS): ...@@ -70,8 +70,17 @@ class SFTPFS(FS):
'atomic.setcontents' : False 'atomic.setcontents' : False
} }
def __init__(self,
def __init__(self, connection, root_path="/", encoding=None, hostkey=None, username='', password=None, pkey=None, agent_auth=True, no_auth=False): connection,
root_path="/",
encoding=None,
hostkey=None,
username='',
password=None,
pkey=None,
agent_auth=True,
no_auth=False,
look_for_keys=True):
"""SFTPFS constructor. """SFTPFS constructor.
The only required argument is 'connection', which must be something The only required argument is 'connection', which must be something
...@@ -84,8 +93,8 @@ class SFTPFS(FS): ...@@ -84,8 +93,8 @@ class SFTPFS(FS):
* a paramiko.Channel instance in "sftp" mode * a paramiko.Channel instance in "sftp" mode
The keyword argument 'root_path' specifies the root directory on the The keyword argument 'root_path' specifies the root directory on the
remote machine - access to files outside this root will be prevented. remote machine - access to files outside this root will be prevented.
:param connection: a connection string :param connection: a connection string
:param root_path: The root path to open :param root_path: The root path to open
:param encoding: String encoding of paths (defaults to UTF-8) :param encoding: String encoding of paths (defaults to UTF-8)
...@@ -95,14 +104,16 @@ class SFTPFS(FS): ...@@ -95,14 +104,16 @@ class SFTPFS(FS):
:param pkey: Public key :param pkey: Public key
:param agent_auth: attempt to authorize with the user's public keys :param agent_auth: attempt to authorize with the user's public keys
:param no_auth: attempt to log in without any kind of authorization :param no_auth: attempt to log in without any kind of authorization
:param look_for_keys: Look for keys in the same locations as ssh,
if other authentication is not succesful
""" """
credentials = dict(username=username, credentials = dict(username=username,
password=password, password=password,
pkey=pkey) pkey=pkey)
self.credentials = credentials self.credentials = credentials
if encoding is None: if encoding is None:
encoding = "utf8" encoding = "utf8"
self.encoding = encoding self.encoding = encoding
...@@ -112,16 +123,16 @@ class SFTPFS(FS): ...@@ -112,16 +123,16 @@ class SFTPFS(FS):
self._tlocal = thread_local() self._tlocal = thread_local()
self._transport = None self._transport = None
self._client = None self._client = None
self.hostname = None self.hostname = None
if isinstance(connection, basestring): if isinstance(connection, basestring):
self.hostname = connection self.hostname = connection
elif isinstance(connection, tuple): elif isinstance(connection, tuple):
self.hostname = '%s:%s' % connection self.hostname = '%s:%s' % connection
super(SFTPFS, self).__init__() super(SFTPFS, self).__init__()
self.root_path = abspath(normpath(root_path)) self.root_path = abspath(normpath(root_path))
if isinstance(connection,paramiko.Channel): if isinstance(connection,paramiko.Channel):
self._transport = None self._transport = None
self._client = paramiko.SFTPClient(connection) self._client = paramiko.SFTPClient(connection)
...@@ -130,75 +141,111 @@ class SFTPFS(FS): ...@@ -130,75 +141,111 @@ class SFTPFS(FS):
connection = paramiko.Transport(connection) connection = paramiko.Transport(connection)
connection.daemon = True connection.daemon = True
self._owns_transport = True self._owns_transport = True
if hostkey is not None: if hostkey is not None:
key = self.get_remote_server_key() key = self.get_remote_server_key()
if hostkey != key: if hostkey != key:
raise WrongHostKeyError('Host keys do not match') raise WrongHostKeyError('Host keys do not match')
connection.start_client() connection.start_client()
if not connection.is_active(): if not connection.is_active():
raise RemoteConnectionError(msg='Unable to connect') raise RemoteConnectionError(msg='Unable to connect')
if no_auth: if no_auth:
try: try:
connection.auth_none('') connection.auth_none('')
except paramiko.SSHException: except paramiko.SSHException:
pass pass
elif not connection.is_authenticated(): elif not connection.is_authenticated():
if not username: if not username:
username = getuser() username = getuser()
try: try:
if pkey: if pkey:
connection.auth_publickey(username, pkey) connection.auth_publickey(username, pkey)
if not connection.is_authenticated() and password: if not connection.is_authenticated() and password:
connection.auth_password(username, password) connection.auth_password(username, password)
if agent_auth and not connection.is_authenticated(): if agent_auth and not connection.is_authenticated():
self._agent_auth(connection, username) self._agent_auth(connection, username)
if not connection.is_authenticated(): if look_for_keys and not connection.is_authenticated():
try: self._userkeys_auth(connection, username, password)
if not connection.is_authenticated():
try:
connection.auth_none(username) connection.auth_none(username)
except paramiko.BadAuthenticationType, e: except paramiko.BadAuthenticationType, e:
self.close() self.close()
allowed = ', '.join(e.allowed_types) allowed = ', '.join(e.allowed_types)
raise RemoteConnectionError(msg='no auth - server requires one of the following: %s' % allowed, details=e) raise RemoteConnectionError(msg='no auth - server requires one of the following: %s' % allowed, details=e)
if not connection.is_authenticated(): if not connection.is_authenticated():
self.close() self.close()
raise RemoteConnectionError(msg='no auth') raise RemoteConnectionError(msg='no auth')
except paramiko.SSHException, e: except paramiko.SSHException, e:
self.close() self.close()
raise RemoteConnectionError(msg='SSH exception (%s)' % str(e), details=e) raise RemoteConnectionError(msg='SSH exception (%s)' % str(e), details=e)
self._transport = connection self._transport = connection
def __unicode__(self): def __unicode__(self):
return u'<SFTPFS: %s>' % self.desc('/') return u'<SFTPFS: %s>' % self.desc('/')
@classmethod @classmethod
def _agent_auth(cls, transport, username): def _agent_auth(cls, transport, username):
""" """
Attempt to authenticate to the given transport using any of the private Attempt to authenticate to the given transport using any of the private
keys available from an SSH agent. keys available from an SSH agent.
""" """
agent = paramiko.Agent() agent = paramiko.Agent()
agent_keys = agent.get_keys() agent_keys = agent.get_keys()
if not agent_keys: if not agent_keys:
return None return None
for key in agent_keys: for key in agent_keys:
try: try:
transport.auth_publickey(username, key) transport.auth_publickey(username, key)
return key return key
except paramiko.SSHException: except paramiko.SSHException:
pass pass
return None return None
@classmethod
def _userkeys_auth(cls, transport, username, password):
"""
Attempt to authenticate to the given transport using any of the private
keys in the users ~/.ssh and ~/ssh dirs
Derived from http://www.lag.net/paramiko/docs/paramiko.client-pysrc.html
"""
keyfiles = []
rsa_key = os.path.expanduser('~/.ssh/id_rsa')
dsa_key = os.path.expanduser('~/.ssh/id_dsa')
if os.path.isfile(rsa_key):
keyfiles.append((paramiko.rsakey.RSAKey, rsa_key))
if os.path.isfile(dsa_key):
keyfiles.append((paramiko.dsskey.DSSKey, dsa_key))
# look in ~/ssh/ for windows users:
rsa_key = os.path.expanduser('~/ssh/id_rsa')
dsa_key = os.path.expanduser('~/ssh/id_dsa')
if os.path.isfile(rsa_key):
keyfiles.append((paramiko.rsakey.RSAKey, rsa_key))
if os.path.isfile(dsa_key):
keyfiles.append((paramiko.dsskey.DSSKey, dsa_key))
for pkey_class, filename in keyfiles:
key = pkey_class.from_private_key_file(filename, password)
try:
transport.auth_publickey(username, key)
return key
except paramike.SSHException:
pass
return None
def __del__(self): def __del__(self):
self.close() self.close()
...@@ -210,9 +257,9 @@ class SFTPFS(FS): ...@@ -210,9 +257,9 @@ class SFTPFS(FS):
if self._owns_transport: if self._owns_transport:
state['_transport'] = self._transport.getpeername() state['_transport'] = self._transport.getpeername()
return state return state
def __setstate__(self,state): def __setstate__(self,state):
super(SFTPFS, self).__setstate__(state) super(SFTPFS, self).__setstate__(state)
#for (k,v) in state.iteritems(): #for (k,v) in state.iteritems():
# self.__dict__[k] = v # self.__dict__[k] = v
#self._lock = threading.RLock() #self._lock = threading.RLock()
...@@ -261,20 +308,20 @@ class SFTPFS(FS): ...@@ -261,20 +308,20 @@ class SFTPFS(FS):
raise PathError(path,msg="Path is outside root: %(path)s") raise PathError(path,msg="Path is outside root: %(path)s")
return npath return npath
def getpathurl(self, path, allow_none=False): def getpathurl(self, path, allow_none=False):
path = self._normpath(path) path = self._normpath(path)
if self.hostname is None: if self.hostname is None:
if allow_none: if allow_none:
return None return None
raise NoPathURLError(path=path) raise NoPathURLError(path=path)
username = self.credentials.get('username', '') or '' username = self.credentials.get('username', '') or ''
password = self.credentials.get('password', '') or '' password = self.credentials.get('password', '') or ''
credentials = ('%s:%s' % (username, password)).rstrip(':') credentials = ('%s:%s' % (username, password)).rstrip(':')
if credentials: if credentials:
url = 'sftp://%s@%s%s' % (credentials, self.hostname.rstrip('/'), abspath(path)) url = 'sftp://%s@%s%s' % (credentials, self.hostname.rstrip('/'), abspath(path))
else: else:
url = 'sftp://%s%s' % (self.hostname.rstrip('/'), abspath(path)) url = 'sftp://%s%s' % (self.hostname.rstrip('/'), abspath(path))
return url return url
@synchronize @synchronize
@convert_os_errors @convert_os_errors
...@@ -349,14 +396,14 @@ class SFTPFS(FS): ...@@ -349,14 +396,14 @@ class SFTPFS(FS):
@convert_os_errors @convert_os_errors
def listdir(self,path="./",wildcard=None,full=False,absolute=False,dirs_only=False,files_only=False): def listdir(self,path="./",wildcard=None,full=False,absolute=False,dirs_only=False,files_only=False):
npath = self._normpath(path) npath = self._normpath(path)
try: try:
attrs_map = None attrs_map = None
if dirs_only or files_only: if dirs_only or files_only:
attrs = self.client.listdir_attr(npath) attrs = self.client.listdir_attr(npath)
attrs_map = dict((a.filename, a) for a in attrs) attrs_map = dict((a.filename, a) for a in attrs)
paths = list(attrs_map.iterkeys()) paths = list(attrs_map.iterkeys())
else: else:
paths = self.client.listdir(npath) paths = self.client.listdir(npath)
except IOError, e: except IOError, e:
if getattr(e,"errno",None) == 2: if getattr(e,"errno",None) == 2:
if self.isfile(path): if self.isfile(path):
...@@ -364,8 +411,8 @@ class SFTPFS(FS): ...@@ -364,8 +411,8 @@ class SFTPFS(FS):
raise ResourceNotFoundError(path) raise ResourceNotFoundError(path)
elif self.isfile(path): elif self.isfile(path):
raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s") raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s")
raise raise
if attrs_map: if attrs_map:
if dirs_only: if dirs_only:
filter_paths = [] filter_paths = []
...@@ -378,22 +425,22 @@ class SFTPFS(FS): ...@@ -378,22 +425,22 @@ class SFTPFS(FS):
for apath, attr in attrs_map.iteritems(): for apath, attr in attrs_map.iteritems():
if isfile(self, apath, attr.__dict__): if isfile(self, apath, attr.__dict__):
filter_paths.append(apath) filter_paths.append(apath)
paths = filter_paths paths = filter_paths
for (i,p) in enumerate(paths): for (i,p) in enumerate(paths):
if not isinstance(p,unicode): if not isinstance(p,unicode):
paths[i] = p.decode(self.encoding) paths[i] = p.decode(self.encoding)
return self._listdir_helper(path, paths, wildcard, full, absolute, False, False) return self._listdir_helper(path, paths, wildcard, full, absolute, False, False)
@synchronize @synchronize
@convert_os_errors @convert_os_errors
def listdirinfo(self,path="./",wildcard=None,full=False,absolute=False,dirs_only=False,files_only=False): def listdirinfo(self,path="./",wildcard=None,full=False,absolute=False,dirs_only=False,files_only=False):
npath = self._normpath(path) npath = self._normpath(path)
try: try:
attrs = self.client.listdir_attr(npath) attrs = self.client.listdir_attr(npath)
attrs_map = dict((a.filename, a) for a in attrs) attrs_map = dict((a.filename, a) for a in attrs)
paths = attrs_map.keys() paths = attrs_map.keys()
except IOError, e: except IOError, e:
if getattr(e,"errno",None) == 2: if getattr(e,"errno",None) == 2:
if self.isfile(path): if self.isfile(path):
...@@ -402,7 +449,7 @@ class SFTPFS(FS): ...@@ -402,7 +449,7 @@ class SFTPFS(FS):
elif self.isfile(path): elif self.isfile(path):
raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s") raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s")
raise raise
if dirs_only: if dirs_only:
filter_paths = [] filter_paths = []
for path, attr in attrs_map.iteritems(): for path, attr in attrs_map.iteritems():
...@@ -415,19 +462,19 @@ class SFTPFS(FS): ...@@ -415,19 +462,19 @@ class SFTPFS(FS):
if isfile(self, path, attr.__dict__): if isfile(self, path, attr.__dict__):
filter_paths.append(path) filter_paths.append(path)
paths = filter_paths paths = filter_paths
for (i, p) in enumerate(paths): for (i, p) in enumerate(paths):
if not isinstance(p, unicode): if not isinstance(p, unicode):
paths[i] = p.decode(self.encoding) paths[i] = p.decode(self.encoding)
def getinfo(p): def getinfo(p):
resourcename = basename(p) resourcename = basename(p)
info = attrs_map.get(resourcename) info = attrs_map.get(resourcename)
if info is None: if info is None:
return self.getinfo(pathjoin(path, p)) return self.getinfo(pathjoin(path, p))
return self._extract_info(info.__dict__) return self._extract_info(info.__dict__)
return [(p, getinfo(p)) for p in return [(p, getinfo(p)) for p in
self._listdir_helper(path, paths, wildcard, full, absolute, False, False)] self._listdir_helper(path, paths, wildcard, full, absolute, False, False)]
@synchronize @synchronize
...@@ -492,7 +539,7 @@ class SFTPFS(FS): ...@@ -492,7 +539,7 @@ class SFTPFS(FS):
if self.isfile(path): if self.isfile(path):
raise ResourceInvalidError(path,msg="Can't use removedir() on a file: %(path)s") raise ResourceInvalidError(path,msg="Can't use removedir() on a file: %(path)s")
raise ResourceNotFoundError(path) raise ResourceNotFoundError(path)
elif self.listdir(path): elif self.listdir(path):
raise DirectoryNotEmptyError(path) raise DirectoryNotEmptyError(path)
raise raise
...@@ -556,7 +603,7 @@ class SFTPFS(FS): ...@@ -556,7 +603,7 @@ class SFTPFS(FS):
@classmethod @classmethod
def _extract_info(cls, stats): def _extract_info(cls, stats):
fromtimestamp = datetime.datetime.fromtimestamp fromtimestamp = datetime.datetime.fromtimestamp
info = dict((k, v) for k, v in stats.iteritems() if k in cls._info_vars and not k.startswith('_')) info = dict((k, v) for k, v in stats.iteritems() if k in cls._info_vars and not k.startswith('_'))
info['size'] = info['st_size'] info['size'] = info['st_size']
ct = info.get('st_ctime') ct = info.get('st_ctime')
if ct is not None: if ct is not None:
...@@ -571,10 +618,10 @@ class SFTPFS(FS): ...@@ -571,10 +618,10 @@ class SFTPFS(FS):
@synchronize @synchronize
@convert_os_errors @convert_os_errors
def getinfo(self, path): def getinfo(self, path):
npath = self._normpath(path) npath = self._normpath(path)
stats = self.client.stat(npath) stats = self.client.stat(npath)
info = dict((k, getattr(stats, k)) for k in dir(stats) if not k.startswith('_')) info = dict((k, getattr(stats, k)) for k in dir(stats) if not k.startswith('_'))
info['size'] = info['st_size'] info['size'] = info['st_size']
ct = info.get('st_ctime', None) ct = info.get('st_ctime', None)
if ct is not None: if ct is not None:
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
The `utils` module provides a number of utility functions that don't belong in the Filesystem interface. Generally the functions in this module work with multiple Filesystems, for instance moving and copying between non-similar Filesystems. The `utils` module provides a number of utility functions that don't belong
in the Filesystem interface. Generally the functions in this module work with
multiple Filesystems, for instance moving and copying between non-similar Filesystems.
""" """
...@@ -36,15 +38,15 @@ def copyfile(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1 ...@@ -36,15 +38,15 @@ def copyfile(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1
:param chunk_size: Size of chunks to move if system copyfile is not available (default 64K) :param chunk_size: Size of chunks to move if system copyfile is not available (default 64K)
""" """
# If the src and dst fs objects are the same, then use a direct copy # If the src and dst fs objects are the same, then use a direct copy
if src_fs is dst_fs: if src_fs is dst_fs:
src_fs.copy(src_path, dst_path, overwrite=overwrite) src_fs.copy(src_path, dst_path, overwrite=overwrite)
return return
src_syspath = src_fs.getsyspath(src_path, allow_none=True) src_syspath = src_fs.getsyspath(src_path, allow_none=True)
dst_syspath = dst_fs.getsyspath(dst_path, allow_none=True) dst_syspath = dst_fs.getsyspath(dst_path, allow_none=True)
if not overwrite and dst_fs.exists(dst_path): if not overwrite and dst_fs.exists(dst_path):
raise DestinationExistsError(dst_path) raise DestinationExistsError(dst_path)
...@@ -57,44 +59,44 @@ def copyfile(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1 ...@@ -57,44 +59,44 @@ def copyfile(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1
if src_lock is not None: if src_lock is not None:
src_lock.acquire() src_lock.acquire()
try: try:
src = None src = None
try: try:
src = src_fs.open(src_path, 'rb') src = src_fs.open(src_path, 'rb')
dst_fs.setcontents(dst_path, src, chunk_size=chunk_size) dst_fs.setcontents(dst_path, src, chunk_size=chunk_size)
finally: finally:
if src is not None: if src is not None:
src.close() src.close()
finally: finally:
if src_lock is not None: if src_lock is not None:
src_lock.release() src_lock.release()
def copyfile_non_atomic(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1024): def copyfile_non_atomic(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1024):
"""A non atomic version of copyfile (will not block other threads using src_fs or dst_fst) """A non atomic version of copyfile (will not block other threads using src_fs or dst_fst)
:param src_fs: Source filesystem object :param src_fs: Source filesystem object
:param src_path: -- Source path :param src_path: -- Source path
:param dst_fs: Destination filesystem object :param dst_fs: Destination filesystem object
:param dst_path: Destination filesystem object :param dst_path: Destination filesystem object
:param chunk_size: Size of chunks to move if system copyfile is not available (default 64K) :param chunk_size: Size of chunks to move if system copyfile is not available (default 64K)
""" """
if not overwrite and dst_fs.exists(dst_path): if not overwrite and dst_fs.exists(dst_path):
raise DestinationExistsError(dst_path) raise DestinationExistsError(dst_path)
src = None src = None
dst = None dst = None
try: try:
src = src_fs.open(src_path, 'rb') src = src_fs.open(src_path, 'rb')
dst = dst_fs.open(dst_path, 'wb') dst = dst_fs.open(dst_path, 'wb')
write = dst.write write = dst.write
read = src.read read = src.read
chunk = read(chunk_size) chunk = read(chunk_size)
while chunk: while chunk:
write(chunk) write(chunk)
chunk = read(chunk_size) chunk = read(chunk_size)
finally: finally:
if src is not None: if src is not None:
src.close() src.close()
...@@ -118,14 +120,14 @@ def movefile(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1 ...@@ -118,14 +120,14 @@ def movefile(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1
if not overwrite and dst_fs.exists(dst_path): if not overwrite and dst_fs.exists(dst_path):
raise DestinationExistsError(dst_path) raise DestinationExistsError(dst_path)
if src_fs is dst_fs: if src_fs is dst_fs:
src_fs.move(src_path, dst_path, overwrite=overwrite) src_fs.move(src_path, dst_path, overwrite=overwrite)
return return
# System copy if there are two sys paths # System copy if there are two sys paths
if src_syspath is not None and dst_syspath is not None: if src_syspath is not None and dst_syspath is not None:
FS._shutil_movefile(src_syspath, dst_syspath) FS._shutil_movefile(src_syspath, dst_syspath)
return return
src_lock = getattr(src_fs, '_lock', None) src_lock = getattr(src_fs, '_lock', None)
...@@ -134,18 +136,18 @@ def movefile(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1 ...@@ -134,18 +136,18 @@ def movefile(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1
src_lock.acquire() src_lock.acquire()
try: try:
src = None src = None
try: try:
# Chunk copy # Chunk copy
src = src_fs.open(src_path, 'rb') src = src_fs.open(src_path, 'rb')
dst_fs.setcontents(dst_path, src, chunk_size=chunk_size) dst_fs.setcontents(dst_path, src, chunk_size=chunk_size)
except: except:
raise raise
else: else:
src_fs.remove(src_path) src_fs.remove(src_path)
finally: finally:
if src is not None: if src is not None:
src.close() src.close()
finally: finally:
if src_lock is not None: if src_lock is not None:
src_lock.release() src_lock.release()
...@@ -160,32 +162,32 @@ def movefile_non_atomic(src_fs, src_path, dst_fs, dst_path, overwrite=True, chun ...@@ -160,32 +162,32 @@ def movefile_non_atomic(src_fs, src_path, dst_fs, dst_path, overwrite=True, chun
:param dst_path: Destination filesystem object :param dst_path: Destination filesystem object
:param chunk_size: Size of chunks to move if system copyfile is not available (default 64K) :param chunk_size: Size of chunks to move if system copyfile is not available (default 64K)
""" """
if not overwrite and dst_fs.exists(dst_path): if not overwrite and dst_fs.exists(dst_path):
raise DestinationExistsError(dst_path) raise DestinationExistsError(dst_path)
src = None src = None
dst = None dst = None
try: try:
# Chunk copy # Chunk copy
src = src_fs.open(src_path, 'rb') src = src_fs.open(src_path, 'rb')
dst = dst_fs.open(dst_path, 'wb') dst = dst_fs.open(dst_path, 'wb')
write = dst.write write = dst.write
read = src.read read = src.read
chunk = read(chunk_size) chunk = read(chunk_size)
while chunk: while chunk:
write(chunk) write(chunk)
chunk = read(chunk_size) chunk = read(chunk_size)
except: except:
raise raise
else: else:
src_fs.remove(src_path) src_fs.remove(src_path)
finally: finally:
if src is not None: if src is not None:
src.close() src.close()
if dst is not None: if dst is not None:
dst.close() dst.close()
def movedir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=64*1024): def movedir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=64*1024):
...@@ -193,28 +195,27 @@ def movedir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=6 ...@@ -193,28 +195,27 @@ def movedir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=6
:param fs1: A tuple of (<filesystem>, <directory path>) :param fs1: A tuple of (<filesystem>, <directory path>)
:param fs2: Destination filesystem, or a tuple of (<filesystem>, <directory path>) :param fs2: Destination filesystem, or a tuple of (<filesystem>, <directory path>)
:param create_destination: If True, the destination will be created if it doesn't exist :param create_destination: If True, the destination will be created if it doesn't exist
:param ignore_errors: If True, exceptions from file moves are ignored :param ignore_errors: If True, exceptions from file moves are ignored
:param chunk_size: Size of chunks to move if a simple copy is used :param chunk_size: Size of chunks to move if a simple copy is used
""" """
if not isinstance(fs1, tuple): if not isinstance(fs1, tuple):
raise ValueError("first argument must be a tuple of (<filesystem>, <path>)") raise ValueError("first argument must be a tuple of (<filesystem>, <path>)")
fs1, dir1 = fs1 fs1, dir1 = fs1
parent_fs1 = fs1 parent_fs1 = fs1
parent_dir1 = dir1 parent_dir1 = dir1
fs1 = fs1.opendir(dir1) fs1 = fs1.opendir(dir1)
print fs1
if parent_dir1 in ('', '/'): if parent_dir1 in ('', '/'):
raise RemoveRootError(dir1) raise RemoveRootError(dir1)
if isinstance(fs2, tuple): if isinstance(fs2, tuple):
fs2, dir2 = fs2 fs2, dir2 = fs2
if create_destination: if create_destination:
fs2.makedir(dir2, allow_recreate=True, recursive=True) fs2.makedir(dir2, allow_recreate=True, recursive=True)
fs2 = fs2.opendir(dir2) fs2 = fs2.opendir(dir2)
mount_fs = MountFS(auto_close=False) mount_fs = MountFS(auto_close=False)
mount_fs.mount('src', fs1) mount_fs.mount('src', fs1)
...@@ -223,8 +224,8 @@ def movedir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=6 ...@@ -223,8 +224,8 @@ def movedir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=6
mount_fs.copydir('src', 'dst', mount_fs.copydir('src', 'dst',
overwrite=True, overwrite=True,
ignore_errors=ignore_errors, ignore_errors=ignore_errors,
chunk_size=chunk_size) chunk_size=chunk_size)
parent_fs1.removedir(parent_dir1, force=True) parent_fs1.removedir(parent_dir1, force=True)
def copydir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=64*1024): def copydir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=64*1024):
...@@ -244,12 +245,12 @@ def copydir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=6 ...@@ -244,12 +245,12 @@ def copydir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=6
fs2, dir2 = fs2 fs2, dir2 = fs2
if create_destination: if create_destination:
fs2.makedir(dir2, allow_recreate=True, recursive=True) fs2.makedir(dir2, allow_recreate=True, recursive=True)
fs2 = fs2.opendir(dir2) fs2 = fs2.opendir(dir2)
mount_fs = MountFS(auto_close=False) mount_fs = MountFS(auto_close=False)
mount_fs.mount('src', fs1) mount_fs.mount('src', fs1)
mount_fs.mount('dst', fs2) mount_fs.mount('dst', fs2)
mount_fs.copydir('src', 'dst', mount_fs.copydir('src', 'dst',
overwrite=True, overwrite=True,
ignore_errors=ignore_errors, ignore_errors=ignore_errors,
chunk_size=chunk_size) chunk_size=chunk_size)
...@@ -257,29 +258,29 @@ def copydir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=6 ...@@ -257,29 +258,29 @@ def copydir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=6
def remove_all(fs, path): def remove_all(fs, path):
"""Remove everything in a directory. Returns True if successful. """Remove everything in a directory. Returns True if successful.
:param fs: A filesystem :param fs: A filesystem
:param path: Path to a directory :param path: Path to a directory
""" """
sub_fs = fs.opendir(path) sub_fs = fs.opendir(path)
for sub_path in sub_fs.listdir(): for sub_path in sub_fs.listdir():
if sub_fs.isdir(sub_path): if sub_fs.isdir(sub_path):
sub_fs.removedir(sub_path, force=True) sub_fs.removedir(sub_path, force=True)
else: else:
sub_fs.remove(sub_path) sub_fs.remove(sub_path)
return fs.isdirempty(path) return fs.isdirempty(path)
def copystructure(src_fs, dst_fs): def copystructure(src_fs, dst_fs):
"""Copies the directory structure from one filesystem to another, so that """Copies the directory structure from one filesystem to another, so that
all directories in `src_fs` will have a corresponding directory in `dst_fs` all directories in `src_fs` will have a corresponding directory in `dst_fs`
:param src_fs: Filesystem to copy structure from :param src_fs: Filesystem to copy structure from
:param dst_fs: Filesystem to copy structure to :param dst_fs: Filesystem to copy structure to
""" """
for path in src_fs.walkdirs(): for path in src_fs.walkdirs():
dst_fs.makedir(path, allow_recreate=True) dst_fs.makedir(path, allow_recreate=True)
...@@ -349,7 +350,7 @@ def find_duplicates(fs, ...@@ -349,7 +350,7 @@ def find_duplicates(fs,
:param signature_size: The total number of bytes read to generate a signature :param signature_size: The total number of bytes read to generate a signature
For example, the following will list all the duplicate .jpg files in "~/Pictures":: For example, the following will list all the duplicate .jpg files in "~/Pictures"::
>>> from fs.utils import find_duplicates >>> from fs.utils import find_duplicates
>>> from fs.osfs import OSFS >>> from fs.osfs import OSFS
>>> fs = OSFS('~/Pictures') >>> fs = OSFS('~/Pictures')
...@@ -457,9 +458,9 @@ def print_fs(fs, ...@@ -457,9 +458,9 @@ def print_fs(fs,
This mostly useful as a debugging aid. This mostly useful as a debugging aid.
Be careful about printing a OSFS, or any other large filesystem. Be careful about printing a OSFS, or any other large filesystem.
Without max_levels set, this function will traverse the entire directory tree. Without max_levels set, this function will traverse the entire directory tree.
For example, the following will print a tree of the files under the current working directory:: For example, the following will print a tree of the files under the current working directory::
>>> from fs.osfs import * >>> from fs.osfs import *
>>> from fs.utils import * >>> from fs.utils import *
>>> fs = OSFS('.') >>> fs = OSFS('.')
...@@ -487,28 +488,28 @@ def print_fs(fs, ...@@ -487,28 +488,28 @@ def print_fs(fs,
terminal_colors = False terminal_colors = False
else: else:
terminal_colors = hasattr(file_out, 'isatty') and file_out.isatty() terminal_colors = hasattr(file_out, 'isatty') and file_out.isatty()
def write(line): def write(line):
file_out.write(line.encode(file_encoding, 'replace')+'\n') file_out.write(line.encode(file_encoding, 'replace')+'\n')
def wrap_prefix(prefix): def wrap_prefix(prefix):
if not terminal_colors: if not terminal_colors:
return prefix return prefix
return '\x1b[32m%s\x1b[0m' % prefix return '\x1b[32m%s\x1b[0m' % prefix
def wrap_dirname(dirname): def wrap_dirname(dirname):
if not terminal_colors: if not terminal_colors:
return dirname return dirname
return '\x1b[1;34m%s\x1b[0m' % dirname return '\x1b[1;34m%s\x1b[0m' % dirname
def wrap_error(msg): def wrap_error(msg):
if not terminal_colors: if not terminal_colors:
return msg return msg
return '\x1b[31m%s\x1b[0m' % msg return '\x1b[31m%s\x1b[0m' % msg
def wrap_filename(fname): def wrap_filename(fname):
if not terminal_colors: if not terminal_colors:
return fname return fname
# if '.' in fname: # if '.' in fname:
# name, ext = os.path.splitext(fname) # name, ext = os.path.splitext(fname)
# fname = '%s\x1b[36m%s\x1b[0m' % (name, ext) # fname = '%s\x1b[36m%s\x1b[0m' % (name, ext)
...@@ -518,7 +519,7 @@ def print_fs(fs, ...@@ -518,7 +519,7 @@ def print_fs(fs,
return fname return fname
dircount = [0] dircount = [0]
filecount = [0] filecount = [0]
def print_dir(fs, path, levels=[]): def print_dir(fs, path, levels=[]):
if file_encoding == 'UTF-8' and terminal_colors: if file_encoding == 'UTF-8' and terminal_colors:
char_vertline = u'│' char_vertline = u'│'
char_newnode = u'├' char_newnode = u'├'
...@@ -529,52 +530,52 @@ def print_fs(fs, ...@@ -529,52 +530,52 @@ def print_fs(fs,
char_newnode = '|' char_newnode = '|'
char_line = '--' char_line = '--'
char_corner = '`' char_corner = '`'
try: try:
dirs = fs.listdir(path, dirs_only=True) dirs = fs.listdir(path, dirs_only=True)
if dirs_only: if dirs_only:
files = [] files = []
else: else:
files = fs.listdir(path, files_only=True, wildcard=files_wildcard) files = fs.listdir(path, files_only=True, wildcard=files_wildcard)
dir_listing = ( [(True, p) for p in dirs] + dir_listing = ( [(True, p) for p in dirs] +
[(False, p) for p in files] ) [(False, p) for p in files] )
except Exception, e: except Exception, e:
prefix = ''.join([(char_vertline + ' ', ' ')[last] for last in levels]) + ' ' prefix = ''.join([(char_vertline + ' ', ' ')[last] for last in levels]) + ' '
write(wrap_prefix(prefix[:-1] + ' ') + wrap_error("unabled to retrieve directory list (%s) ..." % str(e))) write(wrap_prefix(prefix[:-1] + ' ') + wrap_error("unabled to retrieve directory list (%s) ..." % str(e)))
return 0 return 0
if hide_dotfiles: if hide_dotfiles:
dir_listing = [(isdir, p) for isdir, p in dir_listing if not p.startswith('.')] dir_listing = [(isdir, p) for isdir, p in dir_listing if not p.startswith('.')]
if dirs_first: if dirs_first:
dir_listing.sort(key = lambda (isdir, p):(not isdir, p.lower())) dir_listing.sort(key = lambda (isdir, p):(not isdir, p.lower()))
else: else:
dir_listing.sort(key = lambda (isdir, p):p.lower()) dir_listing.sort(key = lambda (isdir, p):p.lower())
for i, (is_dir, item) in enumerate(dir_listing): for i, (is_dir, item) in enumerate(dir_listing):
if is_dir: if is_dir:
dircount[0] += 1 dircount[0] += 1
else: else:
filecount[0] += 1 filecount[0] += 1
is_last_item = (i == len(dir_listing) - 1) is_last_item = (i == len(dir_listing) - 1)
prefix = ''.join([(char_vertline + ' ', ' ')[last] for last in levels]) prefix = ''.join([(char_vertline + ' ', ' ')[last] for last in levels])
if is_last_item: if is_last_item:
prefix += char_corner prefix += char_corner
else: else:
prefix += char_newnode prefix += char_newnode
if is_dir: if is_dir:
write('%s %s' % (wrap_prefix(prefix + char_line), wrap_dirname(item))) write('%s %s' % (wrap_prefix(prefix + char_line), wrap_dirname(item)))
if max_levels is not None and len(levels) + 1 >= max_levels: if max_levels is not None and len(levels) + 1 >= max_levels:
pass pass
#write(wrap_prefix(prefix[:-1] + ' ') + wrap_error('max recursion levels reached')) #write(wrap_prefix(prefix[:-1] + ' ') + wrap_error('max recursion levels reached'))
else: else:
print_dir(fs, pathjoin(path, item), levels[:] + [is_last_item]) print_dir(fs, pathjoin(path, item), levels[:] + [is_last_item])
else: else:
write('%s %s' % (wrap_prefix(prefix + char_line), wrap_filename(item))) write('%s %s' % (wrap_prefix(prefix + char_line), wrap_filename(item)))
return len(dir_listing) return len(dir_listing)
print_dir(fs, path) print_dir(fs, path)
return dircount[0], filecount[0] return dircount[0], filecount[0]
...@@ -585,15 +586,14 @@ if __name__ == "__main__": ...@@ -585,15 +586,14 @@ if __name__ == "__main__":
t1.setcontents("foo", "test") t1.setcontents("foo", "test")
t1.makedir("bar") t1.makedir("bar")
t1.setcontents("bar/baz", "another test") t1.setcontents("bar/baz", "another test")
t1.tree() t1.tree()
t2 = TempFS() t2 = TempFS()
print t2.listdir() print t2.listdir()
movedir(t1, t2) movedir(t1, t2)
print t2.listdir() print t2.listdir()
t1.tree() t1.tree()
t2.tree() t2.tree()
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
fs.wrapfs.hidefs fs.wrapfs.hidefs
================ ================
Removes resources from a directory listing if they match a given set of wildcards Removes resources from a directory listing if they match a given set of wildcards
""" """
...@@ -12,23 +12,24 @@ from fs.errors import ResourceNotFoundError ...@@ -12,23 +12,24 @@ from fs.errors import ResourceNotFoundError
import re import re
import fnmatch import fnmatch
class HideFS(WrapFS): class HideFS(WrapFS):
"""FS wrapper that hides resources if they match a wildcard(s). """FS wrapper that hides resources if they match a wildcard(s).
For example, to hide all pyc file and subversion directories from a filesystem:: For example, to hide all pyc file and subversion directories from a filesystem::
hide_fs = HideFS(my_fs, "*.pyc", ".svn") hide_fs = HideFS(my_fs, "*.pyc", ".svn")
""" """
def __init__(self, wrapped_fs, *hide_wildcards): def __init__(self, wrapped_fs, *hide_wildcards):
self._hide_wildcards = [re.compile(fnmatch.translate(wildcard)) for wildcard in hide_wildcards] self._hide_wildcards = [re.compile(fnmatch.translate(wildcard)) for wildcard in hide_wildcards]
super(HideFS, self).__init__(wrapped_fs) super(HideFS, self).__init__(wrapped_fs)
def _should_hide(self, path): def _should_hide(self, path):
return any(any(wildcard.match(part) for wildcard in self._hide_wildcards) return any(any(wildcard.match(part) for wildcard in self._hide_wildcards)
for part in iteratepath(path)) for part in iteratepath(path))
def _encode(self, path): def _encode(self, path):
if self._should_hide(path): if self._should_hide(path):
raise ResourceNotFoundError(path) raise ResourceNotFoundError(path)
...@@ -42,12 +43,12 @@ class HideFS(WrapFS): ...@@ -42,12 +43,12 @@ class HideFS(WrapFS):
return False return False
return super(HideFS, self).exists(path) return super(HideFS, self).exists(path)
def listdir(self, path="", *args, **kwargs): def listdir(self, path="", *args, **kwargs):
entries = super(HideFS, self).listdir(path, *args, **kwargs) entries = super(HideFS, self).listdir(path, *args, **kwargs)
entries = [entry for entry in entries if not self._should_hide(entry)] entries = [entry for entry in entries if not self._should_hide(entry)]
return entries return entries
if __name__ == "__main__": if __name__ == "__main__":
from fs.osfs import OSFS from fs.osfs import OSFS
hfs = HideFS(OSFS('~/projects/pyfilesystem'), "*.pyc", ".svn") hfs = HideFS(OSFS('~/projects/pyfilesystem'), "*.pyc", ".svn")
hfs.tree() hfs.tree()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment