Commit f5725375 by William Tisäter

Merge branch 'master' of https://github.com/wmalinowski/pygeoip

Conflicts:
	pygeoip/__init__.py
	setup.py
	tests/config.py
	tests/test_pygeoip.py
parents 91d3634a b37c3437
...@@ -36,6 +36,7 @@ import socket ...@@ -36,6 +36,7 @@ import socket
import mmap import mmap
import gzip import gzip
import codecs import codecs
from threading import Lock
try: try:
from StringIO import StringIO from StringIO import StringIO
...@@ -118,6 +119,7 @@ class GeoIP(GeoIPBase): ...@@ -118,6 +119,7 @@ class GeoIP(GeoIPBase):
else: else:
self._filehandle = codecs.open(filename, 'rb', 'latin_1') self._filehandle = codecs.open(filename, 'rb', 'latin_1')
self._lock = Lock()
self._setup_segments() self._setup_segments()
def _setup_segments(self): def _setup_segments(self):
...@@ -142,6 +144,7 @@ class GeoIP(GeoIPBase): ...@@ -142,6 +144,7 @@ class GeoIP(GeoIPBase):
self._recordLength = const.STANDARD_RECORD_LENGTH self._recordLength = const.STANDARD_RECORD_LENGTH
self._databaseSegments = const.COUNTRY_BEGIN self._databaseSegments = const.COUNTRY_BEGIN
self._lock.acquire()
filepos = self._filehandle.tell() filepos = self._filehandle.tell()
self._filehandle.seek(-3, os.SEEK_END) self._filehandle.seek(-3, os.SEEK_END)
...@@ -178,6 +181,7 @@ class GeoIP(GeoIPBase): ...@@ -178,6 +181,7 @@ class GeoIP(GeoIPBase):
else: else:
self._filehandle.seek(-4, os.SEEK_CUR) self._filehandle.seek(-4, os.SEEK_CUR)
self._filehandle.seek(filepos, os.SEEK_SET) self._filehandle.seek(filepos, os.SEEK_SET)
self._lock.release()
def _seek_country(self, ipnum): def _seek_country(self, ipnum):
""" """
...@@ -200,8 +204,10 @@ class GeoIP(GeoIPBase): ...@@ -200,8 +204,10 @@ class GeoIP(GeoIPBase):
else: else:
startIndex = 2 * self._recordLength * offset startIndex = 2 * self._recordLength * offset
readLength = 2 * self._recordLength readLength = 2 * self._recordLength
self._lock.acquire()
self._filehandle.seek(startIndex, os.SEEK_SET) self._filehandle.seek(startIndex, os.SEEK_SET)
buf = self._filehandle.read(readLength) buf = self._filehandle.read(readLength)
self._lock.release()
x = [0, 0] x = [0, 0]
for i in range(2): for i in range(2):
...@@ -231,8 +237,10 @@ class GeoIP(GeoIPBase): ...@@ -231,8 +237,10 @@ class GeoIP(GeoIPBase):
return None return None
read_length = (2 * self._recordLength - 1) * self._databaseSegments read_length = (2 * self._recordLength - 1) * self._databaseSegments
self._lock.acquire()
self._filehandle.seek(seek_org + read_length, os.SEEK_SET) self._filehandle.seek(seek_org + read_length, os.SEEK_SET)
org_buf = self._filehandle.read(const.MAX_ORG_RECORD_LENGTH) org_buf = self._filehandle.read(const.MAX_ORG_RECORD_LENGTH)
self._lock.release()
return org_buf[:org_buf.index(chr(0))] return org_buf[:org_buf.index(chr(0))]
...@@ -298,10 +306,11 @@ class GeoIP(GeoIPBase): ...@@ -298,10 +306,11 @@ class GeoIP(GeoIPBase):
if seek_country == self._databaseSegments: if seek_country == self._databaseSegments:
return None return None
read_length = (2 * self._recordLength - 1) * self._databaseSegments read_length = (2 * self._recordLength - 1) * self._databaseSegments
self._lock.acquire()
self._filehandle.seek(seek_country + read_length, os.SEEK_SET) self._filehandle.seek(seek_country + read_length, os.SEEK_SET)
record_buf = self._filehandle.read(const.FULL_RECORD_LENGTH) record_buf = self._filehandle.read(const.FULL_RECORD_LENGTH)
self._lock.release()
record = { record = {
'dma_code': 0, 'dma_code': 0,
......
# -*- coding: utf-8 -*-
import unittest
import threading
import pygeoip
from tests.config import COUNTRY_DB_PATH
class TestGeoIPThreading(unittest.TestCase):
def setUp(self):
self.us_ip = '64.233.161.99'
self.gb_ip = '212.58.253.68'
self.us_code = 'US'
self.gb_code = 'GB'
self.gi = pygeoip.GeoIP(COUNTRY_DB_PATH)
def testCountryDatabase(self):
us_thread = TestThread('us', self.gi, self.us_ip, self.us_code, self.assertEqual)
gb_thread = TestThread('gb', self.gi, self.us_ip, self.us_code, self.assertEqual)
us_thread.start()
gb_thread.start()
us_thread.join()
gb_thread.join()
class TestThread(threading.Thread):
def __init__(self, name, gi, ip, code, assertEqual):
threading.Thread.__init__(self, name=name)
self.ip = ip
self.gi = gi
self.code = code
self.assertEqual = assertEqual
def run(self):
for i in range(1000):
code = self.gi.country_code_by_addr(self.ip)
self.assertEqual(code, self.code)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment