instrumentation: add next-share/
[cs-p2p-next.git] / instrumentation / next-share / BaseLib / Test / test_crawler.py
1 # Written by Boudewijn Schoon
2 # see LICENSE.txt for license information
3
4 import socket
5 import unittest
6 import os
7 import sys
8 import time
9 from BaseLib.Core.Utilities.Crypto import sha
10 from M2Crypto import Rand
11 import cPickle
12
13 from BaseLib.Test.test_as_server import TestAsServer
14 from olconn import OLConnection
15 from BaseLib.Core.BitTornado.BT1.MessageID import CRAWLER_REQUEST, CRAWLER_REPLY, CRAWLER_DATABASE_QUERY, getMessageName
16
17 from BaseLib.Core.CacheDB.SqliteCacheDBHandler import CrawlerDBHandler
18
19 DEBUG=True
20
21 class TestCrawler(TestAsServer):
22     """ 
23     Testing the user side of the crawler
24     """
25     
26     def setUp(self):
27         """ override TestAsServer """
28         TestAsServer.setUp(self)
29         Rand.load_file('randpool.dat', -1)
30
31     def setUpPreSession(self):
32         """ override TestAsServer """
33         TestAsServer.setUpPreSession(self)
34
35         # Enable buddycast and crawler handling
36         self.config.set_buddycast(True)
37         self.config.set_crawler(True)
38
39     def setUpPostSession(self):
40         """ override TestAsServer """
41         TestAsServer.setUpPostSession(self)
42
43         self.my_permid = str(self.my_keypair.pub().get_der())
44         self.my_hash = sha(self.my_permid).digest()
45         self.his_permid = str(self.his_keypair.pub().get_der())        
46
47         # Start our server side, to with Tribler will try to connect
48         self.listen_port = 4123
49         self.listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
50         self.listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
51         self.listen_socket.bind(("", self.listen_port))
52         self.listen_socket.listen(10)
53         self.listen_socket.settimeout(10)
54         
55     def tearDown(self):
56         """ override TestAsServer """
57         TestAsServer.tearDown(self)
58         try:
59             os.remove('randpool.dat')
60         except:
61             pass
62
63     def test_all(self):
64         """
65         I want to start a Tribler client once and then connect to it
66         many times. So there must be only one test method to prevent
67         setUp() from creating a new client every time.
68
69         The code is constructed so unittest will show the name of the
70         (sub)test where the error occured in the traceback it prints.
71         """
72         self.subtest_invalid_permid()
73         self.subtest_invalid_messageid()
74         self.subtest_invalid_sql_query()
75         self.subtest_invalid_frequency()
76         self.subtest_invalid_tablename()
77         self.subtest_valid_messageid()
78         self.subtest_dialback()
79
80     def subtest_invalid_permid(self):
81         """
82         Send crawler messages from a non-crawler peer
83         """
84         print >>sys.stderr, "-"*80, "\ntest: invalid_permid"
85
86         # make sure that the OLConnection is NOT in the crawler_db
87         crawler_db = CrawlerDBHandler.getInstance()
88         assert not self.my_permid in crawler_db.getCrawlers()
89
90         # We are not a registered crawler, any request from us should
91         # be denied
92         messages = [CRAWLER_REQUEST,
93                     CRAWLER_REQUEST + CRAWLER_DATABASE_QUERY,
94                     CRAWLER_REQUEST + CRAWLER_DATABASE_QUERY,
95                     CRAWLER_REQUEST + chr(0)]
96         for msg in messages:
97             s = OLConnection(self.my_keypair, "localhost", self.hisport)
98             s.send(msg)
99             response  = s.recv()
100             assert response == "", "response type is %s" % getMessageName(response[0])
101
102         time.sleep(1)
103         s.close()
104
105     def subtest_invalid_messageid(self):
106         """
107         Send an invalid message-id from a registered crawler peer
108         """
109         print >>sys.stderr, "-"*80, "\ntest: invalid_messageid"
110
111         # make sure that the OLConnection IS in the crawler_db
112         crawler_db = CrawlerDBHandler.getInstance()
113         crawler_db.temporarilyAddCrawler(self.my_permid)
114
115         # We are a registered crawler, start sending invalid messages
116         messages = [CRAWLER_REQUEST,
117                     CRAWLER_REQUEST + chr(0),
118                     CRAWLER_REPLY,
119                     CRAWLER_REPLY + chr(0)]
120         for msg in messages:
121             s = OLConnection(self.my_keypair, "localhost", self.hisport)
122             s.send(msg)
123             response  = s.recv()
124             assert response == "", "response type is %s" % getMessageName(response[0])
125
126         time.sleep(1)
127         s.close()
128
129     def subtest_invalid_sql_query(self):
130         """
131         Send an invalid sql query from a registered crawler peer
132         """
133         print >>sys.stderr, "-"*80, "\ntest: invalid_sql_query"
134
135         # make sure that the OLConnection IS in the crawler_db
136         crawler_db = CrawlerDBHandler.getInstance()
137         crawler_db.temporarilyAddCrawler(self.my_permid)
138
139         s = OLConnection(self.my_keypair, "localhost", self.hisport)
140
141         queries = ["FOO BAR"]
142         for query in queries:
143             self.send_crawler_request(s, CRAWLER_DATABASE_QUERY, 0, 0, query)
144
145             error, payload = self.receive_crawler_reply(s, CRAWLER_DATABASE_QUERY, 0)
146             
147             
148             
149             assert error == 1
150             if DEBUG:
151                 print >>sys.stderr, payload
152
153         time.sleep(1)
154         s.close()
155
156     def subtest_invalid_frequency(self):
157         """
158         Send two valid requests shortly after each other. However,
159         indicate that the frequency should be large. This should
160         result in a frequency error
161         """
162         print >>sys.stderr, "-"*80, "\ntest: invalid_invalid_frequency"
163
164         # make sure that the OLConnection IS in the crawler_db
165         crawler_db = CrawlerDBHandler.getInstance()
166         crawler_db.temporarilyAddCrawler(self.my_permid)
167
168         s = OLConnection(self.my_keypair, "localhost", self.hisport)
169         self.send_crawler_request(s, CRAWLER_DATABASE_QUERY, 42, 0, "SELECT * FROM peer")
170         error, payload = self.receive_crawler_reply(s, CRAWLER_DATABASE_QUERY, 42)
171         assert error == 0
172
173         # try on the same connection
174         self.send_crawler_request(s, CRAWLER_DATABASE_QUERY, 42, 1000, "SELECT * FROM peer")
175         error, payload = self.receive_crawler_reply(s, CRAWLER_DATABASE_QUERY, 42)
176         assert error == 254 # should give a frequency erro
177         s.close()
178         
179         # try on a new connection
180         s = OLConnection(self.my_keypair, "localhost", self.hisport)
181         self.send_crawler_request(s, CRAWLER_DATABASE_QUERY, 42, 1000, "SELECT * FROM peer")
182         error, payload = self.receive_crawler_reply(s, CRAWLER_DATABASE_QUERY, 42)
183         assert error == 254 # should give a frequency error
184  
185         time.sleep(1)
186         s.close()
187         
188
189     def subtest_invalid_tablename(self):
190         """
191         Send an invalid query and check that we get the actual sql
192         exception back
193         """
194         print >>sys.stderr, "-"*80, "\ntest: invalid_tablename"
195
196         # make sure that the OLConnection IS in the crawler_db
197         crawler_db = CrawlerDBHandler.getInstance()
198         crawler_db.temporarilyAddCrawler(self.my_permid)
199
200         s = OLConnection(self.my_keypair, "localhost", self.hisport)
201         self.send_crawler_request(s, CRAWLER_DATABASE_QUERY, 42, 0, "SELECT * FROM nofoobar")
202         error, payload = self.receive_crawler_reply(s, CRAWLER_DATABASE_QUERY, 42)
203         assert error != 0
204         assert payload == "SQLError: no such table: nofoobar", payload
205
206     def subtest_valid_messageid(self):
207         """
208         Send a valid message-id from a registered crawler peer
209         """
210         print >>sys.stderr, "-"*80, "\ntest: valid_messageid"
211
212         # make sure that the OLConnection IS in the crawler_db
213         crawler_db = CrawlerDBHandler.getInstance()
214         crawler_db.temporarilyAddCrawler(self.my_permid)
215
216         s = OLConnection(self.my_keypair, "localhost", self.hisport)
217
218         queries = ["SELECT name FROM category", "SELECT * FROM peer", "SELECT * FROM torrent"]
219         for query in queries:
220             self.send_crawler_request(s, CRAWLER_DATABASE_QUERY, 0, 0, query)
221
222             error, payload = self.receive_crawler_reply(s, CRAWLER_DATABASE_QUERY, 0)
223             assert error == 0
224             if DEBUG:
225                 print >>sys.stderr, cPickle.loads(payload)
226
227         time.sleep(1)
228         s.close()
229
230     def subtest_dialback(self):
231         """
232         Send a valid request, disconnect, and wait for an incomming
233         connection with the reply
234         """
235         print >>sys.stderr, "-"*80, "\ntest: dialback"
236         
237         # make sure that the OLConnection IS in the crawler_db
238         crawler_db = CrawlerDBHandler.getInstance()
239         crawler_db.temporarilyAddCrawler(self.my_permid)
240
241         s = OLConnection(self.my_keypair, "localhost", self.hisport, mylistenport=self.listen_port)
242         self.send_crawler_request(s, CRAWLER_DATABASE_QUERY, 42, 0, "SELECT * FROM peer")
243         s.close()
244
245         # wait for reply
246         try:
247             conn, addr = self.listen_socket.accept()
248         except socket.timeout:
249             if DEBUG: print >> sys.stderr,"test_crawler: timeout, bad, peer didn't connect to send the crawler reply"
250             assert False, "test_crawler: timeout, bad, peer didn't connect to send the crawler reply"
251         s = OLConnection(self.my_keypair, "", 0, conn, mylistenport=self.listen_port)
252
253         # read reply
254         error, payload = self.receive_crawler_reply(s, CRAWLER_DATABASE_QUERY, 42)
255         assert error == 0
256         if DEBUG: print >>sys.stderr, cPickle.loads(payload)
257
258         time.sleep(1)
259
260     def send_crawler_request(self, sock, message_id, channel_id, frequency, payload):
261         # Sending a request from a Crawler to a Tribler peer
262         #     SIZE    INDEX
263         #     1 byte: 0      CRAWLER_REQUEST (from BaseLib.Core.BitTornado.BT1.MessageID)
264         #     1 byte: 1      --MESSAGE-SPECIFIC-ID--
265         #     1 byte: 2      Channel id
266         #     2 byte: 3+4    Frequency
267         #     n byte: 5...   Request payload
268         sock.send("".join((CRAWLER_REQUEST,
269                            message_id,
270                            chr(channel_id & 0xFF),
271                            chr((frequency >> 8) & 0xFF) + chr(frequency & 0xFF),
272                            payload)))
273
274     def receive_crawler_reply(self, sock, message_id, channel_id):
275         # Sending a reply from a Tribler peer to a Crawler
276         #     SIZE    INDEX
277         #     1 byte: 0      CRAWLER_REPLY (from BaseLib.Core.BitTornado.BT1.MessageID)
278         #     1 byte: 1      --MESSAGE-SPECIFIC-ID--
279         #     1 byte: 2      Channel id
280         #     1 byte: 3      Parts left
281         #     1 byte: 4      Indicating success (0) or failure (non 0)
282         #     n byte: 5...   Reply payload
283
284         if DEBUG:
285             print >>sys.stderr, "test_crawler: receive_crawler_reply: waiting for channel",channel_id
286
287         parts = []
288         while True:
289             response  = sock.recv()
290             if response:
291                 if response[0] == CRAWLER_REPLY and response[1] == message_id and ord(response[2]) == channel_id:
292                     parts.append(response[5:])
293                     if DEBUG:
294                         print >>sys.stderr, "test_crawler: received", getMessageName(response[0:2]), "channel", channel_id, "length", sum([len(part) for part in parts]), "parts left", ord(response[3])
295
296                     if ord(response[3]):
297                         # there are parts left
298                         continue
299
300                     return ord(response[4]), "".join(parts)
301
302             return -1, ""
303
304 def test_suite():
305     suite = unittest.TestSuite()
306     suite.addTest(unittest.makeSuite(TestCrawler))
307     
308     return suite
309
310 if __name__ == "__main__":
311     unittest.main()
312