Commit 4d6725c4 by William Tisäter

Release thread lock on exceptions

parent 2df4b7b2
......@@ -32,16 +32,11 @@ try:
except ImportError: # pragma: no cover
mmap = None
try:
from StringIO import StringIO
range = xrange # Use xrange for Python 2
except ImportError:
from io import StringIO, BytesIO
from pygeoip import util, const
from pygeoip.const import PY2, PY3
from pygeoip.timezone import time_zone_by_country_and_region
range = xrange if PY2 else range
STANDARD = const.STANDARD
MMAP_CACHE = const.MMAP_CACHE
......@@ -72,10 +67,12 @@ class _GeoIPMetaclass(type):
if not kwargs.get('cache', True):
return super(_GeoIPMetaclass, cls).__call__(*args, **kwargs)
cls._instance_lock.acquire()
if filename not in cls._instances:
cls._instances[filename] = super(_GeoIPMetaclass, cls).__call__(*args, **kwargs)
cls._instance_lock.release()
try:
cls._instance_lock.acquire()
if filename not in cls._instances:
cls._instances[filename] = super(_GeoIPMetaclass, cls).__call__(*args, **kwargs)
finally:
cls._instance_lock.release()
return cls._instances[filename]
......@@ -97,6 +94,7 @@ class GeoIP(object):
@param cache: Used in tests to skip instance caching
@type cache: bool
"""
self._lock = Lock()
self._flags = flags
self._netmask = None
......@@ -114,27 +112,18 @@ class GeoIP(object):
elif self._flags & const.MEMORY_CACHE:
f = codecs.open(filename, 'rb', ENCODING)
self._memory = f.read()
self._fp = self._str_to_fp(self._memory)
self._fp = util.str2fp(self._memory)
self._type = 'MEMORY_CACHE'
f.close()
else:
self._fp = codecs.open(filename, 'rb', ENCODING)
self._type = 'STANDARD'
self._lock = Lock()
self._setup_segments()
@classmethod
def _str_to_fp(cls, data):
"""
Convert bytes data to file handle object
@param data: string data
@type data: str
@return: file handle object
@rtype: StringIO or BytesIO
"""
return BytesIO(bytearray(data, ENCODING)) if PY3 else StringIO(data)
try:
self._lock.acquire()
self._setup_segments()
finally:
self._lock.release()
def _setup_segments(self):
"""
......@@ -161,7 +150,6 @@ class GeoIP(object):
self._recordLength = const.STANDARD_RECORD_LENGTH
self._databaseSegments = const.COUNTRY_BEGIN
self._lock.acquire()
filepos = self._fp.tell()
self._fp.seek(-3, os.SEEK_END)
......@@ -215,7 +203,6 @@ class GeoIP(object):
self._fp.seek(-4, os.SEEK_CUR)
self._fp.seek(filepos, os.SEEK_SET)
self._lock.release()
def _seek_country(self, ipnum):
"""
......@@ -239,10 +226,12 @@ class GeoIP(object):
else:
startIndex = 2 * self._recordLength * offset
readLength = 2 * self._recordLength
self._lock.acquire()
self._fp.seek(startIndex, os.SEEK_SET)
buf = self._fp.read(readLength)
self._lock.release()
try:
self._lock.acquire()
self._fp.seek(startIndex, os.SEEK_SET)
buf = self._fp.read(readLength)
finally:
self._lock.release()
if PY3 and type(buf) is bytes:
buf = buf.decode(ENCODING)
......@@ -280,10 +269,12 @@ class GeoIP(object):
return None
read_length = (2 * self._recordLength - 1) * self._databaseSegments
self._lock.acquire()
self._fp.seek(seek_org + read_length, os.SEEK_SET)
buf = self._fp.read(const.MAX_ORG_RECORD_LENGTH)
self._lock.release()
try:
self._lock.acquire()
self._fp.seek(seek_org + read_length, os.SEEK_SET)
buf = self._fp.read(const.MAX_ORG_RECORD_LENGTH)
finally:
self._lock.release()
if PY3 and type(buf) is bytes:
buf = buf.decode(ENCODING)
......@@ -352,10 +343,12 @@ class GeoIP(object):
return {}
read_length = (2 * self._recordLength - 1) * self._databaseSegments
self._lock.acquire()
self._fp.seek(seek_country + read_length, os.SEEK_SET)
buf = self._fp.read(const.FULL_RECORD_LENGTH)
self._lock.release()
try:
self._lock.acquire()
self._fp.seek(seek_country + read_length, os.SEEK_SET)
buf = self._fp.read(const.FULL_RECORD_LENGTH)
finally:
self._lock.release()
if PY3 and type(buf) is bytes:
buf = buf.decode(ENCODING)
......
......@@ -24,6 +24,13 @@ along with this program. If not, see <http://www.gnu.org/licenses/lgpl.txt>.
import socket
import binascii
try:
from StringIO import StringIO
except ImportError:
from io import StringIO, BytesIO
from pygeoip import const
def ip2long(ip):
"""
......@@ -35,3 +42,15 @@ def ip2long(ip):
return int(binascii.hexlify(socket.inet_aton(ip)), 16)
except socket.error:
return int(binascii.hexlify(socket.inet_pton(socket.AF_INET6, ip)), 16)
def str2fp(data):
"""
Convert bytes data to file handle object
@param data: string data
@type data: str
@return: file handle object
@rtype: StringIO or BytesIO
"""
return BytesIO(bytearray(data, const.ENCODING)) if const.PY3 else StringIO(data)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment