Commit 545855cf by William Tisäter

Fix codecs.open() for Python 3.X

parent 2abe41a3
...@@ -103,22 +103,31 @@ class GeoIP(object): ...@@ -103,22 +103,31 @@ class GeoIP(object):
if self._flags & const.MMAP_CACHE: if self._flags & const.MMAP_CACHE:
f = codecs.open(filename, 'rb', ENCODING) f = codecs.open(filename, 'rb', ENCODING)
access = mmap.ACCESS_READ access = mmap.ACCESS_READ
self._filehandle = mmap.mmap(f.fileno(), 0, access=access) self._fp = mmap.mmap(f.fileno(), 0, access=access)
f.close() f.close()
elif self._flags & const.MEMORY_CACHE: elif self._flags & const.MEMORY_CACHE:
f = codecs.open(filename, 'rb', ENCODING) f = codecs.open(filename, 'rb', ENCODING)
self._memoryBuffer = f.read() self._memory = f.read()
iohandle = BytesIO if PY3 else StringIO self._fp = self._str_to_fp(self._memory)
self._filehandle = iohandle(self._memoryBuffer)
f.close() f.close()
else: else:
self._filehandle = codecs.open(filename, 'rb', ENCODING) self._fp = codecs.open(filename, 'rb', ENCODING)
self._lock = Lock() self._lock = Lock()
self._setup_segments() self._setup_segments()
@classmethod
def _str_to_fp(cls, data):
"""
Convert bytes data to file handle object
@param data: string data
@type data: str
@return: file handle object
@rtype: StringIO or BytesIO
"""
return BytesIO(bytearray(data, ENCODING)) if PY3 else StringIO(data)
def _setup_segments(self): def _setup_segments(self):
""" """
Parses the database file to determine what kind of database is Parses the database file to determine what kind of database is
...@@ -145,12 +154,12 @@ class GeoIP(object): ...@@ -145,12 +154,12 @@ class GeoIP(object):
self._databaseSegments = const.COUNTRY_BEGIN self._databaseSegments = const.COUNTRY_BEGIN
self._lock.acquire() self._lock.acquire()
filepos = self._filehandle.tell() filepos = self._fp.tell()
self._filehandle.seek(-3, os.SEEK_END) self._fp.seek(-3, os.SEEK_END)
for i in range(const.STRUCTURE_INFO_MAX_SIZE): for i in range(const.STRUCTURE_INFO_MAX_SIZE):
chars = chr(255) * 3 chars = chr(255) * 3
delim = self._filehandle.read(3) delim = self._fp.read(3)
if PY3 and type(delim) is bytes: if PY3 and type(delim) is bytes:
delim = delim.decode(ENCODING) delim = delim.decode(ENCODING)
...@@ -161,7 +170,7 @@ class GeoIP(object): ...@@ -161,7 +170,7 @@ class GeoIP(object):
delim = delim.decode(ENCODING) delim = delim.decode(ENCODING)
if delim == chars: if delim == chars:
byte = self._filehandle.read(1) byte = self._fp.read(1)
self._databaseType = ord(byte) self._databaseType = ord(byte)
# Compatibility with databases from April 2003 and earlier # Compatibility with databases from April 2003 and earlier
...@@ -182,7 +191,7 @@ class GeoIP(object): ...@@ -182,7 +191,7 @@ class GeoIP(object):
const.ASNUM_EDITION, const.ASNUM_EDITION,
const.ASNUM_EDITION_V6): const.ASNUM_EDITION_V6):
self._databaseSegments = 0 self._databaseSegments = 0
buf = self._filehandle.read(const.SEGMENT_RECORD_LENGTH) buf = self._fp.read(const.SEGMENT_RECORD_LENGTH)
if PY3 and type(buf) is bytes: if PY3 and type(buf) is bytes:
buf = buf.decode(ENCODING) buf = buf.decode(ENCODING)
...@@ -195,9 +204,9 @@ class GeoIP(object): ...@@ -195,9 +204,9 @@ class GeoIP(object):
self._recordLength = const.ORG_RECORD_LENGTH self._recordLength = const.ORG_RECORD_LENGTH
break break
else: else:
self._filehandle.seek(-4, os.SEEK_CUR) self._fp.seek(-4, os.SEEK_CUR)
self._filehandle.seek(filepos, os.SEEK_SET) self._fp.seek(filepos, os.SEEK_SET)
self._lock.release() self._lock.release()
def _seek_country(self, ipnum): def _seek_country(self, ipnum):
...@@ -218,13 +227,13 @@ class GeoIP(object): ...@@ -218,13 +227,13 @@ class GeoIP(object):
if self._flags & const.MEMORY_CACHE: if self._flags & const.MEMORY_CACHE:
startIndex = 2 * self._recordLength * offset startIndex = 2 * self._recordLength * offset
endIndex = startIndex + (2 * self._recordLength) endIndex = startIndex + (2 * self._recordLength)
buf = self._memoryBuffer[startIndex:endIndex] buf = self._memory[startIndex:endIndex]
else: else:
startIndex = 2 * self._recordLength * offset startIndex = 2 * self._recordLength * offset
readLength = 2 * self._recordLength readLength = 2 * self._recordLength
self._lock.acquire() self._lock.acquire()
self._filehandle.seek(startIndex, os.SEEK_SET) self._fp.seek(startIndex, os.SEEK_SET)
buf = self._filehandle.read(readLength) buf = self._fp.read(readLength)
self._lock.release() self._lock.release()
if PY3 and type(buf) is bytes: if PY3 and type(buf) is bytes:
...@@ -262,8 +271,8 @@ class GeoIP(object): ...@@ -262,8 +271,8 @@ class GeoIP(object):
read_length = (2 * self._recordLength - 1) * self._databaseSegments read_length = (2 * self._recordLength - 1) * self._databaseSegments
self._lock.acquire() self._lock.acquire()
self._filehandle.seek(seek_org + read_length, os.SEEK_SET) self._fp.seek(seek_org + read_length, os.SEEK_SET)
buf = self._filehandle.read(const.MAX_ORG_RECORD_LENGTH) buf = self._fp.read(const.MAX_ORG_RECORD_LENGTH)
self._lock.release() self._lock.release()
if PY3 and type(buf) is bytes: if PY3 and type(buf) is bytes:
...@@ -335,8 +344,8 @@ class GeoIP(object): ...@@ -335,8 +344,8 @@ class GeoIP(object):
read_length = (2 * self._recordLength - 1) * self._databaseSegments read_length = (2 * self._recordLength - 1) * self._databaseSegments
self._lock.acquire() self._lock.acquire()
self._filehandle.seek(seek_country + read_length, os.SEEK_SET) self._fp.seek(seek_country + read_length, os.SEEK_SET)
buf = self._filehandle.read(const.FULL_RECORD_LENGTH) buf = self._fp.read(const.FULL_RECORD_LENGTH)
self._lock.release() self._lock.release()
if PY3 and type(buf) is bytes: if PY3 and type(buf) is bytes:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment