instrumentation: add next-share/
[cs-p2p-next.git] / instrumentation / next-share / BaseLib / Test / test_merkle.py
1 # Written by Arno Bakker
2 # see LICENSE.txt for license information
3
4 import unittest
5
6 from tempfile import mkstemp
7 import os
8 from types import StringType, DictType
9 from math import ceil
10
11 from BaseLib.Core.API import *
12 from BaseLib.Core.Merkle.merkle import *
13 from BaseLib.Core.BitTornado.bencode import bdecode
14
15 from traceback import print_exc
16
# Set to True to make the torrent-file test print the intermediate values
# (piece count, tree height, offsets) it computes while rebuilding the tree.
DEBUG=False
18
19 class TestMerkleHashes(unittest.TestCase):
20     """ 
21     Testing Simple Merkle Hashes extension version 0, in particular:
22     * The algorithmic part
23     * The .torrent file part
24     See test_merkle_msg.py for protocol testing.
25     """
26     
27     def setUp(self):
28         pass
29         
30     def tearDown(self):
31         pass
32
33     def test_get_hashes_for_piece(self):
34         """ 
35             test MerkleTree.get_hashes_for_piece() method 
36         """
37         self._test_123pieces_tree_get_hashes()
38         self._test_8piece_tree_uncle_calc()
39
40     def _test_123pieces_tree_get_hashes(self):
41         for n in range(1,64):
42             piece_size = 2 ** n
43             self._test_1piece_tree_get_hashes(piece_size,piece_size)
44             for add in [1,piece_size-1]:
45                 self._test_1piece_tree_get_hashes(piece_size,add)
46                 self._test_2piece_tree_get_hashes(piece_size,add)
47                 self._test_3piece_tree_get_hashes(piece_size,add)
48
49     def _test_1piece_tree_get_hashes(self,piece_size,length_add):
50         """ testing get_hashes_for_piece on tree with 1 piece """
51         msg = "1piece_get_hashes("+str(piece_size)+","+str(length_add)+") failed"
52         npieces = 1
53         total_length = length_add
54
55         piece_hashes = ['\x01\x02\x03\x04\x05\x06\x07\x08\x07\x06\x05\x04\x03\x02\x01\x00\x01\x02\x03\x04' ] * npieces
56         tree = MerkleTree(piece_size,total_length,None,piece_hashes)
57         for p in range(npieces):
58             ohlist = tree.get_hashes_for_piece(p)
59             self.assert_(len(ohlist)==1,msg)
60             self.assert_(ohlist[0][0] == 0,msg)
61             self.assertEquals(ohlist[0][1],piece_hashes[0],msg)
62
63     def _test_2piece_tree_get_hashes(self,piece_size,length_add):
64         """testing get_hashes_for_piece on tree with 2 pieces """
65         msg = "2piece_get_hashes("+str(piece_size)+","+str(length_add)+") failed"
66         npieces = 2
67         total_length = piece_size+length_add
68
69         piece_hashes = ['\x01\x02\x03\x04\x05\x06\x07\x08\x07\x06\x05\x04\x03\x02\x01\x00\x01\x02\x03\x04' ] * npieces
70         tree = MerkleTree(piece_size,total_length,None,piece_hashes)
71         for p in range(npieces):
72             ohlist = tree.get_hashes_for_piece(p)
73             self.assert_(len(ohlist)==3)
74             ohlist.sort()
75             self.assert_(ohlist[0][0] == 0,msg)
76             self.assert_(ohlist[1][0] == 1,msg)
77             self.assert_(ohlist[2][0] == 2,msg)
78             self.assertDigestEquals(ohlist[1][1]+ohlist[2][1], ohlist[0][1],msg)
79
80     def _test_3piece_tree_get_hashes(self,piece_size,length_add):
81         """ testing get_hashes_for_piece on tree with 3 pieces """
82         msg = "3piece_get_hashes("+str(piece_size)+","+str(length_add)+") failed"
83         npieces = 3
84         total_length = 2*piece_size+length_add
85
86         piece_hashes = ['\x01\x02\x03\x04\x05\x06\x07\x08\x07\x06\x05\x04\x03\x02\x01\x00\x01\x02\x03\x04' ] * npieces
87         tree = MerkleTree(piece_size,total_length,None,piece_hashes)
88         for p in range(npieces):
89             ohlist = tree.get_hashes_for_piece(p)
90             self.assert_(len(ohlist)==4,msg)
91             ohlist.sort()
92             if p == 0 or p == 1:
93                 self.assert_(ohlist[0][0] == 0,msg)
94                 self.assert_(ohlist[1][0] == 2,msg)
95                 self.assert_(ohlist[2][0] == 3,msg)
96                 self.assert_(ohlist[3][0] == 4,msg)
97                 digest34 = self.calc_digest(ohlist[2][1]+ohlist[3][1])
98                 self.assertDigestEquals(digest34+ohlist[1][1],ohlist[0][1],msg)
99             else:
100                 self.assert_(ohlist[0][0] == 0,msg)
101                 self.assert_(ohlist[1][0] == 1,msg)
102                 self.assert_(ohlist[2][0] == 5,msg)
103                 self.assert_(ohlist[3][0] == 6,msg)
104                 digest56 = self.calc_digest(ohlist[2][1]+ohlist[3][1])
105                 self.assertDigestEquals(ohlist[1][1]+digest56,ohlist[0][1],msg)
106
107     def assertDigestEquals(self,data,digest,msg=None):
108         self.assertEquals(self.calc_digest(data),digest,msg)
109
110     def calc_digest(self,data):
111         digester = sha()
112         digester.update(data)
113         return digester.digest()
114
115     def _test_8piece_tree_uncle_calc(self):
116         npieces = 8
117         hashlist = self.get_indices_for_piece(0,npieces)
118         assert hashlist == [7, 8, 4, 2, 0]
119
120         hashlist = self.get_indices_for_piece(1,npieces)
121         assert hashlist == [8, 7, 4, 2, 0]
122
123         hashlist = self.get_indices_for_piece(2,npieces)
124         assert hashlist == [9, 10, 3, 2, 0]
125
126         hashlist = self.get_indices_for_piece(3,npieces)
127         assert hashlist == [10, 9, 3, 2, 0]
128
129         hashlist = self.get_indices_for_piece(4,npieces)
130         assert hashlist == [11, 12, 6, 1, 0]
131
132         hashlist = self.get_indices_for_piece(5,npieces)
133         assert hashlist == [12, 11, 6, 1, 0]
134
135         hashlist = self.get_indices_for_piece(6,npieces)
136         assert hashlist == [13, 14, 5, 1, 0]
137
138         hashlist = self.get_indices_for_piece(7,npieces)
139         assert hashlist == [14, 13, 5, 1, 0]
140
141     def get_indices_for_piece(self,index,npieces):
142         height = get_tree_height(npieces)
143         tree = create_tree(height)
144         ohlist = get_hashes_for_piece(tree,height,index)
145         list = []
146         for oh in ohlist:
147             list.append(oh[0])
148         return list
149
150     def test_check_hashes_update_hash_admin(self):
151         """ 
152             test MerkleTree.check_hashes() and update_hash_admin() methods
153         """
154         for n in range(1,64):
155             piece_size = 2 ** n
156             for add in [1,piece_size-1]:
157                 self._test_3piece_tree_check_hashes_update_hash_admin(piece_size, add)
158
159     def _test_3piece_tree_check_hashes_update_hash_admin(self,piece_size,length_add):
160         """ testing check_hashes and update_hash_admin tree with 3 pieces """
161         msg = "3piece_check_hashes("+str(piece_size)+","+str(length_add)+") failed"
162         npieces = 3
163         total_length = 2*piece_size+length_add
164
165         piece_hashes = ['\x01\x02\x03\x04\x05\x06\x07\x08\x07\x06\x05\x04\x03\x02\x01\x00\x01\x02\x03\x04' ] * npieces
166         fulltree = MerkleTree(piece_size,total_length,None,piece_hashes)
167         root_hash = fulltree.get_root_hash()
168         emptytree = MerkleTree(piece_size,total_length,root_hash,None)
169         empty_piece_hashes = [0] * npieces
170
171         for p in range(npieces):
172             ohlist = fulltree.get_hashes_for_piece(p)
173             self.assert_(emptytree.check_hashes(ohlist),msg)
174
175         for p in range(npieces):
176             ohlist = fulltree.get_hashes_for_piece(p)
177             self.assert_(emptytree.check_hashes(ohlist),msg)
178             emptytree.update_hash_admin(ohlist,empty_piece_hashes)
179
180         for p in range(npieces):
181             self.assert_(piece_hashes[p] == empty_piece_hashes[p],msg)
182
183     def test_merkle_torrent(self):
184         """
185             test the creation of Merkle torrent files via TorrentMaker/btmakemetafile.py
186         """
187         piece_size = 2 ** 18
188         for file_size in [1,piece_size-1,piece_size,piece_size+1,2*piece_size,(2*piece_size)+1]:
189             self.create_merkle_torrent(file_size,piece_size)
190
191     def create_merkle_torrent(self,file_size,piece_size):
192         try:
193             # 1. create file
194             [handle,datafilename]= mkstemp()
195             os.close(handle)
196             block = "".zfill(file_size)
197             fp = open(datafilename,"wb")
198             fp.write(block)
199             fp.close()
200             torrentfilename = datafilename+'.tribe'
201
202             # 2. Set torrent args
203             tdef = TorrentDef()
204             tdef.set_tracker("http://localhost:6969/announce")
205             tdef.set_create_merkle_torrent(True)
206             tdef.set_piece_length(int(log(piece_size,2)))
207
208             # 3. create Merkle torrent
209             #make_meta_file(datafilename,url,params,flag,dummy_progress,1,dummy_filecallback)
210             tdef.add_content(datafilename)
211             tdef.finalize()
212             tdef.save(torrentfilename)
213
214             # 4. read Merkle torrent
215             fp = open(torrentfilename,"rb")
216             data = fp.read(10000)
217             fp.close()
218
219             # 5. test Merkle torrent
220             # basic tests
221             dict = bdecode(data)
222             self.assert_(type(dict) == DictType)
223             self.assert_(dict.has_key('info'))
224             info = dict['info']
225             self.assert_(type(info) == DictType)
226             self.assert_(not info.has_key('pieces'))
227             self.assert_(info.has_key('root hash'))
228             roothash = info['root hash']
229             self.assert_(type(roothash) == StringType)
230             self.assert_(len(roothash)== 20)
231
232             # create hash tree
233             hashes = self.read_and_calc_hashes(datafilename,piece_size)
234             npieces = len(hashes)
235             if DEBUG:
236                 print "npieces is",npieces
237             height = log(npieces,2)+1
238             if height > int(height):
239                 height += 1
240             height = int(height)
241             if DEBUG:
242                 print "height is",height
243
244             starto = (2 ** (height-1))-1
245
246             if DEBUG:
247                 print "starto is",starto
248             tree = [0] * ((2 ** (height))-1)
249             if DEBUG:
250                 print "len tree is",len(tree)
251             # put hashes in tree
252             for i in range(len(hashes)):
253                 o = starto + i
254                 tree[o] = hashes[i]
255
256             # fill unused
257             nplaces = (2 ** height)-(2 ** (height-1))        
258             xso = starto+npieces
259             xeo = starto+nplaces
260             for o in range(xso,xeo):
261                 tree[o] = '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
262
263             # calc higher level ones
264             if height > 1:
265                 for o in range(len(tree)-starto-2,-1,-1):
266                     co = self.get_child_offset(o,height)
267                     if DEBUG:
268                         print "offset is",o,"co is",co
269                     data = tree[co]+tree[co+1]
270                     digest = self.calc_digest(data)
271                     tree[o] = digest
272             self.assert_(tree[0],roothash)
273         except Exception,e:
274             print_exc()
275         #finally:
276         #    os.remove(datafilename)
277         #    os.remove(torrentfilename)
278
279     def read_and_calc_hashes(self,filename,piece_size):
280         hashes = []
281         fp = open(filename,"rb")
282         while True:
283             block = fp.read(piece_size)
284             if len(block) == 0:
285                 break
286             digest = self.calc_digest(block)
287             hashes.append(digest)
288             if len(block) != piece_size:
289                 break
290         fp.close()
291         return hashes
292
293     def get_child_offset(self,offset,height):
294         if DEBUG:
295             print "get_child(",offset,",",height,")"
296         if offset == 0:
297             level = 1
298         else:
299             level = log(offset,2)
300             if level == int(level):
301                 level += 1
302             else:
303                 level = ceil(level)
304             level = int(level)
305         starto = (2 ** (level-1))-1
306         diffo = offset-starto
307         diffo *= 2
308         cstarto = (2 ** level)-1
309         return cstarto+diffo
310
def test_suite():
    """ Return a TestSuite holding all tests of TestMerkleHashes. """
    merkle_tests = unittest.makeSuite(TestMerkleHashes)
    return unittest.TestSuite([merkle_tests])
316
# Run every TestCase in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()