Commit 619b26c3 by willmcgugan

Fixed tests for ftp

parent a5ffbc28
...@@ -50,6 +50,12 @@ class DummyLock(object): ...@@ -50,6 +50,12 @@ class DummyLock(object):
def release(self): def release(self):
"""Releasing a DummyLock always succeeds.""" """Releasing a DummyLock always succeeds."""
pass pass
def __enter__(self):
pass
def __exit__(self, *args):
pass
def silence_fserrors(f, *args, **kwargs): def silence_fserrors(f, *args, **kwargs):
......
...@@ -12,6 +12,7 @@ import fs ...@@ -12,6 +12,7 @@ import fs
from fs.base import * from fs.base import *
from fs.errors import * from fs.errors import *
from fs.path import pathsplit, abspath, dirname, recursepath, normpath from fs.path import pathsplit, abspath, dirname, recursepath, normpath
from fs.remote import RemoteFileBuffer
from ftplib import FTP, error_perm, error_temp, error_proto, error_reply from ftplib import FTP, error_perm, error_temp, error_proto, error_reply
...@@ -29,6 +30,11 @@ import re ...@@ -29,6 +30,11 @@ import re
from socket import error as socket_error from socket import error as socket_error
from fs.local_functools import wraps from fs.local_functools import wraps
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
import time import time
import sys import sys
...@@ -528,7 +534,7 @@ class _FTPFile(object): ...@@ -528,7 +534,7 @@ class _FTPFile(object):
def __init__(self, ftpfs, ftp, path, mode): def __init__(self, ftpfs, ftp, path, mode):
if not hasattr(self, '_lock'): if not hasattr(self, '_lock'):
self._lock = threading.RLock() self._lock = threading.RLock()
self.ftpfs = ftpfs self.ftpfs = ftpfs
self.ftp = ftp self.ftp = ftp
self.path = path self.path = path
...@@ -536,18 +542,23 @@ class _FTPFile(object): ...@@ -536,18 +542,23 @@ class _FTPFile(object):
self.read_pos = 0 self.read_pos = 0
self.write_pos = 0 self.write_pos = 0
self.closed = False self.closed = False
self.file_size = None
if 'r' in mode or 'a' in mode: if 'r' in mode or 'a' in mode:
self.file_size = ftpfs.getsize(path) self.file_size = ftpfs.getsize(path)
self.conn = None self.conn = None
path = _encode(path) path = _encode(path)
#self._lock = ftpfs._lock #self._lock = ftpfs._lock
self._start_file(mode, path)
def _start_file(self, mode, path):
self.read_pos = 0
self.write_pos = 0
if 'r' in mode: if 'r' in mode:
self.ftp.voidcmd('TYPE I') self.ftp.voidcmd('TYPE I')
self.conn = ftp.transfercmd('RETR '+path, None) self.conn = self.ftp.transfercmd('RETR '+path, None)
elif 'w' in mode or 'a' in mode: else:#if 'w' in mode or 'a' in mode:
self.ftp.voidcmd('TYPE I') self.ftp.voidcmd('TYPE I')
if 'a' in mode: if 'a' in mode:
self.write_pos = self.file_size self.write_pos = self.file_size
...@@ -607,15 +618,16 @@ class _FTPFile(object): ...@@ -607,15 +618,16 @@ class _FTPFile(object):
def __exit__(self,exc_type,exc_value,traceback): def __exit__(self,exc_type,exc_value,traceback):
self.close() self.close()
@synchronize #@synchronize
def flush(self): def flush(self):
return return
@synchronize
def seek(self, pos, where=fs.SEEK_SET): def seek(self, pos, where=fs.SEEK_SET):
# Ftp doesn't support a real seek, so we close the transfer and resume # Ftp doesn't support a real seek, so we close the transfer and resume
# it at the new position with the REST command # it at the new position with the REST command
# I'm not sure how reliable this method is! # I'm not sure how reliable this method is!
if not self.file_size: if self.file_size is None:
raise ValueError("Seek only works with files open for read") raise ValueError("Seek only works with files open for read")
self._lock.acquire() self._lock.acquire()
...@@ -659,6 +671,33 @@ class _FTPFile(object): ...@@ -659,6 +671,33 @@ class _FTPFile(object):
return self.write_pos return self.write_pos
@synchronize @synchronize
def truncate(self, size=None):
# Inefficient, but I don't know how else to implement this
if size is None:
size = self.tell()
if self.conn is not None:
self.conn.close()
self.close()
read_f = None
try:
read_f = self.ftpfs.open(self.path, 'rb')
data = read_f.read(size)
finally:
if read_f is not None:
read_f.close()
self.ftp = self.ftpfs._open_ftp()
self.mode = 'w'
self.__init__(self.ftpfs, self.ftp, _encode(self.path), self.mode)
#self._start_file(self.mode, self.path)
self.write(data)
if len(data) < size:
self.write('\0' * (size - len(data)))
@synchronize
def close(self): def close(self):
if self.conn is not None: if self.conn is not None:
self.conn.close() self.conn.close()
...@@ -706,7 +745,7 @@ class _FTPFile(object): ...@@ -706,7 +745,7 @@ class _FTPFile(object):
def ftperrors(f): def ftperrors(f):
@wraps(f) @wraps(f)
def deco(self, *args, **kwargs): def deco(self, *args, **kwargs):
self._lock.acquire() self._lock.acquire()
...@@ -747,6 +786,7 @@ class FTPFS(FS): ...@@ -747,6 +786,7 @@ class FTPFS(FS):
'atomic.makedir' : True, 'atomic.makedir' : True,
'atomic.rename' : True, 'atomic.rename' : True,
'atomic.setcontents' : False, 'atomic.setcontents' : False,
'file.read_and_write' : False,
} }
def __init__(self, host='', user='', passwd='', acct='', timeout=_GLOBAL_DEFAULT_TIMEOUT, def __init__(self, host='', user='', passwd='', acct='', timeout=_GLOBAL_DEFAULT_TIMEOUT,
...@@ -967,8 +1007,24 @@ class FTPFS(FS): ...@@ -967,8 +1007,24 @@ class FTPFS(FS):
if 'w' in mode or 'a' in mode: if 'w' in mode or 'a' in mode:
self.clear_dircache(dirname(path)) self.clear_dircache(dirname(path))
ftp = self._open_ftp() ftp = self._open_ftp()
f = _FTPFile(self, ftp, path, mode) f = _FTPFile(self, ftp, path, mode)
return f return f
#remote_f = RemoteFileBuffer(self, path, mode, rfile = f)
#return remote_f
@ftperrors
def setcontents(self, path, data, chunk_size=8192):
if isinstance(data, basestring):
data = StringIO(data)
self.ftp.storbinary('STOR %s' % _encode(normpath(path)), data, blocksize=chunk_size)
@ftperrors
def getcontents(self, path, chunk_size=8192):
if not self.exists(path):
raise ResourceNotFoundError(path=path)
contents = StringIO()
self.ftp.retrbinary('RETR %s' % _encode(normpath(path)), contents.write, blocksize=chunk_size)
return contents.getvalue()
@ftperrors @ftperrors
def exists(self, path): def exists(self, path):
......
...@@ -49,7 +49,7 @@ class OpenerRegistry(object): ...@@ -49,7 +49,7 @@ class OpenerRegistry(object):
) )
(?: (?:
\+(.*?)$ \!(.*?)$
)*$ )*$
''', re.VERBOSE) ''', re.VERBOSE)
......
...@@ -62,7 +62,7 @@ class SFTPFS(FS): ...@@ -62,7 +62,7 @@ class SFTPFS(FS):
} }
def __init__(self, connection, root_path="/", encoding=None, **credentials): def __init__(self, connection, root_path="/", encoding=None, username='', password=None, pkey=None):
"""SFTPFS constructor. """SFTPFS constructor.
The only required argument is 'connection', which must be something The only required argument is 'connection', which must be something
...@@ -75,14 +75,21 @@ class SFTPFS(FS): ...@@ -75,14 +75,21 @@ class SFTPFS(FS):
* a paramiko.Channel instance in "sftp" mode * a paramiko.Channel instance in "sftp" mode
The kwd argument 'root_path' specifies the root directory on the remote The kwd argument 'root_path' specifies the root directory on the remote
machine - access to files outsite this root wil be prevented. Any machine - access to files outsite this root wil be prevented.
other keyword arguments are assumed to be credentials to be used when
connecting the transport.
:param connection: a connection string :param connection: a connection string
:param root_path: The root path to open :param root_path: The root path to open
:param encoding: String encoding of paths (defaults to UTF-8)
:param username: Name of SFTP user
:param password: Password for SFTP user
:param pkey: Public key
""" """
credentials = dict(username=username,
password=password,
pkey=pkey)
if encoding is None: if encoding is None:
encoding = "utf8" encoding = "utf8"
self.encoding = encoding self.encoding = encoding
...@@ -93,15 +100,6 @@ class SFTPFS(FS): ...@@ -93,15 +100,6 @@ class SFTPFS(FS):
self._transport = None self._transport = None
self._client = None self._client = None
hostname = None
if isinstance(connection, basestring):
hostname = connection
else:
try:
hostname, port = connection
except ValueError:
pass
super(SFTPFS, self).__init__() super(SFTPFS, self).__init__()
self.root_path = abspath(normpath(root_path)) self.root_path = abspath(normpath(root_path))
...@@ -112,10 +110,32 @@ class SFTPFS(FS): ...@@ -112,10 +110,32 @@ class SFTPFS(FS):
else: else:
if not isinstance(connection,paramiko.Transport): if not isinstance(connection,paramiko.Transport):
connection = paramiko.Transport(connection) connection = paramiko.Transport(connection)
connection.daemon = True
self._owns_transport = True self._owns_transport = True
if not connection.is_authenticated():
connection.connect(**credentials) if not connection.is_authenticated():
self._transport = connection
try:
connection.start_client()
if pkey:
connection.auth_publickey(username, pkey)
if not connection.is_authenticated() and password:
connection.auth_password(username, password)
if not connection.is_authenticated():
self._agent_auth(connection, username)
if not connection.is_authenticated():
connection.close()
raise RemoteConnectionError('no auth')
except paramiko.SSHException, e:
connection.close()
raise RemoteConnectionError('SSH exception (%s)' % str(e), details=e)
self._transport = connection
@classmethod @classmethod
...@@ -210,6 +230,8 @@ class SFTPFS(FS): ...@@ -210,6 +230,8 @@ class SFTPFS(FS):
@convert_os_errors @convert_os_errors
def exists(self,path): def exists(self,path):
if path in ('', '/'):
return True
npath = self._normpath(path) npath = self._normpath(path)
try: try:
self.client.stat(npath) self.client.stat(npath)
...@@ -221,7 +243,7 @@ class SFTPFS(FS): ...@@ -221,7 +243,7 @@ class SFTPFS(FS):
@convert_os_errors @convert_os_errors
def isdir(self,path): def isdir(self,path):
if path == '/': if path in ('', '/'):
return True return True
npath = self._normpath(path) npath = self._normpath(path)
try: try:
...@@ -246,12 +268,15 @@ class SFTPFS(FS): ...@@ -246,12 +268,15 @@ class SFTPFS(FS):
@convert_os_errors @convert_os_errors
def listdir(self,path="./",wildcard=None,full=False,absolute=False,dirs_only=False,files_only=False): def listdir(self,path="./",wildcard=None,full=False,absolute=False,dirs_only=False,files_only=False):
npath = self._normpath(path) npath = self._normpath(path)
try: try:
paths = self.client.listdir(npath) attrs_map = None
if dirs_only or files_only: if dirs_only or files_only:
path_attrs = self.client.listdir_attr(npath) attrs = self.client.listdir_attr(npath)
attrs_map = dict((a.filename, a) for a in attrs)
paths = attrs_map.keys()
else: else:
path_attrs = None paths = self.client.listdir(npath)
except IOError, e: except IOError, e:
if getattr(e,"errno",None) == 2: if getattr(e,"errno",None) == 2:
if self.isfile(path): if self.isfile(path):
...@@ -261,16 +286,16 @@ class SFTPFS(FS): ...@@ -261,16 +286,16 @@ class SFTPFS(FS):
raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s") raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s")
raise raise
if path_attrs is not None: if attrs_map:
if dirs_only: if dirs_only:
filter_paths = [] filter_paths = []
for path, attr in zip(paths, path_attrs): for path, attr in attrs_map.iteritems():
if isdir(self, path, attr.__dict__): if isdir(self, path, attr.__dict__):
filter_paths.append(path) filter_paths.append(path)
paths = filter_paths paths = filter_paths
elif files_only: elif files_only:
filter_paths = [] filter_paths = []
for path, attr in zip(paths, path_attrs): for path, attr in attrs_map.iteritems():
if isfile(self, path, attr.__dict__): if isfile(self, path, attr.__dict__):
filter_paths.append(path) filter_paths.append(path)
paths = filter_paths paths = filter_paths
...@@ -284,10 +309,10 @@ class SFTPFS(FS): ...@@ -284,10 +309,10 @@ class SFTPFS(FS):
@convert_os_errors @convert_os_errors
def listdirinfo(self,path="./",wildcard=None,full=False,absolute=False,dirs_only=False,files_only=False): def listdirinfo(self,path="./",wildcard=None,full=False,absolute=False,dirs_only=False,files_only=False):
npath = self._normpath(path) npath = self._normpath(path)
try: try:
paths = self.client.listdir(npath)
attrs = self.client.listdir_attr(npath) attrs = self.client.listdir_attr(npath)
attrs_map = dict(zip(paths, attrs)) attrs_map = dict((a.filename, a) for a in attrs)
paths = attrs.keys()
except IOError, e: except IOError, e:
if getattr(e,"errno",None) == 2: if getattr(e,"errno",None) == 2:
if self.isfile(path): if self.isfile(path):
...@@ -296,17 +321,16 @@ class SFTPFS(FS): ...@@ -296,17 +321,16 @@ class SFTPFS(FS):
elif self.isfile(path): elif self.isfile(path):
raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s") raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s")
raise raise
if dirs_only: if dirs_only:
filter_paths = [] filter_paths = []
for path, attr in zip(paths, attrs): for path, attr in attrs_map.iteritems():
if isdir(self, path, attr.__dict__): if isdir(self, path, attr.__dict__):
filter_paths.append(path) filter_paths.append(path)
paths = filter_paths paths = filter_paths
elif files_only: elif files_only:
filter_paths = [] filter_paths = []
for path, attr in zip(paths, attrs): for path, attr in attrs_map.iteritems():
if isfile(self, path, attr.__dict__): if isfile(self, path, attr.__dict__):
filter_paths.append(path) filter_paths.append(path)
paths = filter_paths paths = filter_paths
......
...@@ -645,15 +645,20 @@ class FSTestCases(object): ...@@ -645,15 +645,20 @@ class FSTestCases(object):
checkcontents("hello","12345") checkcontents("hello","12345")
def test_truncate_to_larger_size(self): def test_truncate_to_larger_size(self):
with self.fs.open("hello","w") as f: with self.fs.open("hello","w") as f:
f.truncate(30) f.truncate(30)
self.assertEquals(self.fs.getsize("hello"),30) self.assertEquals(self.fs.getsize("hello"), 30)
with self.fs.open("hello","r+") as f:
f.seek(25) # Some file systems (FTPFS) don't support both reading and writing
f.write("123456") if self.fs.getmeta('file.read_and_write', True):
with self.fs.open("hello","r") as f: with self.fs.open("hello","r+") as f:
f.seek(25) f.seek(25)
self.assertEquals(f.read(),"123456") f.write("123456")
with self.fs.open("hello","r") as f:
f.seek(25)
self.assertEquals(f.read(),"123456")
def test_with_statement(self): def test_with_statement(self):
# This is a little tricky since 'with' is actually new syntax. # This is a little tricky since 'with' is actually new syntax.
......
...@@ -32,7 +32,7 @@ class TestFTPFS(unittest.TestCase, FSTestCases, ThreadingTestCases): ...@@ -32,7 +32,7 @@ class TestFTPFS(unittest.TestCase, FSTestCases, ThreadingTestCases):
self.ftp_server = subprocess.Popen([sys.executable, abspath(__file__), self.temp_dir, str(use_port)]) self.ftp_server = subprocess.Popen([sys.executable, abspath(__file__), self.temp_dir, str(use_port)])
# Need to sleep to allow ftp server to start # Need to sleep to allow ftp server to start
time.sleep(.2) time.sleep(.1)
self.fs = ftpfs.FTPFS('127.0.0.1', 'user', '12345', port=use_port, timeout=5.0) self.fs = ftpfs.FTPFS('127.0.0.1', 'user', '12345', port=use_port, timeout=5.0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment