Commit ee47991d by willmcgugan@gmail.com

Added closing context manager to files returned by open

parent 0dc13500
...@@ -8,6 +8,7 @@ A FS object that represents the contents of a Zip file ...@@ -8,6 +8,7 @@ A FS object that represents the contents of a Zip file
import datetime import datetime
import os.path import os.path
from contextlib import closing
from fs.base import * from fs.base import *
from fs.path import * from fs.path import *
...@@ -32,7 +33,6 @@ class ZipNotFoundError(CreateFailedError): ...@@ -32,7 +33,6 @@ class ZipNotFoundError(CreateFailedError):
class _TempWriteFile(object): class _TempWriteFile(object):
"""Proxies a file object and calls a callback when the file is closed.""" """Proxies a file object and calls a callback when the file is closed."""
def __init__(self, fs, filename, close_callback): def __init__(self, fs, filename, close_callback):
...@@ -50,13 +50,18 @@ class _TempWriteFile(object): ...@@ -50,13 +50,18 @@ class _TempWriteFile(object):
def close(self): def close(self):
self._file.close() self._file.close()
self.close_callback(self.filename) self.close_callback(self.filename)
def flush(self): def flush(self):
self._file.flush() self._file.flush()
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()
class _ExceptionProxy(object):
class _ExceptionProxy(object):
"""A placeholder for an object that may no longer be used.""" """A placeholder for an object that may no longer be used."""
def __getattr__(self, name): def __getattr__(self, name):
...@@ -70,9 +75,8 @@ class _ExceptionProxy(object): ...@@ -70,9 +75,8 @@ class _ExceptionProxy(object):
class ZipFS(FS): class ZipFS(FS):
"""A FileSystem that represents a zip file.""" """A FileSystem that represents a zip file."""
_meta = { 'thread_safe' : True, _meta = { 'thread_safe' : True,
'virtual' : False, 'virtual' : False,
'read_only' : False, 'read_only' : False,
...@@ -107,27 +111,27 @@ class ZipFS(FS): ...@@ -107,27 +111,27 @@ class ZipFS(FS):
raise ValueError("mode must be 'r', 'w' or 'a'") raise ValueError("mode must be 'r', 'w' or 'a'")
self.zip_mode = mode self.zip_mode = mode
self.encoding = encoding self.encoding = encoding
if isinstance(zip_file, basestring): if isinstance(zip_file, basestring):
zip_file = os.path.expanduser(os.path.expandvars(zip_file)) zip_file = os.path.expanduser(os.path.expandvars(zip_file))
zip_file = os.path.normpath(os.path.abspath(zip_file)) zip_file = os.path.normpath(os.path.abspath(zip_file))
self._zip_file_string = True self._zip_file_string = True
else: else:
self._zip_file_string = False self._zip_file_string = False
try: try:
self.zf = ZipFile(zip_file, mode, compression_type, allow_zip_64) self.zf = ZipFile(zip_file, mode, compression_type, allow_zip_64)
except BadZipfile, bzf: except BadZipfile, bzf:
raise ZipOpenError("Not a zip file or corrupt (%s)" % str(zip_file), raise ZipOpenError("Not a zip file or corrupt (%s)" % str(zip_file),
details=bzf) details=bzf)
except IOError, ioe: except IOError, ioe:
if str(ioe).startswith('[Errno 22] Invalid argument'): if str(ioe).startswith('[Errno 22] Invalid argument'):
raise ZipOpenError("Not a zip file or corrupt (%s)" % str(zip_file), raise ZipOpenError("Not a zip file or corrupt (%s)" % str(zip_file),
details=ioe) details=ioe)
raise ZipNotFoundError("Zip file not found (%s)" % str(zip_file), raise ZipNotFoundError("Zip file not found (%s)" % str(zip_file),
details=ioe) details=ioe)
self.zip_path = str(zip_file) self.zip_path = str(zip_file)
self.temp_fs = None self.temp_fs = None
if mode in 'wa': if mode in 'wa':
...@@ -136,8 +140,8 @@ class ZipFS(FS): ...@@ -136,8 +140,8 @@ class ZipFS(FS):
self._path_fs = MemoryFS() self._path_fs = MemoryFS()
if mode in 'ra': if mode in 'ra':
self._parse_resource_list() self._parse_resource_list()
self.read_only = mode == 'r' self.read_only = mode == 'r'
def __str__(self): def __str__(self):
return "<ZipFS: %s>" % self.zip_path return "<ZipFS: %s>" % self.zip_path
...@@ -149,7 +153,7 @@ class ZipFS(FS): ...@@ -149,7 +153,7 @@ class ZipFS(FS):
if PY3: if PY3:
return path return path
return path.decode(self.encoding) return path.decode(self.encoding)
def _encode_path(self, path): def _encode_path(self, path):
if PY3: if PY3:
return path return path
...@@ -172,11 +176,10 @@ class ZipFS(FS): ...@@ -172,11 +176,10 @@ class ZipFS(FS):
f = self._path_fs.open(path, 'w') f = self._path_fs.open(path, 'w')
f.close() f.close()
def getmeta(self, meta_name, default=NoDefaultMeta): def getmeta(self, meta_name, default=NoDefaultMeta):
if meta_name == 'read_only': if meta_name == 'read_only':
return self.read_only return self.read_only
return super(ZipFS, self).getmeta(meta_name, default) return super(ZipFS, self).getmeta(meta_name, default)
def close(self): def close(self):
"""Finalizes the zip file so that it can be read. """Finalizes the zip file so that it can be read.
...@@ -187,8 +190,8 @@ class ZipFS(FS): ...@@ -187,8 +190,8 @@ class ZipFS(FS):
self.zf = _ExceptionProxy() self.zf = _ExceptionProxy()
@synchronize @synchronize
def open(self, path, mode="r", **kwargs): def open(self, path, mode="r", **kwargs):
path = normpath(relpath(path)) path = normpath(relpath(path))
if 'r' in mode: if 'r' in mode:
if self.zip_mode not in 'ra': if self.zip_mode not in 'ra':
...@@ -202,10 +205,10 @@ class ZipFS(FS): ...@@ -202,10 +205,10 @@ class ZipFS(FS):
contents = self.zf.read(self._encode_path(path)) contents = self.zf.read(self._encode_path(path))
except KeyError: except KeyError:
raise ResourceNotFoundError(path) raise ResourceNotFoundError(path)
return StringIO(contents) return closing(StringIO(contents))
if 'w' in mode: if 'w' in mode:
if self.zip_mode not in 'wa': if self.zip_mode not in 'wa':
raise OperationFailedError("open file", raise OperationFailedError("open file",
path=path, path=path,
msg="2 Zip file must be opened for writing ('w') or appending ('a')") msg="2 Zip file must be opened for writing ('w') or appending ('a')")
...@@ -215,7 +218,6 @@ class ZipFS(FS): ...@@ -215,7 +218,6 @@ class ZipFS(FS):
self._add_resource(path) self._add_resource(path)
f = _TempWriteFile(self.temp_fs, path, self._on_write_close) f = _TempWriteFile(self.temp_fs, path, self._on_write_close)
return f return f
raise ValueError("Mode must contain be 'r' or 'w'") raise ValueError("Mode must contain be 'r' or 'w'")
...@@ -238,8 +240,8 @@ class ZipFS(FS): ...@@ -238,8 +240,8 @@ class ZipFS(FS):
sys_path = self.temp_fs.getsyspath(filename) sys_path = self.temp_fs.getsyspath(filename)
self.zf.write(sys_path, self._encode_path(filename)) self.zf.write(sys_path, self._encode_path(filename))
def desc(self, path): def desc(self, path):
return "%s in zip file %s" % (path, self.zip_path) return "%s in zip file %s" % (path, self.zip_path)
def isdir(self, path): def isdir(self, path):
return self._path_fs.isdir(path) return self._path_fs.isdir(path)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment