"""Unit tests for the database layer (guilds, channels, users, messages)."""

import os
import unittest
import zoneinfo
from datetime import datetime

import database
from model import WatchGuild, WatchChannel, WatchUser, WatchMessage


def _close_db(db):
    """Close the connection held by *db* so its backing file can be deleted.

    NOTE(review): assumes ``db.conn`` is a sqlite3-style connection with a
    ``close()`` method (these tests use it via ``with conn:`` and
    ``conn.execute``) -- confirm against ``database.DatabaseManager``.
    """
    conn = getattr(db, 'conn', None)
    if conn is not None:
        conn.close()


def _remove_db(path):
    """Delete the database file at *path* if it exists; no-op otherwise."""
    if os.path.exists(path):
        os.remove(path)


class TestMessageParsing(unittest.TestCase):
    """Placeholder: message parsing tests are not written yet."""


class DbTest(unittest.TestCase):
    """Base fixture: a fresh database pre-populated with guilds and channels.

    Defines no tests itself; subclasses inherit ``setUp``/``tearDown`` and
    can compare query results against the class-level fixture lists.
    """

    dbname = 'test.db'
    testguilds = [WatchGuild(1000, 'test1', None),
                  WatchGuild(2000, 'test2', None),
                  WatchGuild(3000, 'test3', None)]
    testchannels = [WatchChannel(1001, 'channel1', datetime.now(), testguilds[0]),
                    WatchChannel(2001, 'channel2', datetime.now(), testguilds[1]),
                    WatchChannel(3001, 'channel3', datetime.now(), testguilds[2])]
    testusers = []
    testmessages = []

    def setUp(self):
        # Drop leftovers from an aborted run so fixture counts start clean.
        _remove_db(self.dbname)
        self.db = database.DatabaseManager(self.dbname)
        for guild in self.testguilds:
            self.db.create_guild(guild)
        for channel in self.testchannels:
            self.db.create_channel(channel)
        # TODO: insert self.testusers / self.testmessages once user and
        # message creation is covered by tests.

    def tearDown(self):
        # Close before deleting: an open handle can block removal on some
        # platforms and otherwise leaks for the rest of the run.
        _close_db(self.db)
        _remove_db(self.dbname)


class TestGuilds(unittest.TestCase):
    """Guild creation and lookup against a database seeded with three guilds."""

    dbname = 'test.db'
    testguilds = [WatchGuild(1000, 'test1', None),
                  WatchGuild(2000, 'test2', None),
                  WatchGuild(3000, 'test3', None)]

    def setUp(self):
        _remove_db(self.dbname)
        self.db = database.DatabaseManager(self.dbname)
        for guild in self.testguilds:
            self.db.create_guild(guild)

    def test_get_all_guilds(self):
        guilds = self.db.get_guild()
        self.assertEqual(len(guilds), len(self.testguilds))

    def test_get_guild_by_id(self):
        guilds = self.db.get_guild(id=2000)
        # A unique id must match exactly one guild, mirroring
        # TestChannel.test_get_channel_by_id.
        self.assertEqual(len(guilds), 1)
        self.assertEqual(guilds[0].id, 2000)

    def test_get_guild_by_name(self):
        guilds = self.db.get_guild(name='test1')
        self.assertEqual(len(guilds), 1)
        self.assertEqual(guilds[0].name, 'test1')

    def test_get_multiple_guilds_by_name(self):
        # Names are not unique: a second guild named 'test1' is also returned.
        new_guild = WatchGuild(1001, 'test1', None)
        self.db.create_guild(new_guild)
        guilds = self.db.get_guild(name='test1')
        self.assertEqual(len(guilds), 2)

    def test_new_guild_add(self):
        with self.db.conn:
            existing = self.db.get_guild()
            guild = WatchGuild(4000, 'test4',
                               datetime.now(tz=zoneinfo.ZoneInfo('UTC')))
            self.db.create_guild(guild)
            new = self.db.get_guild()
            self.assertEqual(len(new), len(existing) + 1)

    def tearDown(self):
        _close_db(self.db)
        _remove_db(self.dbname)


class TestChannel(DbTest):
    """Channel lookup against the DbTest fixture."""

    def test_get_channel(self):
        channels = self.db.get_channel()
        self.assertEqual(len(channels), len(self.testchannels), 'not equal')

    def test_get_channel_by_id(self):
        channels = self.db.get_channel(id=self.testchannels[0].id)
        self.assertEqual(len(channels), 1)
        self.assertEqual(channels[0].id, self.testchannels[0].id)


class TestUser(unittest.TestCase):
    """Placeholder: user tests are not written yet."""


class TestMessage(unittest.TestCase):
    """Placeholder: message tests are not written yet."""


class TestDatabase(unittest.TestCase):
    """Schema initialisation and internal query-builder tests."""

    dbname = 'test.db'
    get_guilds_stmt = 'SELECT guild_id, name FROM guilds'

    def setUp(self):
        _remove_db(self.dbname)
        # Each test stores its manager here so tearDown can close it.
        self.db = None

    def test_database_init(self):
        self.db = database.DatabaseManager(self.dbname)
        with self.db.conn:
            cur = self.db.conn.execute(self.get_guilds_stmt)
            self.assertEqual(len(cur.fetchall()), 0)

    def test_build_select_query(self):
        expected_result = '''SELECT guild_id, name, join_date FROM guilds WHERE guild_id=1000, name=test1, nonsense=yes'''
        self.db = database.DatabaseManager(self.dbname)
        result, values = self.db._build_select_query(
            WatchGuild, guild_id='1000', name='test1', nonsense='yes')
        # NOTE(review): assertTrue(x, msg) uses its second argument only as
        # the failure message, so both checks below pass whenever the first
        # argument is truthy -- they compare nothing. Confirm the real SQL
        # produced by _build_select_query (placeholders? AND vs commas? is
        # the unknown 'nonsense' key dropped?) and switch to assertEqual.
        self.assertTrue(expected_result, result)
        self.assertTrue(len(values), 3)

    def tearDown(self):
        if self.db is not None:
            _close_db(self.db)
        _remove_db(self.dbname)


if __name__ == "__main__":
    unittest.main()