Commit 1583506f by William Tisäter

Merge pull request #53 from appliedsec/lock-fix

Release thread lock on exceptions
parents 0771abc2 f4b195fb
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
### Release 0.3.1 ### Release 0.3.1
* New: MaxMind Netspeed database support * New: MaxMind Netspeed database support
* Fix: Release thread lock on exceptions
Not yet uploaded to PyPi Not yet uploaded to PyPi
......
...@@ -32,16 +32,11 @@ try: ...@@ -32,16 +32,11 @@ try:
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
mmap = None mmap = None
try:
from StringIO import StringIO
range = xrange # Use xrange for Python 2
except ImportError:
from io import StringIO, BytesIO
from pygeoip import util, const from pygeoip import util, const
from pygeoip.const import PY2, PY3 from pygeoip.const import PY2, PY3
from pygeoip.timezone import time_zone_by_country_and_region from pygeoip.timezone import time_zone_by_country_and_region
range = xrange if PY2 else range
STANDARD = const.STANDARD STANDARD = const.STANDARD
MMAP_CACHE = const.MMAP_CACHE MMAP_CACHE = const.MMAP_CACHE
...@@ -72,10 +67,12 @@ class _GeoIPMetaclass(type): ...@@ -72,10 +67,12 @@ class _GeoIPMetaclass(type):
if not kwargs.get('cache', True): if not kwargs.get('cache', True):
return super(_GeoIPMetaclass, cls).__call__(*args, **kwargs) return super(_GeoIPMetaclass, cls).__call__(*args, **kwargs)
cls._instance_lock.acquire() try:
if filename not in cls._instances: cls._instance_lock.acquire()
cls._instances[filename] = super(_GeoIPMetaclass, cls).__call__(*args, **kwargs) if filename not in cls._instances:
cls._instance_lock.release() cls._instances[filename] = super(_GeoIPMetaclass, cls).__call__(*args, **kwargs)
finally:
cls._instance_lock.release()
return cls._instances[filename] return cls._instances[filename]
...@@ -97,6 +94,7 @@ class GeoIP(object): ...@@ -97,6 +94,7 @@ class GeoIP(object):
@param cache: Used in tests to skip instance caching @param cache: Used in tests to skip instance caching
@type cache: bool @type cache: bool
""" """
self._lock = Lock()
self._flags = flags self._flags = flags
self._netmask = None self._netmask = None
...@@ -114,27 +112,18 @@ class GeoIP(object): ...@@ -114,27 +112,18 @@ class GeoIP(object):
elif self._flags & const.MEMORY_CACHE: elif self._flags & const.MEMORY_CACHE:
f = codecs.open(filename, 'rb', ENCODING) f = codecs.open(filename, 'rb', ENCODING)
self._memory = f.read() self._memory = f.read()
self._fp = self._str_to_fp(self._memory) self._fp = util.str2fp(self._memory)
self._type = 'MEMORY_CACHE' self._type = 'MEMORY_CACHE'
f.close() f.close()
else: else:
self._fp = codecs.open(filename, 'rb', ENCODING) self._fp = codecs.open(filename, 'rb', ENCODING)
self._type = 'STANDARD' self._type = 'STANDARD'
self._lock = Lock() try:
self._setup_segments() self._lock.acquire()
self._setup_segments()
@classmethod finally:
def _str_to_fp(cls, data): self._lock.release()
"""
Convert bytes data to file handle object
@param data: string data
@type data: str
@return: file handle object
@rtype: StringIO or BytesIO
"""
return BytesIO(bytearray(data, ENCODING)) if PY3 else StringIO(data)
def _setup_segments(self): def _setup_segments(self):
""" """
...@@ -162,7 +151,6 @@ class GeoIP(object): ...@@ -162,7 +151,6 @@ class GeoIP(object):
self._recordLength = const.STANDARD_RECORD_LENGTH self._recordLength = const.STANDARD_RECORD_LENGTH
self._databaseSegments = const.COUNTRY_BEGIN self._databaseSegments = const.COUNTRY_BEGIN
self._lock.acquire()
filepos = self._fp.tell() filepos = self._fp.tell()
self._fp.seek(-3, os.SEEK_END) self._fp.seek(-3, os.SEEK_END)
...@@ -216,7 +204,6 @@ class GeoIP(object): ...@@ -216,7 +204,6 @@ class GeoIP(object):
self._fp.seek(-4, os.SEEK_CUR) self._fp.seek(-4, os.SEEK_CUR)
self._fp.seek(filepos, os.SEEK_SET) self._fp.seek(filepos, os.SEEK_SET)
self._lock.release()
def _seek_country(self, ipnum): def _seek_country(self, ipnum):
""" """
...@@ -240,10 +227,12 @@ class GeoIP(object): ...@@ -240,10 +227,12 @@ class GeoIP(object):
else: else:
startIndex = 2 * self._recordLength * offset startIndex = 2 * self._recordLength * offset
readLength = 2 * self._recordLength readLength = 2 * self._recordLength
self._lock.acquire() try:
self._fp.seek(startIndex, os.SEEK_SET) self._lock.acquire()
buf = self._fp.read(readLength) self._fp.seek(startIndex, os.SEEK_SET)
self._lock.release() buf = self._fp.read(readLength)
finally:
self._lock.release()
if PY3 and type(buf) is bytes: if PY3 and type(buf) is bytes:
buf = buf.decode(ENCODING) buf = buf.decode(ENCODING)
...@@ -281,10 +270,12 @@ class GeoIP(object): ...@@ -281,10 +270,12 @@ class GeoIP(object):
return None return None
read_length = (2 * self._recordLength - 1) * self._databaseSegments read_length = (2 * self._recordLength - 1) * self._databaseSegments
self._lock.acquire() try:
self._fp.seek(seek_org + read_length, os.SEEK_SET) self._lock.acquire()
buf = self._fp.read(const.MAX_ORG_RECORD_LENGTH) self._fp.seek(seek_org + read_length, os.SEEK_SET)
self._lock.release() buf = self._fp.read(const.MAX_ORG_RECORD_LENGTH)
finally:
self._lock.release()
if PY3 and type(buf) is bytes: if PY3 and type(buf) is bytes:
buf = buf.decode(ENCODING) buf = buf.decode(ENCODING)
...@@ -353,10 +344,12 @@ class GeoIP(object): ...@@ -353,10 +344,12 @@ class GeoIP(object):
return {} return {}
read_length = (2 * self._recordLength - 1) * self._databaseSegments read_length = (2 * self._recordLength - 1) * self._databaseSegments
self._lock.acquire() try:
self._fp.seek(seek_country + read_length, os.SEEK_SET) self._lock.acquire()
buf = self._fp.read(const.FULL_RECORD_LENGTH) self._fp.seek(seek_country + read_length, os.SEEK_SET)
self._lock.release() buf = self._fp.read(const.FULL_RECORD_LENGTH)
finally:
self._lock.release()
if PY3 and type(buf) is bytes: if PY3 and type(buf) is bytes:
buf = buf.decode(ENCODING) buf = buf.decode(ENCODING)
......
...@@ -24,6 +24,13 @@ along with this program. If not, see <http://www.gnu.org/licenses/lgpl.txt>. ...@@ -24,6 +24,13 @@ along with this program. If not, see <http://www.gnu.org/licenses/lgpl.txt>.
import socket import socket
import binascii import binascii
try:
from StringIO import StringIO
except ImportError:
from io import StringIO, BytesIO
from pygeoip import const
def ip2long(ip): def ip2long(ip):
""" """
...@@ -35,3 +42,15 @@ def ip2long(ip): ...@@ -35,3 +42,15 @@ def ip2long(ip):
return int(binascii.hexlify(socket.inet_aton(ip)), 16) return int(binascii.hexlify(socket.inet_aton(ip)), 16)
except socket.error: except socket.error:
return int(binascii.hexlify(socket.inet_pton(socket.AF_INET6, ip)), 16) return int(binascii.hexlify(socket.inet_pton(socket.AF_INET6, ip)), 16)
def str2fp(data):
"""
Convert bytes data to file handle object
@param data: string data
@type data: str
@return: file handle object
@rtype: StringIO or BytesIO
"""
return BytesIO(bytearray(data, const.ENCODING)) if const.PY3 else StringIO(data)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment