Commit 9ace9f0d by rfkelly0

SFTPFS: better thread-safety using a per-thread SFTPClient instance

parent 12ca3e2c
...@@ -11,6 +11,22 @@ import paramiko ...@@ -11,6 +11,22 @@ import paramiko
from fs.base import * from fs.base import *
# SFTPClient appears to not be thread-safe, so we use an instance per thread
if hasattr(threading,"local"):
thread_local = threading.local
else:
class thread_local(object):
def __init__(self):
self._map = {}
def __getattr__(self,attr):
try:
return self._map[(threading.currentThread().ident,attr)]
except KeyError:
raise AttributeError, attr
def __setattr__(self,attr,value):
self._map[(threading.currentThread().ident,attr)] = value
if not hasattr(paramiko.SFTPFile,"__enter__"): if not hasattr(paramiko.SFTPFile,"__enter__"):
paramiko.SFTPFile.__enter__ = lambda self: self paramiko.SFTPFile.__enter__ = lambda self: self
...@@ -41,17 +57,20 @@ class SFTPFS(FS): ...@@ -41,17 +57,20 @@ class SFTPFS(FS):
other keyword arguments are assumed to be credentials to be used when other keyword arguments are assumed to be credentials to be used when
connecting the transport. connecting the transport.
""" """
self.closed = False
self._owns_transport = False self._owns_transport = False
self._credentials = credentials self._credentials = credentials
self._tlocal = thread_local()
if isinstance(connection,paramiko.Channel): if isinstance(connection,paramiko.Channel):
self.client = paramiko.SFTPClient(connection) self._transport = None
self._client = paramiko.SFTPClient(connection)
else: else:
if not isinstance(connection,paramiko.Transport): if not isinstance(connection,paramiko.Transport):
connection = paramiko.Transport(connection) connection = paramiko.Transport(connection)
self._owns_transport = True self._owns_transport = True
if not connection.is_authenticated(): if not connection.is_authenticated():
connection.connect(**credentials) connection.connect(**credentials)
self.client = paramiko.SFTPClient.from_transport(connection) self._transport = connection
self.root = abspath(normpath(root)) self.root = abspath(normpath(root))
def __del__(self): def __del__(self):
...@@ -59,28 +78,37 @@ class SFTPFS(FS): ...@@ -59,28 +78,37 @@ class SFTPFS(FS):
def __getstate__(self): def __getstate__(self):
state = super(SFTPFS,self).__getstate__() state = super(SFTPFS,self).__getstate__()
del state["_tlocal"]
if self._owns_transport: if self._owns_transport:
state['client'] = self.client.get_channel().get_transport().getpeername() state['_transport'] = self._transport.getpeername()
return state return state
def __setstate__(self,state): def __setstate__(self,state):
for (k,v) in state.iteritems(): for (k,v) in state.iteritems():
self.__dict__[k] = v self.__dict__[k] = v
self._tlocal = thread_local()
if self._owns_transport: if self._owns_transport:
t = paramiko.Transport(self.client) self._transport = paramiko.Transport(self._transport)
t.connect(**self._credentials) self._transport.connect(**self._credentials)
self.client = paramiko.SFTPClient.from_transport(t)
@property
def client(self):
try:
return self._tlocal.client
except AttributeError:
if self._transport is None:
return self._client
client = paramiko.SFTPClient.from_transport(self._transport)
self._tlocal.client = client
return client
def close(self): def close(self):
"""Close the connection to the remote server.""" """Close the connection to the remote server."""
if getattr(self,"client",None): if not self.closed:
if self._owns_transport: if self.client:
t = self.client.get_channel().get_transport()
self.client.close()
t.close()
else:
self.client.close() self.client.close()
self.client = None if self._owns_transport and self._transport:
self._transport.close()
def _normpath(self,path): def _normpath(self,path):
npath = pathjoin(self.root,relpath(normpath(path))) npath = pathjoin(self.root,relpath(normpath(path)))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment