Commit 9ace9f0d by rfkelly0

SFTPFS: better thread-safety using a per-thread SFTPClient instance

parent 12ca3e2c
......@@ -11,6 +11,22 @@ import paramiko
from fs.base import *
# SFTPClient appears to not be thread-safe, so we use an instance per thread
if hasattr(threading,"local"):
thread_local = threading.local
else:
class thread_local(object):
def __init__(self):
self._map = {}
def __getattr__(self,attr):
try:
return self._map[(threading.currentThread().ident,attr)]
except KeyError:
raise AttributeError, attr
def __setattr__(self,attr,value):
self._map[(threading.currentThread().ident,attr)] = value
if not hasattr(paramiko.SFTPFile,"__enter__"):
paramiko.SFTPFile.__enter__ = lambda self: self
......@@ -41,17 +57,20 @@ class SFTPFS(FS):
other keyword arguments are assumed to be credentials to be used when
connecting the transport.
"""
self.closed = False
self._owns_transport = False
self._credentials = credentials
self._tlocal = thread_local()
if isinstance(connection,paramiko.Channel):
self.client = paramiko.SFTPClient(connection)
self._transport = None
self._client = paramiko.SFTPClient(connection)
else:
if not isinstance(connection,paramiko.Transport):
connection = paramiko.Transport(connection)
self._owns_transport = True
if not connection.is_authenticated():
connection.connect(**credentials)
self.client = paramiko.SFTPClient.from_transport(connection)
self._transport = connection
self.root = abspath(normpath(root))
def __del__(self):
......@@ -59,28 +78,37 @@ class SFTPFS(FS):
def __getstate__(self):
state = super(SFTPFS,self).__getstate__()
del state["_tlocal"]
if self._owns_transport:
state['client'] = self.client.get_channel().get_transport().getpeername()
state['_transport'] = self._transport.getpeername()
return state
def __setstate__(self,state):
for (k,v) in state.iteritems():
self.__dict__[k] = v
self._tlocal = thread_local()
if self._owns_transport:
t = paramiko.Transport(self.client)
t.connect(**self._credentials)
self.client = paramiko.SFTPClient.from_transport(t)
self._transport = paramiko.Transport(self._transport)
self._transport.connect(**self._credentials)
@property
def client(self):
try:
return self._tlocal.client
except AttributeError:
if self._transport is None:
return self._client
client = paramiko.SFTPClient.from_transport(self._transport)
self._tlocal.client = client
return client
def close(self):
"""Close the connection to the remote server."""
if getattr(self,"client",None):
if self._owns_transport:
t = self.client.get_channel().get_transport()
self.client.close()
t.close()
else:
if not self.closed:
if self.client:
self.client.close()
self.client = None
if self._owns_transport and self._transport:
self._transport.close()
def _normpath(self,path):
npath = pathjoin(self.root,relpath(normpath(path)))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment