generate SELECT statements, tests
This commit is contained in:
parent
aa5c8a0bf4
commit
6640e21df3
|
@ -55,13 +55,23 @@ class DatabaseManager:
|
|||
with self.conn:
|
||||
self.conn.executescript(script)
|
||||
|
||||
def _build_select_query(self, cls, **kwargs) -> (str, tuple):
|
||||
wildcards = ['?=?' * len(kwargs)]
|
||||
query = f'''SELECT {','.join(cls.fields())} FROM {cls.table}
|
||||
WHERE {','.join(wildcards)}'''
|
||||
constraints = []
|
||||
for k, v in kwargs.items():
|
||||
constraints.append(k)
|
||||
constraints.append(v)
|
||||
return query, tuple(constraints)
|
||||
|
||||
def create_guild(self, guild: WatchGuild):
|
||||
'''Insert a new guild into the database'''
|
||||
query = '''INSERT INTO guilds VALUES(?, ?, datetime('now'));'''
|
||||
with self.conn:
|
||||
self.conn.execute(query, (guild.id, guild.name))
|
||||
|
||||
def get_guild(self, id: int = None, name: str = None) -> list[WatchGuild]:
|
||||
def get_guild(self, **kwargs) -> list[WatchGuild]:
|
||||
'''Access guilds stored in the Database.
|
||||
|
||||
Query will be filtered by the passed parameters,
|
||||
|
@ -73,7 +83,11 @@ class DatabaseManager:
|
|||
return cur.fetchall()
|
||||
|
||||
def create_channel(self, channel: WatchChannel):
|
||||
pass
|
||||
'''Insert a new channel into the database'''
|
||||
query = '''INSERT INTO channels VALUES(?, ?, ?, ?);'''
|
||||
with self.conn:
|
||||
self.conn.execute(query, (channel.id, channel.name,
|
||||
channel.register_date, channel.guild.id))
|
||||
|
||||
def get_channel(self, **kwargs):
|
||||
query = '''SELECT channel_id, name,
|
||||
|
|
38
bot/model.py
38
bot/model.py
|
@ -2,8 +2,17 @@ from dataclasses import dataclass
|
|||
from datetime import datetime
|
||||
|
||||
|
||||
class WatchObject:
|
||||
table = ''
|
||||
|
||||
@classmethod
|
||||
def fields(cls):
|
||||
'''Returns the field names in the database'''
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class WatchGuild:
|
||||
class WatchGuild(WatchObject):
|
||||
'''WatchGuild represents a Discord guild as stored in the database.
|
||||
|
||||
Attributes:
|
||||
|
@ -13,10 +22,15 @@ class WatchGuild:
|
|||
id: int
|
||||
name: str
|
||||
join_date: datetime
|
||||
table: str = 'guilds'
|
||||
|
||||
@classmethod
|
||||
def fields(cls):
|
||||
return ['guild_id', 'name', 'join_date']
|
||||
|
||||
|
||||
@dataclass
|
||||
class WatchChannel:
|
||||
class WatchChannel(WatchObject):
|
||||
'''WatchChannel represents a Discord channel being watched for new messages
|
||||
|
||||
Attributes:
|
||||
|
@ -29,10 +43,15 @@ class WatchChannel:
|
|||
name: str
|
||||
register_date: datetime
|
||||
guild: WatchGuild
|
||||
table = 'channels'
|
||||
|
||||
@classmethod
|
||||
def fields(cls):
|
||||
return ['channel_id', 'name', 'register_date', 'guild_id']
|
||||
|
||||
|
||||
@dataclass
|
||||
class WatchUser:
|
||||
class WatchUser(WatchObject):
|
||||
'''WatchUser represents a Discord user who has sent a message in a watched channel.
|
||||
Attributes:
|
||||
id: User ID as given by Discord
|
||||
|
@ -40,10 +59,15 @@ class WatchUser:
|
|||
'''
|
||||
id: int
|
||||
name: str
|
||||
table = 'users'
|
||||
|
||||
@classmethod
|
||||
def fields(cls):
|
||||
return ['user_id', 'name']
|
||||
|
||||
|
||||
@dataclass
|
||||
class WatchMessage:
|
||||
class WatchMessage(WatchObject):
|
||||
'''WatchMessage represents a Discord message sent in a watched channel.
|
||||
|
||||
Attributes:
|
||||
|
@ -60,3 +84,9 @@ class WatchMessage:
|
|||
author: WatchUser
|
||||
channel: WatchChannel
|
||||
guild: WatchGuild
|
||||
table = 'messages'
|
||||
|
||||
@classmethod
|
||||
def fields(cls):
|
||||
return ['message_id', 'contents', 'published_date',
|
||||
'user_id', 'channel_id', 'guild_id']
|
||||
|
|
69
bot/tests.py
69
bot/tests.py
|
@ -10,9 +10,35 @@ class TestMessageParsing(unittest.TestCase):
|
|||
pass
|
||||
|
||||
|
||||
class DbTest(unittest.TestCase):
|
||||
dbname = 'test.db'
|
||||
testguilds = [WatchGuild(1000, 'test1', None),
|
||||
WatchGuild(2000, 'test2', None),
|
||||
WatchGuild(3000, 'test3', None)]
|
||||
testchannels = [WatchChannel(1001, 'channel1', datetime.now(), testguilds[0]),
|
||||
WatchChannel(2001, 'channel2', datetime.now(), testguilds[1]),
|
||||
WatchChannel(3001, 'channel3', datetime.now(), testguilds[2])]
|
||||
testusers = []
|
||||
testmessages = []
|
||||
|
||||
def setUp(self):
|
||||
self.db = database.DatabaseManager(self.dbname)
|
||||
for g in self.testguilds:
|
||||
self.db.create_guild(g)
|
||||
for c in self.testchannels:
|
||||
self.db.create_channel(c)
|
||||
for u in self.testusers:
|
||||
pass
|
||||
for m in self.testmessages:
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
if os.path.exists(self.dbname):
|
||||
os.remove(self.dbname)
|
||||
|
||||
|
||||
class TestGuilds(unittest.TestCase):
|
||||
dbname = 'test.db'
|
||||
get_guilds_stmt = 'SELECT guild_id, name FROM guilds'
|
||||
testguilds = [WatchGuild(1000, 'test1', None),
|
||||
WatchGuild(2000, 'test2', None),
|
||||
WatchGuild(3000, 'test3', None)]
|
||||
|
@ -28,11 +54,13 @@ class TestGuilds(unittest.TestCase):
|
|||
|
||||
def test_get_guild_by_id(self):
|
||||
guilds = self.db.get_guild(id=2000)
|
||||
self.assertTrue(len(guilds) == 1)
|
||||
self.assertTrue(len(guilds) == len(self.testguilds))
|
||||
self.assertTrue(guilds[0].id == 2000)
|
||||
|
||||
def test_get_guild_by_name(self):
|
||||
guilds = self.db.get_guild(name='test1')
|
||||
self.assertTrue(len(guilds) == 1)
|
||||
self.assertTrue(guilds[0].name == 'test1')
|
||||
|
||||
def test_get_multiple_guilds_by_name(self):
|
||||
new_guild = WatchGuild(1001, 'test1', None)
|
||||
|
@ -40,17 +68,13 @@ class TestGuilds(unittest.TestCase):
|
|||
guilds = self.db.get_guild(name='test1')
|
||||
self.assertTrue(len(guilds) == 2)
|
||||
|
||||
def test_get_guild_with_nonsense_params(self):
|
||||
guilds = self.db.get_guild(bleep='bloop')
|
||||
self.assertTrue(len(guilds) == 0)
|
||||
|
||||
def test_new_guild_add(self):
|
||||
with self.db.conn:
|
||||
existing = self.db.conn.execute(self.get_guilds_stmt).fetchall()
|
||||
existing = self.db.get_guild()
|
||||
guild = WatchGuild(4000, 'test4',
|
||||
datetime.now(tz=zoneinfo.ZoneInfo('UTC')))
|
||||
self.db.create_guild(guild)
|
||||
new = self.db.conn.execute(self.get_guilds_stmt).fetchall()
|
||||
new = self.db.get_guild()
|
||||
self.assertTrue(len(new) == len(existing) + 1)
|
||||
|
||||
def tearDown(self):
|
||||
|
@ -58,6 +82,25 @@ class TestGuilds(unittest.TestCase):
|
|||
os.remove(self.dbname)
|
||||
|
||||
|
||||
class TestChannel(DbTest):
|
||||
def test_get_channel(self):
|
||||
channels = self.db.get_channel()
|
||||
self.assertEqual(len(channels), len(self.testchannels), 'not equal')
|
||||
|
||||
def test_get_channel_by_id(self):
|
||||
channels = self.db.get_channel(id=self.testchannels[0].id)
|
||||
self.assertTrue(len(channels) == 1)
|
||||
self.assertTrue(channels[0].id == self.testchannels[0].id)
|
||||
|
||||
|
||||
class TestUser(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class TestMessage(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
dbname = 'test.db'
|
||||
get_guilds_stmt = 'SELECT guild_id, name FROM guilds'
|
||||
|
@ -72,6 +115,16 @@ class TestDatabase(unittest.TestCase):
|
|||
cur = db.conn.execute(self.get_guilds_stmt)
|
||||
self.assertTrue(len(cur.fetchall()) == 0)
|
||||
|
||||
def test_build_select_query(self):
|
||||
expected_result = '''SELECT guild_id, name, join_date
|
||||
FROM guilds
|
||||
WHERE guild_id=1000, name=test1, nonsense=yes'''
|
||||
db = database.DatabaseManager(self.dbname)
|
||||
result, values = db._build_select_query(WatchGuild, guild_id='1000',
|
||||
name='test1', nonsense='yes')
|
||||
self.assertTrue(expected_result, result)
|
||||
self.assertTrue(len(values), 3)
|
||||
|
||||
def tearDown(self):
|
||||
os.remove(self.dbname)
|
||||
|
||||
|
|
Loading…
Reference in New Issue