make module, add test file, skeleton code
This commit is contained in:
parent
08e82b49da
commit
d27f3cad8a
|
@ -1,2 +1,11 @@
|
|||
# announcement-bot
|
||||
|
||||
## Features
|
||||
- Watches specific channels for new messages and new pins
|
||||
- Writes out messages in various formats (HTML, XML, etc)
|
||||
|
||||
## Channel Registration
|
||||
Once the bot is added to a server, register channels to be watched with /register command
|
||||
|
||||
## Export
|
||||
- Preserves Discord formatting
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
import os
|
||||
import discord
|
||||
import database
|
||||
|
||||
|
||||
def read_env(file):
|
||||
for line in file:
|
||||
splitline = line.split('=', maxsplit=1)
|
||||
if len(splitline) == 2:
|
||||
os.environ[splitline[0]] = splitline[1]
|
||||
|
||||
|
||||
class AnnouncementWatcherClient(discord.Client):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.db = database.DatabaseManager('announcement.db')
|
||||
|
||||
async def on_guild_join(self, guild):
|
||||
self.db.create_guild()
|
||||
|
||||
# TODO: on_guild_leave() delete all data?
|
||||
|
||||
async def on_message(self, message: discord.Message):
|
||||
if message.author == self.user:
|
||||
return
|
||||
|
||||
if message.content.startswith('$hello'):
|
||||
await message.channel.send('Hello!')
|
||||
|
||||
async def on_guild_channel_pins_update(self, channel: discord.GuildChannel, last_pin):
|
||||
pass
|
||||
|
||||
def register_channel(self):
|
||||
pass
|
||||
|
||||
def unregister_channel(self):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open('.env') as f:
|
||||
read_env(f)
|
||||
discord_token = os.environ.get('DISCORD_TOKEN')
|
||||
if discord_token is None:
|
||||
print('Could not read token')
|
||||
exit(1)
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
client = AnnouncementWatcherClient(intents=intents)
|
||||
client.run(discord_token)
|
|
@ -0,0 +1,52 @@
|
|||
import sqlite3
|
||||
import discord
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
def __init__(self, name: str):
|
||||
self.conn = sqlite3.connect(name)
|
||||
self._setup_database()
|
||||
|
||||
def _setup_database(self):
|
||||
with self.conn:
|
||||
cur = self.conn.executescript('''
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS guilds(
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT);
|
||||
CREATE TABLE IF NOT EXISTS watched_channels(
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT);
|
||||
COMMIT;
|
||||
''')
|
||||
|
||||
def create_guild(self, guild_id: int, guild_name: str):
|
||||
with self.conn:
|
||||
self.conn.execute('''INSERT INTO guilds
|
||||
VALUES(?, ?);''',
|
||||
(guild_id, guild_name))
|
||||
|
||||
def get_all_guilds(self):
|
||||
with self.conn:
|
||||
cur = self.conn.execute('''SELECT id, name FROM guilds;''')
|
||||
return cur.fetchall()
|
||||
|
||||
def get_all_watched_channels(self):
|
||||
with self.conn:
|
||||
curs = self.conn.execute('''SELECT id, name
|
||||
FROM watched_channels;''')
|
||||
return curs.fetchall()
|
||||
|
||||
def get_guild_watched_channels(self, guild_id: int):
|
||||
with self.conn:
|
||||
curs = self.conn.execute('''SELECT id, name
|
||||
FROM watched_channels
|
||||
WHERE id=?;''',
|
||||
(guild_id,))
|
||||
return curs.fetchall()
|
||||
|
||||
def write_message(self, message: discord.Message):
|
||||
pass
|
||||
|
||||
def get_messages(self, channel_id: int):
|
||||
pass
|
|
@ -0,0 +1,2 @@
|
|||
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
import unittest
|
||||
import os
|
||||
import database
|
||||
|
||||
|
||||
class TestMessageParsing(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
dbname = 'test.db'
|
||||
|
||||
def setUp(self):
|
||||
if os.path.exists(self.dbname):
|
||||
os.remove(self.dbname)
|
||||
|
||||
def test_database_init(self):
|
||||
db = database.DatabaseManager('test.db')
|
||||
with db.conn:
|
||||
cur = db.conn.execute('SELECT id, name FROM guilds;')
|
||||
self.assertTrue(len(cur.fetchall()) == 0)
|
||||
|
||||
def tearDown(self):
|
||||
os.remove(self.dbname)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,10 @@
|
|||
aiohttp==3.8.4
|
||||
aiosignal==1.3.1
|
||||
async-timeout==4.0.2
|
||||
attrs==22.2.0
|
||||
charset-normalizer==3.1.0
|
||||
discord.py==2.2.2
|
||||
frozenlist==1.3.3
|
||||
idna==3.4
|
||||
multidict==6.0.4
|
||||
yarl==1.8.2
|
Loading…
Reference in New Issue