ppf: Add MySQL access tests.
authorMariana Marasoiu <mariana.marasoiu@gmail.com>
Tue, 23 Aug 2011 10:27:35 +0000 (13:27 +0300)
committerMariana Marasoiu <mariana.marasoiu@gmail.com>
Tue, 23 Aug 2011 10:29:22 +0000 (13:29 +0300)
ppf/new/tests/test_storage.py

index 4d7df81..ac4827a 100644 (file)
@@ -2,6 +2,7 @@
 Test suite for storage. Uses unittest module.
 
 2011, Razvan Deaconescu, razvan.deaconescu@cs.pub.ro.
+2011, Mariana Marasoiu, mariana.marasoiu@gmail.com
 """
 
 import unittest
@@ -9,6 +10,8 @@ import os
 import os.path
 import shutil
 import sqlite3
+import MySQLdb
+import sqlparse
 import subprocess
 import sys
 import datetime
@@ -492,5 +495,208 @@ class SQLiteDatabaseAccessTest(unittest.TestCase):
 
         self.assertEqual(verbmsg_count, self.expected_verbmsg_count)
 
+
+class MySQLDatabaseAccessTest(unittest.TestCase):
+    """
+    Test suite for MySQLDatabaseAccess class in storage.py.
+    """
+
+    # Class specific variables. Initialized in setUp, used throughout tests.
+    database = "p2p_test"
+    user = "root"
+    password = "p2p4th3m45535"
+    conn = None
+    cursor = None
+    sql_create_script = "p2p-log-mysql.sql"
+    sql_init_script = "p2p-init-test-mysql.sql"
+    expected_swarm_count = 0
+    expected_session_count = 0
+    expected_statmsg_count = 0
+    expected_pstatmsg_count = 0
+    expected_verbmsg_count = 0
+
+    def get_swarm_count(self):
+        """ Retrieve number of entries in `swarms` table. """
+
+        self.cursor.execute('SELECT COUNT(id) FROM swarms')
+        self.count = self.cursor.fetchone()[0]
+
+        self.cursor.close()
+        self.conn.close()
+
+        return count
+
+    def get_session_count(self):
+        """ Retrieve number of entries in `client_sessions` table. """
+        self.cursor.execute('SELECT COUNT(id) FROM client_sessions')
+        self.count = self.cursor.fetchone()[0]
+
+        self.cursor.close()
+        self.conn.close()
+
+        return count
+
+    def get_statmsg_count(self):
+        """ Retrieve number of entries in `status_messages` table. """
+        self.cursor.execute('SELECT COUNT(id) FROM status_messages')
+        self.count = self.cursor.fetchone()[0]
+
+        self.cursor.close()
+        self.conn.close()
+
+        return count
+
+    def get_pstatmsg_count(self):
+        """ Retrieve number of entries in `peer_status_messages` table. """
+        self.cursor.execute('SELECT COUNT(id) FROM peer_status_messages')
+        self.count = self.cursor.fetchone()[0]
+
+        self.cursor.close()
+        self.conn.close()
+
+        return count
+
+    def get_verbmsg_count(self):
+        """ Retrieve number of entries in `verbose_messages` table. """
+        self.cursor.execute('SELECT COUNT(id) FROM verbose_messages')
+        self.count = self.cursor.fetchone()[0]
+
+        self.cursor.close()
+        self.conn.close()
+
+        return count
+
+    def setUp(self):
+        """ Create database file and instantiate objects. """
+        # Create database file. Create tables and indices.
+        # TODO: Check exceptions.
+
+        sql_create_script_path = os.path.join(os.path.dirname(__file__),
+                                              self.sql_create_script)
+        self.conn = MySQLdb.Connection(user = self.user, passwd = self.password)
+        self.cursor = self.conn.cursor()
+
+        self.cursor.execute("CREATE DATABASE p2p_test")
+        self.cursor.execute("USE p2p_test")
+
+        f = open(sql_create_script_path).read()
+        sql_parts = sqlparse.split(f)
+        for sql_part in sql_parts:
+            if sql_part.strip() == '':
+                continue
+            self.cursor.execute(sql_part)
+
+        # Populate database.
+        # TODO: Check exceptions.
+        sql_init_script_path = os.path.join(os.path.dirname(__file__),
+                self.sql_init_script)
+
+        f = open(sql_init_script_path).read()
+        sql_parts = sqlparse.split(f)
+        for sql_part in sql_parts:
+            if sql_part.strip() == '':
+                continue
+            self.cursor.execute(sql_part)
+
+        # Initialize to number of table entries inserted in the init script.
+        self.expected_swarm_count = 2
+        self.expected_session_count = 2
+        self.expected_statmsg_count = 2
+        self.expected_pstatmsg_count = 2
+        self.expected_verbmsg_count = 2
+
+    def tearDown(self):
+        """ Close connection and remove database file. """
+        self.cursor.execute("DROP DATABASE p2p_test")
+        print "database deleted"
+        self.cursor.close()
+        self.conn.close()
+
+    def test_add_swarm_new_entry_in_table(self):
+        # Add new swarm.
+        s = storage.Swarm(torrent_filename="fedora.torrent",
+                data_size=102400)
+        a = storage.MySQLDatabaseAccess(self.database)
+        a.connect()
+        a.add_swarm(s)
+        a.disconnect()
+        self.expected_swarm_count = self.expected_swarm_count + 1
+
+        # Select number of swarms.
+        swarm_count = self.get_swarm_count()
+
+        self.assertEqual(swarm_count, self.expected_swarm_count)
+
+    def test_add_client_session_new_entry_in_table(self):
+        # Add new client session.
+        cs = storage.ClientSession(swarm_id=1, btclient="Tribler",
+                system_os="Linux", system_os_version="2.6.26",
+                system_ram=2048, system_cpu=3000, public_ip="141.85.224.201",
+                public_port="50500", ds_limit=300, us_limit=200)
+        a = storage.SQLiteDatabaseAccess(self.database)
+        a.connect()
+        a.add_client_session(cs)
+        a.disconnect()
+        self.expected_session_count = self.expected_session_count + 1
+
+        # Select number of swarms.
+        session_count = self.get_session_count()
+
+        self.assertEqual(session_count, self.expected_session_count)
+
+    def test_add_status_message_new_entry_in_table(self):
+        # Add new status message.
+        ts = datetime.datetime.strptime("2010-09-12 08:43:15",
+                "%Y-%m-%d %H:%M:%S")
+        msg = storage.StatusMessage(client_session_id=1, timestamp=ts,
+                num_peers=10, num_dht_peers=3, download_speed=102,
+                upload_speed=99, download_size=10213, upload_size=3301)
+        a = storage.SQLiteDatabaseAccess(self.database)
+        a.connect()
+        a.add_status_message(msg)
+        a.disconnect()
+        self.expected_statmsg_count = self.expected_statmsg_count + 1
+
+        # Select number of status messages.
+        statmsg_count = self.get_statmsg_count()
+
+        self.assertEqual(statmsg_count, self.expected_statmsg_count)
+
+    def test_add_peer_status_message_new_entry_in_table(self):
+        # Add new peer status message.
+        ts = datetime.datetime.strptime("2010-09-12 13:43:25",
+                "%Y-%m-%d %H:%M:%S")
+        msg = storage.PeerStatusMessage(client_session_id=1, timestamp=ts,
+                peer_ip="141.85.224.202", peer_port="12345",
+                download_speed=13, upload_speed=98)
+        a = storage.SQLiteDatabaseAccess(self.database)
+        a.connect()
+        a.add_peer_status_message(msg)
+        a.disconnect()
+        self.expected_pstatmsg_count = self.expected_pstatmsg_count + 1
+
+        # Select number of peer status messages.
+        pstatmsg_count = self.get_pstatmsg_count()
+
+        self.assertEqual(pstatmsg_count, self.expected_pstatmsg_count)
+
+    def test_add_verbose_message_new_entry_in_databas(self):
+        # Add new verbose message.
+        ts = datetime.datetime.strptime("2010-09-12 13:43:24",
+                "%Y-%m-%d %H:%M:%S")
+        msg = storage.VerboseMessage(client_session_id=1, timestamp=ts,
+            transfer_direction="send", peer_ip="141.85.224.202",
+            peer_port="12345", message_type="CHOKE")
+        a = storage.SQLiteDatabaseAccess(self.database)
+        a.connect()
+        a.add_verbose_message(msg)
+        a.disconnect()
+        self.expected_verbmsg_count = self.expected_verbmsg_count + 1
+
+        # Select number of verbose messages.
+        verbmsg_count = self.get_verbmsg_count()
+
+        self.assertEqual(verbmsg_count, self.expected_verbmsg_count)
+
 if __name__ == "__main__":
     unittest.main()