249 lines
8.9 KiB
Python
249 lines
8.9 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, Iterator, Optional
|
|
|
|
from .state import GameState, GameRecord, Constraints, AttributeStats, GameStatus
|
|
from .api import APIClient
|
|
from .policy import load_policy, PolicyLoadError, call_policy, default_policy_path
|
|
from .config import VENUE_CAP, REJECT_CAP
|
|
|
|
|
|
class BerghainClient:
    """Drive games by combining persisted state, the remote API and a user policy.

    Attributes:
        gs: The ``GameState`` loaded via ``GameState.load()``; saved after
            every mutation performed by this client.
        api: The ``APIClient`` used for all remote calls.
    """

    def __init__(self) -> None:
        """Load the persisted game state and create the API client."""
        self.gs = GameState.load()
        self.api = APIClient()

    # ----------------------------
    # Player / game lifecycle
    # ----------------------------

    def set_player(self, uuid: str) -> None:
        """Set the player ID and persist state.

        Raises:
            ValueError: If ``uuid`` is empty.
        """
        if not uuid:
            raise ValueError("Player ID cannot be empty")
        self.gs.player_id = uuid
        self.gs.save()

    def new_game(self, scenario: int) -> GameRecord:
        """Start a new game and return the created ``GameRecord``.

        The policy is optional and may be attached later. The new record
        inherits ``policy_path`` from the current game when it has one,
        otherwise ``default_policy_path()`` is used.

        Args:
            scenario: Scenario number, or ``-1`` to reuse the scenario of the
                current (previous) game.

        Raises:
            ValueError: If ``scenario`` is ``-1`` but there is no previous
                game to copy the scenario from.
        """
        prev = self.gs.current_game()
        if scenario == -1:
            # Without this guard a missing previous game surfaced as an
            # opaque AttributeError on ``prev.scenario``.
            if prev is None:
                raise ValueError(
                    "No previous game to reuse a scenario from; "
                    "pass an explicit scenario."
                )
            scenario = prev.scenario

        resp = self.api.new_game(self.gs.player_id, scenario)

        inherit_path = getattr(prev, "policy_path", None) if prev else None
        policy_path = inherit_path or default_policy_path()

        # ``or`` fallbacks also cover keys that are present but null.
        raw_constraints = resp.get("constraints") or []
        raw_stats = resp.get("attributeStatistics") or {}

        rec = GameRecord(
            game_id=resp["gameId"],
            initial_raw=resp,
            scenario=scenario,
            status=GameStatus.RUNNING,
            constraints=Constraints(
                mins={c["attribute"]: int(c["minCount"]) for c in raw_constraints}
            ),
            stats=AttributeStats(
                relative_frequencies=raw_stats.get("relativeFrequencies", {}) or {},
                correlations=raw_stats.get("correlations", {}) or {},
            ),
            admitted_count=0,
            rejected_count=0,
            next_person=None,
            tallies={},
            policy_path=policy_path,
        )
        self.gs.add_game(rec)
        self.gs.save()
        return rec

    def set_policy_for_current_game(self, policy_path: str) -> GameRecord:
        """Attach or replace the policy of the current game.

        The policy is loaded once for validation before the path is stored,
        so an invalid policy never ends up in the saved state.

        Raises:
            ValueError: If there is no current game.
        """
        rec = self.gs.current_game()
        if not rec:
            raise ValueError("No current game to set a policy for.")
        load_policy(policy_path)  # validate before saving
        rec.policy_path = policy_path
        self.gs.save()
        return rec

    # ----------------------------
    # Internals
    # ----------------------------

    def _get_decide_callable(self, rec: GameRecord):
        """Load and return the user's decide() callable for ``rec``.

        Raises:
            PolicyLoadError: If the record has no policy path. Errors raised
                by ``load_policy`` itself propagate unchanged.
        """
        if not rec.policy_path:
            raise PolicyLoadError("No policy set for this game. Use `berghain set-policy <file>`.")
        return load_policy(rec.policy_path)

    def _update_from_decide_response(self, rec: GameRecord, data: Dict[str, Any]) -> None:
        """Update the local record from a /decide-and-next response and save.

        The status is overwritten only when the lower-cased response value is
        one of RUNNING / COMPLETED / FAILED. Counters keep their current value
        when the response omits them, and ``next_person`` becomes ``None``
        when ``nextPerson`` is absent.
        """
        status = (data.get("status") or "").lower()
        # NOTE(review): assumes GameStatus members compare equal to lowercase
        # strings - confirm against .state.
        if status in {GameStatus.RUNNING, GameStatus.COMPLETED, GameStatus.FAILED}:
            rec.status = status
        rec.admitted_count = int(data.get("admittedCount", rec.admitted_count))
        rec.rejected_count = int(data.get("rejectedCount", rec.rejected_count))
        rec.next_person = data.get("nextPerson")
        self.gs.save()

    def fetch_first_person(self, rec: GameRecord) -> None:
        """Fetch the first person (personIndex=0) and store the response on ``rec``.

        The API allows 'accept' to be omitted for index 0.
        """
        data = self.api.decide_and_next(rec.game_id, person_index=0, accept=None)
        self._update_from_decide_response(rec, data)

    def _remaining_need(self, rec: GameRecord) -> Dict[str, int]:
        """Return the remaining shortfall per constrained attribute.

        Computed from the local tallies; never negative.
        """
        return {
            a: max(0, rec.constraints.mins.get(a, 0) - rec.tallies.get(a, 0))
            for a in rec.constraints.mins
        }

    def _noop_entry(self, rec: GameRecord, reason: str) -> Dict[str, Any]:
        """Build a transcript entry for a step that made no decision."""
        return {
            "personIndex": None,
            "accept": False,
            "attributes": {},
            "remaining_need_before": self._remaining_need(rec),
            "admitted_count": rec.admitted_count,
            "rejected_count": rec.rejected_count,
            "status": rec.status,
            "reason": reason,
        }

    # ----------------------------
    # Step / autoplay
    # ----------------------------

    def step(self, rec: GameRecord) -> Dict[str, Any]:
        """Execute ONE decision step using the current policy.

        Returns a transcript entry:

            {
              "personIndex": int | None,
              "accept": bool,
              "attributes": {attr: bool},
              "remaining_need_before": {attr: int},
              "admitted_count": int,     # after API call
              "rejected_count": int,     # after API call
              "status": "running"|"completed"|"failed",
              "reason": str              # human-readable reason (from policy)
            }

        If no move is possible (game not running, or no person available even
        after fetching), returns a no-op entry with ``personIndex`` of ``None``
        and the current counters.

        Raises:
            PolicyLoadError: If no policy is set or it cannot be loaded.
        """
        # If the game isn't running, return a no-op entry.
        if rec.status != GameStatus.RUNNING:
            return self._noop_entry(rec, "Game not running")

        # Ensure we have a person ready.
        if not rec.next_person:
            self.fetch_first_person(rec)
            if not rec.next_person:
                return self._noop_entry(rec, "No next person")

        # Load user policy decide() and compute decision.
        decide = self._get_decide_callable(rec)

        idx = int(rec.next_person.get("personIndex", 0))
        attrs: Dict[str, bool] = rec.next_person.get("attributes", {}) or {}
        remaining_need_before = self._remaining_need(rec)

        accept = call_policy(
            decide,
            attributes=attrs,
            tallies=rec.tallies,
            mins=rec.constraints.mins,
            admitted_count=rec.admitted_count,
            venue_cap=VENUE_CAP,
            # Optional context for policies that opt in:
            p=rec.stats.relative_frequencies,
            corr=rec.stats.correlations,
            stats=rec.stats,  # optional: full object if the policy wants it
        )

        # Ask the policy module for its last explanation: if the module that
        # defines decide() exposes a callable ``policy_reason``, call it.
        reason = None
        if hasattr(decide, "__globals__"):
            pol = decide.__globals__
            if "policy_reason" in pol and callable(pol["policy_reason"]):
                reason = pol["policy_reason"]()

        # Send the decision first; a failing call must not leave the tallies
        # inflated, otherwise a retry would count the same person twice.
        data = self.api.decide_and_next(rec.game_id, person_index=idx, accept=bool(accept))

        # Bump tallies before the state update so the save performed there
        # persists them together with the new counters.
        if accept:
            rec.bump_tallies(attrs)
        self._update_from_decide_response(rec, data)

        # Build the entry dict (with reason included).
        return {
            "personIndex": idx,
            "accept": bool(accept),
            "attributes": attrs,
            "remaining_need_before": remaining_need_before,
            "admitted_count": rec.admitted_count,
            "rejected_count": rec.rejected_count,
            "status": rec.status,
            "reason": reason or ("Admit" if accept else "Reject"),
        }

    def autoplay_stream(
        self,
        rec: GameRecord,
        max_steps: Optional[int] = None,
    ) -> Iterator[Dict[str, Any]]:
        """Yield transcript entries (same shape as step()) until the game ends.

        Iteration stops when the status is no longer RUNNING, when the
        admitted or rejected counter reaches VENUE_CAP / REJECT_CAP, when a
        step makes no decision, or after ``max_steps`` entries. The step
        limit is checked after each step.

        Raises:
            PolicyLoadError: If no policy is set or it cannot be loaded.
        """
        # Ensure policy exists before starting (raises PolicyLoadError if missing/invalid).
        self._get_decide_callable(rec)

        # Ensure we have the first person available to step on.
        if rec.next_person is None and rec.status == GameStatus.RUNNING:
            self.fetch_first_person(rec)

        steps = 0
        while rec.status == GameStatus.RUNNING:
            if rec.admitted_count >= VENUE_CAP or rec.rejected_count >= REJECT_CAP:
                break
            entry = self.step(rec)
            yield entry
            # A no-op entry while still RUNNING means no person is available;
            # looping again would spin forever and hit the API on every pass.
            if entry["personIndex"] is None:
                break
            steps += 1
            if max_steps is not None and steps >= max_steps:
                break

    def autoplay(self, rec: GameRecord) -> GameRecord:
        """Run until completion/failure using the current policy (drains the stream)."""
        for _ in self.autoplay_stream(rec):
            pass
        return rec
|
|
|