tests for guild create/retrieve
This commit is contained in:
parent
ae1a56342c
commit
aa5c8a0bf4
56
bot/tests.py
56
bot/tests.py
|
@ -1,23 +1,75 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
import zoneinfo
|
||||||
|
from datetime import datetime
|
||||||
import os
|
import os
|
||||||
import database
|
import database
|
||||||
|
from model import WatchGuild, WatchChannel, WatchUser, WatchMessage
|
||||||
|
|
||||||
|
|
||||||
class TestMessageParsing(unittest.TestCase):
|
class TestMessageParsing(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestGuilds(unittest.TestCase):
|
||||||
|
dbname = 'test.db'
|
||||||
|
get_guilds_stmt = 'SELECT guild_id, name FROM guilds'
|
||||||
|
testguilds = [WatchGuild(1000, 'test1', None),
|
||||||
|
WatchGuild(2000, 'test2', None),
|
||||||
|
WatchGuild(3000, 'test3', None)]
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.db = database.DatabaseManager(self.dbname)
|
||||||
|
for g in self.testguilds:
|
||||||
|
self.db.create_guild(g)
|
||||||
|
|
||||||
|
def test_get_all_guilds(self):
|
||||||
|
guilds = self.db.get_guild()
|
||||||
|
self.assertTrue(len(guilds) == len(self.testguilds))
|
||||||
|
|
||||||
|
def test_get_guild_by_id(self):
|
||||||
|
guilds = self.db.get_guild(id=2000)
|
||||||
|
self.assertTrue(len(guilds) == 1)
|
||||||
|
|
||||||
|
def test_get_guild_by_name(self):
|
||||||
|
guilds = self.db.get_guild(name='test1')
|
||||||
|
self.assertTrue(len(guilds) == 1)
|
||||||
|
|
||||||
|
def test_get_multiple_guilds_by_name(self):
|
||||||
|
new_guild = WatchGuild(1001, 'test1', None)
|
||||||
|
self.db.create_guild(new_guild)
|
||||||
|
guilds = self.db.get_guild(name='test1')
|
||||||
|
self.assertTrue(len(guilds) == 2)
|
||||||
|
|
||||||
|
def test_get_guild_with_nonsense_params(self):
|
||||||
|
guilds = self.db.get_guild(bleep='bloop')
|
||||||
|
self.assertTrue(len(guilds) == 0)
|
||||||
|
|
||||||
|
def test_new_guild_add(self):
|
||||||
|
with self.db.conn:
|
||||||
|
existing = self.db.conn.execute(self.get_guilds_stmt).fetchall()
|
||||||
|
guild = WatchGuild(4000, 'test4',
|
||||||
|
datetime.now(tz=zoneinfo.ZoneInfo('UTC')))
|
||||||
|
self.db.create_guild(guild)
|
||||||
|
new = self.db.conn.execute(self.get_guilds_stmt).fetchall()
|
||||||
|
self.assertTrue(len(new) == len(existing) + 1)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
if os.path.exists(self.dbname):
|
||||||
|
os.remove(self.dbname)
|
||||||
|
|
||||||
|
|
||||||
class TestDatabase(unittest.TestCase):
|
class TestDatabase(unittest.TestCase):
|
||||||
dbname = 'test.db'
|
dbname = 'test.db'
|
||||||
|
get_guilds_stmt = 'SELECT guild_id, name FROM guilds'
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
if os.path.exists(self.dbname):
|
if os.path.exists(self.dbname):
|
||||||
os.remove(self.dbname)
|
os.remove(self.dbname)
|
||||||
|
|
||||||
def test_database_init(self):
|
def test_database_init(self):
|
||||||
db = database.DatabaseManager('test.db')
|
db = database.DatabaseManager(self.dbname)
|
||||||
with db.conn:
|
with db.conn:
|
||||||
cur = db.conn.execute('SELECT id, name FROM guilds;')
|
cur = db.conn.execute(self.get_guilds_stmt)
|
||||||
self.assertTrue(len(cur.fetchall()) == 0)
|
self.assertTrue(len(cur.fetchall()) == 0)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
|
Loading…
Reference in New Issue