Commit 5e4af47e by William Tisäter

Optimize and refactor internal seek functions

parent e0a63175
......@@ -89,12 +89,12 @@ class GeoIP(GeoIPBase):
"""
Initialize the class.
@param filename: path to a geoip database. If MEMORY_CACHE is used,
@param filename: Path to a geoip database. If MEMORY_CACHE is used,
the file can be gzipped.
@type filename: str
@param flags: flags that affect how the database is processed.
Currently the only supported flags are STANDARD (the default),
MEMORY_CACHE (preload the whole file into memory), and
@param flags: Flags that affect how the database is processed.
Currently supported flags are STANDARD (the default),
MEMORY_CACHE (preload the whole file into memory) and
MMAP_CACHE (access the file via mmap).
@type flags: int
"""
......@@ -103,7 +103,8 @@ class GeoIP(GeoIPBase):
if self._flags & const.MMAP_CACHE:
with open(filename, 'rb') as f:
self._filehandle = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
access = mmap.ACCESS_READ
self._filehandle = mmap.mmap(f.fileno(), 0, access=access)
elif self._flags & const.MEMORY_CACHE:
if filename.endswith('.gz'):
......@@ -115,37 +116,48 @@ class GeoIP(GeoIPBase):
self._memoryBuffer = f.read()
self._filehandle = StringIO(self._memoryBuffer)
else:
self._filehandle = codecs.open(filename, 'rb','latin_1')
self._filehandle = codecs.open(filename, 'rb', 'latin_1')
self._setup_segments()
def _setup_segments(self):
"""
Parses the database file to determine what kind of database is being used and setup
segment sizes and start points that will be used by the seek*() methods later.
Parses the database file to determine what kind of database is
being used and setup segment sizes and start points that will
be used by the seek*() methods later.
Supported databases:
COUNTRY_EDITION
REGION_EDITION_REV0
REGION_EDITION_REV1
CITY_EDITION_REV0
CITY_EDITION_REV1
ORG_EDITION
ISP_EDITION
ASNUM_EDITION
"""
self._databaseType = const.COUNTRY_EDITION
self._recordLength = const.STANDARD_RECORD_LENGTH
self._databaseSegments = const.COUNTRY_BEGIN
filepos = self._filehandle.tell()
self._filehandle.seek(-3, os.SEEK_END)
for i in range(const.STRUCTURE_INFO_MAX_SIZE):
delim = self._filehandle.read(3)
if delim == six.u(chr(255) * 3):
if self._filehandle.read(3) == six.u(chr(255) * 3):
self._databaseType = ord(self._filehandle.read(1))
# Backwards compatibility with databases from
# April 2003 and earlier
if (self._databaseType >= 106):
# backwards compatibility with databases from April 2003 and earlier
self._databaseType -= 105
if self._databaseType == const.REGION_EDITION_REV0:
self._databaseSegments = const.STATE_BEGIN_REV0
elif self._databaseType == const.REGION_EDITION_REV1:
self._databaseSegments = const.STATE_BEGIN_REV1
elif self._databaseType in (const.CITY_EDITION_REV0,
const.CITY_EDITION_REV1,
const.ORG_EDITION,
......@@ -153,20 +165,15 @@ class GeoIP(GeoIPBase):
const.ASNUM_EDITION):
self._databaseSegments = 0
buf = self._filehandle.read(const.SEGMENT_RECORD_LENGTH)
for j in range(const.SEGMENT_RECORD_LENGTH):
self._databaseSegments += (ord(buf[j]) << (j * 8))
if self._databaseType in (const.ORG_EDITION, const.ISP_EDITION):
LONG_RECORDS = (const.ORG_EDITION, const.ISP_EDITION)
if self._databaseType in LONG_RECORDS:
self._recordLength = const.ORG_RECORD_LENGTH
break
else:
self._filehandle.seek(-4, os.SEEK_CUR)
if self._databaseType == const.COUNTRY_EDITION:
self._databaseSegments = const.COUNTRY_BEGIN
self._filehandle.seek(filepos, os.SEEK_SET)
def _seek_country(self, ipnum):
......@@ -180,105 +187,92 @@ class GeoIP(GeoIPBase):
@rtype: int
"""
offset = 0
for depth in range(31, -1, -1):
if self._flags & const.MEMORY_CACHE:
startIndex = 2 * self._recordLength * offset
length = 2 * self._recordLength
endIndex = startIndex + length
endIndex = startIndex + (2 * self._recordLength)
buf = self._memoryBuffer[startIndex:endIndex]
else:
self._filehandle.seek(2 * self._recordLength * offset, os.SEEK_SET)
buf = self._filehandle.read(2 * self._recordLength)
x = [0,0]
startIndex = 2 * self._recordLength * offset
readLength = 2 * self._recordLength
self._filehandle.seek(startIndex, os.SEEK_SET)
buf = self._filehandle.read(readLength)
x = [0, 0]
for i in range(2):
for j in range(self._recordLength):
x[i] += ord(buf[self._recordLength * i + j]) << (j * 8)
if ipnum & (1 << depth):
if x[1] >= self._databaseSegments:
return x[1]
offset = x[1]
else:
if x[0] >= self._databaseSegments:
return x[0]
offset = x[0]
raise Exception('Error traversing database - perhaps it is corrupt?')
def _get_org(self, ipnum):
"""
Seek and return organization (or ISP) name for converted IP addr.
Seek and return organization or ISP name for ipnum.
@param ipnum: Converted IP address
@type ipnum: int
@return: org/isp name
@rtype: str
"""
seek_org = self._seek_country(ipnum)
if seek_org == self._databaseSegments:
return None
record_pointer = seek_org + (2 * self._recordLength - 1) * self._databaseSegments
self._filehandle.seek(record_pointer, os.SEEK_SET)
read_length = (2 * self._recordLength - 1) * self._databaseSegments
self._filehandle.seek(seek_org + read_length, os.SEEK_SET)
org_buf = self._filehandle.read(const.MAX_ORG_RECORD_LENGTH)
return org_buf[:org_buf.index(chr(0))]
def _get_region(self, ipnum):
"""
Seek and return the region info (dict containing country_code and region_name).
Seek and return the region info (dict containing country_code
and region_name).
@param ipnum: converted IP address
@param ipnum: Converted IP address
@type ipnum: int
@return: dict containing country_code and region_name
@rtype: dict
"""
country_code = ''
region = ''
country_code = ''
seek_country = self._seek_country(ipnum)
def get_region_name(offset):
region1 = chr(offset // 26 + 65)
region2 = chr(offset % 26 + 65)
return ''.join([region1, region2])
if self._databaseType == const.REGION_EDITION_REV0:
seek_country = self._seek_country(ipnum)
seek_region = seek_country - const.STATE_BEGIN_REV0
if seek_region >= 1000:
country_code = 'US'
region = ''.join([chr((seek_region - 1000) // 26 + 65), chr((seek_region - 1000) % 26 + 65)])
region = get_region_name(seek_region - 1000)
else:
country_code = const.COUNTRY_CODES[seek_region]
region = ''
elif self._databaseType == const.REGION_EDITION_REV1:
seek_country = self._seek_country(ipnum)
seek_region = seek_country - const.STATE_BEGIN_REV1
if seek_region < const.US_OFFSET:
country_code = '';
region = ''
pass
elif seek_region < const.CANADA_OFFSET:
country_code = 'US'
region = ''.join([chr((seek_region - const.US_OFFSET) // 26 + 65), chr((seek_region - const.US_OFFSET) % 26 + 65)])
region = get_region_name(seek_region - const.US_OFFSET)
elif seek_region < const.WORLD_OFFSET:
country_code = 'CA'
region = ''.join([chr((seek_region - const.CANADA_OFFSET) // 26 + 65), chr((seek_region - const.CANADA_OFFSET) % 26 + 65)])
region = get_region_name(seek_region - const.CANADA_OFFSET)
else:
i = (seek_region - const.WORLD_OFFSET) // const.FIPS_RANGE
if i < len(const.COUNTRY_CODES):
#country_code = const.COUNTRY_CODES[(seek_region - const.WORLD_OFFSET) // const.FIPS_RANGE]
country_code = const.COUNTRY_CODES[i]
else:
country_code = ''
region = ''
elif self._databaseType in (const.CITY_EDITION_REV0, const.CITY_EDITION_REV1):
index = (seek_region - const.WORLD_OFFSET) // const.FIPS_RANGE
if index in const.COUNTRY_CODES:
country_code = const.COUNTRY_CODES[index]
elif self._databaseType in const.CITY_EDITIONS:
rec = self._get_record(ipnum)
country_code = rec['country_code'] if 'country_code' in rec else ''
region = rec['region_name'] if 'region_name' in rec else ''
......@@ -289,7 +283,7 @@ class GeoIP(GeoIPBase):
"""
Populate location dict for converted IP.
@param ipnum: converted IP address
@param ipnum: Converted IP address
@type ipnum: int
@return: dict with country_code, country_code3, country_name,
region, city, postal_code, latitude, longitude,
......@@ -300,102 +294,73 @@ class GeoIP(GeoIPBase):
if seek_country == self._databaseSegments:
return None
record_pointer = seek_country + (2 * self._recordLength - 1) * self._databaseSegments
self._filehandle.seek(record_pointer, os.SEEK_SET)
read_length = (2 * self._recordLength - 1) * self._databaseSegments
self._filehandle.seek(seek_country + read_length, os.SEEK_SET)
record_buf = self._filehandle.read(const.FULL_RECORD_LENGTH)
record = {}
record = {
'dma_code': 0,
'area_code': 0,
'metro_code': '',
'postal_code': ''
}
latitude = 0
longitude = 0
record_buf_pos = 0
# Get country
char = ord(record_buf[record_buf_pos])
#char = record_buf[record_buf_pos] if six.PY3 else ord(record_buf[record_buf_pos])
record['country_code'] = const.COUNTRY_CODES[char]
record['country_code3'] = const.COUNTRY_CODES3[char]
record['country_name'] = const.COUNTRY_NAMES[char]
record_buf_pos += 1
str_length = 0
# get region
char = ord(record_buf[record_buf_pos+str_length])
while (char != 0):
str_length += 1
char = ord(record_buf[record_buf_pos+str_length])
if str_length > 0:
record['region_name'] = record_buf[record_buf_pos:record_buf_pos+str_length]
record_buf_pos += str_length + 1
str_length = 0
# get city
char = ord(record_buf[record_buf_pos+str_length])
while (char != 0):
str_length += 1
char = ord(record_buf[record_buf_pos+str_length])
if str_length > 0:
record['city'] = record_buf[record_buf_pos:record_buf_pos+str_length]
else:
record['city'] = ''
record_buf_pos += str_length + 1
str_length = 0
def get_data(record_buf, record_buf_pos):
offset = record_buf_pos
char = ord(record_buf[offset])
while (char != 0):
offset += 1
char = ord(record_buf[offset])
if offset > record_buf_pos:
return (offset, record_buf[record_buf_pos:offset])
return (offset, '')
offset, record['region_name'] = get_data(record_buf, record_buf_pos)
offset, record['city'] = get_data(record_buf, offset + 1)
offset, record['postal_code'] = get_data(record_buf, offset + 1)
record_buf_pos = offset + 1
# get the postal code
char = ord(record_buf[record_buf_pos+str_length])
while (char != 0):
str_length += 1
char = ord(record_buf[record_buf_pos+str_length])
if str_length > 0:
record['postal_code'] = record_buf[record_buf_pos:record_buf_pos+str_length]
else:
record['postal_code'] = None
record_buf_pos += str_length + 1
str_length = 0
latitude = 0
longitude = 0
for j in range(3):
char = ord(record_buf[record_buf_pos])
record_buf_pos += 1
latitude += (char << (j * 8))
record['latitude'] = (latitude/10000.0) - 180.0
for j in range(3):
char = ord(record_buf[record_buf_pos])
record_buf_pos += 1
longitude += (char << (j * 8))
record['longitude'] = (longitude/10000.0) - 180.0
record['latitude'] = (latitude / 10000.0) - 180.0
record['longitude'] = (longitude / 10000.0) - 180.0
if self._databaseType == const.CITY_EDITION_REV1:
dmaarea_combo = 0
if record['country_code'] == 'US':
for j in range(3):
char = ord(record_buf[record_buf_pos])
record_buf_pos += 1
dmaarea_combo += (char << (j*8))
record_buf_pos += 1
record['dma_code'] = int(math.floor(dmaarea_combo/1000))
record['area_code'] = dmaarea_combo%1000
else:
record['dma_code'] = 0
record['area_code'] = 0
record['dma_code'] = int(math.floor(dmaarea_combo / 1000))
record['area_code'] = dmaarea_combo % 1000
if 'dma_code' in record and record['dma_code'] in const.DMA_MAP:
if record['dma_code'] in const.DMA_MAP:
record['metro_code'] = const.DMA_MAP[record['dma_code']]
else:
record['metro_code'] = ''
if 'country_code' in record:
record['time_zone'] = time_zone_by_country_and_region(
record['country_code'], record.get('region_name')) or ''
else:
record['time_zone'] = ''
params = (record['country_code'], record['region_name'])
record['time_zone'] = time_zone_by_country_and_region(*params)
return record
......
......@@ -362,11 +362,16 @@ CITY_EDITION_REV0 = 6
CITY_EDITION_REV1 = 2
ORG_EDITION = 5
ISP_EDITION = 4
PROXY_EDITION = 8
ASNUM_EDITION = 9
# Not yet supported databases
PROXY_EDITION = 8
NETSPEED_EDITION = 11
COUNTRY_EDITION_V6 = 12
# Collection of databases
CITY_EDITIONS = (CITY_EDITION_REV0, CITY_EDITION_REV1)
REGION_EDITIONS = (REGION_EDITION_REV0, REGION_EDITION_REV1)
SEGMENT_RECORD_LENGTH = 3
STANDARD_RECORD_LENGTH = 3
ORG_RECORD_LENGTH = 4
......
......@@ -699,17 +699,17 @@ _country["ZW"] = "Africa/Harare"
def time_zone_by_country_and_region(country_code, region_name=None):
if country_code not in _country:
return None
return ''
if not region_name or region_name == '00':
region_name = None
timezones = _country[country_code]
if isinstance(timezones, str):
return timezones
if region_name:
return timezones.get(region_name)
if not region_name:
return ''
return timezones.get(region_name)
......@@ -40,7 +40,8 @@ class TestGeoIPCityFunctions(unittest.TestCase):
'longitude': -0.23339999999998895,
'country_code3': 'GBR',
'latitude': 51.283299999999997,
'postal_code': None, 'dma_code': 0,
'postal_code': '',
'dma_code': 0,
'country_code': 'GB',
'country_name': 'United Kingdom',
'time_zone': 'Europe/London'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment