From d27f3cad8aa52aedf6fd476adedf6abb2aac29e5 Mon Sep 17 00:00:00 2001 From: yequari Date: Sat, 15 Apr 2023 13:54:16 -0700 Subject: [PATCH] make module, add test file, skeleton code --- README.md | 9 +++++++++ bot/__init__.py | 0 bot/bot.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++ bot/database.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++ bot/messages.py | 2 ++ bot/tests.py | 28 ++++++++++++++++++++++++++ requirements.txt | 10 ++++++++++ 7 files changed, 151 insertions(+) create mode 100644 bot/__init__.py create mode 100644 bot/bot.py create mode 100644 bot/database.py create mode 100644 bot/messages.py create mode 100644 bot/tests.py create mode 100644 requirements.txt diff --git a/README.md b/README.md index 0eab738..493dc82 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,11 @@ # announcement-bot +## Features +- Watches specific channels for new messages and new pins +- Writes out messages in various formats (HTML, XML, etc) + +## Channel Registration +Once the bot is added to a server, register channels to be watched with /register command + +## Export +- Preserves Discord formatting diff --git a/bot/__init__.py b/bot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bot/bot.py b/bot/bot.py new file mode 100644 index 0000000..b512060 --- /dev/null +++ b/bot/bot.py @@ -0,0 +1,50 @@ +import os +import discord +import database + + +def read_env(file): + for line in file: + splitline = line.split('=', maxsplit=1) + if len(splitline) == 2: + os.environ[splitline[0]] = splitline[1] + + +class AnnouncementWatcherClient(discord.Client): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.db = database.DatabaseManager('announcement.db') + + async def on_guild_join(self, guild): + self.db.create_guild() + + # TODO: on_guild_leave() delete all data? + + async def on_message(self, message: discord.Message): + if message.author == self.user: + return + + if message.content.startswith('$hello'): + await message.channel.send('Hello!') + + async def on_guild_channel_pins_update(self, channel: discord.GuildChannel, last_pin): + pass + + def register_channel(self): + pass + + def unregister_channel(self): + pass + + +if __name__ == "__main__": + with open('.env') as f: + read_env(f) + discord_token = os.environ.get('DISCORD_TOKEN') + if discord_token is None: + print('Could not read token') + exit(1) + intents = discord.Intents.default() + intents.message_content = True + client = AnnouncementWatcherClient(intents=intents) + client.run(discord_token) diff --git a/bot/database.py b/bot/database.py new file mode 100644 index 0000000..75f683a --- /dev/null +++ b/bot/database.py @@ -0,0 +1,52 @@ +import sqlite3 +import discord + + +class DatabaseManager: + def __init__(self, name: str): + self.conn = sqlite3.connect(name) + self._setup_database() + + def _setup_database(self): + with self.conn: + cur = self.conn.executescript(''' + BEGIN; + CREATE TABLE IF NOT EXISTS guilds( + id INTEGER PRIMARY KEY, + name TEXT); + CREATE TABLE IF NOT EXISTS watched_channels( + id INTEGER PRIMARY KEY, + name TEXT); + COMMIT; + ''') + + def create_guild(self, guild_id: int, guild_name: str): + with self.conn: + self.conn.execute('''INSERT INTO guilds + VALUES(?, ?);''', + (guild_id, guild_name)) + + def get_all_guilds(self): + with self.conn: + cur = self.conn.execute('''SELECT id, name FROM guilds;''') + return cur.fetchall() + + def get_all_watched_channels(self): + with self.conn: + curs = self.conn.execute('''SELECT id, name + FROM watched_channels;''') + return curs.fetchall() + + def get_guild_watched_channels(self, guild_id: int): + with self.conn: + curs = self.conn.execute('''SELECT id, name + FROM watched_channels + WHERE id=?;''', + (guild_id,)) + return curs.fetchall() + + def write_message(self, message: discord.Message): + pass + + def get_messages(self, channel_id: int): + pass diff --git a/bot/messages.py b/bot/messages.py new file mode 100644 index 0000000..139597f --- /dev/null +++ b/bot/messages.py @@ -0,0 +1,2 @@ + + diff --git a/bot/tests.py b/bot/tests.py new file mode 100644 index 0000000..5d99068 --- /dev/null +++ b/bot/tests.py @@ -0,0 +1,28 @@ +import unittest +import os +import database + + +class TestMessageParsing(unittest.TestCase): + pass + + +class TestDatabase(unittest.TestCase): + dbname = 'test.db' + + def setUp(self): + if os.path.exists(self.dbname): + os.remove(self.dbname) + + def test_database_init(self): + db = database.DatabaseManager('test.db') + with db.conn: + cur = db.conn.execute('SELECT id, name FROM guilds;') + self.assertTrue(len(cur.fetchall()) == 0) + + def tearDown(self): + os.remove(self.dbname) + + +if __name__ == "__main__": + unittest.main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9a9a238 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +aiohttp==3.8.4 +aiosignal==1.3.1 +async-timeout==4.0.2 +attrs==22.2.0 +charset-normalizer==3.1.0 +discord.py==2.2.2 +frozenlist==1.3.3 +idna==3.4 +multidict==6.0.4 +yarl==1.8.2