Commit 671d3ab9 by William Tisäter

Remove dead code and make all tests pass

parent b6d171c6
...@@ -62,6 +62,9 @@ class _GeoIPMetaclass(type): ...@@ -62,6 +62,9 @@ class _GeoIPMetaclass(type):
def __call__(cls, *args, **kwargs): def __call__(cls, *args, **kwargs):
""" Singleton method to gets an instance without reparsing """ Singleton method to gets an instance without reparsing
the database, the filename is being used as cache key. the database, the filename is being used as cache key.
@param cache: Used in tests for skipping instance caching
@type cache: bool
""" """
if len(args) > 0: if len(args) > 0:
filename = args[0] filename = args[0]
...@@ -70,6 +73,10 @@ class _GeoIPMetaclass(type): ...@@ -70,6 +73,10 @@ class _GeoIPMetaclass(type):
else: else:
return None return None
if not kwargs.get('cache', True):
del kwargs['cache']
return super(_GeoIPMetaclass, cls).__call__(*args, **kwargs)
cls._instance_lock.acquire() cls._instance_lock.acquire()
if filename not in cls._instances: if filename not in cls._instances:
cls._instances[filename] = super(_GeoIPMetaclass, cls).__call__(*args, **kwargs) cls._instances[filename] = super(_GeoIPMetaclass, cls).__call__(*args, **kwargs)
...@@ -104,14 +111,17 @@ class GeoIP(object): ...@@ -104,14 +111,17 @@ class GeoIP(object):
f = codecs.open(filename, 'rb', ENCODING) f = codecs.open(filename, 'rb', ENCODING)
access = mmap.ACCESS_READ access = mmap.ACCESS_READ
self._fp = mmap.mmap(f.fileno(), 0, access=access) self._fp = mmap.mmap(f.fileno(), 0, access=access)
self._type = 'MMAP_CACHE'
f.close() f.close()
elif self._flags & const.MEMORY_CACHE: elif self._flags & const.MEMORY_CACHE:
f = codecs.open(filename, 'rb', ENCODING) f = codecs.open(filename, 'rb', ENCODING)
self._memory = f.read() self._memory = f.read()
self._fp = self._str_to_fp(self._memory) self._fp = self._str_to_fp(self._memory)
self._type = 'MEMORY_CACHE'
f.close() f.close()
else: else:
self._fp = codecs.open(filename, 'rb', ENCODING) self._fp = codecs.open(filename, 'rb', ENCODING)
self._type = 'STANDARD'
self._lock = Lock() self._lock = Lock()
self._setup_segments() self._setup_segments()
...@@ -430,24 +440,25 @@ class GeoIP(object): ...@@ -430,24 +440,25 @@ class GeoIP(object):
def id_by_addr(self, addr): def id_by_addr(self, addr):
""" """
Get the country index. Looks up the index for the country which is the key for the
Looks up the index for the country which is the key for code and name.
the code and name.
@param addr: The IP address @param addr: IPv4 or IPv6 address
@type addr: str @type addr: str
@return: network byte order 32-bit integer @return: network byte order 32-bit integer
@rtype: int @rtype: int
""" """
ipnum = util.ip2long(addr)
if not ipnum:
raise ValueError("Invalid IP address: %s" % addr)
COUNTY_EDITIONS = (const.COUNTRY_EDITION, const.COUNTRY_EDITION_V6) COUNTY_EDITIONS = (const.COUNTRY_EDITION, const.COUNTRY_EDITION_V6)
if self._databaseType not in COUNTY_EDITIONS: if self._databaseType not in COUNTY_EDITIONS:
message = 'Invalid database type, expected Country' raise GeoIPError('Invalid database type, expected Country')
raise GeoIPError(message)
ipv = 6 if addr.find(':') >= 0 else 4
if ipv == 4 and self._databaseType != const.COUNTRY_EDITION:
raise GeoIPError('Invalid database type; expected IPv6 address')
if ipv == 6 and self._databaseType != const.COUNTRY_EDITION_V6:
raise GeoIPError('Invalid database type; expected IPv4 address')
ipnum = util.ip2long(addr)
return self._seek_country(ipnum) - const.COUNTRY_BEGIN return self._seek_country(ipnum) - const.COUNTRY_BEGIN
def country_code_by_addr(self, addr): def country_code_by_addr(self, addr):
...@@ -460,27 +471,14 @@ class GeoIP(object): ...@@ -460,27 +471,14 @@ class GeoIP(object):
@return: 2-letter country code @return: 2-letter country code
@rtype: str @rtype: str
""" """
try: VALID_EDITIONS = (const.COUNTRY_EDITION, const.COUNTRY_EDITION_V6)
VALID_EDITIONS = (const.COUNTRY_EDITION, const.COUNTRY_EDITION_V6) if self._databaseType in VALID_EDITIONS:
if self._databaseType in VALID_EDITIONS: country_id = self.id_by_addr(addr)
ipv = 6 if addr.find(':') >= 0 else 4 return const.COUNTRY_CODES[country_id]
elif self._databaseType in const.REGION_CITY_EDITIONS:
if ipv == 4 and self._databaseType != const.COUNTRY_EDITION: return self.region_by_addr(addr).get('country_code')
message = 'Invalid database type; expected IPv6 address'
raise ValueError(message) raise GeoIPError('Invalid database type, expected Country, City or Region')
if ipv == 6 and self._databaseType != const.COUNTRY_EDITION_V6:
message = 'Invalid database type; expected IPv4 address'
raise ValueError(message)
country_id = self.id_by_addr(addr)
return const.COUNTRY_CODES[country_id]
elif self._databaseType in const.REGION_CITY_EDITIONS:
return self.region_by_addr(addr).get('country_code')
message = 'Invalid database type, expected Country, City or Region'
raise GeoIPError(message)
except ValueError:
raise GeoIPError('Failed to lookup address %s' % addr)
def country_code_by_name(self, hostname): def country_code_by_name(self, hostname):
""" """
...@@ -505,18 +503,15 @@ class GeoIP(object): ...@@ -505,18 +503,15 @@ class GeoIP(object):
@return: country name @return: country name
@rtype: str @rtype: str
""" """
try: VALID_EDITIONS = (const.COUNTRY_EDITION, const.COUNTRY_EDITION_V6)
VALID_EDITIONS = (const.COUNTRY_EDITION, const.COUNTRY_EDITION_V6) if self._databaseType in VALID_EDITIONS:
if self._databaseType in VALID_EDITIONS: country_id = self.id_by_addr(addr)
country_id = self.id_by_addr(addr) return const.COUNTRY_NAMES[country_id]
return const.COUNTRY_NAMES[country_id] elif self._databaseType in const.CITY_EDITIONS:
elif self._databaseType in const.CITY_EDITIONS: return self.record_by_addr(addr).get('country_name')
return self.record_by_addr(addr).get('country_name') else:
else: message = 'Invalid database type, expected Country or City'
message = 'Invalid database type, expected Country or City' raise GeoIPError(message)
raise GeoIPError(message)
except ValueError:
raise GeoIPError('Failed to lookup address %s' % addr)
def country_name_by_name(self, hostname): def country_name_by_name(self, hostname):
""" """
...@@ -541,19 +536,13 @@ class GeoIP(object): ...@@ -541,19 +536,13 @@ class GeoIP(object):
@return: organization or ISP name @return: organization or ISP name
@rtype: str @rtype: str
""" """
try: valid = (const.ORG_EDITION, const.ISP_EDITION, const.ASNUM_EDITION, const.ASNUM_EDITION_V6)
ipnum = util.ip2long(addr) if self._databaseType not in valid:
if not ipnum: message = 'Invalid database type, expected Org, ISP or ASNum'
raise ValueError('Invalid IP address') raise GeoIPError(message)
valid = (const.ORG_EDITION, const.ISP_EDITION, const.ASNUM_EDITION, const.ASNUM_EDITION_V6)
if self._databaseType not in valid:
message = 'Invalid database type, expected Org, ISP or ASNum'
raise GeoIPError(message)
return self._get_org(ipnum) ipnum = util.ip2long(addr)
except ValueError: return self._get_org(ipnum)
raise GeoIPError('Failed to lookup address %s' % addr)
def org_by_name(self, hostname): def org_by_name(self, hostname):
""" """
...@@ -580,22 +569,16 @@ class GeoIP(object): ...@@ -580,22 +569,16 @@ class GeoIP(object):
metro_code, area_code, region_name, time_zone metro_code, area_code, region_name, time_zone
@rtype: dict @rtype: dict
""" """
try: if self._databaseType not in const.CITY_EDITIONS:
ipnum = util.ip2long(addr) message = 'Invalid database type, expected City'
if not ipnum: raise GeoIPError(message)
raise ValueError('Invalid IP address')
if self._databaseType not in const.CITY_EDITIONS:
message = 'Invalid database type, expected City'
raise GeoIPError(message)
rec = self._get_record(ipnum) ipnum = util.ip2long(addr)
if not rec: rec = self._get_record(ipnum)
return None if not rec:
return None
return rec return rec
except ValueError:
raise GeoIPError('Failed to lookup address %s' % addr)
def record_by_name(self, hostname): def record_by_name(self, hostname):
""" """
...@@ -622,18 +605,12 @@ class GeoIP(object): ...@@ -622,18 +605,12 @@ class GeoIP(object):
@return: Dictionary containing country_code, region and region_name @return: Dictionary containing country_code, region and region_name
@rtype: dict @rtype: dict
""" """
try: if self._databaseType not in const.REGION_CITY_EDITIONS:
ipnum = util.ip2long(addr) message = 'Invalid database type, expected Region or City'
if not ipnum: raise GeoIPError(message)
raise ValueError('Invalid IP address')
if self._databaseType not in const.REGION_CITY_EDITIONS:
message = 'Invalid database type, expected Region or City'
raise GeoIPError(message)
return self._get_region(ipnum) ipnum = util.ip2long(addr)
except ValueError: return self._get_region(ipnum)
raise GeoIPError('Failed to lookup address %s' % addr)
def region_by_name(self, hostname): def region_by_name(self, hostname):
""" """
...@@ -658,18 +635,12 @@ class GeoIP(object): ...@@ -658,18 +635,12 @@ class GeoIP(object):
@return: Time zone @return: Time zone
@rtype: str @rtype: str
""" """
try: if self._databaseType not in const.CITY_EDITIONS:
ipnum = util.ip2long(addr) message = 'Invalid database type, expected City'
if not ipnum: raise GeoIPError(message)
raise ValueError('Invalid IP address')
if self._databaseType not in const.CITY_EDITIONS:
message = 'Invalid database type, expected City'
raise GeoIPError(message)
return self._get_record(ipnum).get('time_zone') ipnum = util.ip2long(addr)
except ValueError: return self._get_record(ipnum).get('time_zone')
raise GeoIPError('Failed to lookup address %s' % addr)
def time_zone_by_name(self, hostname): def time_zone_by_name(self, hostname):
""" """
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import unittest import unittest
from nose.tools import raises
import pygeoip import pygeoip
from tests.config import COUNTRY_DB_PATH, COUNTRY_V6_DB_PATH from tests.config import COUNTRY_DB_PATH, COUNTRY_V6_DB_PATH
...@@ -64,3 +65,13 @@ class TestGeoIPCountryFunctions(unittest.TestCase): ...@@ -64,3 +65,13 @@ class TestGeoIPCountryFunctions(unittest.TestCase):
self.assertEqual(us_name, self.us_name) self.assertEqual(us_name, self.us_name)
self.assertEqual(gb_name, self.gb_name) self.assertEqual(gb_name, self.gb_name)
self.assertEqual(ie6_name, self.ie_name) self.assertEqual(ie6_name, self.ie_name)
@raises(pygeoip.GeoIPError)
def testOpen4With6(self):
data = self.gi.country_code_by_addr(self.ie6_ip)
raise ValueError(data)
@raises(pygeoip.GeoIPError)
def testOpen6With4(self):
data = self.gi6.country_code_by_addr(self.gb_ip)
raise ValueError(data)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment