Commit b17b976a by Wojciech Malinowski

Used a lock to prevent concurrency problems between calls to file.seek()

and file.read()
parent 368e3dce
......@@ -35,6 +35,7 @@ import mmap
import gzip
import codecs
from StringIO import StringIO
from threading import Lock
from . import const
from .util import ip2long
......@@ -109,6 +110,7 @@ class GeoIP(GeoIPBase):
else:
self._filehandle = codecs.open(filename, 'rb','latin_1')
self._lock = Lock()
self._setup_segments()
def _setup_segments(self):
......@@ -119,6 +121,7 @@ class GeoIP(GeoIPBase):
self._databaseType = const.COUNTRY_EDITION
self._recordLength = const.STANDARD_RECORD_LENGTH
self._lock.acquire()
filepos = self._filehandle.tell()
self._filehandle.seek(-3, os.SEEK_END)
......@@ -160,6 +163,7 @@ class GeoIP(GeoIPBase):
self._databaseSegments = const.COUNTRY_BEGIN
self._filehandle.seek(filepos, os.SEEK_SET)
self._lock.release()
def _lookup_country_id(self, addr):
"""
......@@ -206,8 +210,10 @@ class GeoIP(GeoIPBase):
endIndex = startIndex + length
buf = self._memoryBuffer[startIndex:endIndex]
else:
self._lock.acquire()
self._filehandle.seek(2 * self._recordLength * offset, os.SEEK_SET)
buf = self._filehandle.read(2 * self._recordLength)
self._lock.release()
x = [0,0]
......@@ -247,9 +253,11 @@ class GeoIP(GeoIPBase):
record_pointer = seek_org + (2 * self._recordLength - 1) * self._databaseSegments
self._lock.acquire()
self._filehandle.seek(record_pointer, os.SEEK_SET)
org_buf = self._filehandle.read(const.MAX_ORG_RECORD_LENGTH)
self._lock.release()
return org_buf[:org_buf.index(chr(0))]
......@@ -319,9 +327,11 @@ class GeoIP(GeoIPBase):
record_pointer = seek_country + (2 * self._recordLength - 1) * self._databaseSegments
self._lock.acquire()
self._filehandle.seek(record_pointer, os.SEEK_SET)
record_buf = self._filehandle.read(const.FULL_RECORD_LENGTH)
self._lock.release()
record = {}
record_buf_pos = 0
......
from __future__ import with_statement, absolute_import
import threading
import unittest
import pygeoip
from .config import (CITY_DB_PATH, COUNTRY_DB_PATH, ISP_DB_PATH, ORG_DB_PATH, REGION_DB_PATH, _data_dir)
from config import (CITY_DB_PATH, COUNTRY_DB_PATH, ISP_DB_PATH, ORG_DB_PATH, REGION_DB_PATH, _data_dir)
class BaseGeoIPTestCase(unittest.TestCase):
def setUp(self):
......@@ -197,5 +198,41 @@ class TestGeoIPRegionFunctions(BaseGeoIPTestCase):
self.assertEqual(self.gir.region_by_addr(self.yahoo_ip), self.yahoo_region_data)
class TestThread(threading.Thread):
def __init__(self, name, geoip, ip, country_code, assertEqual):
threading.Thread.__init__(self, name=name)
self.geoip = geoip
self.ip = ip
self.country_code = country_code
self.assertEqual = assertEqual
def run(self):
for i in range(1000):
self.assertEqual(self.geoip.country_code_by_addr(self.ip), self.country_code)
class TestThreadedFunctions(BaseGeoIPTestCase):
def setUp(self):
super(TestThreadedFunctions, self).setUp()
self.gi = pygeoip.GeoIP(COUNTRY_DB_PATH)
self.gic = pygeoip.GeoIP(CITY_DB_PATH)
def testCountryDB(self):
us_thread = TestThread('country-us', self.gi, self.us_ip, self.us_code, self.assertEqual)
gb_thread = TestThread('country-gb', self.gi, self.gb_ip, self.gb_code, self.assertEqual)
us_thread.start()
gb_thread.start()
us_thread.join()
gb_thread.join()
def testCityDB(self):
us_thread = TestThread('city-us', self.gic, self.us_ip, self.us_code, self.assertEqual)
gb_thread = TestThread('city-gb', self.gic, self.gb_ip, self.gb_code, self.assertEqual)
us_thread.start()
gb_thread.start()
us_thread.join()
gb_thread.join()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment