2023-04-15 20:54:16 +00:00
|
|
|
import unittest
|
2023-04-26 05:14:02 +00:00
|
|
|
import zoneinfo
|
|
|
|
from datetime import datetime
|
2023-04-15 20:54:16 +00:00
|
|
|
import os
|
|
|
|
import database
|
2023-04-26 05:14:02 +00:00
|
|
|
from model import WatchGuild, WatchChannel, WatchUser, WatchMessage
|
2023-04-15 20:54:16 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestMessageParsing(unittest.TestCase):
    """Placeholder for message-parsing tests; no test cases are defined yet."""
|
|
|
|
|
|
|
|
|
2023-04-29 02:41:42 +00:00
|
|
|
class DbTest(unittest.TestCase):
    """Base fixture: a fresh database seeded with test guilds and channels.

    Subclasses inherit ``setUp``/``tearDown`` and find ``self.db`` populated
    with every entry of ``testguilds`` and ``testchannels`` before each test.
    """

    dbname = 'test.db'

    testguilds = [WatchGuild(1000, 'test1', None),
                  WatchGuild(2000, 'test2', None),
                  WatchGuild(3000, 'test3', None)]

    # Each channel is attached to the guild at the same list index.
    # NOTE(review): these are naive datetimes, while TestGuilds builds a
    # UTC-aware one; confirm which form the model/database expects.
    testchannels = [WatchChannel(1001, 'channel1', datetime.now(), testguilds[0]),
                    WatchChannel(2001, 'channel2', datetime.now(), testguilds[1]),
                    WatchChannel(3001, 'channel3', datetime.now(), testguilds[2])]

    # TODO: populate once user/message fixtures exist; setUp does not seed
    # these yet.
    testusers = []
    testmessages = []

    def setUp(self):
        """Create a fresh database file and seed it with the test fixtures."""
        # A file left behind by an aborted run could already contain rows with
        # the seeded ids, so always start from an empty database.
        if os.path.exists(self.dbname):
            os.remove(self.dbname)
        self.db = database.DatabaseManager(self.dbname)
        for guild in self.testguilds:
            self.db.create_guild(guild)
        for channel in self.testchannels:
            self.db.create_channel(channel)

    def tearDown(self):
        """Close the connection, then delete the test database file."""
        # Close before removing: deleting a file that is still open fails on
        # Windows and leaks the handle elsewhere.
        # NOTE(review): assumes DatabaseManager.conn is a DB-API connection
        # (e.g. sqlite3) exposing close(); confirm.
        self.db.conn.close()
        if os.path.exists(self.dbname):
            os.remove(self.dbname)
|
|
|
|
|
|
|
|
|
2023-04-26 05:14:02 +00:00
|
|
|
class TestGuilds(unittest.TestCase):
    """Guild queries against a fresh database seeded with three guilds."""

    dbname = 'test.db'

    testguilds = [WatchGuild(1000, 'test1', None),
                  WatchGuild(2000, 'test2', None),
                  WatchGuild(3000, 'test3', None)]

    def setUp(self):
        """Create a fresh database file and insert every guild in testguilds."""
        # Start from an empty database even if an earlier run crashed and
        # left the file behind.
        if os.path.exists(self.dbname):
            os.remove(self.dbname)
        self.db = database.DatabaseManager(self.dbname)
        for guild in self.testguilds:
            self.db.create_guild(guild)

    def test_get_all_guilds(self):
        """get_guild() with no filter returns every seeded guild."""
        guilds = self.db.get_guild()
        self.assertEqual(len(guilds), len(self.testguilds))

    def test_get_guild_by_id(self):
        """Filtering by id returns exactly the one matching guild."""
        guilds = self.db.get_guild(id=2000)
        # Only one seeded guild has id 2000; the old check compared against
        # the total number of guilds, which contradicts a by-id lookup.
        self.assertEqual(len(guilds), 1)
        self.assertEqual(guilds[0].id, 2000)

    def test_get_guild_by_name(self):
        """Filtering by name returns the single guild with that name."""
        guilds = self.db.get_guild(name='test1')
        self.assertEqual(len(guilds), 1)
        self.assertEqual(guilds[0].name, 'test1')

    def test_get_multiple_guilds_by_name(self):
        """A name shared by two guilds yields both of them."""
        self.db.create_guild(WatchGuild(1001, 'test1', None))
        guilds = self.db.get_guild(name='test1')
        self.assertEqual(len(guilds), 2)

    def test_new_guild_add(self):
        """create_guild() increases the guild count by exactly one."""
        with self.db.conn:
            existing = self.db.get_guild()
            guild = WatchGuild(4000, 'test4',
                               datetime.now(tz=zoneinfo.ZoneInfo('UTC')))
            self.db.create_guild(guild)
            new = self.db.get_guild()
            self.assertEqual(len(new), len(existing) + 1)

    def tearDown(self):
        """Close the connection, then delete the test database file."""
        # NOTE(review): assumes DatabaseManager.conn is a DB-API connection
        # (e.g. sqlite3) exposing close(); confirm.
        self.db.conn.close()
        if os.path.exists(self.dbname):
            os.remove(self.dbname)
|
|
|
|
|
|
|
|
|
2023-04-29 02:41:42 +00:00
|
|
|
class TestChannel(DbTest):
    """Channel queries against the guild/channel fixtures seeded by DbTest."""

    def test_get_channel(self):
        """get_channel() with no filter returns every seeded channel."""
        channels = self.db.get_channel()
        self.assertEqual(len(channels), len(self.testchannels), 'not equal')

    def test_get_channel_by_id(self):
        """Filtering by id returns exactly the one matching channel."""
        expected = self.testchannels[0]
        channels = self.db.get_channel(id=expected.id)
        # assertEqual (not assertTrue(a == b)) so a failure shows both values.
        self.assertEqual(len(channels), 1)
        self.assertEqual(channels[0].id, expected.id)
|
|
|
|
|
|
|
|
|
|
|
|
class TestUser(unittest.TestCase):
    """Placeholder for user-related tests; no test cases are defined yet."""
|
|
|
|
|
|
|
|
|
|
|
|
class TestMessage(unittest.TestCase):
    """Placeholder for message-related tests; no test cases are defined yet."""
|
|
|
|
|
|
|
|
|
2023-04-15 20:54:16 +00:00
|
|
|
class TestDatabase(unittest.TestCase):
    """Low-level checks on DatabaseManager initialisation and query building."""

    dbname = 'test.db'

    # Raw SQL used to inspect the guilds table directly, bypassing the
    # DatabaseManager query helpers.
    get_guilds_stmt = 'SELECT guild_id, name FROM guilds'

    def setUp(self):
        """Start each test with no database file and no open manager."""
        # Tests that open a DatabaseManager store it here so tearDown can
        # close its connection.
        self.db = None
        if os.path.exists(self.dbname):
            os.remove(self.dbname)

    def test_database_init(self):
        """A freshly created database has an empty guilds table."""
        self.db = db = database.DatabaseManager(self.dbname)
        with db.conn:
            cur = db.conn.execute(self.get_guilds_stmt)
            self.assertEqual(len(cur.fetchall()), 0)

    def test_build_select_query(self):
        """_build_select_query yields the expected SQL text and one value per filter."""
        # NOTE(review): _build_select_query returns the values separately,
        # which suggests a parameterized query, yet this expected text inlines
        # the values and comma-separates the WHERE terms. Confirm it matches
        # the builder's real output.
        expected_result = '''SELECT guild_id, name, join_date
FROM guilds
WHERE guild_id=1000, name=test1, nonsense=yes'''
        self.db = db = database.DatabaseManager(self.dbname)
        result, values = db._build_select_query(WatchGuild, guild_id='1000',
                                                name='test1', nonsense='yes')
        # Previously assertTrue(expected, result) / assertTrue(len(values), 3),
        # which pass for any truthy first argument and so checked nothing.
        # Whitespace is normalised so line breaks/indentation don't matter.
        self.assertEqual(' '.join(expected_result.split()),
                         ' '.join(result.split()))
        self.assertEqual(len(values), 3)

    def tearDown(self):
        """Close any open connection and delete the test database file."""
        if self.db is not None:
            # NOTE(review): assumes DatabaseManager.conn is a DB-API connection
            # (e.g. sqlite3) exposing close(); confirm.
            self.db.conn.close()
        # Guarded so a test that never created the file doesn't error here.
        if os.path.exists(self.dbname):
            os.remove(self.dbname)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main(module="__main__")
|