generate SELECT statements, tests

This commit is contained in:
yequari 2023-04-28 19:41:42 -07:00
parent aa5c8a0bf4
commit 6640e21df3
3 changed files with 111 additions and 14 deletions

View File

@ -55,13 +55,23 @@ class DatabaseManager:
with self.conn: with self.conn:
self.conn.executescript(script) self.conn.executescript(script)
def _build_select_query(self, cls, **kwargs) -> (str, tuple):
wildcards = ['?=?' * len(kwargs)]
query = f'''SELECT {','.join(cls.fields())} FROM {cls.table}
WHERE {','.join(wildcards)}'''
constraints = []
for k, v in kwargs.items():
constraints.append(k)
constraints.append(v)
return query, tuple(constraints)
def create_guild(self, guild: WatchGuild): def create_guild(self, guild: WatchGuild):
'''Insert a new guild into the database''' '''Insert a new guild into the database'''
query = '''INSERT INTO guilds VALUES(?, ?, datetime('now'));''' query = '''INSERT INTO guilds VALUES(?, ?, datetime('now'));'''
with self.conn: with self.conn:
self.conn.execute(query, (guild.id, guild.name)) self.conn.execute(query, (guild.id, guild.name))
def get_guild(self, id: int = None, name: str = None) -> list[WatchGuild]: def get_guild(self, **kwargs) -> list[WatchGuild]:
'''Access guilds stored in the Database. '''Access guilds stored in the Database.
Query will be filtered by the passed parameters, Query will be filtered by the passed parameters,
@ -73,7 +83,11 @@ class DatabaseManager:
return cur.fetchall() return cur.fetchall()
def create_channel(self, channel: WatchChannel): def create_channel(self, channel: WatchChannel):
pass '''Insert a new channel into the database'''
query = '''INSERT INTO channels VALUES(?, ?, ?, ?);'''
with self.conn:
self.conn.execute(query, (channel.id, channel.name,
channel.register_date, channel.guild.id))
def get_channel(self, **kwargs): def get_channel(self, **kwargs):
query = '''SELECT channel_id, name, query = '''SELECT channel_id, name,

View File

@ -2,8 +2,17 @@ from dataclasses import dataclass
from datetime import datetime from datetime import datetime
class WatchObject:
table = ''
@classmethod
def fields(cls):
'''Returns the field names in the database'''
raise NotImplementedError
@dataclass @dataclass
class WatchGuild: class WatchGuild(WatchObject):
'''WatchGuild represents a Discord guild as stored in the database. '''WatchGuild represents a Discord guild as stored in the database.
Attributes: Attributes:
@ -13,10 +22,15 @@ class WatchGuild:
id: int id: int
name: str name: str
join_date: datetime join_date: datetime
table: str = 'guilds'
@classmethod
def fields(cls):
return ['guild_id', 'name', 'join_date']
@dataclass @dataclass
class WatchChannel: class WatchChannel(WatchObject):
'''WatchChannel represents a Discord channel being watched for new messages '''WatchChannel represents a Discord channel being watched for new messages
Attributes: Attributes:
@ -29,10 +43,15 @@ class WatchChannel:
name: str name: str
register_date: datetime register_date: datetime
guild: WatchGuild guild: WatchGuild
table = 'channels'
@classmethod
def fields(cls):
return ['channel_id', 'name', 'register_date', 'guild_id']
@dataclass @dataclass
class WatchUser: class WatchUser(WatchObject):
'''WatchUser represents a Discord user who has sent a message in a watched channel. '''WatchUser represents a Discord user who has sent a message in a watched channel.
Attributes: Attributes:
id: User ID as given by Discord id: User ID as given by Discord
@ -40,10 +59,15 @@ class WatchUser:
''' '''
id: int id: int
name: str name: str
table = 'users'
@classmethod
def fields(cls):
return ['user_id', 'name']
@dataclass @dataclass
class WatchMessage: class WatchMessage(WatchObject):
'''WatchMessage represents a Discord message sent in a watched channel. '''WatchMessage represents a Discord message sent in a watched channel.
Attributes: Attributes:
@ -60,3 +84,9 @@ class WatchMessage:
author: WatchUser author: WatchUser
channel: WatchChannel channel: WatchChannel
guild: WatchGuild guild: WatchGuild
table = 'messages'
@classmethod
def fields(cls):
return ['message_id', 'contents', 'published_date',
'user_id', 'channel_id', 'guild_id']

View File

@ -10,9 +10,35 @@ class TestMessageParsing(unittest.TestCase):
pass pass
class DbTest(unittest.TestCase):
dbname = 'test.db'
testguilds = [WatchGuild(1000, 'test1', None),
WatchGuild(2000, 'test2', None),
WatchGuild(3000, 'test3', None)]
testchannels = [WatchChannel(1001, 'channel1', datetime.now(), testguilds[0]),
WatchChannel(2001, 'channel2', datetime.now(), testguilds[1]),
WatchChannel(3001, 'channel3', datetime.now(), testguilds[2])]
testusers = []
testmessages = []
def setUp(self):
self.db = database.DatabaseManager(self.dbname)
for g in self.testguilds:
self.db.create_guild(g)
for c in self.testchannels:
self.db.create_channel(c)
for u in self.testusers:
pass
for m in self.testmessages:
pass
def tearDown(self):
if os.path.exists(self.dbname):
os.remove(self.dbname)
class TestGuilds(unittest.TestCase): class TestGuilds(unittest.TestCase):
dbname = 'test.db' dbname = 'test.db'
get_guilds_stmt = 'SELECT guild_id, name FROM guilds'
testguilds = [WatchGuild(1000, 'test1', None), testguilds = [WatchGuild(1000, 'test1', None),
WatchGuild(2000, 'test2', None), WatchGuild(2000, 'test2', None),
WatchGuild(3000, 'test3', None)] WatchGuild(3000, 'test3', None)]
@ -28,11 +54,13 @@ class TestGuilds(unittest.TestCase):
def test_get_guild_by_id(self): def test_get_guild_by_id(self):
guilds = self.db.get_guild(id=2000) guilds = self.db.get_guild(id=2000)
self.assertTrue(len(guilds) == 1) self.assertTrue(len(guilds) == len(self.testguilds))
self.assertTrue(guilds[0].id == 2000)
def test_get_guild_by_name(self): def test_get_guild_by_name(self):
guilds = self.db.get_guild(name='test1') guilds = self.db.get_guild(name='test1')
self.assertTrue(len(guilds) == 1) self.assertTrue(len(guilds) == 1)
self.assertTrue(guilds[0].name == 'test1')
def test_get_multiple_guilds_by_name(self): def test_get_multiple_guilds_by_name(self):
new_guild = WatchGuild(1001, 'test1', None) new_guild = WatchGuild(1001, 'test1', None)
@ -40,17 +68,13 @@ class TestGuilds(unittest.TestCase):
guilds = self.db.get_guild(name='test1') guilds = self.db.get_guild(name='test1')
self.assertTrue(len(guilds) == 2) self.assertTrue(len(guilds) == 2)
def test_get_guild_with_nonsense_params(self):
guilds = self.db.get_guild(bleep='bloop')
self.assertTrue(len(guilds) == 0)
def test_new_guild_add(self): def test_new_guild_add(self):
with self.db.conn: with self.db.conn:
existing = self.db.conn.execute(self.get_guilds_stmt).fetchall() existing = self.db.get_guild()
guild = WatchGuild(4000, 'test4', guild = WatchGuild(4000, 'test4',
datetime.now(tz=zoneinfo.ZoneInfo('UTC'))) datetime.now(tz=zoneinfo.ZoneInfo('UTC')))
self.db.create_guild(guild) self.db.create_guild(guild)
new = self.db.conn.execute(self.get_guilds_stmt).fetchall() new = self.db.get_guild()
self.assertTrue(len(new) == len(existing) + 1) self.assertTrue(len(new) == len(existing) + 1)
def tearDown(self): def tearDown(self):
@ -58,6 +82,25 @@ class TestGuilds(unittest.TestCase):
os.remove(self.dbname) os.remove(self.dbname)
class TestChannel(DbTest):
def test_get_channel(self):
channels = self.db.get_channel()
self.assertEqual(len(channels), len(self.testchannels), 'not equal')
def test_get_channel_by_id(self):
channels = self.db.get_channel(id=self.testchannels[0].id)
self.assertTrue(len(channels) == 1)
self.assertTrue(channels[0].id == self.testchannels[0].id)
class TestUser(unittest.TestCase):
pass
class TestMessage(unittest.TestCase):
pass
class TestDatabase(unittest.TestCase): class TestDatabase(unittest.TestCase):
dbname = 'test.db' dbname = 'test.db'
get_guilds_stmt = 'SELECT guild_id, name FROM guilds' get_guilds_stmt = 'SELECT guild_id, name FROM guilds'
@ -72,6 +115,16 @@ class TestDatabase(unittest.TestCase):
cur = db.conn.execute(self.get_guilds_stmt) cur = db.conn.execute(self.get_guilds_stmt)
self.assertTrue(len(cur.fetchall()) == 0) self.assertTrue(len(cur.fetchall()) == 0)
def test_build_select_query(self):
expected_result = '''SELECT guild_id, name, join_date
FROM guilds
WHERE guild_id=1000, name=test1, nonsense=yes'''
db = database.DatabaseManager(self.dbname)
result, values = db._build_select_query(WatchGuild, guild_id='1000',
name='test1', nonsense='yes')
self.assertTrue(expected_result, result)
self.assertTrue(len(values), 3)
def tearDown(self): def tearDown(self):
os.remove(self.dbname) os.remove(self.dbname)