Commit 07888d97 by willmcgugan@gmail.com

Applied patch to sftp that searches for ssh keys in default locations

parent dccbc1f3
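For context, a minimal usage sketch of the behaviour this patch enables (the host, port and username below are placeholders, not part of the commit):

from fs.sftpfs import SFTPFS

# With look_for_keys=True (the new default), SFTPFS falls back to the keys in
# ~/.ssh/id_rsa, ~/.ssh/id_dsa, ~/ssh/id_rsa and ~/ssh/id_dsa when explicit
# pkey, password and agent authentication have not succeeded.
sftp = SFTPFS(('example.com', 22), username='bob', look_for_keys=True)
print sftp.listdir('/')
sftp.close()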
......@@ -14,7 +14,6 @@ import threading
import os
import paramiko
from getpass import getuser
from binascii import hexlify
from fs.base import *
from fs.path import *
......@@ -34,17 +33,18 @@ else:
class thread_local(object):
def __init__(self):
self._map = {}
    def __getattr__(self, attr):
        try:
            return self._map[(threading.currentThread().ident, attr)]
        except KeyError:
            raise AttributeError(attr)

    def __setattr__(self, attr, value):
        self._map[(threading.currentThread().ident, attr)] = value

if not hasattr(paramiko.SFTPFile, "__enter__"):
paramiko.SFTPFile.__enter__ = lambda self: self
paramiko.SFTPFile.__exit__ = lambda self,et,ev,tb: self.close() and False
......@@ -54,7 +54,7 @@ class SFTPFS(FS):
This is basically a compatibility wrapper for the excellent SFTPClient
class in the paramiko module.
"""
_meta = { 'thread_safe' : True,
......@@ -70,8 +70,17 @@ class SFTPFS(FS):
'atomic.setcontents' : False
}
    def __init__(self,
                 connection,
                 root_path="/",
                 encoding=None,
                 hostkey=None,
                 username='',
                 password=None,
                 pkey=None,
                 agent_auth=True,
                 no_auth=False,
                 look_for_keys=True):
"""SFTPFS constructor.
The only required argument is 'connection', which must be something
......@@ -84,8 +93,8 @@ class SFTPFS(FS):
* a paramiko.Channel instance in "sftp" mode
The keyword argument 'root_path' specifies the root directory on the
        remote machine - access to files outside this root will be prevented.
:param connection: a connection string
:param root_path: The root path to open
:param encoding: String encoding of paths (defaults to UTF-8)
......@@ -95,14 +104,16 @@ class SFTPFS(FS):
:param pkey: Public key
:param agent_auth: attempt to authorize with the user's public keys
:param no_auth: attempt to log in without any kind of authorization
        :param look_for_keys: Look for keys in the same locations as ssh,
            if other authentication is not successful
"""
credentials = dict(username=username,
password=password,
pkey=pkey)
self.credentials = credentials
if encoding is None:
encoding = "utf8"
self.encoding = encoding
......@@ -112,16 +123,16 @@ class SFTPFS(FS):
self._tlocal = thread_local()
self._transport = None
self._client = None
self.hostname = None
if isinstance(connection, basestring):
self.hostname = connection
elif isinstance(connection, tuple):
            self.hostname = '%s:%s' % connection
        super(SFTPFS, self).__init__()
self.root_path = abspath(normpath(root_path))
if isinstance(connection,paramiko.Channel):
self._transport = None
self._client = paramiko.SFTPClient(connection)
......@@ -130,75 +141,111 @@ class SFTPFS(FS):
connection = paramiko.Transport(connection)
connection.daemon = True
self._owns_transport = True
if hostkey is not None:
key = self.get_remote_server_key()
if hostkey != key:
raise WrongHostKeyError('Host keys do not match')
connection.start_client()
if not connection.is_active():
raise RemoteConnectionError(msg='Unable to connect')
            if no_auth:
                try:
                    connection.auth_none('')
                except paramiko.SSHException:
                    pass

            elif not connection.is_authenticated():
                if not username:
                    username = getuser()
                try:
                    if pkey:
                        connection.auth_publickey(username, pkey)

                    if not connection.is_authenticated() and password:
                        connection.auth_password(username, password)

                    if agent_auth and not connection.is_authenticated():
                        self._agent_auth(connection, username)

                    if look_for_keys and not connection.is_authenticated():
                        self._userkeys_auth(connection, username, password)

                    if not connection.is_authenticated():
                        try:
                            connection.auth_none(username)
                        except paramiko.BadAuthenticationType, e:
                            self.close()
                            allowed = ', '.join(e.allowed_types)
                            raise RemoteConnectionError(msg='no auth - server requires one of the following: %s' % allowed, details=e)

                    if not connection.is_authenticated():
                        self.close()
                        raise RemoteConnectionError(msg='no auth')

                except paramiko.SSHException, e:
                    self.close()
                    raise RemoteConnectionError(msg='SSH exception (%s)' % str(e), details=e)

            self._transport = connection
def __unicode__(self):
return u'<SFTPFS: %s>' % self.desc('/')
@classmethod
def _agent_auth(cls, transport, username):
"""
Attempt to authenticate to the given transport using any of the private
keys available from an SSH agent.
"""
agent = paramiko.Agent()
agent_keys = agent.get_keys()
if not agent_keys:
            return None
        for key in agent_keys:
            try:
                transport.auth_publickey(username, key)
                return key
            except paramiko.SSHException:
                pass
        return None
@classmethod
def _userkeys_auth(cls, transport, username, password):
"""
Attempt to authenticate to the given transport using any of the private
        keys in the user's ~/.ssh and ~/ssh dirs
Derived from http://www.lag.net/paramiko/docs/paramiko.client-pysrc.html
"""
keyfiles = []
rsa_key = os.path.expanduser('~/.ssh/id_rsa')
dsa_key = os.path.expanduser('~/.ssh/id_dsa')
if os.path.isfile(rsa_key):
keyfiles.append((paramiko.rsakey.RSAKey, rsa_key))
if os.path.isfile(dsa_key):
keyfiles.append((paramiko.dsskey.DSSKey, dsa_key))
# look in ~/ssh/ for windows users:
rsa_key = os.path.expanduser('~/ssh/id_rsa')
dsa_key = os.path.expanduser('~/ssh/id_dsa')
if os.path.isfile(rsa_key):
keyfiles.append((paramiko.rsakey.RSAKey, rsa_key))
if os.path.isfile(dsa_key):
keyfiles.append((paramiko.dsskey.DSSKey, dsa_key))
        for pkey_class, filename in keyfiles:
            # load each candidate key and offer it to the server until one is accepted
            key = pkey_class.from_private_key_file(filename, password)
            try:
                transport.auth_publickey(username, key)
                return key
            except paramiko.SSHException:
                pass
return None
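As a rough illustration of the search behaviour added above (a sketch only; it simply mirrors the keyfiles list built by _userkeys_auth):

import os
import paramiko

# Default key locations tried by _userkeys_auth, in order.
candidates = [
    (paramiko.rsakey.RSAKey, os.path.expanduser('~/.ssh/id_rsa')),
    (paramiko.dsskey.DSSKey, os.path.expanduser('~/.ssh/id_dsa')),
    (paramiko.rsakey.RSAKey, os.path.expanduser('~/ssh/id_rsa')),    # windows-style location
    (paramiko.dsskey.DSSKey, os.path.expanduser('~/ssh/id_dsa')),
]
for key_class, path in candidates:
    if os.path.isfile(path):
        print '%s would be tried as a %s' % (path, key_class.__name__)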
def __del__(self):
self.close()
......@@ -210,9 +257,9 @@ class SFTPFS(FS):
if self._owns_transport:
state['_transport'] = self._transport.getpeername()
return state
def __setstate__(self,state):
        super(SFTPFS, self).__setstate__(state)
#for (k,v) in state.iteritems():
# self.__dict__[k] = v
#self._lock = threading.RLock()
......@@ -261,20 +308,20 @@ class SFTPFS(FS):
raise PathError(path,msg="Path is outside root: %(path)s")
return npath
    def getpathurl(self, path, allow_none=False):
        path = self._normpath(path)
        if self.hostname is None:
            if allow_none:
                return None
            raise NoPathURLError(path=path)

        username = self.credentials.get('username', '') or ''
        password = self.credentials.get('password', '') or ''
        credentials = ('%s:%s' % (username, password)).rstrip(':')
        if credentials:
            url = 'sftp://%s@%s%s' % (credentials, self.hostname.rstrip('/'), abspath(path))
        else:
            url = 'sftp://%s%s' % (self.hostname.rstrip('/'), abspath(path))
        return url
@synchronize
@convert_os_errors
......@@ -349,14 +396,14 @@ class SFTPFS(FS):
@convert_os_errors
def listdir(self,path="./",wildcard=None,full=False,absolute=False,dirs_only=False,files_only=False):
npath = self._normpath(path)
        try:
            attrs_map = None
            if dirs_only or files_only:
                attrs = self.client.listdir_attr(npath)
                attrs_map = dict((a.filename, a) for a in attrs)
                paths = list(attrs_map.iterkeys())
            else:
                paths = self.client.listdir(npath)
except IOError, e:
if getattr(e,"errno",None) == 2:
if self.isfile(path):
......@@ -364,8 +411,8 @@ class SFTPFS(FS):
raise ResourceNotFoundError(path)
elif self.isfile(path):
raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s")
            raise
if attrs_map:
if dirs_only:
filter_paths = []
......@@ -378,22 +425,22 @@ class SFTPFS(FS):
for apath, attr in attrs_map.iteritems():
if isfile(self, apath, attr.__dict__):
filter_paths.append(apath)
                paths = filter_paths

        for (i, p) in enumerate(paths):
            if not isinstance(p, unicode):
                paths[i] = p.decode(self.encoding)
return self._listdir_helper(path, paths, wildcard, full, absolute, False, False)
@synchronize
@convert_os_errors
def listdirinfo(self,path="./",wildcard=None,full=False,absolute=False,dirs_only=False,files_only=False):
npath = self._normpath(path)
        try:
            attrs = self.client.listdir_attr(npath)
            attrs_map = dict((a.filename, a) for a in attrs)
            paths = attrs_map.keys()
except IOError, e:
if getattr(e,"errno",None) == 2:
if self.isfile(path):
......@@ -402,7 +449,7 @@ class SFTPFS(FS):
elif self.isfile(path):
raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s")
raise
if dirs_only:
filter_paths = []
for path, attr in attrs_map.iteritems():
......@@ -415,19 +462,19 @@ class SFTPFS(FS):
if isfile(self, path, attr.__dict__):
filter_paths.append(path)
paths = filter_paths
for (i, p) in enumerate(paths):
if not isinstance(p, unicode):
paths[i] = p.decode(self.encoding)
def getinfo(p):
resourcename = basename(p)
info = attrs_map.get(resourcename)
if info is None:
return self.getinfo(pathjoin(path, p))
return self._extract_info(info.__dict__)
        return [(p, getinfo(p)) for p in
self._listdir_helper(path, paths, wildcard, full, absolute, False, False)]
@synchronize
......@@ -492,7 +539,7 @@ class SFTPFS(FS):
if self.isfile(path):
raise ResourceInvalidError(path,msg="Can't use removedir() on a file: %(path)s")
raise ResourceNotFoundError(path)
elif self.listdir(path):
raise DirectoryNotEmptyError(path)
raise
......@@ -556,7 +603,7 @@ class SFTPFS(FS):
@classmethod
def _extract_info(cls, stats):
fromtimestamp = datetime.datetime.fromtimestamp
        info = dict((k, v) for k, v in stats.iteritems() if k in cls._info_vars and not k.startswith('_'))
info['size'] = info['st_size']
ct = info.get('st_ctime')
if ct is not None:
......@@ -571,10 +618,10 @@ class SFTPFS(FS):
@synchronize
@convert_os_errors
    def getinfo(self, path):
npath = self._normpath(path)
stats = self.client.stat(npath)
        info = dict((k, getattr(stats, k)) for k in dir(stats) if not k.startswith('_'))
info['size'] = info['st_size']
ct = info.get('st_ctime', None)
if ct is not None:
......
# -*- coding: utf-8 -*-
"""
The `utils` module provides a number of utility functions that don't belong
in the Filesystem interface. Generally the functions in this module work with
multiple Filesystems, for instance moving and copying between non-similar Filesystems.
"""
......@@ -36,15 +38,15 @@ def copyfile(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1
:param chunk_size: Size of chunks to move if system copyfile is not available (default 64K)
"""
# If the src and dst fs objects are the same, then use a direct copy
if src_fs is dst_fs:
src_fs.copy(src_path, dst_path, overwrite=overwrite)
return
src_syspath = src_fs.getsyspath(src_path, allow_none=True)
dst_syspath = dst_fs.getsyspath(dst_path, allow_none=True)
if not overwrite and dst_fs.exists(dst_path):
raise DestinationExistsError(dst_path)
......@@ -57,44 +59,44 @@ def copyfile(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1
if src_lock is not None:
src_lock.acquire()
try:
        src = None
        try:
            src = src_fs.open(src_path, 'rb')
            dst_fs.setcontents(dst_path, src, chunk_size=chunk_size)
        finally:
            if src is not None:
                src.close()
finally:
if src_lock is not None:
src_lock.release()
def copyfile_non_atomic(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1024):
"""A non atomic version of copyfile (will not block other threads using src_fs or dst_fst)
:param src_fs: Source filesystem object
:param src_path: -- Source path
:param dst_fs: Destination filesystem object
:param dst_path: Destination filesystem object
:param chunk_size: Size of chunks to move if system copyfile is not available (default 64K)
"""
if not overwrite and dst_fs.exists(dst_path):
raise DestinationExistsError(dst_path)
    src = None
    dst = None
    try:
        src = src_fs.open(src_path, 'rb')
        dst = dst_fs.open(dst_path, 'wb')
        write = dst.write
        read = src.read
        chunk = read(chunk_size)
        while chunk:
            write(chunk)
            chunk = read(chunk_size)
finally:
if src is not None:
src.close()
......@@ -118,14 +120,14 @@ def movefile(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1
if not overwrite and dst_fs.exists(dst_path):
raise DestinationExistsError(dst_path)
if src_fs is dst_fs:
src_fs.move(src_path, dst_path, overwrite=overwrite)
return
# System copy if there are two sys paths
    if src_syspath is not None and dst_syspath is not None:
        FS._shutil_movefile(src_syspath, dst_syspath)
return
src_lock = getattr(src_fs, '_lock', None)
......@@ -134,18 +136,18 @@ def movefile(src_fs, src_path, dst_fs, dst_path, overwrite=True, chunk_size=64*1
src_lock.acquire()
try:
        src = None
        try:
            # Chunk copy
            src = src_fs.open(src_path, 'rb')
            dst_fs.setcontents(dst_path, src, chunk_size=chunk_size)
        except:
            raise
        else:
            src_fs.remove(src_path)
        finally:
            if src is not None:
                src.close()
finally:
if src_lock is not None:
src_lock.release()
......@@ -160,32 +162,32 @@ def movefile_non_atomic(src_fs, src_path, dst_fs, dst_path, overwrite=True, chun
    :param dst_path: Destination path
    :param chunk_size: Size of chunks to move if system copyfile is not available (default 64K)
    """
if not overwrite and dst_fs.exists(dst_path):
raise DestinationExistsError(dst_path)
    src = None
    dst = None
    try:
        # Chunk copy
        src = src_fs.open(src_path, 'rb')
        dst = dst_fs.open(dst_path, 'wb')
        write = dst.write
        read = src.read
        chunk = read(chunk_size)
        while chunk:
            write(chunk)
            chunk = read(chunk_size)
    except:
        raise
    else:
        src_fs.remove(src_path)
    finally:
        if src is not None:
            src.close()
        if dst is not None:
            dst.close()
def movedir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=64*1024):
......@@ -193,28 +195,27 @@ def movedir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=6
:param fs1: A tuple of (<filesystem>, <directory path>)
:param fs2: Destination filesystem, or a tuple of (<filesystem>, <directory path>)
    :param create_destination: If True, the destination will be created if it doesn't exist
:param ignore_errors: If True, exceptions from file moves are ignored
:param chunk_size: Size of chunks to move if a simple copy is used
"""
if not isinstance(fs1, tuple):
raise ValueError("first argument must be a tuple of (<filesystem>, <path>)")
fs1, dir1 = fs1
parent_fs1 = fs1
    parent_dir1 = dir1
fs1 = fs1.opendir(dir1)
print fs1
if parent_dir1 in ('', '/'):
raise RemoveRootError(dir1)
if isinstance(fs2, tuple):
fs2, dir2 = fs2
        if create_destination:
            fs2.makedir(dir2, allow_recreate=True, recursive=True)
        fs2 = fs2.opendir(dir2)
mount_fs = MountFS(auto_close=False)
mount_fs.mount('src', fs1)
......@@ -223,8 +224,8 @@ def movedir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=6
mount_fs.copydir('src', 'dst',
overwrite=True,
ignore_errors=ignore_errors,
                     chunk_size=chunk_size)

    parent_fs1.removedir(parent_dir1, force=True)
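A short sketch of movedir with (filesystem, path) tuples, as described in the docstring above (the TempFS contents are placeholders):

from fs.tempfs import TempFS
from fs.utils import movedir

t1 = TempFS()
t2 = TempFS()
t1.makedir('stuff')
t1.setcontents('stuff/readme.txt', 'example')

# Moves the 'stuff' directory from t1 to t2 and removes it from the source.
movedir((t1, 'stuff'), (t2, 'stuff'))
print t2.listdir('stuff')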
def copydir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=64*1024):
......@@ -244,12 +245,12 @@ def copydir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=6
fs2, dir2 = fs2
if create_destination:
fs2.makedir(dir2, allow_recreate=True, recursive=True)
        fs2 = fs2.opendir(dir2)

    mount_fs = MountFS(auto_close=False)
    mount_fs.mount('src', fs1)
    mount_fs.mount('dst', fs2)
    mount_fs.copydir('src', 'dst',
overwrite=True,
ignore_errors=ignore_errors,
chunk_size=chunk_size)
......@@ -257,29 +258,29 @@ def copydir(fs1, fs2, create_destination=True, ignore_errors=False, chunk_size=6
def remove_all(fs, path):
"""Remove everything in a directory. Returns True if successful.
:param fs: A filesystem
:param path: Path to a directory
"""
sub_fs = fs.opendir(path)
for sub_path in sub_fs.listdir():
"""
sub_fs = fs.opendir(path)
for sub_path in sub_fs.listdir():
if sub_fs.isdir(sub_path):
sub_fs.removedir(sub_path, force=True)
else:
sub_fs.remove(sub_path)
return fs.isdirempty(path)
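A small illustrative use of remove_all (the directory contents are placeholders):

from fs.tempfs import TempFS
from fs.utils import remove_all

tmp = TempFS()
tmp.makedir('cache')
tmp.setcontents('cache/a.txt', 'x')
tmp.makedir('cache/subdir')

# Empties 'cache' without removing the directory itself; returns True on success.
print remove_all(tmp, 'cache')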
def copystructure(src_fs, dst_fs):
"""Copies the directory structure from one filesystem to another, so that
all directories in `src_fs` will have a corresponding directory in `dst_fs`
:param src_fs: Filesystem to copy structure from
:param dst_fs: Filesystem to copy structure to
"""
for path in src_fs.walkdirs():
dst_fs.makedir(path, allow_recreate=True)
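For illustration, copystructure mirrors directories but not files (the paths are arbitrary):

from fs.tempfs import TempFS
from fs.utils import copystructure

src_fs = TempFS()
dst_fs = TempFS()
src_fs.makedir('a/b/c', recursive=True)

# Only the directory tree is recreated on the destination.
copystructure(src_fs, dst_fs)
print dst_fs.isdir('a/b/c')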
......@@ -349,7 +350,7 @@ def find_duplicates(fs,
:param signature_size: The total number of bytes read to generate a signature
For example, the following will list all the duplicate .jpg files in "~/Pictures"::
>>> from fs.utils import find_duplicates
>>> from fs.osfs import OSFS
>>> fs = OSFS('~/Pictures')
......@@ -457,9 +458,9 @@ def print_fs(fs,
    This is mostly useful as a debugging aid.
    Be careful about printing an OSFS, or any other large filesystem.
Without max_levels set, this function will traverse the entire directory tree.
For example, the following will print a tree of the files under the current working directory::
>>> from fs.osfs import *
>>> from fs.utils import *
>>> fs = OSFS('.')
......@@ -487,28 +488,28 @@ def print_fs(fs,
terminal_colors = False
else:
terminal_colors = hasattr(file_out, 'isatty') and file_out.isatty()
def write(line):
file_out.write(line.encode(file_encoding, 'replace')+'\n')
def wrap_prefix(prefix):
if not terminal_colors:
            return prefix
        return '\x1b[32m%s\x1b[0m' % prefix
def wrap_dirname(dirname):
if not terminal_colors:
return dirname
return '\x1b[1;34m%s\x1b[0m' % dirname
def wrap_error(msg):
if not terminal_colors:
return msg
return '\x1b[31m%s\x1b[0m' % msg
def wrap_filename(fname):
if not terminal_colors:
            return fname
# if '.' in fname:
# name, ext = os.path.splitext(fname)
# fname = '%s\x1b[36m%s\x1b[0m' % (name, ext)
......@@ -518,7 +519,7 @@ def print_fs(fs,
return fname
dircount = [0]
filecount = [0]
    def print_dir(fs, path, levels=[]):
if file_encoding == 'UTF-8' and terminal_colors:
char_vertline = u'│'
char_newnode = u'├'
......@@ -529,52 +530,52 @@ def print_fs(fs,
char_newnode = '|'
char_line = '--'
char_corner = '`'
        try:
dirs = fs.listdir(path, dirs_only=True)
if dirs_only:
files = []
else:
                files = fs.listdir(path, files_only=True, wildcard=files_wildcard)
dir_listing = ( [(True, p) for p in dirs] +
[(False, p) for p in files] )
except Exception, e:
prefix = ''.join([(char_vertline + ' ', ' ')[last] for last in levels]) + ' '
            write(wrap_prefix(prefix[:-1] + ' ') + wrap_error("unable to retrieve directory list (%s) ..." % str(e)))
            return 0
if hide_dotfiles:
dir_listing = [(isdir, p) for isdir, p in dir_listing if not p.startswith('.')]
if dirs_first:
dir_listing.sort(key = lambda (isdir, p):(not isdir, p.lower()))
else:
dir_listing.sort(key = lambda (isdir, p):p.lower())
for i, (is_dir, item) in enumerate(dir_listing):
if is_dir:
dircount[0] += 1
else:
filecount[0] += 1
            is_last_item = (i == len(dir_listing) - 1)
prefix = ''.join([(char_vertline + ' ', ' ')[last] for last in levels])
if is_last_item:
prefix += char_corner
else:
prefix += char_newnode
            if is_dir:
write('%s %s' % (wrap_prefix(prefix + char_line), wrap_dirname(item)))
if max_levels is not None and len(levels) + 1 >= max_levels:
pass
#write(wrap_prefix(prefix[:-1] + ' ') + wrap_error('max recursion levels reached'))
else:
                    print_dir(fs, pathjoin(path, item), levels[:] + [is_last_item])
else:
write('%s %s' % (wrap_prefix(prefix + char_line), wrap_filename(item)))
return len(dir_listing)
print_dir(fs, path)
return dircount[0], filecount[0]
......@@ -585,15 +586,14 @@ if __name__ == "__main__":
t1.setcontents("foo", "test")
t1.makedir("bar")
t1.setcontents("bar/baz", "another test")
t1.tree()
    t2 = TempFS()
    print t2.listdir()
movedir(t1, t2)
print t2.listdir()
t1.tree()
t2.tree()
......@@ -2,7 +2,7 @@
fs.wrapfs.hidefs
================
Removes resources from a directory listing if they match a given set of wildcards
"""
......@@ -12,23 +12,24 @@ from fs.errors import ResourceNotFoundError
import re
import fnmatch
class HideFS(WrapFS):
"""FS wrapper that hides resources if they match a wildcard(s).
    For example, to hide all pyc files and subversion directories from a filesystem::
hide_fs = HideFS(my_fs, "*.pyc", ".svn")
"""
    def __init__(self, wrapped_fs, *hide_wildcards):
self._hide_wildcards = [re.compile(fnmatch.translate(wildcard)) for wildcard in hide_wildcards]
super(HideFS, self).__init__(wrapped_fs)
    def _should_hide(self, path):
return any(any(wildcard.match(part) for wildcard in self._hide_wildcards)
for part in iteratepath(path))
def _encode(self, path):
if self._should_hide(path):
raise ResourceNotFoundError(path)
......@@ -42,12 +43,12 @@ class HideFS(WrapFS):
return False
return super(HideFS, self).exists(path)
    def listdir(self, path="", *args, **kwargs):
        entries = super(HideFS, self).listdir(path, *args, **kwargs)
        entries = [entry for entry in entries if not self._should_hide(entry)]
return entries
if __name__ == "__main__":
from fs.osfs import OSFS
hfs = HideFS(OSFS('~/projects/pyfilesystem'), "*.pyc", ".svn")
    hfs.tree()