diff --git a/bot/tests.py b/bot/tests.py index 5d99068..741a4b4 100644 --- a/bot/tests.py +++ b/bot/tests.py @@ -1,23 +1,75 @@ import unittest +import zoneinfo +from datetime import datetime import os import database +from model import WatchGuild, WatchChannel, WatchUser, WatchMessage class TestMessageParsing(unittest.TestCase): pass +class TestGuilds(unittest.TestCase): + dbname = 'test.db' + get_guilds_stmt = 'SELECT guild_id, name FROM guilds' + testguilds = [WatchGuild(1000, 'test1', None), + WatchGuild(2000, 'test2', None), + WatchGuild(3000, 'test3', None)] + + def setUp(self): + self.db = database.DatabaseManager(self.dbname) + for g in self.testguilds: + self.db.create_guild(g) + + def test_get_all_guilds(self): + guilds = self.db.get_guild() + self.assertTrue(len(guilds) == len(self.testguilds)) + + def test_get_guild_by_id(self): + guilds = self.db.get_guild(id=2000) + self.assertTrue(len(guilds) == 1) + + def test_get_guild_by_name(self): + guilds = self.db.get_guild(name='test1') + self.assertTrue(len(guilds) == 1) + + def test_get_multiple_guilds_by_name(self): + new_guild = WatchGuild(1001, 'test1', None) + self.db.create_guild(new_guild) + guilds = self.db.get_guild(name='test1') + self.assertTrue(len(guilds) == 2) + + def test_get_guild_with_nonsense_params(self): + guilds = self.db.get_guild(bleep='bloop') + self.assertTrue(len(guilds) == 0) + + def test_new_guild_add(self): + with self.db.conn: + existing = self.db.conn.execute(self.get_guilds_stmt).fetchall() + guild = WatchGuild(4000, 'test4', + datetime.now(tz=zoneinfo.ZoneInfo('UTC'))) + self.db.create_guild(guild) + new = self.db.conn.execute(self.get_guilds_stmt).fetchall() + self.assertTrue(len(new) == len(existing) + 1) + + def tearDown(self): + if os.path.exists(self.dbname): + os.remove(self.dbname) + + class TestDatabase(unittest.TestCase): dbname = 'test.db' + get_guilds_stmt = 'SELECT guild_id, name FROM guilds' def setUp(self): if os.path.exists(self.dbname): os.remove(self.dbname) def test_database_init(self): - db = database.DatabaseManager('test.db') + db = database.DatabaseManager(self.dbname) with db.conn: - cur = db.conn.execute('SELECT id, name FROM guilds;') + cur = db.conn.execute(self.get_guilds_stmt) self.assertTrue(len(cur.fetchall()) == 0) def tearDown(self):