Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 120 additions & 36 deletions cogs/counting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,47 @@ def __init__(self, bot):
self.bot = bot
# Cache for counting channels: guild_id -> channel_id
self.counting_channels = {}
# Protect against occasional duplicate MESSAGE_CREATE dispatches or accidental double-processing.
# Key: message_id, Value: monotonic timestamp
self._recent_message_ids: dict[int, float] = {}
# Throttle reaction API calls to avoid Discord rate limits in fast counting channels.
self._reaction_queue: asyncio.Queue[tuple[discord.Message, str]] = asyncio.Queue()
self._pending_reactions: set[tuple[int, str]] = set()
self._reaction_worker_task: Optional[asyncio.Task[None]] = None

async def cog_unload(self) -> None:
if self._reaction_worker_task and not self._reaction_worker_task.done():
self._reaction_worker_task.cancel()

async def _reaction_worker(self) -> None:
# A small delay between reaction requests keeps us under the common reaction route limits.
# Reactions may appear slightly delayed, but they will still be added.
while True:
message, emoji = await self._reaction_queue.get()
try:
try:
await message.add_reaction(emoji)
except Exception:
pass
await asyncio.sleep(0.35)
finally:
self._pending_reactions.discard((message.id, emoji))
self._reaction_queue.task_done()

def _enqueue_reaction(self, message: discord.Message, emoji: str) -> None:
key = (message.id, emoji)
if key in self._pending_reactions:
return
self._pending_reactions.add(key)
try:
self._reaction_queue.put_nowait((message, emoji))
except Exception:
self._pending_reactions.discard(key)

async def cog_load(self):
"""Load counting channels into memory on startup"""
if self._reaction_worker_task is None or self._reaction_worker_task.done():
self._reaction_worker_task = asyncio.create_task(self._reaction_worker())
try:
async with aiosqlite.connect(DB_PATH, timeout=30.0) as db:
# Ensure auxiliary tables exist
Expand Down Expand Up @@ -140,27 +178,23 @@ async def _mark_highscore_message(
new_count: int,
previous_high_score: int,
) -> None:
"""Add ✅+🏆 to the message and keep only one active highscore marker per guild."""
"""Add ✅+🏆 to the message.

Note: Reactions, once added by the bot, should never be removed.
"""
if not message.guild or not isinstance(message.channel, discord.TextChannel):
return

guild_id = message.guild.id
channel = message.channel

previous_marker = await self._get_active_highscore_message_id(guild_id)
if previous_marker and previous_marker != message.id:
await self._remove_bot_reactions(channel, previous_marker)

# Ensure reactions exist (✅ may already be there)
try:
await message.add_reaction("✅")
except Exception:
pass
try:
await message.add_reaction("🏆")
except Exception:
pass
# Only add the trophy here.
# The ✅ reaction is added for all valid counts in the main handler;
# adding it again here causes extra API calls and rate limits.
self._enqueue_reaction(message, "🏆")

# Track the latest highscore/tie message ID for bookkeeping.
# (We no longer remove reactions from older messages.)
await self._set_active_highscore_message_id(guild_id, message.id)

# Record history only if it is a NEW record
Expand All @@ -178,24 +212,48 @@ async def _mark_highscore_message(
async def _clear_highscore_marker_if_any(self, guild_id: int, channel: discord.TextChannel) -> None:
marker_id = await self._get_active_highscore_message_id(guild_id)
if marker_id:
await self._remove_bot_reactions(channel, marker_id)
await self._set_active_highscore_message_id(guild_id, None)

@app_commands.command(name="setcountingchannel", description="Set the channel for the counting game")
@app_commands.checks.has_permissions(administrator=True)
async def setcountingchannel(self, interaction: discord.Interaction, channel: discord.TextChannel):
async with aiosqlite.connect(DB_PATH, timeout=30.0) as db:
await db.execute("""
INSERT INTO counting_config (guild_id, channel_id)
VALUES (?, ?)
ON CONFLICT(guild_id) DO UPDATE SET channel_id = excluded.channel_id
""", (interaction.guild_id, channel.id))
await db.commit()

# Slash command interactions must be acknowledged quickly.
# DB operations can take >3s (locks, slow disks), so defer immediately.
if interaction.response.is_done():
# Extremely defensive; normally false here.
pass
else:
await interaction.response.defer(ephemeral=True)

if interaction.guild_id is None:
await interaction.followup.send("This command can only be used in a server.", ephemeral=True)
return

retries = 3
while retries > 0:
try:
async with aiosqlite.connect(DB_PATH, timeout=30.0) as db:
await db.execute(
"""
INSERT INTO counting_config (guild_id, channel_id)
VALUES (?, ?)
ON CONFLICT(guild_id) DO UPDATE SET channel_id = excluded.channel_id
""",
(interaction.guild_id, channel.id),
)
await db.commit()
break
except aiosqlite.OperationalError as e:
if "locked" in str(e).lower():
retries -= 1
await asyncio.sleep(0.5)
continue
raise

# Update cache
self.counting_channels[interaction.guild_id] = channel.id
await interaction.response.send_message(f"Counting channel set to {channel.mention}", ephemeral=True)

await interaction.followup.send(f"Counting channel set to {channel.mention}", ephemeral=True)

def safe_eval(self, expr):
operators = {
Expand Down Expand Up @@ -293,6 +351,17 @@ async def on_message(self, message):
if message.channel.id != self.counting_channels[message.guild.id]:
return

# Deduplicate processing of the same message ID within this process.
# This prevents duplicate warnings/messages if Discord or the bot dispatches the event twice.
now = time.monotonic()
last_seen = self._recent_message_ids.get(message.id)
if last_seen is not None and (now - last_seen) < 30:
return
self._recent_message_ids[message.id] = now
if len(self._recent_message_ids) > 5000:
cutoff = now - 120
self._recent_message_ids = {mid: ts for mid, ts in self._recent_message_ids.items() if ts >= cutoff}

# 2. Process the message logic
# Wrap DB operations in retry loop for robustness
retries = 3
Expand Down Expand Up @@ -327,19 +396,29 @@ async def on_message(self, message):

if message.author.id == last_user_id:
# Warn instead of instant ruin. 3 warnings ruins the count.
warnings = await self._get_warning_count(message.guild.id, message.author.id)
warnings += 1
await self._set_warning_count(message.guild.id, message.author.id, warnings)

try:
await message.add_reaction("⚠️")
except Exception:
pass
# Use an atomic increment in the SAME connection to avoid races and DB-lock retries.
await db.execute(
"""
INSERT INTO counting_warnings (guild_id, user_id, warnings)
VALUES (?, ?, 1)
ON CONFLICT(guild_id, user_id) DO UPDATE SET warnings = warnings + 1
""",
(message.guild.id, message.author.id),
)
async with db.execute(
"SELECT warnings FROM counting_warnings WHERE guild_id = ? AND user_id = ?",
(message.guild.id, message.author.id),
) as cursor:
row = await cursor.fetchone()
warnings = int(row[0]) if row else 1
await db.commit()

if warnings >= 3:
await self.fail_count(message, current_count, "Too many warnings (counted twice in a row 3 times)!")
return

self._enqueue_reaction(message, "⚠️")

await message.channel.send(
f"You can't count twice in a row, {message.author.mention}. "
f"You have **{warnings}/3** warnings.",
Expand All @@ -348,7 +427,6 @@ async def on_message(self, message):
return

# Valid count - Update DB
await message.add_reaction("✅")
new_high_score = max(high_score, next_count)

# Update configuration tables
Expand All @@ -364,11 +442,17 @@ async def on_message(self, message):
VALUES (?, ?, 1, 0)
ON CONFLICT(user_id, guild_id) DO UPDATE SET total_counts = total_counts + 1
""", (message.author.id, message.guild.id))

# Reset warnings for this user on a valid count (in the same transaction).
await db.execute(
"DELETE FROM counting_warnings WHERE guild_id = ? AND user_id = ?",
(message.guild.id, message.author.id),
)

await db.commit()

# Reset warnings for this user on a valid count
await self._set_warning_count(message.guild.id, message.author.id, 0)
# Side effects after commit to avoid duplicate reactions on retries.
self._enqueue_reaction(message, "✅")

# Highscore marker: react ✅+🏆 when reaching/topping the record
if next_count >= high_score:
Expand Down
44 changes: 43 additions & 1 deletion utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import os
from typing import Any, Optional

from pydantic import Field, field_validator, model_validator
Expand Down Expand Up @@ -31,6 +32,42 @@ class Config(BaseSettings):
extra='ignore' # Ignore extra fields from .env
)

@field_validator('guild_id', mode='before')
@classmethod
def _parse_guild_id(cls, v: Any) -> Optional[int]:
"""Parse legacy single guild id.

Some deployments mistakenly set `GUILD_ID` as a CSV / JSON list.
To stay backwards compatible and avoid startup failures, we accept:
- int
- numeric string
- CSV: "1,2,3" (uses the first value)
- JSON list: "[1,2]" (uses the first value)
"""
if v is None:
return None
if isinstance(v, int):
return v
if isinstance(v, str):
s = v.strip()
if not s:
return None
if s.startswith('[') and s.endswith(']'):
try:
import json

parsed = json.loads(s)
if isinstance(parsed, list) and parsed:
return int(parsed[0])
except Exception:
# Fall back to CSV/int parsing
pass
if ',' in s:
first = s.split(',', 1)[0].strip()
return int(first) if first else None
return int(s)
return int(v)

@field_validator('guild_ids', mode='before')
@classmethod
def _parse_guild_ids(cls, v: Any) -> list[int]:
Expand All @@ -42,7 +79,12 @@ def _parse_guild_ids(cls, v: Any) -> list[int]:
- CSV: "1,2,3"
"""
if v is None:
return []
# If GUILD_IDS isn't set, allow legacy GUILD_ID to behave like a list.
legacy = os.getenv('GUILD_ID')
if legacy and str(legacy).strip():
v = legacy
else:
return []
if isinstance(v, list):
return [int(x) for x in v if str(x).strip()]
if isinstance(v, (int, str)):
Expand Down
Loading