generate SELECT statements, tests
This commit is contained in:
parent
aa5c8a0bf4
commit
6640e21df3
|
@ -55,13 +55,23 @@ class DatabaseManager:
|
||||||
with self.conn:
|
with self.conn:
|
||||||
self.conn.executescript(script)
|
self.conn.executescript(script)
|
||||||
|
|
||||||
|
def _build_select_query(self, cls, **kwargs) -> (str, tuple):
|
||||||
|
wildcards = ['?=?' * len(kwargs)]
|
||||||
|
query = f'''SELECT {','.join(cls.fields())} FROM {cls.table}
|
||||||
|
WHERE {','.join(wildcards)}'''
|
||||||
|
constraints = []
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
constraints.append(k)
|
||||||
|
constraints.append(v)
|
||||||
|
return query, tuple(constraints)
|
||||||
|
|
||||||
def create_guild(self, guild: WatchGuild):
|
def create_guild(self, guild: WatchGuild):
|
||||||
'''Insert a new guild into the database'''
|
'''Insert a new guild into the database'''
|
||||||
query = '''INSERT INTO guilds VALUES(?, ?, datetime('now'));'''
|
query = '''INSERT INTO guilds VALUES(?, ?, datetime('now'));'''
|
||||||
with self.conn:
|
with self.conn:
|
||||||
self.conn.execute(query, (guild.id, guild.name))
|
self.conn.execute(query, (guild.id, guild.name))
|
||||||
|
|
||||||
def get_guild(self, id: int = None, name: str = None) -> list[WatchGuild]:
|
def get_guild(self, **kwargs) -> list[WatchGuild]:
|
||||||
'''Access guilds stored in the Database.
|
'''Access guilds stored in the Database.
|
||||||
|
|
||||||
Query will be filtered by the passed parameters,
|
Query will be filtered by the passed parameters,
|
||||||
|
@ -73,7 +83,11 @@ class DatabaseManager:
|
||||||
return cur.fetchall()
|
return cur.fetchall()
|
||||||
|
|
||||||
def create_channel(self, channel: WatchChannel):
|
def create_channel(self, channel: WatchChannel):
|
||||||
pass
|
'''Insert a new channel into the database'''
|
||||||
|
query = '''INSERT INTO channels VALUES(?, ?, ?, ?);'''
|
||||||
|
with self.conn:
|
||||||
|
self.conn.execute(query, (channel.id, channel.name,
|
||||||
|
channel.register_date, channel.guild.id))
|
||||||
|
|
||||||
def get_channel(self, **kwargs):
|
def get_channel(self, **kwargs):
|
||||||
query = '''SELECT channel_id, name,
|
query = '''SELECT channel_id, name,
|
||||||
|
|
38
bot/model.py
38
bot/model.py
|
@ -2,8 +2,17 @@ from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class WatchObject:
|
||||||
|
table = ''
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fields(cls):
|
||||||
|
'''Returns the field names in the database'''
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WatchGuild:
|
class WatchGuild(WatchObject):
|
||||||
'''WatchGuild represents a Discord guild as stored in the database.
|
'''WatchGuild represents a Discord guild as stored in the database.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -13,10 +22,15 @@ class WatchGuild:
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
join_date: datetime
|
join_date: datetime
|
||||||
|
table: str = 'guilds'
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fields(cls):
|
||||||
|
return ['guild_id', 'name', 'join_date']
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WatchChannel:
|
class WatchChannel(WatchObject):
|
||||||
'''WatchChannel represents a Discord channel being watched for new messages
|
'''WatchChannel represents a Discord channel being watched for new messages
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -29,10 +43,15 @@ class WatchChannel:
|
||||||
name: str
|
name: str
|
||||||
register_date: datetime
|
register_date: datetime
|
||||||
guild: WatchGuild
|
guild: WatchGuild
|
||||||
|
table = 'channels'
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fields(cls):
|
||||||
|
return ['channel_id', 'name', 'register_date', 'guild_id']
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WatchUser:
|
class WatchUser(WatchObject):
|
||||||
'''WatchUser represents a Discord user who has sent a message in a watched channel.
|
'''WatchUser represents a Discord user who has sent a message in a watched channel.
|
||||||
Attributes:
|
Attributes:
|
||||||
id: User ID as given by Discord
|
id: User ID as given by Discord
|
||||||
|
@ -40,10 +59,15 @@ class WatchUser:
|
||||||
'''
|
'''
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
|
table = 'users'
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fields(cls):
|
||||||
|
return ['user_id', 'name']
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WatchMessage:
|
class WatchMessage(WatchObject):
|
||||||
'''WatchMessage represents a Discord message sent in a watched channel.
|
'''WatchMessage represents a Discord message sent in a watched channel.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -60,3 +84,9 @@ class WatchMessage:
|
||||||
author: WatchUser
|
author: WatchUser
|
||||||
channel: WatchChannel
|
channel: WatchChannel
|
||||||
guild: WatchGuild
|
guild: WatchGuild
|
||||||
|
table = 'messages'
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fields(cls):
|
||||||
|
return ['message_id', 'contents', 'published_date',
|
||||||
|
'user_id', 'channel_id', 'guild_id']
|
||||||
|
|
69
bot/tests.py
69
bot/tests.py
|
@ -10,9 +10,35 @@ class TestMessageParsing(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DbTest(unittest.TestCase):
|
||||||
|
dbname = 'test.db'
|
||||||
|
testguilds = [WatchGuild(1000, 'test1', None),
|
||||||
|
WatchGuild(2000, 'test2', None),
|
||||||
|
WatchGuild(3000, 'test3', None)]
|
||||||
|
testchannels = [WatchChannel(1001, 'channel1', datetime.now(), testguilds[0]),
|
||||||
|
WatchChannel(2001, 'channel2', datetime.now(), testguilds[1]),
|
||||||
|
WatchChannel(3001, 'channel3', datetime.now(), testguilds[2])]
|
||||||
|
testusers = []
|
||||||
|
testmessages = []
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.db = database.DatabaseManager(self.dbname)
|
||||||
|
for g in self.testguilds:
|
||||||
|
self.db.create_guild(g)
|
||||||
|
for c in self.testchannels:
|
||||||
|
self.db.create_channel(c)
|
||||||
|
for u in self.testusers:
|
||||||
|
pass
|
||||||
|
for m in self.testmessages:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
if os.path.exists(self.dbname):
|
||||||
|
os.remove(self.dbname)
|
||||||
|
|
||||||
|
|
||||||
class TestGuilds(unittest.TestCase):
|
class TestGuilds(unittest.TestCase):
|
||||||
dbname = 'test.db'
|
dbname = 'test.db'
|
||||||
get_guilds_stmt = 'SELECT guild_id, name FROM guilds'
|
|
||||||
testguilds = [WatchGuild(1000, 'test1', None),
|
testguilds = [WatchGuild(1000, 'test1', None),
|
||||||
WatchGuild(2000, 'test2', None),
|
WatchGuild(2000, 'test2', None),
|
||||||
WatchGuild(3000, 'test3', None)]
|
WatchGuild(3000, 'test3', None)]
|
||||||
|
@ -28,11 +54,13 @@ class TestGuilds(unittest.TestCase):
|
||||||
|
|
||||||
def test_get_guild_by_id(self):
|
def test_get_guild_by_id(self):
|
||||||
guilds = self.db.get_guild(id=2000)
|
guilds = self.db.get_guild(id=2000)
|
||||||
self.assertTrue(len(guilds) == 1)
|
self.assertTrue(len(guilds) == len(self.testguilds))
|
||||||
|
self.assertTrue(guilds[0].id == 2000)
|
||||||
|
|
||||||
def test_get_guild_by_name(self):
|
def test_get_guild_by_name(self):
|
||||||
guilds = self.db.get_guild(name='test1')
|
guilds = self.db.get_guild(name='test1')
|
||||||
self.assertTrue(len(guilds) == 1)
|
self.assertTrue(len(guilds) == 1)
|
||||||
|
self.assertTrue(guilds[0].name == 'test1')
|
||||||
|
|
||||||
def test_get_multiple_guilds_by_name(self):
|
def test_get_multiple_guilds_by_name(self):
|
||||||
new_guild = WatchGuild(1001, 'test1', None)
|
new_guild = WatchGuild(1001, 'test1', None)
|
||||||
|
@ -40,17 +68,13 @@ class TestGuilds(unittest.TestCase):
|
||||||
guilds = self.db.get_guild(name='test1')
|
guilds = self.db.get_guild(name='test1')
|
||||||
self.assertTrue(len(guilds) == 2)
|
self.assertTrue(len(guilds) == 2)
|
||||||
|
|
||||||
def test_get_guild_with_nonsense_params(self):
|
|
||||||
guilds = self.db.get_guild(bleep='bloop')
|
|
||||||
self.assertTrue(len(guilds) == 0)
|
|
||||||
|
|
||||||
def test_new_guild_add(self):
|
def test_new_guild_add(self):
|
||||||
with self.db.conn:
|
with self.db.conn:
|
||||||
existing = self.db.conn.execute(self.get_guilds_stmt).fetchall()
|
existing = self.db.get_guild()
|
||||||
guild = WatchGuild(4000, 'test4',
|
guild = WatchGuild(4000, 'test4',
|
||||||
datetime.now(tz=zoneinfo.ZoneInfo('UTC')))
|
datetime.now(tz=zoneinfo.ZoneInfo('UTC')))
|
||||||
self.db.create_guild(guild)
|
self.db.create_guild(guild)
|
||||||
new = self.db.conn.execute(self.get_guilds_stmt).fetchall()
|
new = self.db.get_guild()
|
||||||
self.assertTrue(len(new) == len(existing) + 1)
|
self.assertTrue(len(new) == len(existing) + 1)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
@ -58,6 +82,25 @@ class TestGuilds(unittest.TestCase):
|
||||||
os.remove(self.dbname)
|
os.remove(self.dbname)
|
||||||
|
|
||||||
|
|
||||||
|
class TestChannel(DbTest):
|
||||||
|
def test_get_channel(self):
|
||||||
|
channels = self.db.get_channel()
|
||||||
|
self.assertEqual(len(channels), len(self.testchannels), 'not equal')
|
||||||
|
|
||||||
|
def test_get_channel_by_id(self):
|
||||||
|
channels = self.db.get_channel(id=self.testchannels[0].id)
|
||||||
|
self.assertTrue(len(channels) == 1)
|
||||||
|
self.assertTrue(channels[0].id == self.testchannels[0].id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUser(unittest.TestCase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessage(unittest.TestCase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestDatabase(unittest.TestCase):
|
class TestDatabase(unittest.TestCase):
|
||||||
dbname = 'test.db'
|
dbname = 'test.db'
|
||||||
get_guilds_stmt = 'SELECT guild_id, name FROM guilds'
|
get_guilds_stmt = 'SELECT guild_id, name FROM guilds'
|
||||||
|
@ -72,6 +115,16 @@ class TestDatabase(unittest.TestCase):
|
||||||
cur = db.conn.execute(self.get_guilds_stmt)
|
cur = db.conn.execute(self.get_guilds_stmt)
|
||||||
self.assertTrue(len(cur.fetchall()) == 0)
|
self.assertTrue(len(cur.fetchall()) == 0)
|
||||||
|
|
||||||
|
def test_build_select_query(self):
|
||||||
|
expected_result = '''SELECT guild_id, name, join_date
|
||||||
|
FROM guilds
|
||||||
|
WHERE guild_id=1000, name=test1, nonsense=yes'''
|
||||||
|
db = database.DatabaseManager(self.dbname)
|
||||||
|
result, values = db._build_select_query(WatchGuild, guild_id='1000',
|
||||||
|
name='test1', nonsense='yes')
|
||||||
|
self.assertTrue(expected_result, result)
|
||||||
|
self.assertTrue(len(values), 3)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
os.remove(self.dbname)
|
os.remove(self.dbname)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue