2025-09-03 03:34:30 -07:00

249 lines
8.9 KiB
Python

from __future__ import annotations
from typing import Any, Dict, Iterator, Optional
from .state import GameState, GameRecord, Constraints, AttributeStats, GameStatus
from .api import APIClient
from .policy import load_policy, PolicyLoadError, call_policy, default_policy_path
from .config import VENUE_CAP, REJECT_CAP
class BerghainClient:
    """Drive games through the remote API while keeping local state persisted.

    The client owns a ``GameState`` (loaded on construction) and an
    ``APIClient``. Game records are updated in place and the state is saved
    after each lifecycle change and after every API response is applied.
    """

    def __init__(self) -> None:
        """Load the persisted game state and create the API client."""
        self.gs = GameState.load()
        self.api = APIClient()

    # ----------------------------
    # Player / game lifecycle
    # ----------------------------
    def set_player(self, uuid: str) -> None:
        """Set the player ID and persist state.

        Raises:
            ValueError: If ``uuid`` is empty.
        """
        if not uuid:
            raise ValueError("Player ID cannot be empty")
        self.gs.player_id = uuid
        self.gs.save()

    def new_game(self, scenario: int) -> GameRecord:
        """Start a new game and return the created ``GameRecord``.

        The policy is optional at this point and may be attached later. The
        new record inherits the previous game's ``policy_path`` when there is
        one, otherwise it uses ``default_policy_path()``.

        Args:
            scenario: Scenario number to start. ``-1`` means "reuse the
                scenario of the current (previous) game".

        Raises:
            ValueError: If ``scenario`` is ``-1`` but there is no previous
                game to take the scenario from.
        """
        prev = self.gs.current_game()
        if scenario == -1:
            # Without this guard a missing previous game surfaced as an
            # opaque AttributeError on ``prev.scenario``.
            if not prev:
                raise ValueError(
                    "No previous game to inherit the scenario from (scenario == -1)."
                )
            scenario = prev.scenario

        resp = self.api.new_game(self.gs.player_id, scenario)

        inherit_path = getattr(prev, "policy_path", None) if prev else None
        policy_path = inherit_path or default_policy_path()

        # ``or`` guards tolerate keys that are present but null in the response.
        raw_constraints = resp.get("constraints") or []
        raw_stats = resp.get("attributeStatistics") or {}

        rec = GameRecord(
            game_id=resp["gameId"],
            initial_raw=resp,
            scenario=scenario,
            status=GameStatus.RUNNING,
            constraints=Constraints(
                mins={c["attribute"]: int(c["minCount"]) for c in raw_constraints}
            ),
            stats=AttributeStats(
                relative_frequencies=raw_stats.get("relativeFrequencies", {}) or {},
                correlations=raw_stats.get("correlations", {}) or {},
            ),
            admitted_count=0,
            rejected_count=0,
            next_person=None,
            tallies={},
            policy_path=policy_path,
        )
        self.gs.add_game(rec)
        self.gs.save()
        return rec

    def set_policy_for_current_game(self, policy_path: str) -> GameRecord:
        """Attach or replace the policy of the current game.

        The policy is loaded once for validation before the path is stored,
        so an invalid policy never ends up in the saved state.

        Raises:
            ValueError: If there is no current game.
            PolicyLoadError: Presumably raised by ``load_policy`` when the
                policy is invalid (not visible here; confirm in ``.policy``).
        """
        rec = self.gs.current_game()
        if not rec:
            raise ValueError("No current game to set a policy for.")
        load_policy(policy_path)  # validate before saving
        rec.policy_path = policy_path
        self.gs.save()
        return rec

    # ----------------------------
    # Internals
    # ----------------------------
    def _get_decide_callable(self, rec: GameRecord):
        """Load and return the user's ``decide()`` callable for ``rec``.

        Raises:
            PolicyLoadError: If the record has no policy path (or, presumably,
                if ``load_policy`` rejects the file).
        """
        if not rec.policy_path:
            raise PolicyLoadError("No policy set for this game. Use `berghain set-policy <file>`.")
        return load_policy(rec.policy_path)

    def _update_from_decide_response(self, rec: GameRecord, data: Dict[str, Any]) -> None:
        """Apply a /decide-and-next API response to ``rec`` and save state.

        The status is only overwritten when the response carries one of the
        known statuses; counters fall back to the current local values when
        the response omits them. ``next_person`` is always replaced (it becomes
        ``None`` when the response has no ``nextPerson``).
        """
        status = (data.get("status") or "").lower()
        # NOTE(review): assumes GameStatus members compare equal to lowercase
        # status strings -- confirm against .state.
        if status in {GameStatus.RUNNING, GameStatus.COMPLETED, GameStatus.FAILED}:
            rec.status = status
        rec.admitted_count = int(data.get("admittedCount", rec.admitted_count))
        rec.rejected_count = int(data.get("rejectedCount", rec.rejected_count))
        rec.next_person = data.get("nextPerson")
        self.gs.save()

    def fetch_first_person(self, rec: GameRecord) -> None:
        """Fetch the first person (personIndex=0) and update ``rec``.

        The API allows 'accept' to be omitted for index 0, so ``None`` is sent.
        """
        data = self.api.decide_and_next(rec.game_id, person_index=0, accept=None)
        self._update_from_decide_response(rec, data)

    def _remaining_need(self, rec: GameRecord) -> Dict[str, int]:
        """Return the remaining shortfall per constrained attribute.

        Computed from local tallies; never negative.
        """
        return {
            attr: max(0, minimum - rec.tallies.get(attr, 0))
            for attr, minimum in rec.constraints.mins.items()
        }

    def _noop_entry(self, rec: GameRecord, reason: str) -> Dict[str, Any]:
        """Build a transcript entry for a step that made no decision."""
        return {
            "personIndex": None,
            "accept": False,
            "attributes": {},
            "remaining_need_before": self._remaining_need(rec),
            "admitted_count": rec.admitted_count,
            "rejected_count": rec.rejected_count,
            "status": rec.status,
            "reason": reason,
        }

    # ----------------------------
    # Step / autoplay
    # ----------------------------
    def step(self, rec: GameRecord) -> Dict[str, Any]:
        """Execute ONE decision step using the current policy.

        Returns a transcript entry:
            {
                "personIndex": int | None,
                "accept": bool,
                "attributes": {attr: bool},
                "remaining_need_before": {attr: int},
                "admitted_count": int,   # after API call
                "rejected_count": int,   # after API call
                "status": "running"|"completed"|"failed",
                "reason": str            # human-readable reason (from policy)
            }

        If no move is possible (game not running, or no next person is
        available even after fetching the first one), a no-op entry with
        ``personIndex`` set to ``None`` and the current counters is returned.

        Raises:
            PolicyLoadError: If no policy is set/invalid.
        """
        # If the game isn't running, return a no-op entry.
        if rec.status != GameStatus.RUNNING:
            return self._noop_entry(rec, "Game not running")

        # Ensure we have a person ready.
        if not rec.next_person:
            self.fetch_first_person(rec)
        if not rec.next_person:
            return self._noop_entry(rec, "No next person")

        # Load user policy decide() and compute decision.
        decide = self._get_decide_callable(rec)
        idx = int(rec.next_person.get("personIndex", 0))
        attrs: Dict[str, bool] = rec.next_person.get("attributes", {}) or {}
        remaining_need_before = self._remaining_need(rec)

        accept = call_policy(
            decide,
            attributes=attrs,
            tallies=rec.tallies,
            mins=rec.constraints.mins,
            admitted_count=rec.admitted_count,
            venue_cap=VENUE_CAP,
            # New optional context for policies that opt in:
            p=rec.stats.relative_frequencies,
            corr=rec.stats.correlations,
            stats=rec.stats,  # optional: full object if the policy wants it
        )

        # Ask the policy module for its last explanation: a module-level
        # ``policy_reason()`` callable, looked up via the function's globals.
        reason = None
        policy_globals = getattr(decide, "__globals__", None)
        if policy_globals is not None:
            reason_fn = policy_globals.get("policy_reason")
            if callable(reason_fn):
                reason = reason_fn()

        # Send the decision first. Tallies are bumped only once the request
        # has succeeded, so a failed call does not leave local tallies
        # counting a person the server never recorded (and a retry does not
        # double-count). The bump still precedes the save performed by
        # _update_from_decide_response.
        data = self.api.decide_and_next(rec.game_id, person_index=idx, accept=bool(accept))
        if accept:
            rec.bump_tallies(attrs)
        self._update_from_decide_response(rec, data)

        # Build the entry dict (with reason included).
        return {
            "personIndex": idx,
            "accept": bool(accept),
            "attributes": attrs,
            "remaining_need_before": remaining_need_before,
            "admitted_count": rec.admitted_count,
            "rejected_count": rec.rejected_count,
            "status": rec.status,
            "reason": reason or ("Admit" if accept else "Reject"),
        }

    def autoplay_stream(
        self,
        rec: GameRecord,
        max_steps: Optional[int] = None,
    ) -> Iterator[Dict[str, Any]]:
        """Yield transcript entries (same shape as ``step()``) until the game ends.

        Stops when the status is no longer running, when the local admitted or
        rejected counters reach ``VENUE_CAP`` / ``REJECT_CAP``, when
        ``max_steps`` entries have been produced, or when ``step()`` reports
        that no move was possible.

        Raises:
            PolicyLoadError: If no policy is set/invalid (checked up front,
                when iteration starts).
        """
        # Ensure policy exists before starting (raises PolicyLoadError if missing/invalid).
        self._get_decide_callable(rec)

        # Ensure we have the first person available to step on.
        if rec.next_person is None and rec.status == GameStatus.RUNNING:
            self.fetch_first_person(rec)

        steps = 0
        while rec.status == GameStatus.RUNNING:
            if rec.admitted_count >= VENUE_CAP or rec.rejected_count >= REJECT_CAP:
                break
            entry = self.step(rec)
            yield entry
            steps += 1
            if entry["personIndex"] is None:
                # step() could not act (no next person while still "running").
                # Continuing would repeat the same no-op forever.
                break
            if max_steps is not None and steps >= max_steps:
                break

    def autoplay(self, rec: GameRecord) -> GameRecord:
        """Run until completion/failure using the current policy.

        Drains ``autoplay_stream`` and returns the same (mutated) record.
        """
        for _ in self.autoplay_stream(rec):
            pass
        return rec