make module, add test file, skeleton code
This commit is contained in:
parent
08e82b49da
commit
d27f3cad8a
|
@ -1,2 +1,11 @@
|
||||||
# announcement-bot
|
# announcement-bot
|
||||||
|
|
||||||
|
## Features
|
||||||
|
- Watches specific channels for new messages and new pins
|
||||||
|
- Writes out messages in various formats (HTML, XML, etc)
|
||||||
|
|
||||||
|
## Channel Registration
|
||||||
|
Once the bot is added to a server, register channels to be watched with /register command
|
||||||
|
|
||||||
|
## Export
|
||||||
|
- Preserves Discord formatting
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
import os
|
||||||
|
import discord
|
||||||
|
import database
|
||||||
|
|
||||||
|
|
||||||
|
def read_env(file):
|
||||||
|
for line in file:
|
||||||
|
splitline = line.split('=', maxsplit=1)
|
||||||
|
if len(splitline) == 2:
|
||||||
|
os.environ[splitline[0]] = splitline[1]
|
||||||
|
|
||||||
|
|
||||||
|
class AnnouncementWatcherClient(discord.Client):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.db = database.DatabaseManager('announcement.db')
|
||||||
|
|
||||||
|
async def on_guild_join(self, guild):
|
||||||
|
self.db.create_guild()
|
||||||
|
|
||||||
|
# TODO: on_guild_leave() delete all data?
|
||||||
|
|
||||||
|
async def on_message(self, message: discord.Message):
|
||||||
|
if message.author == self.user:
|
||||||
|
return
|
||||||
|
|
||||||
|
if message.content.startswith('$hello'):
|
||||||
|
await message.channel.send('Hello!')
|
||||||
|
|
||||||
|
async def on_guild_channel_pins_update(self, channel: discord.GuildChannel, last_pin):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def register_channel(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def unregister_channel(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
with open('.env') as f:
|
||||||
|
read_env(f)
|
||||||
|
discord_token = os.environ.get('DISCORD_TOKEN')
|
||||||
|
if discord_token is None:
|
||||||
|
print('Could not read token')
|
||||||
|
exit(1)
|
||||||
|
intents = discord.Intents.default()
|
||||||
|
intents.message_content = True
|
||||||
|
client = AnnouncementWatcherClient(intents=intents)
|
||||||
|
client.run(discord_token)
|
|
@ -0,0 +1,52 @@
|
||||||
|
import sqlite3
|
||||||
|
import discord
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseManager:
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.conn = sqlite3.connect(name)
|
||||||
|
self._setup_database()
|
||||||
|
|
||||||
|
def _setup_database(self):
|
||||||
|
with self.conn:
|
||||||
|
cur = self.conn.executescript('''
|
||||||
|
BEGIN;
|
||||||
|
CREATE TABLE IF NOT EXISTS guilds(
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
name TEXT);
|
||||||
|
CREATE TABLE IF NOT EXISTS watched_channels(
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
name TEXT);
|
||||||
|
COMMIT;
|
||||||
|
''')
|
||||||
|
|
||||||
|
def create_guild(self, guild_id: int, guild_name: str):
|
||||||
|
with self.conn:
|
||||||
|
self.conn.execute('''INSERT INTO guilds
|
||||||
|
VALUES(?, ?);''',
|
||||||
|
(guild_id, guild_name))
|
||||||
|
|
||||||
|
def get_all_guilds(self):
|
||||||
|
with self.conn:
|
||||||
|
cur = self.conn.execute('''SELECT id, name FROM guilds;''')
|
||||||
|
return cur.fetchall()
|
||||||
|
|
||||||
|
def get_all_watched_channels(self):
|
||||||
|
with self.conn:
|
||||||
|
curs = self.conn.execute('''SELECT id, name
|
||||||
|
FROM watched_channels;''')
|
||||||
|
return curs.fetchall()
|
||||||
|
|
||||||
|
def get_guild_watched_channels(self, guild_id: int):
|
||||||
|
with self.conn:
|
||||||
|
curs = self.conn.execute('''SELECT id, name
|
||||||
|
FROM watched_channels
|
||||||
|
WHERE id=?;''',
|
||||||
|
(guild_id,))
|
||||||
|
return curs.fetchall()
|
||||||
|
|
||||||
|
def write_message(self, message: discord.Message):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_messages(self, channel_id: int):
|
||||||
|
pass
|
|
@ -0,0 +1,2 @@
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
import unittest
|
||||||
|
import os
|
||||||
|
import database
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageParsing(unittest.TestCase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatabase(unittest.TestCase):
|
||||||
|
dbname = 'test.db'
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
if os.path.exists(self.dbname):
|
||||||
|
os.remove(self.dbname)
|
||||||
|
|
||||||
|
def test_database_init(self):
|
||||||
|
db = database.DatabaseManager('test.db')
|
||||||
|
with db.conn:
|
||||||
|
cur = db.conn.execute('SELECT id, name FROM guilds;')
|
||||||
|
self.assertTrue(len(cur.fetchall()) == 0)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
os.remove(self.dbname)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
|
@ -0,0 +1,10 @@
|
||||||
|
aiohttp==3.8.4
|
||||||
|
aiosignal==1.3.1
|
||||||
|
async-timeout==4.0.2
|
||||||
|
attrs==22.2.0
|
||||||
|
charset-normalizer==3.1.0
|
||||||
|
discord.py==2.2.2
|
||||||
|
frozenlist==1.3.3
|
||||||
|
idna==3.4
|
||||||
|
multidict==6.0.4
|
||||||
|
yarl==1.8.2
|
Loading…
Reference in New Issue