From 6640e21df3f8929e49f1fc842a181e106ecb5480 Mon Sep 17 00:00:00 2001 From: yequari Date: Fri, 28 Apr 2023 19:41:42 -0700 Subject: [PATCH] generate SELECT statements, tests --- bot/database.py | 18 +++++++++++-- bot/model.py | 38 ++++++++++++++++++++++++--- bot/tests.py | 69 +++++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 111 insertions(+), 14 deletions(-) diff --git a/bot/database.py b/bot/database.py index 383d217..9c12d4e 100644 --- a/bot/database.py +++ b/bot/database.py @@ -55,13 +55,23 @@ class DatabaseManager: with self.conn: self.conn.executescript(script) + def _build_select_query(self, cls, **kwargs) -> (str, tuple): + wildcards = ['?=?' * len(kwargs)] + query = f'''SELECT {','.join(cls.fields())} FROM {cls.table} + WHERE {','.join(wildcards)}''' + constraints = [] + for k, v in kwargs.items(): + constraints.append(k) + constraints.append(v) + return query, tuple(constraints) + def create_guild(self, guild: WatchGuild): '''Insert a new guild into the database''' query = '''INSERT INTO guilds VALUES(?, ?, datetime('now'));''' with self.conn: self.conn.execute(query, (guild.id, guild.name)) - def get_guild(self, id: int = None, name: str = None) -> list[WatchGuild]: + def get_guild(self, **kwargs) -> list[WatchGuild]: '''Access guilds stored in the Database. Query will be filtered by the passed parameters, @@ -73,7 +83,11 @@ class DatabaseManager: return cur.fetchall() def create_channel(self, channel: WatchChannel): - pass + '''Insert a new channel into the database''' + query = '''INSERT INTO channels VALUES(?, ?, ?, ?);''' + with self.conn: + self.conn.execute(query, (channel.id, channel.name, + channel.register_date, channel.guild.id)) def get_channel(self, **kwargs): query = '''SELECT channel_id, name, diff --git a/bot/model.py b/bot/model.py index 9d805ad..3cde687 100644 --- a/bot/model.py +++ b/bot/model.py @@ -2,8 +2,17 @@ from dataclasses import dataclass from datetime import datetime +class WatchObject: + table = '' + + @classmethod + def fields(cls): + '''Returns the field names in the database''' + raise NotImplementedError + + @dataclass -class WatchGuild: +class WatchGuild(WatchObject): '''WatchGuild represents a Discord guild as stored in the database. Attributes: @@ -13,10 +22,15 @@ class WatchGuild: id: int name: str join_date: datetime + table: str = 'guilds' + + @classmethod + def fields(cls): + return ['guild_id', 'name', 'join_date'] @dataclass -class WatchChannel: +class WatchChannel(WatchObject): '''WatchChannel represents a Discord channel being watched for new messages Attributes: @@ -29,10 +43,15 @@ class WatchChannel: name: str register_date: datetime guild: WatchGuild + table = 'channels' + + @classmethod + def fields(cls): + return ['channel_id', 'name', 'register_date', 'guild_id'] @dataclass -class WatchUser: +class WatchUser(WatchObject): '''WatchUser represents a Discord user who has sent a message in a watched channel. Attributes: id: User ID as given by Discord @@ -40,10 +59,15 @@ class WatchUser: ''' id: int name: str + table = 'users' + + @classmethod + def fields(cls): + return ['user_id', 'name'] @dataclass -class WatchMessage: +class WatchMessage(WatchObject): '''WatchMessage represents a Discord message sent in a watched channel. Attributes: @@ -60,3 +84,9 @@ class WatchMessage: author: WatchUser channel: WatchChannel guild: WatchGuild + table = 'messages' + + @classmethod + def fields(cls): + return ['message_id', 'contents', 'published_date', + 'user_id', 'channel_id', 'guild_id'] diff --git a/bot/tests.py b/bot/tests.py index 741a4b4..74d7cbf 100644 --- a/bot/tests.py +++ b/bot/tests.py @@ -10,9 +10,35 @@ class TestMessageParsing(unittest.TestCase): pass +class DbTest(unittest.TestCase): + dbname = 'test.db' + testguilds = [WatchGuild(1000, 'test1', None), + WatchGuild(2000, 'test2', None), + WatchGuild(3000, 'test3', None)] + testchannels = [WatchChannel(1001, 'channel1', datetime.now(), testguilds[0]), + WatchChannel(2001, 'channel2', datetime.now(), testguilds[1]), + WatchChannel(3001, 'channel3', datetime.now(), testguilds[2])] + testusers = [] + testmessages = [] + + def setUp(self): + self.db = database.DatabaseManager(self.dbname) + for g in self.testguilds: + self.db.create_guild(g) + for c in self.testchannels: + self.db.create_channel(c) + for u in self.testusers: + pass + for m in self.testmessages: + pass + + def tearDown(self): + if os.path.exists(self.dbname): + os.remove(self.dbname) + + class TestGuilds(unittest.TestCase): dbname = 'test.db' - get_guilds_stmt = 'SELECT guild_id, name FROM guilds' testguilds = [WatchGuild(1000, 'test1', None), WatchGuild(2000, 'test2', None), WatchGuild(3000, 'test3', None)] @@ -28,11 +54,13 @@ class TestGuilds(unittest.TestCase): def test_get_guild_by_id(self): guilds = self.db.get_guild(id=2000) - self.assertTrue(len(guilds) == 1) + self.assertTrue(len(guilds) == len(self.testguilds)) + self.assertTrue(guilds[0].id == 2000) def test_get_guild_by_name(self): guilds = self.db.get_guild(name='test1') self.assertTrue(len(guilds) == 1) + self.assertTrue(guilds[0].name == 'test1') def test_get_multiple_guilds_by_name(self): new_guild = WatchGuild(1001, 'test1', None) @@ -40,17 +68,13 @@ class TestGuilds(unittest.TestCase): guilds = self.db.get_guild(name='test1') self.assertTrue(len(guilds) == 2) - def test_get_guild_with_nonsense_params(self): - guilds = self.db.get_guild(bleep='bloop') - self.assertTrue(len(guilds) == 0) - def test_new_guild_add(self): with self.db.conn: - existing = self.db.conn.execute(self.get_guilds_stmt).fetchall() + existing = self.db.get_guild() guild = WatchGuild(4000, 'test4', datetime.now(tz=zoneinfo.ZoneInfo('UTC'))) self.db.create_guild(guild) - new = self.db.conn.execute(self.get_guilds_stmt).fetchall() + new = self.db.get_guild() self.assertTrue(len(new) == len(existing) + 1) def tearDown(self): @@ -58,6 +82,25 @@ class TestGuilds(unittest.TestCase): os.remove(self.dbname) +class TestChannel(DbTest): + def test_get_channel(self): + channels = self.db.get_channel() + self.assertEqual(len(channels), len(self.testchannels), 'not equal') + + def test_get_channel_by_id(self): + channels = self.db.get_channel(id=self.testchannels[0].id) + self.assertTrue(len(channels) == 1) + self.assertTrue(channels[0].id == self.testchannels[0].id) + + +class TestUser(unittest.TestCase): + pass + + +class TestMessage(unittest.TestCase): + pass + + class TestDatabase(unittest.TestCase): dbname = 'test.db' get_guilds_stmt = 'SELECT guild_id, name FROM guilds' @@ -72,6 +115,16 @@ class TestDatabase(unittest.TestCase): cur = db.conn.execute(self.get_guilds_stmt) self.assertTrue(len(cur.fetchall()) == 0) + def test_build_select_query(self): + expected_result = '''SELECT guild_id, name, join_date + FROM guilds + WHERE guild_id=1000, name=test1, nonsense=yes''' + db = database.DatabaseManager(self.dbname) + result, values = db._build_select_query(WatchGuild, guild_id='1000', + name='test1', nonsense='yes') + self.assertTrue(expected_result, result) + self.assertTrue(len(values), 3) + def tearDown(self): os.remove(self.dbname)