instrumentation: add next-share/
[cs-p2p-next.git] / instrumentation / next-share / BaseLib / Core / Video / LiveSourceAuth.py
1 # written by Arno Bakker
2 # see LICENSE.txt for license information
3
4 import sys
5 from traceback import print_exc
6 from cStringIO import StringIO
7 import struct
8 import time
9 import array
10
11 from BaseLib.Core.Utilities.Crypto import sha,RSA_pub_key_from_der
12 from BaseLib.Core.osutils import *
13 from M2Crypto import EC
14 from BaseLib.Core.osutils import *
15 from types import StringType
16
17 DEBUG = False
18
19 class Authenticator:
20     
21     def __init__(self,piecelen,npieces):
22         self.piecelen = piecelen
23         self.npieces = npieces
24         self.seqnum = 0L
25     
26     def get_piece_length(self):
27         return self.piecelen
28     
29     def get_npieces(self):
30         return self.npieces
31     
32     def get_content_blocksize(self):
33         pass
34     
35     def sign(self,content):
36         pass
37     
38     def verify(self,piece):
39         pass
40     
41     def get_content(self,piece):
42         pass
43
44     def get_source_seqnum(self):
45         return self.seqnum
46
47     def set_source_seqnum(self,seqnum):
48         self.seqnum = seqnum
49
50
51 class NullAuthenticator(Authenticator):
52     
53     def __init__(self,piecelen,npieces):
54         Authenticator.__init__(self,piecelen,npieces)
55         self.contentblocksize = piecelen
56     
57     def get_content_blocksize(self):
58         return self.contentblocksize
59     
60     def sign(self,content):
61         return [content]
62     
63     def verify(self,piece):
64         return True
65     
66     def get_content(self,piece):
67         return piece
68
69
70 class ECDSAAuthenticator(Authenticator):
71     """ Authenticator who places a ECDSA signature in the last part of a
72     piece. In particular, the sig consists of:
73     - an 8 byte sequence number
74     - an 8 byte real-time timestamp
75     - a 1 byte length field followed by
76     - a variable-length ECDSA signature in ASN.1, (max 64 bytes)  
77     - optionally 0x00 padding bytes, if the ECDSA sig is less than 64 bytes,
78     to give a total of 81 bytes.
79     """
80     
81     SEQNUM_SIZE = 8
82     RTSTAMP_SIZE = 8
83     LENGTH_SIZE = 1
84     MAX_ECDSA_ASN1_SIGSIZE = 64
85     EXTRA_SIZE = SEQNUM_SIZE + RTSTAMP_SIZE
86     # = seqnum + rtstamp + 1 byte length + MAX_ECDSA, padded
87     # put seqnum + rtstamp directly after content, so we calc the sig directly 
88     # from the received buffer.
89     OUR_SIGSIZE = EXTRA_SIZE+LENGTH_SIZE+MAX_ECDSA_ASN1_SIGSIZE 
90     
91     def __init__(self,piecelen,npieces,keypair=None,pubkeypem=None):
92         
93         print >>sys.stderr,"ECDSAAuth: npieces",npieces
94         
95         Authenticator.__init__(self,piecelen,npieces)
96         self.contentblocksize = piecelen-self.OUR_SIGSIZE
97         self.keypair = keypair
98         if pubkeypem is not None:
99             #print >>sys.stderr,"ECDSAAuth: pubkeypem",`pubkeypem`
100             self.pubkey = EC.pub_key_from_der(pubkeypem)
101         else:
102             self.pubkey = None
103         self.startts = None
104
105     def get_content_blocksize(self):
106         return self.contentblocksize
107     
108     def sign(self,content):
109         rtstamp = time.time()
110         #print >>sys.stderr,"ECDSAAuth: sign: ts %.5f s" % rtstamp
111         
112         extra = struct.pack('>Qd', self.seqnum,rtstamp)
113         self.seqnum += 1L
114
115         sig = ecdsa_sign_data(content,extra,self.keypair)
116         # The sig returned is either 64 or 63 bytes long (62 also possible I 
117         # guess). Therefore we transmit size as 1 bytes and fill to 64 bytes.
118         lensig = chr(len(sig))
119         if len(sig) != self.MAX_ECDSA_ASN1_SIGSIZE:
120             # Note: this is not official ASN.1 padding. Also need to modify
121             # the header length for that I assume.
122             diff = self.MAX_ECDSA_ASN1_SIGSIZE-len(sig)
123             padding = '\x00' * diff 
124             return [content,extra,lensig,sig,padding]
125         else:
126             return [content,extra,lensig,sig]
127         
128     def verify(self,piece,index):
129         """ A piece is valid if:
130         - the signature is correct,
131         - the seqnum % npieces == piecenr.
132         - the seqnum is no older than self.seqnum - npieces
133         @param piece The piece data as received from peer
134         @param index The piece number as received from peer
135         @return Boolean
136         """
137         try:
138             # Can we do this without memcpy?
139             #print >>sys.stderr,"ECDSAAuth: verify",len(piece)
140             extra = piece[-self.OUR_SIGSIZE:-self.OUR_SIGSIZE+self.EXTRA_SIZE]
141             lensig = ord(piece[-self.OUR_SIGSIZE+self.EXTRA_SIZE])
142             if lensig > self.MAX_ECDSA_ASN1_SIGSIZE:
143                 print >>sys.stderr,"ECDSAAuth: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ failed piece",index,"lensig wrong",lensig
144                 return False
145             #print >>sys.stderr,"ECDSAAuth: verify lensig",lensig
146             diff = lensig-self.MAX_ECDSA_ASN1_SIGSIZE
147             if diff == 0:
148                 sig = piece[-self.OUR_SIGSIZE+self.EXTRA_SIZE+self.LENGTH_SIZE:]
149             else:
150                 sig = piece[-self.OUR_SIGSIZE+self.EXTRA_SIZE+self.LENGTH_SIZE:diff]
151             content = piece[:-self.OUR_SIGSIZE]
152             if DEBUG:
153                 print >>sys.stderr,"ECDSAAuth: verify piece",index,"sig",`sig`
154                 print >>sys.stderr,"ECDSAAuth: verify dig",sha(content).hexdigest()
155         
156             ret = ecdsa_verify_data_pubkeyobj(content,extra,self.pubkey,sig)
157             if ret:
158                 (seqnum, rtstamp) = self._decode_extra(piece)
159                 
160                 if DEBUG:
161                     print >>sys.stderr,"ECDSAAuth: verify piece",index,"seq",seqnum,"ts %.5f s" % rtstamp,"ls",lensig
162                 
163                 mod = seqnum % self.get_npieces()
164                 thres = self.seqnum - self.get_npieces()/2
165                 if seqnum <= thres:
166                     print >>sys.stderr,"ECDSAAuth: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ failed piece",index,"old seqnum",seqnum,"<<",self.seqnum
167                     return False
168                 elif mod != index:
169                     print >>sys.stderr,"ECDSAAuth: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ failed piece",index,"expected",mod
170                     return False 
171                 elif self.startts is not None and rtstamp < self.startts:
172                     print >>sys.stderr,"ECDSAAuth: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ failed piece",index,"older than oldest known ts",rtstamp,self.startts
173                     return False
174                 else:
175                     self.seqnum = max(self.seqnum,seqnum)
176                     if self.startts is None:
177                         self.startts = rtstamp-300.0 # minus 5 min in case we read piece N+1 before piece N
178                         print >>sys.stderr,"ECDSAAuth: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@: startts",self.startts
179             else:
180                 print >>sys.stderr,"ECDSAAuth: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ piece",index,"failed sig"
181             
182             return ret
183         except:
184             print_exc()
185             return False 
186
187     def get_content(self,piece):
188         return piece[:-self.OUR_SIGSIZE]
189
190     # Extra fields
191     def get_seqnum(self,piece):
192         (seqnum, rtstamp) = self._decode_extra(piece)
193         return seqnum
194
195     def get_rtstamp(self,piece):
196         (seqnum, rtstamp) = self._decode_extra(piece)
197         return rtstamp
198         
199     def _decode_extra(self,piece):
200         extra = piece[-self.OUR_SIGSIZE:-self.OUR_SIGSIZE+self.EXTRA_SIZE]
201         if type(extra) == array.array:
202             extra = extra.tostring()
203         return struct.unpack('>Qd',extra)
204
205     
206 def ecdsa_sign_data(plaintext,extra,ec_keypair):
207     digester = sha(plaintext)
208     digester.update(extra)
209     digest = digester.digest()
210     return ec_keypair.sign_dsa_asn1(digest)
211     
212 def ecdsa_verify_data_pubkeyobj(plaintext,extra,pubkey,blob):
213     digester = sha(plaintext)
214     digester.update(extra)
215     digest = digester.digest()
216     return pubkey.verify_dsa_asn1(digest,blob)
217     
218
219
220
221 class RSAAuthenticator(Authenticator):
222     """ Authenticator who places a RSA signature in the last part of a piece. 
223     In particular, the sig consists of:
224     - an 8 byte sequence number
225     - an 8 byte real-time timestamp
226     - a variable-length RSA signature, length equivalent to the keysize in bytes  
227     to give a total of 16+(keysize/8) bytes.
228     """
229     
230     SEQNUM_SIZE = 8
231     RTSTAMP_SIZE = 8
232     EXTRA_SIZE = SEQNUM_SIZE + RTSTAMP_SIZE
233     # put seqnum + rtstamp directly after content, so we calc the sig directly 
234     # from the received buffer.
235     def our_sigsize(self):
236         return self.EXTRA_SIZE+self.rsa_sigsize() 
237     
238     def rsa_sigsize(self):
239         return len(self.pubkey)/8
240     
241     def __init__(self,piecelen,npieces,keypair=None,pubkeypem=None):
242         Authenticator.__init__(self,piecelen,npieces)
243         self.keypair = keypair
244         if pubkeypem is not None:
245             #print >>sys.stderr,"ECDSAAuth: pubkeypem",`pubkeypem`
246             self.pubkey = RSA_pub_key_from_der(pubkeypem)
247         else:
248             self.pubkey = self.keypair
249         self.contentblocksize = piecelen-self.our_sigsize()
250         self.startts = None
251
252     def get_content_blocksize(self):
253         return self.contentblocksize
254     
255     def sign(self,content):
256         rtstamp = time.time()
257         #print >>sys.stderr,"ECDSAAuth: sign: ts %.5f s" % rtstamp
258         
259         extra = struct.pack('>Qd', self.seqnum,rtstamp)
260         self.seqnum += 1L
261
262         sig = rsa_sign_data(content,extra,self.keypair)
263         return [content,extra,sig]
264         
265     def verify(self,piece,index):
266         """ A piece is valid if:
267         - the signature is correct,
268         - the seqnum % npieces == piecenr.
269         - the seqnum is no older than self.seqnum - npieces
270         @param piece The piece data as received from peer
271         @param index The piece number as received from peer
272         @return Boolean
273         """
274         try:
275             # Can we do this without memcpy?
276             #print >>sys.stderr,"ECDSAAuth: verify",len(piece)
277             extra = piece[-self.our_sigsize():-self.our_sigsize()+self.EXTRA_SIZE]
278             sig = piece[-self.our_sigsize()+self.EXTRA_SIZE:]
279             content = piece[:-self.our_sigsize()]
280             #if DEBUG:
281             #    print >>sys.stderr,"RSAAuth: verify piece",index,"sig",`sig`
282             #    print >>sys.stderr,"RSAAuth: verify dig",sha(content).hexdigest()
283         
284             ret = rsa_verify_data_pubkeyobj(content,extra,self.pubkey,sig)
285             if ret:
286                 (seqnum, rtstamp) = self._decode_extra(piece)
287                 
288                 if DEBUG:
289                     print >>sys.stderr,"RSAAuth: verify piece",index,"seq",seqnum,"ts %.5f s" % rtstamp
290                 
291                 mod = seqnum % self.get_npieces()
292                 thres = self.seqnum - self.get_npieces()/2
293                 if seqnum <= thres:
294                     print >>sys.stderr,"RSAAuth: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ failed piece",index,"old seqnum",seqnum,"<<",self.seqnum
295                     return False
296                 elif mod != index:
297                     print >>sys.stderr,"RSAAuth: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ failed piece",index,"expected",mod
298                     return False
299                 elif self.startts is not None and rtstamp < self.startts:
300                     print >>sys.stderr,"RSAAuth: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ failed piece",index,"older than oldest known ts",rtstamp,self.startts
301                     return False
302                 else:
303                     self.seqnum = max(self.seqnum,seqnum)
304                     if self.startts is None:
305                         self.startts = rtstamp-300.0 # minus 5 min in case we read piece N+1 before piece N
306                         print >>sys.stderr,"RSAAuth: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@: startts",self.startts
307             else:
308                 print >>sys.stderr,"RSAAuth: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ piece",index,"failed sig"
309             
310             return ret
311         except:
312             print_exc()
313             return False 
314
315     def get_content(self,piece):
316         return piece[:-self.our_sigsize()]
317
318     # Extra fields
319     def get_seqnum(self,piece):
320         (seqnum, rtstamp) = self._decode_extra(piece)
321         return seqnum
322
323     def get_rtstamp(self,piece):
324         (seqnum, rtstamp) = self._decode_extra(piece)
325         return rtstamp
326         
327     def _decode_extra(self,piece):
328         extra = piece[-self.our_sigsize():-self.our_sigsize()+self.EXTRA_SIZE]
329         if type(extra) == array.array:
330             extra = extra.tostring()
331         return struct.unpack('>Qd',extra)
332
333
334 def rsa_sign_data(plaintext,extra,rsa_keypair):
335     digester = sha(plaintext)
336     digester.update(extra)
337     digest = digester.digest()
338     return rsa_keypair.sign(digest)
339     
340 def rsa_verify_data_pubkeyobj(plaintext,extra,pubkey,sig):
341     digester = sha(plaintext)
342     digester.update(extra)
343     digest = digester.digest()
344     
345     # The type of sig is array.array() at this point (why?), M2Crypto RSA verify
346     # will complain if it is not a string or Unicode object. Check if this is a
347     # memcpy. 
348     s = sig.tostring()
349     return pubkey.verify(digest,s)
350
351
352
353
354
355
356     
357 class AuthStreamWrapper:
358     """ Wrapper around the stream returned by VideoOnDemand/MovieOnDemandTransporter
359     that strips of the signature info
360     """
361     
362     def __init__(self,inputstream,authenticator):
363         self.inputstream = inputstream
364         self.buffer = StringIO()
365         self.authenticator = authenticator
366         self.piecelen = authenticator.get_piece_length()
367         self.last_rtstamp = None
368
369     def read(self,numbytes=None):
370         rawdata = self._readn(self.piecelen)
371         if len(rawdata) == 0:
372             # EOF
373             return rawdata
374         content = self.authenticator.get_content(rawdata)
375         self.last_rtstamp = self.authenticator.get_rtstamp(rawdata)
376         if numbytes is None or numbytes < 0:
377             raise ValueError('Stream has unlimited size, read all not supported.')
378         elif numbytes < len(content):
379             # TODO: buffer unread data for next read
380             raise ValueError('reading less than piecesize not supported yet')
381         else:
382             return content
383
384     def get_generation_time(self):
385         """ Returns the time at which the last read piece was generated at the source. """
386         return self.last_rtstamp
387     
388     def seek(self,pos,whence=os.SEEK_SET):
389         if pos == 0 and whence == os.SEEK_SET:
390             print >>sys.stderr,"authstream: seek: Ignoring seek 0 in live"
391         else:
392             raise ValueError("authstream does not support seek")
393
394     def close(self):
395         self.inputstream.close()
396
397     def available(self):
398         return self.inputstream.available()
399
400
401     # Internal method
402     def _readn(self,n):
403         """ read exactly n bytes from inputstream, block if unavail """
404         nwant = n
405         while True:
406             data = self.inputstream.read(nwant)
407             if len(data) == 0:
408                 return data
409             nwant -= len(data)
410             self.buffer.write(data)
411             if nwant == 0:
412                 break
413         self.buffer.seek(0)
414         data = self.buffer.read(n)
415         self.buffer.seek(0)
416         return data
417         
418
419
420 class VariableReadAuthStreamWrapper:
421     """ Wrapper around AuthStreamWrapper that allows reading of variable
422     number of bytes. TODO: optimize whole stack of AuthWrapper, 
423     MovieTransportWrapper, MovieOnDemandTransporter
424     """
425     
426     def __init__(self,inputstream,piecelen):
427         self.inputstream = inputstream
428         self.buffer = ''
429         self.piecelen = piecelen
430
431     def read(self,numbytes=None):
432         if numbytes is None or numbytes < 0:
433             raise ValueError('Stream has unlimited size, read all not supported.')
434         return self._readn(numbytes)
435
436     def get_generation_time(self):
437         """ Returns the time at which the last read piece was generated at the source. """
438         return self.inputstream.get_generation_time()
439     
440     def seek(self,pos,whence=os.SEEK_SET):
441         return self.inputstream.seek(pos,whence=whence)
442         
443     def close(self):
444         self.inputstream.close()
445
446     def available(self):
447         return self.inputstream.available()
448
449     # Internal method
450     def _readn(self,nwant):
451         """ read *at most* nwant bytes from inputstream """
452         
453         if len(self.buffer) == 0:
454             # Must read fixed size blocks from authwrapper
455             data = self.inputstream.read(self.piecelen)
456             #print >>sys.stderr,"varread: Got",len(data),"want",nwant
457             if len(data) == 0:
458                 return data
459             self.buffer = data
460
461         lenb = len(self.buffer)
462         tosend = min(nwant,lenb)
463             
464         if tosend == lenb:
465             #print >>sys.stderr,"varread: zero copy 2 lenb",lenb
466             pre = self.buffer
467             post = ''
468         else:
469             #print >>sys.stderr,"varread: copy",tosend,"lenb",lenb
470             pre = self.buffer[0:tosend]
471             post = self.buffer[tosend:]
472             
473         self.buffer = post
474         #print >>sys.stderr,"varread: Returning",len(pre)
475         return pre
476     
477