Commit 5e4af47e by William Tisäter

Optimize and refactor internal seek functions

parent e0a63175
...@@ -89,12 +89,12 @@ class GeoIP(GeoIPBase): ...@@ -89,12 +89,12 @@ class GeoIP(GeoIPBase):
""" """
Initialize the class. Initialize the class.
@param filename: path to a geoip database. If MEMORY_CACHE is used, @param filename: Path to a geoip database. If MEMORY_CACHE is used,
the file can be gzipped. the file can be gzipped.
@type filename: str @type filename: str
@param flags: flags that affect how the database is processed. @param flags: Flags that affect how the database is processed.
Currently the only supported flags are STANDARD (the default), Currently supported flags are STANDARD (the default),
MEMORY_CACHE (preload the whole file into memory), and MEMORY_CACHE (preload the whole file into memory) and
MMAP_CACHE (access the file via mmap). MMAP_CACHE (access the file via mmap).
@type flags: int @type flags: int
""" """
...@@ -103,7 +103,8 @@ class GeoIP(GeoIPBase): ...@@ -103,7 +103,8 @@ class GeoIP(GeoIPBase):
if self._flags & const.MMAP_CACHE: if self._flags & const.MMAP_CACHE:
with open(filename, 'rb') as f: with open(filename, 'rb') as f:
self._filehandle = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) access = mmap.ACCESS_READ
self._filehandle = mmap.mmap(f.fileno(), 0, access=access)
elif self._flags & const.MEMORY_CACHE: elif self._flags & const.MEMORY_CACHE:
if filename.endswith('.gz'): if filename.endswith('.gz'):
...@@ -115,37 +116,48 @@ class GeoIP(GeoIPBase): ...@@ -115,37 +116,48 @@ class GeoIP(GeoIPBase):
self._memoryBuffer = f.read() self._memoryBuffer = f.read()
self._filehandle = StringIO(self._memoryBuffer) self._filehandle = StringIO(self._memoryBuffer)
else: else:
self._filehandle = codecs.open(filename, 'rb','latin_1') self._filehandle = codecs.open(filename, 'rb', 'latin_1')
self._setup_segments() self._setup_segments()
def _setup_segments(self): def _setup_segments(self):
""" """
Parses the database file to determine what kind of database is being used and setup Parses the database file to determine what kind of database is
segment sizes and start points that will be used by the seek*() methods later. being used and setup segment sizes and start points that will
be used by the seek*() methods later.
Supported databases:
COUNTRY_EDITION
REGION_EDITION_REV0
REGION_EDITION_REV1
CITY_EDITION_REV0
CITY_EDITION_REV1
ORG_EDITION
ISP_EDITION
ASNUM_EDITION
""" """
self._databaseType = const.COUNTRY_EDITION self._databaseType = const.COUNTRY_EDITION
self._recordLength = const.STANDARD_RECORD_LENGTH self._recordLength = const.STANDARD_RECORD_LENGTH
self._databaseSegments = const.COUNTRY_BEGIN
filepos = self._filehandle.tell() filepos = self._filehandle.tell()
self._filehandle.seek(-3, os.SEEK_END) self._filehandle.seek(-3, os.SEEK_END)
for i in range(const.STRUCTURE_INFO_MAX_SIZE): for i in range(const.STRUCTURE_INFO_MAX_SIZE):
delim = self._filehandle.read(3) if self._filehandle.read(3) == six.u(chr(255) * 3):
if delim == six.u(chr(255) * 3):
self._databaseType = ord(self._filehandle.read(1)) self._databaseType = ord(self._filehandle.read(1))
# Backwards compatibility with databases from
# April 2003 and earlier
if (self._databaseType >= 106): if (self._databaseType >= 106):
# backwards compatibility with databases from April 2003 and earlier
self._databaseType -= 105 self._databaseType -= 105
if self._databaseType == const.REGION_EDITION_REV0: if self._databaseType == const.REGION_EDITION_REV0:
self._databaseSegments = const.STATE_BEGIN_REV0 self._databaseSegments = const.STATE_BEGIN_REV0
elif self._databaseType == const.REGION_EDITION_REV1: elif self._databaseType == const.REGION_EDITION_REV1:
self._databaseSegments = const.STATE_BEGIN_REV1 self._databaseSegments = const.STATE_BEGIN_REV1
elif self._databaseType in (const.CITY_EDITION_REV0, elif self._databaseType in (const.CITY_EDITION_REV0,
const.CITY_EDITION_REV1, const.CITY_EDITION_REV1,
const.ORG_EDITION, const.ORG_EDITION,
...@@ -153,20 +165,15 @@ class GeoIP(GeoIPBase): ...@@ -153,20 +165,15 @@ class GeoIP(GeoIPBase):
const.ASNUM_EDITION): const.ASNUM_EDITION):
self._databaseSegments = 0 self._databaseSegments = 0
buf = self._filehandle.read(const.SEGMENT_RECORD_LENGTH) buf = self._filehandle.read(const.SEGMENT_RECORD_LENGTH)
for j in range(const.SEGMENT_RECORD_LENGTH): for j in range(const.SEGMENT_RECORD_LENGTH):
self._databaseSegments += (ord(buf[j]) << (j * 8)) self._databaseSegments += (ord(buf[j]) << (j * 8))
if self._databaseType in (const.ORG_EDITION, const.ISP_EDITION): LONG_RECORDS = (const.ORG_EDITION, const.ISP_EDITION)
if self._databaseType in LONG_RECORDS:
self._recordLength = const.ORG_RECORD_LENGTH self._recordLength = const.ORG_RECORD_LENGTH
break break
else: else:
self._filehandle.seek(-4, os.SEEK_CUR) self._filehandle.seek(-4, os.SEEK_CUR)
if self._databaseType == const.COUNTRY_EDITION:
self._databaseSegments = const.COUNTRY_BEGIN
self._filehandle.seek(filepos, os.SEEK_SET) self._filehandle.seek(filepos, os.SEEK_SET)
def _seek_country(self, ipnum): def _seek_country(self, ipnum):
...@@ -180,105 +187,92 @@ class GeoIP(GeoIPBase): ...@@ -180,105 +187,92 @@ class GeoIP(GeoIPBase):
@rtype: int @rtype: int
""" """
offset = 0 offset = 0
for depth in range(31, -1, -1): for depth in range(31, -1, -1):
if self._flags & const.MEMORY_CACHE: if self._flags & const.MEMORY_CACHE:
startIndex = 2 * self._recordLength * offset startIndex = 2 * self._recordLength * offset
length = 2 * self._recordLength endIndex = startIndex + (2 * self._recordLength)
endIndex = startIndex + length
buf = self._memoryBuffer[startIndex:endIndex] buf = self._memoryBuffer[startIndex:endIndex]
else: else:
self._filehandle.seek(2 * self._recordLength * offset, os.SEEK_SET) startIndex = 2 * self._recordLength * offset
buf = self._filehandle.read(2 * self._recordLength) readLength = 2 * self._recordLength
self._filehandle.seek(startIndex, os.SEEK_SET)
x = [0,0] buf = self._filehandle.read(readLength)
x = [0, 0]
for i in range(2): for i in range(2):
for j in range(self._recordLength): for j in range(self._recordLength):
x[i] += ord(buf[self._recordLength * i + j]) << (j * 8) x[i] += ord(buf[self._recordLength * i + j]) << (j * 8)
if ipnum & (1 << depth): if ipnum & (1 << depth):
if x[1] >= self._databaseSegments: if x[1] >= self._databaseSegments:
return x[1] return x[1]
offset = x[1] offset = x[1]
else: else:
if x[0] >= self._databaseSegments: if x[0] >= self._databaseSegments:
return x[0] return x[0]
offset = x[0] offset = x[0]
raise Exception('Error traversing database - perhaps it is corrupt?') raise Exception('Error traversing database - perhaps it is corrupt?')
def _get_org(self, ipnum): def _get_org(self, ipnum):
""" """
Seek and return organization (or ISP) name for converted IP addr. Seek and return organization or ISP name for ipnum.
@param ipnum: Converted IP address @param ipnum: Converted IP address
@type ipnum: int @type ipnum: int
@return: org/isp name @return: org/isp name
@rtype: str @rtype: str
""" """
seek_org = self._seek_country(ipnum) seek_org = self._seek_country(ipnum)
if seek_org == self._databaseSegments: if seek_org == self._databaseSegments:
return None return None
record_pointer = seek_org + (2 * self._recordLength - 1) * self._databaseSegments read_length = (2 * self._recordLength - 1) * self._databaseSegments
self._filehandle.seek(seek_org + read_length, os.SEEK_SET)
self._filehandle.seek(record_pointer, os.SEEK_SET)
org_buf = self._filehandle.read(const.MAX_ORG_RECORD_LENGTH) org_buf = self._filehandle.read(const.MAX_ORG_RECORD_LENGTH)
return org_buf[:org_buf.index(chr(0))] return org_buf[:org_buf.index(chr(0))]
def _get_region(self, ipnum): def _get_region(self, ipnum):
""" """
Seek and return the region info (dict containing country_code and region_name). Seek and return the region info (dict containing country_code
and region_name).
@param ipnum: converted IP address @param ipnum: Converted IP address
@type ipnum: int @type ipnum: int
@return: dict containing country_code and region_name @return: dict containing country_code and region_name
@rtype: dict @rtype: dict
""" """
country_code = ''
region = '' region = ''
country_code = ''
seek_country = self._seek_country(ipnum)
def get_region_name(offset):
region1 = chr(offset // 26 + 65)
region2 = chr(offset % 26 + 65)
return ''.join([region1, region2])
if self._databaseType == const.REGION_EDITION_REV0: if self._databaseType == const.REGION_EDITION_REV0:
seek_country = self._seek_country(ipnum)
seek_region = seek_country - const.STATE_BEGIN_REV0 seek_region = seek_country - const.STATE_BEGIN_REV0
if seek_region >= 1000: if seek_region >= 1000:
country_code = 'US' country_code = 'US'
region = ''.join([chr((seek_region - 1000) // 26 + 65), chr((seek_region - 1000) % 26 + 65)]) region = get_region_name(seek_region - 1000)
else: else:
country_code = const.COUNTRY_CODES[seek_region] country_code = const.COUNTRY_CODES[seek_region]
region = ''
elif self._databaseType == const.REGION_EDITION_REV1: elif self._databaseType == const.REGION_EDITION_REV1:
seek_country = self._seek_country(ipnum)
seek_region = seek_country - const.STATE_BEGIN_REV1 seek_region = seek_country - const.STATE_BEGIN_REV1
if seek_region < const.US_OFFSET: if seek_region < const.US_OFFSET:
country_code = ''; pass
region = ''
elif seek_region < const.CANADA_OFFSET: elif seek_region < const.CANADA_OFFSET:
country_code = 'US' country_code = 'US'
region = ''.join([chr((seek_region - const.US_OFFSET) // 26 + 65), chr((seek_region - const.US_OFFSET) % 26 + 65)]) region = get_region_name(seek_region - const.US_OFFSET)
elif seek_region < const.WORLD_OFFSET: elif seek_region < const.WORLD_OFFSET:
country_code = 'CA' country_code = 'CA'
region = ''.join([chr((seek_region - const.CANADA_OFFSET) // 26 + 65), chr((seek_region - const.CANADA_OFFSET) % 26 + 65)]) region = get_region_name(seek_region - const.CANADA_OFFSET)
else: else:
i = (seek_region - const.WORLD_OFFSET) // const.FIPS_RANGE index = (seek_region - const.WORLD_OFFSET) // const.FIPS_RANGE
if i < len(const.COUNTRY_CODES): if index in const.COUNTRY_CODES:
#country_code = const.COUNTRY_CODES[(seek_region - const.WORLD_OFFSET) // const.FIPS_RANGE] country_code = const.COUNTRY_CODES[index]
country_code = const.COUNTRY_CODES[i] elif self._databaseType in const.CITY_EDITIONS:
else:
country_code = ''
region = ''
elif self._databaseType in (const.CITY_EDITION_REV0, const.CITY_EDITION_REV1):
rec = self._get_record(ipnum) rec = self._get_record(ipnum)
country_code = rec['country_code'] if 'country_code' in rec else '' country_code = rec['country_code'] if 'country_code' in rec else ''
region = rec['region_name'] if 'region_name' in rec else '' region = rec['region_name'] if 'region_name' in rec else ''
...@@ -289,7 +283,7 @@ class GeoIP(GeoIPBase): ...@@ -289,7 +283,7 @@ class GeoIP(GeoIPBase):
""" """
Populate location dict for converted IP. Populate location dict for converted IP.
@param ipnum: converted IP address @param ipnum: Converted IP address
@type ipnum: int @type ipnum: int
@return: dict with country_code, country_code3, country_name, @return: dict with country_code, country_code3, country_name,
region, city, postal_code, latitude, longitude, region, city, postal_code, latitude, longitude,
...@@ -300,102 +294,73 @@ class GeoIP(GeoIPBase): ...@@ -300,102 +294,73 @@ class GeoIP(GeoIPBase):
if seek_country == self._databaseSegments: if seek_country == self._databaseSegments:
return None return None
record_pointer = seek_country + (2 * self._recordLength - 1) * self._databaseSegments
self._filehandle.seek(record_pointer, os.SEEK_SET) read_length = (2 * self._recordLength - 1) * self._databaseSegments
self._filehandle.seek(seek_country + read_length, os.SEEK_SET)
record_buf = self._filehandle.read(const.FULL_RECORD_LENGTH) record_buf = self._filehandle.read(const.FULL_RECORD_LENGTH)
record = {} record = {
'dma_code': 0,
'area_code': 0,
'metro_code': '',
'postal_code': ''
}
latitude = 0
longitude = 0
record_buf_pos = 0 record_buf_pos = 0
# Get country
char = ord(record_buf[record_buf_pos]) char = ord(record_buf[record_buf_pos])
#char = record_buf[record_buf_pos] if six.PY3 else ord(record_buf[record_buf_pos])
record['country_code'] = const.COUNTRY_CODES[char] record['country_code'] = const.COUNTRY_CODES[char]
record['country_code3'] = const.COUNTRY_CODES3[char] record['country_code3'] = const.COUNTRY_CODES3[char]
record['country_name'] = const.COUNTRY_NAMES[char] record['country_name'] = const.COUNTRY_NAMES[char]
record_buf_pos += 1 record_buf_pos += 1
str_length = 0
# get region
char = ord(record_buf[record_buf_pos+str_length])
while (char != 0):
str_length += 1
char = ord(record_buf[record_buf_pos+str_length])
if str_length > 0:
record['region_name'] = record_buf[record_buf_pos:record_buf_pos+str_length]
record_buf_pos += str_length + 1 def get_data(record_buf, record_buf_pos):
str_length = 0 offset = record_buf_pos
char = ord(record_buf[offset])
# get city
char = ord(record_buf[record_buf_pos+str_length])
while (char != 0):
str_length += 1
char = ord(record_buf[record_buf_pos+str_length])
if str_length > 0:
record['city'] = record_buf[record_buf_pos:record_buf_pos+str_length]
else:
record['city'] = ''
record_buf_pos += str_length + 1
str_length = 0
# get the postal code
char = ord(record_buf[record_buf_pos+str_length])
while (char != 0): while (char != 0):
str_length += 1 offset += 1
char = ord(record_buf[record_buf_pos+str_length]) char = ord(record_buf[offset])
if offset > record_buf_pos:
if str_length > 0: return (offset, record_buf[record_buf_pos:offset])
record['postal_code'] = record_buf[record_buf_pos:record_buf_pos+str_length] return (offset, '')
else:
record['postal_code'] = None
record_buf_pos += str_length + 1 offset, record['region_name'] = get_data(record_buf, record_buf_pos)
str_length = 0 offset, record['city'] = get_data(record_buf, offset + 1)
offset, record['postal_code'] = get_data(record_buf, offset + 1)
record_buf_pos = offset + 1
latitude = 0
longitude = 0
for j in range(3): for j in range(3):
char = ord(record_buf[record_buf_pos]) char = ord(record_buf[record_buf_pos])
record_buf_pos += 1 record_buf_pos += 1
latitude += (char << (j * 8)) latitude += (char << (j * 8))
record['latitude'] = (latitude/10000.0) - 180.0
for j in range(3): for j in range(3):
char = ord(record_buf[record_buf_pos]) char = ord(record_buf[record_buf_pos])
record_buf_pos += 1 record_buf_pos += 1
longitude += (char << (j * 8)) longitude += (char << (j * 8))
record['longitude'] = (longitude/10000.0) - 180.0 record['latitude'] = (latitude / 10000.0) - 180.0
record['longitude'] = (longitude / 10000.0) - 180.0
if self._databaseType == const.CITY_EDITION_REV1: if self._databaseType == const.CITY_EDITION_REV1:
dmaarea_combo = 0 dmaarea_combo = 0
if record['country_code'] == 'US': if record['country_code'] == 'US':
for j in range(3): for j in range(3):
char = ord(record_buf[record_buf_pos]) char = ord(record_buf[record_buf_pos])
record_buf_pos += 1
dmaarea_combo += (char << (j*8)) dmaarea_combo += (char << (j*8))
record_buf_pos += 1
record['dma_code'] = int(math.floor(dmaarea_combo/1000)) record['dma_code'] = int(math.floor(dmaarea_combo / 1000))
record['area_code'] = dmaarea_combo%1000 record['area_code'] = dmaarea_combo % 1000
else:
record['dma_code'] = 0
record['area_code'] = 0
if 'dma_code' in record and record['dma_code'] in const.DMA_MAP: if record['dma_code'] in const.DMA_MAP:
record['metro_code'] = const.DMA_MAP[record['dma_code']] record['metro_code'] = const.DMA_MAP[record['dma_code']]
else:
record['metro_code'] = ''
if 'country_code' in record: params = (record['country_code'], record['region_name'])
record['time_zone'] = time_zone_by_country_and_region( record['time_zone'] = time_zone_by_country_and_region(*params)
record['country_code'], record.get('region_name')) or ''
else:
record['time_zone'] = ''
return record return record
......
...@@ -362,11 +362,16 @@ CITY_EDITION_REV0 = 6 ...@@ -362,11 +362,16 @@ CITY_EDITION_REV0 = 6
CITY_EDITION_REV1 = 2 CITY_EDITION_REV1 = 2
ORG_EDITION = 5 ORG_EDITION = 5
ISP_EDITION = 4 ISP_EDITION = 4
PROXY_EDITION = 8
ASNUM_EDITION = 9 ASNUM_EDITION = 9
# Not yet supported databases
PROXY_EDITION = 8
NETSPEED_EDITION = 11 NETSPEED_EDITION = 11
COUNTRY_EDITION_V6 = 12 COUNTRY_EDITION_V6 = 12
# Collection of databases
CITY_EDITIONS = (CITY_EDITION_REV0, CITY_EDITION_REV1)
REGION_EDITIONS = (REGION_EDITION_REV0, REGION_EDITION_REV1)
SEGMENT_RECORD_LENGTH = 3 SEGMENT_RECORD_LENGTH = 3
STANDARD_RECORD_LENGTH = 3 STANDARD_RECORD_LENGTH = 3
ORG_RECORD_LENGTH = 4 ORG_RECORD_LENGTH = 4
......
...@@ -699,17 +699,17 @@ _country["ZW"] = "Africa/Harare" ...@@ -699,17 +699,17 @@ _country["ZW"] = "Africa/Harare"
def time_zone_by_country_and_region(country_code, region_name=None): def time_zone_by_country_and_region(country_code, region_name=None):
if country_code not in _country: if country_code not in _country:
return None return ''
if not region_name or region_name == '00': if not region_name or region_name == '00':
region_name = None region_name = None
timezones = _country[country_code] timezones = _country[country_code]
if isinstance(timezones, str): if isinstance(timezones, str):
return timezones return timezones
if region_name: if not region_name:
return ''
return timezones.get(region_name) return timezones.get(region_name)
...@@ -40,7 +40,8 @@ class TestGeoIPCityFunctions(unittest.TestCase): ...@@ -40,7 +40,8 @@ class TestGeoIPCityFunctions(unittest.TestCase):
'longitude': -0.23339999999998895, 'longitude': -0.23339999999998895,
'country_code3': 'GBR', 'country_code3': 'GBR',
'latitude': 51.283299999999997, 'latitude': 51.283299999999997,
'postal_code': None, 'dma_code': 0, 'postal_code': '',
'dma_code': 0,
'country_code': 'GB', 'country_code': 'GB',
'country_name': 'United Kingdom', 'country_name': 'United Kingdom',
'time_zone': 'Europe/London' 'time_zone': 'Europe/London'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment