diff --git a/.github/workflows/docker-ci.yml b/.github/workflows/docker-ci.yml index f028aaa8..65270e9d 100644 --- a/.github/workflows/docker-ci.yml +++ b/.github/workflows/docker-ci.yml @@ -111,6 +111,7 @@ jobs: docker cp pyproject.toml test_container:/root/python-sc2/ docker cp uv.lock test_container:/root/python-sc2/ docker cp sc2 test_container:/root/python-sc2/sc2 + docker cp s2clientprotocol test_container:/root/python-sc2/s2clientprotocol docker cp test test_container:/root/python-sc2/test docker cp examples test_container:/root/python-sc2/examples docker exec -i test_container bash -c "pip install uv \ diff --git a/dockerfiles/test_docker_image.sh b/dockerfiles/test_docker_image.sh index 4b203c2e..7c10be5f 100644 --- a/dockerfiles/test_docker_image.sh +++ b/dockerfiles/test_docker_image.sh @@ -46,6 +46,7 @@ docker cp uv.lock test_container:/root/python-sc2/ docker exec -i test_container bash -c "pip install uv && cd python-sc2 && uv sync --no-cache --no-install-project" docker cp sc2 test_container:/root/python-sc2/sc2 +docker cp s2clientprotocol test_container:/root/python-sc2/s2clientprotocol docker cp test test_container:/root/python-sc2/test # Run various test bots diff --git a/examples/arcade_bot.py b/examples/arcade_bot.py index 32bbf22c..811ee944 100644 --- a/examples/arcade_bot.py +++ b/examples/arcade_bot.py @@ -42,7 +42,7 @@ async def on_start(self): await self.chat_send("Edit this message for automatic chat commands.") self.client.game_step = 2 - async def on_step(self, iteration): + async def on_step(self, iteration: int): # do marine micro vs zerglings for unit in self.units(UnitTypeId.MARINE): if self.enemy_units: diff --git a/examples/competitive/bot.py b/examples/competitive/bot.py index 5170635a..253337b6 100644 --- a/examples/competitive/bot.py +++ b/examples/competitive/bot.py @@ -7,7 +7,7 @@ async def on_start(self): print("Game started") # Do things here before the game starts - async def on_step(self, iteration): + async def on_step(self, iteration: int): # Populate this function with whatever your bot should do! pass diff --git a/examples/distributed_workers.py b/examples/distributed_workers.py index 95d3d4af..9e7940e5 100644 --- a/examples/distributed_workers.py +++ b/examples/distributed_workers.py @@ -8,7 +8,7 @@ class TerranBot(BotAI): - async def on_step(self, iteration): + async def on_step(self, iteration: int): await self.distribute_workers() await self.build_supply() await self.build_workers() diff --git a/examples/fastreload.py b/examples/fastreload.py index 3cacde4a..4fc5439a 100644 --- a/examples/fastreload.py +++ b/examples/fastreload.py @@ -4,11 +4,14 @@ from sc2 import maps from sc2.data import Difficulty, Race from sc2.main import _host_game_iter -from sc2.player import Bot, Computer +from sc2.player import AbstractPlayer, Bot, Computer def main(): - player_config = [Bot(Race.Zerg, zerg_rush.ZergRushBot()), Computer(Race.Terran, Difficulty.Medium)] + player_config: list[AbstractPlayer] = [ + Bot(Race.Zerg, zerg_rush.ZergRushBot()), + Computer(Race.Terran, Difficulty.Medium), + ] gen = _host_game_iter(maps.get("Abyssal Reef LE"), player_config, realtime=False) diff --git a/examples/host_external_norestart.py b/examples/host_external_norestart.py index eb2558a9..c5626ac0 100644 --- a/examples/host_external_norestart.py +++ b/examples/host_external_norestart.py @@ -1,13 +1,13 @@ -import sc2 from examples.zerg.zerg_rush import ZergRushBot from sc2 import maps from sc2.data import Race from sc2.main import _host_game_iter from sc2.player import Bot +from sc2.portconfig import Portconfig def main(): - portconfig = sc2.portconfig.Portconfig() + portconfig: Portconfig = Portconfig() print(portconfig.as_json) player_config = [Bot(Race.Zerg, ZergRushBot()), Bot(Race.Zerg, None)] diff --git a/examples/protoss/cannon_rush.py b/examples/protoss/cannon_rush.py index 2d287202..abe691ac 100644 --- a/examples/protoss/cannon_rush.py +++ b/examples/protoss/cannon_rush.py @@ -9,7 +9,7 @@ class CannonRushBot(BotAI): - async def on_step(self, iteration): + async def on_step(self, iteration: int): if iteration == 0: await self.chat_send("(probe)(pylon)(cannon)(cannon)(gg)") diff --git a/examples/protoss/find_adept_shades.py b/examples/protoss/find_adept_shades.py index 8b136cfc..d10896d3 100644 --- a/examples/protoss/find_adept_shades.py +++ b/examples/protoss/find_adept_shades.py @@ -13,7 +13,7 @@ class FindAdeptShadesBot(BotAI): def __init__(self): self.shaded = False - self.shades_mapping = {} + self.shades_mapping: dict[int, int] = {} async def on_start(self): self.client.game_step = 2 diff --git a/examples/protoss/threebase_voidray.py b/examples/protoss/threebase_voidray.py index 2030a28f..314f6696 100644 --- a/examples/protoss/threebase_voidray.py +++ b/examples/protoss/threebase_voidray.py @@ -9,7 +9,7 @@ class ThreebaseVoidrayBot(BotAI): - async def on_step(self, iteration): + async def on_step(self, iteration: int): target_base_count = 3 target_stargate_count = 3 diff --git a/examples/protoss/warpgate_push.py b/examples/protoss/warpgate_push.py index 7058ea43..f1b36079 100644 --- a/examples/protoss/warpgate_push.py +++ b/examples/protoss/warpgate_push.py @@ -9,6 +9,7 @@ from sc2.ids.upgrade_id import UpgradeId from sc2.main import run_game from sc2.player import Bot, Computer +from sc2.unit import Unit class WarpGateBot(BotAI): @@ -16,11 +17,11 @@ def __init__(self): # Initialize inherited class self.proxy_built = False - async def warp_new_units(self, proxy): + async def warp_new_units(self, proxy: Unit): for warpgate in self.structures(UnitTypeId.WARPGATE).ready: - abilities = await self.get_available_abilities(warpgate) + abilities = await self.get_available_abilities([warpgate]) # all the units have the same cooldown anyway so let's just look at ZEALOT - if AbilityId.WARPGATETRAIN_STALKER in abilities: + if AbilityId.WARPGATETRAIN_STALKER in abilities[0]: pos = proxy.position.to2.random_on_distance(4) placement = await self.find_placement(AbilityId.WARPGATETRAIN_STALKER, pos, placement_step=1) if placement is None: @@ -29,7 +30,7 @@ async def warp_new_units(self, proxy): return warpgate.warp_in(UnitTypeId.STALKER, placement) - async def on_step(self, iteration): + async def on_step(self, iteration: int): await self.distribute_workers() if not self.townhalls.ready: diff --git a/examples/simulate_fight_scenario.py b/examples/simulate_fight_scenario.py index b3d66511..2bee4390 100644 --- a/examples/simulate_fight_scenario.py +++ b/examples/simulate_fight_scenario.py @@ -15,7 +15,7 @@ class FightBot(BotAI): def __init__(self): super().__init__() - self.enemy_location: Point2 = None + self.enemy_location: Point2 | None = None self.fight_started = False async def on_start(self): @@ -23,7 +23,7 @@ async def on_start(self): await self.client.debug_show_map() await self.client.debug_control_enemy() - async def on_step(self, iteration): + async def on_step(self, iteration: int): # Wait till control retrieved, destroy all starting units, recreate the world if iteration > 0 and self.enemy_units and not self.enemy_location: await self.reset_arena() diff --git a/examples/terran/cyclone_push.py b/examples/terran/cyclone_push.py index cf04c91a..f66a18bb 100644 --- a/examples/terran/cyclone_push.py +++ b/examples/terran/cyclone_push.py @@ -28,7 +28,7 @@ def select_target(self) -> Point2: # Pick a random mineral field on the map return self.mineral_field.random.position - async def on_step(self, iteration): + async def on_step(self, iteration: int): CCs: Units = self.townhalls(UnitTypeId.COMMANDCENTER) # If no command center exists, attack-move with all workers and cyclones if not CCs: @@ -87,7 +87,7 @@ async def on_step(self, iteration): if self.gas_buildings.filter(lambda unit: unit.distance_to(vg) < 1): continue # Select a worker closest to the vespene geysir - worker: Unit = self.select_build_worker(vg) + worker: Unit | None = self.select_build_worker(vg) # Worker can be none in cases where all workers are dead # or 'select_build_worker' function only selects from workers which carry no minerals if worker is None: @@ -112,9 +112,9 @@ async def on_step(self, iteration): # Saturate gas for refinery in self.gas_buildings: if refinery.assigned_harvesters < refinery.ideal_harvesters: - worker: Units = self.workers.closer_than(10, refinery) - if worker: - worker.random.gather(refinery) + workers: Units = self.workers.closer_than(10, refinery) + if workers: + workers.random.gather(refinery) for scv in self.workers.idle: scv.gather(self.mineral_field.closest_to(cc)) diff --git a/examples/terran/mass_reaper.py b/examples/terran/mass_reaper.py index 01aba5dd..4ca3f1aa 100644 --- a/examples/terran/mass_reaper.py +++ b/examples/terran/mass_reaper.py @@ -25,7 +25,7 @@ def __init__(self): # Select distance calculation method 0, which is the pure python distance calculation without caching or indexing, using math.hypot(), for more info see bot_ai_internal.py _distances_override_functions() function self.distance_calculation_method = 3 - async def on_step(self, iteration): + async def on_step(self, iteration: int): # Benchmark and print duration time of the on_step method based on "self.distance_calculation_method" value # logger.info(self.time_formatted, self.supply_used, self.step_time[1]) """ @@ -45,7 +45,9 @@ async def on_step(self, iteration): # If workers were found if workers: worker: Unit = workers.furthest_to(workers.center) - location: Point2 = await self.find_placement(UnitTypeId.SUPPLYDEPOT, worker.position, placement_step=3) + location: Point2 | None = await self.find_placement( + UnitTypeId.SUPPLYDEPOT, worker.position, placement_step=3 + ) # If a placement location was found if location: # Order worker to build exactly on that location @@ -72,13 +74,13 @@ async def on_step(self, iteration): and self.can_afford(UnitTypeId.COMMANDCENTER) ): # get_next_expansion returns the position of the next possible expansion location where you can place a command center - location: Point2 = await self.get_next_expansion() + location: Point2 | None = await self.get_next_expansion() if location: # Now we "select" (or choose) the nearest worker to that found location - worker: Unit = self.select_build_worker(location) - if worker and self.can_afford(UnitTypeId.COMMANDCENTER): + worker2: Unit | None = self.select_build_worker(location) + if worker2 and self.can_afford(UnitTypeId.COMMANDCENTER): # The worker will be commanded to build the command center - worker.build(UnitTypeId.COMMANDCENTER, location) + worker2.build(UnitTypeId.COMMANDCENTER, location) # Build up to 4 barracks if we can afford them # Check if we have a supply depot (tech requirement) before trying to make barracks @@ -97,7 +99,7 @@ async def on_step(self, iteration): ): # need to check if townhalls.amount > 0 because placement is based on townhall location worker: Unit = workers.furthest_to(workers.center) # I chose placement_step 4 here so there will be gaps between barracks hopefully - location: Point2 = await self.find_placement( + location: Point2 | None = await self.find_placement( UnitTypeId.BARRACKS, self.townhalls.random.position, placement_step=4 ) if location: @@ -168,7 +170,7 @@ async def on_step(self, iteration): retreat_points: set[Point2] = {x for x in retreat_points if self.in_pathing_grid(x)} if retreat_points: closest_enemy: Unit = enemy_threats_close.closest_to(r) - retreat_point: Unit = closest_enemy.position.furthest(retreat_points) + retreat_point: Point2 = closest_enemy.position.furthest(retreat_points) r.move(retreat_point) continue # Continue for loop, dont execute any of the following @@ -259,13 +261,13 @@ async def on_step(self, iteration): # Stolen and modified from position.py @staticmethod - def neighbors4(position, distance=1) -> set[Point2]: + def neighbors4(position: Point2, distance: float = 1) -> set[Point2]: p = position d = distance return {Point2((p.x - d, p.y)), Point2((p.x + d, p.y)), Point2((p.x, p.y - d)), Point2((p.x, p.y + d))} # Stolen and modified from position.py - def neighbors8(self, position, distance=1) -> set[Point2]: + def neighbors8(self, position: Point2, distance: float = 1) -> set[Point2]: p = position d = distance return self.neighbors4(position, distance) | { diff --git a/examples/terran/onebase_battlecruiser.py b/examples/terran/onebase_battlecruiser.py index 1173af5e..47cd5f62 100644 --- a/examples/terran/onebase_battlecruiser.py +++ b/examples/terran/onebase_battlecruiser.py @@ -29,7 +29,7 @@ def select_target(self) -> tuple[Point2, bool]: return self.mineral_field.random.position, False - async def on_step(self, iteration): + async def on_step(self, iteration: int): ccs: Units = self.townhalls # If we no longer have townhalls, attack with all workers if not ccs: @@ -85,7 +85,7 @@ async def on_step(self, iteration): if self.gas_buildings.filter(lambda unit: unit.distance_to(vg) < 1): break - worker: Unit = self.select_build_worker(vg.position) + worker: Unit | None = self.select_build_worker(vg.position) if worker is None: break @@ -172,9 +172,9 @@ def starport_land_positions(sp_position: Point2) -> list[Point2]: # Saturate refineries for refinery in self.gas_buildings: if refinery.assigned_harvesters < refinery.ideal_harvesters: - worker: Units = self.workers.closer_than(10, refinery) - if worker: - worker.random.gather(refinery) + workers: Units = self.workers.closer_than(10, refinery) + if workers: + workers.random.gather(refinery) # Send workers back to mine if they are idle for scv in self.workers.idle: diff --git a/examples/terran/proxy_rax.py b/examples/terran/proxy_rax.py index 0ce8d789..5912e461 100644 --- a/examples/terran/proxy_rax.py +++ b/examples/terran/proxy_rax.py @@ -13,7 +13,7 @@ class ProxyRaxBot(BotAI): async def on_start(self): self.client.game_step = 2 - async def on_step(self, iteration): + async def on_step(self, iteration: int): # If we don't have a townhall anymore, send all units to attack ccs: Units = self.townhalls(UnitTypeId.COMMANDCENTER) if not ccs: diff --git a/examples/terran/ramp_wall.py b/examples/terran/ramp_wall.py index 244bce99..0aa12fea 100644 --- a/examples/terran/ramp_wall.py +++ b/examples/terran/ramp_wall.py @@ -19,7 +19,7 @@ class RampWallBot(BotAI): def __init__(self): self.unit_command_uses_self_do = False - async def on_step(self, iteration): + async def on_step(self, iteration: int): ccs: Units = self.townhalls(UnitTypeId.COMMANDCENTER) if not ccs: return @@ -70,11 +70,11 @@ async def on_step(self, iteration): # Draw if two selected units are facing each other - green if this guy is facing the other, red if he is not self.draw_facing_units() - depot_placement_positions: frozenset[Point2] = self.main_base_ramp.corner_depots + depot_placement_positions: set[Point2] = self.main_base_ramp.corner_depots # Uncomment the following if you want to build 3 supply depots in the wall instead of a barracks in the middle + 2 depots in the corner # depot_placement_positions = self.main_base_ramp.corner_depots | {self.main_base_ramp.depot_in_middle} - barracks_placement_position: Point2 = self.main_base_ramp.barracks_correct_placement + barracks_placement_position: Point2 | None = self.main_base_ramp.barracks_correct_placement # If you prefer to have the barracks in the middle without room for addons, use the following instead # barracks_placement_position = self.main_base_ramp.barracks_in_middle diff --git a/examples/too_slow_bot.py b/examples/too_slow_bot.py index 28d32c0e..6e8b9baf 100644 --- a/examples/too_slow_bot.py +++ b/examples/too_slow_bot.py @@ -9,7 +9,7 @@ class SlowBot(ProxyRaxBot): - async def on_step(self, iteration): + async def on_step(self, iteration: int): await asyncio.sleep(random.random()) await super().on_step(iteration) diff --git a/examples/worker_rush.py b/examples/worker_rush.py index 686c7256..3f57e7d2 100644 --- a/examples/worker_rush.py +++ b/examples/worker_rush.py @@ -6,7 +6,7 @@ class WorkerRushBot(BotAI): - async def on_step(self, iteration): + async def on_step(self, iteration: int): if iteration == 0: for worker in self.workers: worker.attack(self.enemy_start_locations[0]) diff --git a/examples/zerg/banes_banes_banes.py b/examples/zerg/banes_banes_banes.py index 85a00c70..c5216977 100644 --- a/examples/zerg/banes_banes_banes.py +++ b/examples/zerg/banes_banes_banes.py @@ -23,7 +23,7 @@ def select_target(self) -> Point2: return random.choice(self.enemy_structures).position return self.enemy_start_locations[0] - async def on_step(self, iteration): + async def on_step(self, iteration: int): larvae: Units = self.larva lings: Units = self.units(UnitTypeId.ZERGLING) # Send all idle banes to enemy diff --git a/examples/zerg/expand_everywhere.py b/examples/zerg/expand_everywhere.py index 552ae0f9..87186ac2 100644 --- a/examples/zerg/expand_everywhere.py +++ b/examples/zerg/expand_everywhere.py @@ -16,7 +16,7 @@ async def on_start(self): self.client.game_step = 50 await self.client.debug_show_map() - async def on_step(self, iteration): + async def on_step(self, iteration: int): # Build overlords if about to be supply blocked if ( self.supply_left < 2 diff --git a/examples/zerg/hydralisk_push.py b/examples/zerg/hydralisk_push.py index 6e6d17e2..34f80003 100644 --- a/examples/zerg/hydralisk_push.py +++ b/examples/zerg/hydralisk_push.py @@ -19,7 +19,7 @@ def select_target(self) -> Point2: return random.choice(self.enemy_structures).position return self.enemy_start_locations[0] - async def on_step(self, iteration): + async def on_step(self, iteration: int): larvae: Units = self.larva forces: Units = self.units.of_type({UnitTypeId.ZERGLING, UnitTypeId.HYDRALISK}) diff --git a/examples/zerg/onebase_broodlord.py b/examples/zerg/onebase_broodlord.py index 72d75bca..5db57a51 100644 --- a/examples/zerg/onebase_broodlord.py +++ b/examples/zerg/onebase_broodlord.py @@ -19,7 +19,7 @@ def select_target(self) -> Point2: return random.choice(self.enemy_structures).position return self.enemy_start_locations[0] - async def on_step(self, iteration): + async def on_step(self, iteration: int): larvae: Units = self.larva forces: Units = self.units.of_type({UnitTypeId.ZERGLING, UnitTypeId.CORRUPTOR, UnitTypeId.BROODLORD}) diff --git a/examples/zerg/worker_split.py b/examples/zerg/worker_split.py index 3edec5bb..689bdd90 100644 --- a/examples/zerg/worker_split.py +++ b/examples/zerg/worker_split.py @@ -30,7 +30,7 @@ async def on_before_start(self): async def on_start(self): """This function is run after the expansion locations and ramps are calculated.""" - async def on_step(self, iteration): + async def on_step(self, iteration: int): if iteration % 10 == 0: await asyncio.sleep(3) # In realtime=False, this should print "8*x" and "x" if diff --git a/examples/zerg/zerg_rush.py b/examples/zerg/zerg_rush.py index 93139434..15d0df50 100644 --- a/examples/zerg/zerg_rush.py +++ b/examples/zerg/zerg_rush.py @@ -22,7 +22,7 @@ def __init__(self): async def on_start(self): self.client.game_step = 2 - async def on_step(self, iteration): + async def on_step(self, iteration: int): if iteration == 0: await self.chat_send("(glhf)") @@ -38,11 +38,11 @@ async def on_step(self, iteration): hatch: Unit = self.townhalls[0] # Pick a target location - target: Point2 = self.enemy_structures.not_flying.random_or(self.enemy_start_locations[0]).position + target_pos: Point2 = self.enemy_structures.not_flying.random_or(self.enemy_start_locations[0]).position # Give all zerglings an attack command for zergling in self.units(UnitTypeId.ZERGLING): - zergling.attack(target) + zergling.attack(target=target_pos) # Inject hatchery if queen has more than 25 energy for queen in self.units(UnitTypeId.QUEEN): diff --git a/pyproject.toml b/pyproject.toml index 22aac81d..f634fb32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,10 @@ dev = [ [tool.setuptools] license-files = [] -package-dir = { sc2 = "sc2" } +package-dir = { sc2 = "sc2", s2clientprotocol = "s2clientprotocol" } + +[tool.setuptools.package-data] +sc2 = ["py.typed", "*.pyi"] [build-system] # https://packaging.python.org/en/latest/tutorials/packaging-projects/#choosing-a-build-backend diff --git a/s2clientprotocol/__init__.pyi b/s2clientprotocol/__init__.pyi new file mode 100644 index 00000000..e69de29b diff --git a/s2clientprotocol/common_pb2.pyi b/s2clientprotocol/common_pb2.pyi new file mode 100644 index 00000000..e586cfef --- /dev/null +++ b/s2clientprotocol/common_pb2.pyi @@ -0,0 +1,48 @@ +# https://github.com/Blizzard/s2client-proto/blob/bff45dae1fc685e6acbaae084670afb7d1c0832c/s2clientprotocol/common.proto +from enum import Enum + +from google.protobuf.message import Message + +class AvailableAbility(Message): + ability_id: int + requires_point: bool + def __init__(self, ability_id: int = ..., requires_point: bool = ...) -> None: ... + +class ImageData(Message): + bits_per_pixel: int + size: Size2DI + data: bytes + def __init__(self, bits_per_pixel: int = ..., size: Size2DI = ..., data: bytes = ...) -> None: ... + +class PointI(Message): + x: int + y: int + def __init__(self, x: int = ..., y: int = ...) -> None: ... + +class RectangleI(Message): + p0: PointI + p1: PointI + def __init__(self, p0: PointI = ..., p1: PointI = ...) -> None: ... + +class Point2D(Message): + x: float + y: float + def __init__(self, x: float = ..., y: float = ...) -> None: ... + +class Point(Message): + x: float + y: float + z: float + def __init__(self, x: float = ..., y: float = ..., z: float = ...) -> None: ... + +class Size2DI(Message): + x: int + y: int + def __init__(self, x: int = ..., y: int = ...) -> None: ... + +class Race(Enum): + NoRace: int + Terran: int + Zerg: int + Protoss: int + Random: int diff --git a/s2clientprotocol/data_pb2.pyi b/s2clientprotocol/data_pb2.pyi new file mode 100644 index 00000000..8b839cf5 --- /dev/null +++ b/s2clientprotocol/data_pb2.pyi @@ -0,0 +1,170 @@ +from collections.abc import Iterable +from enum import Enum + +from google.protobuf.message import Message + +class Target(Enum): + # NONE: int + Point: int + Unit: int + PointOrUnit: int + PointOrNone: int + +class AbilityData(Message): + ability_id: int + link_name: str + link_index: int + button_name: str + friendly_name: str + hotkey: str + remaps_to_ability_id: int + available: bool + target: int + allow_minimap: bool + allow_autocast: bool + is_building: bool + footprint_radius: float + is_instant_placement: bool + cast_range: float + def __init__( + self, + ability_id: int = ..., + link_name: str = ..., + link_index: int = ..., + button_name: str = ..., + friendly_name: str = ..., + hotkey: str = ..., + remaps_to_ability_id: int = ..., + available: bool = ..., + target: int = ..., + allow_minimap: bool = ..., + allow_autocast: bool = ..., + is_building: bool = ..., + footprint_radius: float = ..., + is_instant_placement: bool = ..., + cast_range: float = ..., + ) -> None: ... + +class Attribute(Enum): + Light: int + Armored: int + Biological: int + Mechanical: int + Robotic: int + Psionic: int + Massive: int + Structure: int + Hover: int + Heroic: int + Summoned: int + +class DamageBonus(Message): + attribute: int + bonus: float + def __init__(self, attribute: int = ..., bonus: float = ...) -> None: ... + +class TargetType(Enum): + Ground: int + Air: int + Any: int + +class Weapon(Message): + type: int + damage: float + damage_bonus: Iterable[DamageBonus] + attacks: int + range: float + speed: float + def __init__( + self, + type: int = ..., + damage: float = ..., + damage_bonus: Iterable[DamageBonus] = ..., + attacks: int = ..., + range: float = ..., + speed: float = ..., + ) -> None: ... + +class UnitTypeData(Message): + unit_id: int + name: str + available: bool + cargo_size: int + mineral_cost: int + vespene_cost: int + food_required: float + food_provided: float + ability_id: int + race: int + build_time: float + has_vespene: bool + has_minerals: bool + sight_range: float + tech_alias: Iterable[int] + unit_alias: int + tech_requirement: int + require_attached: bool + attributes: Iterable[int] + movement_speed: float + armor: float + weapons: Iterable[Weapon] + def __init__( + self, + unit_id: int = ..., + name: str = ..., + available: bool = ..., + cargo_size: int = ..., + mineral_cost: int = ..., + vespene_cost: int = ..., + food_required: float = ..., + food_provided: float = ..., + ability_id: int = ..., + race: int = ..., + build_time: float = ..., + has_vespene: bool = ..., + has_minerals: bool = ..., + sight_range: float = ..., + tech_alias: Iterable[int] = ..., + unit_alias: int = ..., + tech_requirement: int = ..., + require_attached: bool = ..., + attributes: Iterable[int] = ..., + movement_speed: float = ..., + armor: float = ..., + weapons: Iterable[Weapon] = ..., + ) -> None: ... + +class UpgradeData(Message): + upgrade_id: int + name: str + mineral_cost: int + vespene_cost: int + research_time: float + ability_id: int + def __init__( + self, + upgrade_id: int = ..., + name: str = ..., + mineral_cost: int = ..., + vespene_cost: int = ..., + research_time: float = ..., + ability_id: int = ..., + ) -> None: ... + +class BuffData(Message): + buff_id: int + name: str + def __init__(self, buff_id: int = ..., name: str = ...) -> None: ... + +class EffectData(Message): + effect_id: int + name: str + friendly_name: str + radius: float + def __init__( + self, + effect_id: int = ..., + name: str = ..., + friendly_name: str = ..., + radius: float = ..., + ) -> None: ... diff --git a/s2clientprotocol/debug_pb2.pyi b/s2clientprotocol/debug_pb2.pyi new file mode 100644 index 00000000..edc956ee --- /dev/null +++ b/s2clientprotocol/debug_pb2.pyi @@ -0,0 +1,152 @@ +from collections.abc import Iterable +from enum import Enum + +from google.protobuf.message import Message + +from .common_pb2 import Point, Point2D + +class DebugCommand(Message): + draw: DebugDraw + game_state: int + create_unit: DebugCreateUnit + kill_unit: DebugKillUnit + test_process: DebugTestProcess + score: DebugSetScore + end_game: DebugEndGame + unit_value: DebugSetUnitValue + def __init__( + self, + draw: DebugDraw = ..., + game_state: int = ..., + create_unit: DebugCreateUnit = ..., + kill_unit: DebugKillUnit = ..., + test_process: DebugTestProcess = ..., + score: DebugSetScore = ..., + end_game: DebugEndGame = ..., + unit_value: DebugSetUnitValue = ..., + ) -> None: ... + +class DebugDraw(Message): + text: Iterable[DebugText] + lines: Iterable[DebugLine] + boxes: Iterable[DebugBox] + spheres: Iterable[DebugSphere] + def __init__( + self, + text: Iterable[DebugText] = ..., + lines: Iterable[DebugLine] = ..., + boxes: Iterable[DebugBox] = ..., + spheres: Iterable[DebugSphere] = ..., + ) -> None: ... + +class Line(Message): + p0: Point + p1: Point + def __init__(self, p0: Point = ..., p1: Point = ...) -> None: ... + +class Color(Message): + r: int + g: int + b: int + def __init__(self, r: int = ..., g: int = ..., b: int = ...) -> None: ... + +class DebugText(Message): + color: Color + text: str + virtual_pos: Point + world_pos: Point + size: int + def __init__( + self, + color: Color = ..., + text: str = ..., + virtual_pos: Point = ..., + world_pos: Point = ..., + size: int = ..., + ) -> None: ... + +class DebugLine(Message): + color: Color + line: Line + def __init__(self, color: Color = ..., line: Line = ...) -> None: ... + +class DebugBox(Message): + color: Color + min: Point + max: Point + def __init__(self, color: Color = ..., min: Point = ..., max: Point = ...) -> None: ... + +class DebugSphere(Message): + color: Color + p: Point + r: float + def __init__(self, color: Color = ..., p: Point = ..., r: float = ...) -> None: ... + +class DebugGameState(Enum): + show_map: int + control_enemy: int + food: int + free: int + all_resources: int + god: int + minerals: int + gas: int + cooldown: int + tech_tree: int + upgrade: int + fast_build: int + +class DebugCreateUnit(Message): + unit_type: int + owner: int + pos: Point2D + quantity: int + def __init__( + self, + unit_type: int = ..., + owner: int = ..., + pos: Point2D = ..., + quantity: int = ..., + ) -> None: ... + +class DebugKillUnit(Message): + tag: Iterable[int] + def __init__(self, tag: Iterable[int] = ...) -> None: ... + +class Test(Enum): + hang: int + crash: int + exit: int + +class DebugTestProcess(Message): + test: int + delay_ms: int + def __init__(self, test: int = ..., delay_ms: int = ...) -> None: ... + +class DebugSetScore(Message): + score: float + def __init__(self, score: float = ...) -> None: ... + +class EndResult(Enum): + Surrender: int + DeclareVictory: int + +class DebugEndGame(Message): + end_result: int + def __init__(self, end_result: int = ...) -> None: ... + +class UnitValue(Enum): + Energy: int + Life: int + Shields: int + +class DebugSetUnitValue(Message): + unit_value: int + value: float + unit_tag: int + def __init__( + self, + unit_value: int = ..., + value: float = ..., + unit_tag: int = ..., + ) -> None: ... diff --git a/s2clientprotocol/error_pb2.pyi b/s2clientprotocol/error_pb2.pyi new file mode 100644 index 00000000..017262d9 --- /dev/null +++ b/s2clientprotocol/error_pb2.pyi @@ -0,0 +1,217 @@ +from enum import Enum + +class ActionResult(Enum): + Success: int + NotSupported: int + Error: int + CantQueueThatOrder: int + Retry: int + Cooldown: int + QueueIsFull: int + RallyQueueIsFull: int + NotEnoughMinerals: int + NotEnoughVespene: int + NotEnoughTerrazine: int + NotEnoughCustom: int + NotEnoughFood: int + FoodUsageImpossible: int + NotEnoughLife: int + NotEnoughShields: int + NotEnoughEnergy: int + LifeSuppressed: int + ShieldsSuppressed: int + EnergySuppressed: int + NotEnoughCharges: int + CantAddMoreCharges: int + TooMuchMinerals: int + TooMuchVespene: int + TooMuchTerrazine: int + TooMuchCustom: int + TooMuchFood: int + TooMuchLife: int + TooMuchShields: int + TooMuchEnergy: int + MustTargetUnitWithLife: int + MustTargetUnitWithShields: int + MustTargetUnitWithEnergy: int + CantTrade: int + CantSpend: int + CantTargetThatUnit: int + CouldntAllocateUnit: int + UnitCantMove: int + TransportIsHoldingPosition: int + BuildTechRequirementsNotMet: int + CantFindPlacementLocation: int + CantBuildOnThat: int + CantBuildTooCloseToDropOff: int + CantBuildLocationInvalid: int + CantSeeBuildLocation: int + CantBuildTooCloseToCreepSource: int + CantBuildTooCloseToResources: int + CantBuildTooFarFromWater: int + CantBuildTooFarFromCreepSource: int + CantBuildTooFarFromBuildPowerSource: int + CantBuildOnDenseTerrain: int + CantTrainTooFarFromTrainPowerSource: int + CantLandLocationInvalid: int + CantSeeLandLocation: int + CantLandTooCloseToCreepSource: int + CantLandTooCloseToResources: int + CantLandTooFarFromWater: int + CantLandTooFarFromCreepSource: int + CantLandTooFarFromBuildPowerSource: int + CantLandTooFarFromTrainPowerSource: int + CantLandOnDenseTerrain: int + AddOnTooFarFromBuilding: int + MustBuildRefineryFirst: int + BuildingIsUnderConstruction: int + CantFindDropOff: int + CantLoadOtherPlayersUnits: int + NotEnoughRoomToLoadUnit: int + CantUnloadUnitsThere: int + CantWarpInUnitsThere: int + CantLoadImmobileUnits: int + CantRechargeImmobileUnits: int + CantRechargeUnderConstructionUnits: int + CantLoadThatUnit: int + NoCargoToUnload: int + LoadAllNoTargetsFound: int + NotWhileOccupied: int + CantAttackWithoutAmmo: int + CantHoldAnyMoreAmmo: int + TechRequirementsNotMet: int + MustLockdownUnitFirst: int + MustTargetUnit: int + MustTargetInventory: int + MustTargetVisibleUnit: int + MustTargetVisibleLocation: int + MustTargetWalkableLocation: int + MustTargetPawnableUnit: int + YouCantControlThatUnit: int + YouCantIssueCommandsToThatUnit: int + MustTargetResources: int + RequiresHealTarget: int + RequiresRepairTarget: int + NoItemsToDrop: int + CantHoldAnyMoreItems: int + CantHoldThat: int + TargetHasNoInventory: int + CantDropThisItem: int + CantMoveThisItem: int + CantPawnThisUnit: int + MustTargetCaster: int + CantTargetCaster: int + MustTargetOuter: int + CantTargetOuter: int + MustTargetYourOwnUnits: int + CantTargetYourOwnUnits: int + MustTargetFriendlyUnits: int + CantTargetFriendlyUnits: int + MustTargetNeutralUnits: int + CantTargetNeutralUnits: int + MustTargetEnemyUnits: int + CantTargetEnemyUnits: int + MustTargetAirUnits: int + CantTargetAirUnits: int + MustTargetGroundUnits: int + CantTargetGroundUnits: int + MustTargetStructures: int + CantTargetStructures: int + MustTargetLightUnits: int + CantTargetLightUnits: int + MustTargetArmoredUnits: int + CantTargetArmoredUnits: int + MustTargetBiologicalUnits: int + CantTargetBiologicalUnits: int + MustTargetHeroicUnits: int + CantTargetHeroicUnits: int + MustTargetRoboticUnits: int + CantTargetRoboticUnits: int + MustTargetMechanicalUnits: int + CantTargetMechanicalUnits: int + MustTargetPsionicUnits: int + CantTargetPsionicUnits: int + MustTargetMassiveUnits: int + CantTargetMassiveUnits: int + MustTargetMissile: int + CantTargetMissile: int + MustTargetWorkerUnits: int + CantTargetWorkerUnits: int + MustTargetEnergyCapableUnits: int + CantTargetEnergyCapableUnits: int + MustTargetShieldCapableUnits: int + CantTargetShieldCapableUnits: int + MustTargetFlyers: int + CantTargetFlyers: int + MustTargetBuriedUnits: int + CantTargetBuriedUnits: int + MustTargetCloakedUnits: int + CantTargetCloakedUnits: int + MustTargetUnitsInAStasisField: int + CantTargetUnitsInAStasisField: int + MustTargetUnderConstructionUnits: int + CantTargetUnderConstructionUnits: int + MustTargetDeadUnits: int + CantTargetDeadUnits: int + MustTargetRevivableUnits: int + CantTargetRevivableUnits: int + MustTargetHiddenUnits: int + CantTargetHiddenUnits: int + CantRechargeOtherPlayersUnits: int + MustTargetHallucinations: int + CantTargetHallucinations: int + MustTargetInvulnerableUnits: int + CantTargetInvulnerableUnits: int + MustTargetDetectedUnits: int + CantTargetDetectedUnits: int + CantTargetUnitWithEnergy: int + CantTargetUnitWithShields: int + MustTargetUncommandableUnits: int + CantTargetUncommandableUnits: int + MustTargetPreventDefeatUnits: int + CantTargetPreventDefeatUnits: int + MustTargetPreventRevealUnits: int + CantTargetPreventRevealUnits: int + MustTargetPassiveUnits: int + CantTargetPassiveUnits: int + MustTargetStunnedUnits: int + CantTargetStunnedUnits: int + MustTargetSummonedUnits: int + CantTargetSummonedUnits: int + MustTargetUser1: int + CantTargetUser1: int + MustTargetUnstoppableUnits: int + CantTargetUnstoppableUnits: int + MustTargetResistantUnits: int + CantTargetResistantUnits: int + MustTargetDazedUnits: int + CantTargetDazedUnits: int + CantLockdown: int + CantMindControl: int + MustTargetDestructibles: int + CantTargetDestructibles: int + MustTargetItems: int + CantTargetItems: int + NoCalldownAvailable: int + WaypointListFull: int + MustTargetRace: int + CantTargetRace: int + MustTargetSimilarUnits: int + CantTargetSimilarUnits: int + CantFindEnoughTargets: int + AlreadySpawningLarva: int + CantTargetExhaustedResources: int + CantUseMinimap: int + CantUseInfoPanel: int + OrderQueueIsFull: int + CantHarvestThatResource: int + HarvestersNotRequired: int + AlreadyTargeted: int + CantAttackWeaponsDisabled: int + CouldntReachTarget: int + TargetIsOutOfRange: int + TargetIsTooClose: int + TargetIsOutOfArc: int + CantFindTeleportLocation: int + InvalidItemClass: int + CantFindCancelOrder: int diff --git a/s2clientprotocol/query_pb2.pyi b/s2clientprotocol/query_pb2.pyi new file mode 100644 index 00000000..746d86d9 --- /dev/null +++ b/s2clientprotocol/query_pb2.pyi @@ -0,0 +1,74 @@ +from collections.abc import Iterable + +from google.protobuf.message import Message + +from .common_pb2 import AvailableAbility, Point2D + +class RequestQuery(Message): + pathing: Iterable[RequestQueryPathing] + abilities: Iterable[RequestQueryAvailableAbilities] + placements: Iterable[RequestQueryBuildingPlacement] + ignore_resource_requirements: bool + def __init__( + self, + pathing: Iterable[RequestQueryPathing] = ..., + abilities: Iterable[RequestQueryAvailableAbilities] = ..., + placements: Iterable[RequestQueryBuildingPlacement] = ..., + ignore_resource_requirements: bool = ..., + ) -> None: ... + +class ResponseQuery(Message): + pathing: Iterable[ResponseQueryPathing] + abilities: Iterable[ResponseQueryAvailableAbilities] + placements: Iterable[ResponseQueryBuildingPlacement] + def __init__( + self, + pathing: Iterable[ResponseQueryPathing] = ..., + abilities: Iterable[ResponseQueryAvailableAbilities] = ..., + placements: Iterable[ResponseQueryBuildingPlacement] = ..., + ) -> None: ... + +class RequestQueryPathing(Message): + start_pos: Point2D + unit_tag: int + end_pos: Point2D + def __init__( + self, + start_pos: Point2D = ..., + unit_tag: int = ..., + end_pos: Point2D = ..., + ) -> None: ... + +class ResponseQueryPathing(Message): + distance: float + def __init__(self, distance: float = ...) -> None: ... + +class RequestQueryAvailableAbilities(Message): + unit_tag: int + def __init__(self, unit_tag: int = ...) -> None: ... + +class ResponseQueryAvailableAbilities(Message): + abilities: Iterable[AvailableAbility] + unit_tag: int + unit_type_id: int + def __init__( + self, + abilities: Iterable[AvailableAbility] = ..., + unit_tag: int = ..., + unit_type_id: int = ..., + ) -> None: ... + +class RequestQueryBuildingPlacement(Message): + ability_id: int + target_pos: Point2D + placing_unit_tag: int + def __init__( + self, + ability_id: int = ..., + target_pos: Point2D = ..., + placing_unit_tag: int = ..., + ) -> None: ... + +class ResponseQueryBuildingPlacement(Message): + result: int + def __init__(self, result: int = ...) -> None: ... diff --git a/s2clientprotocol/raw_pb2.pyi b/s2clientprotocol/raw_pb2.pyi new file mode 100644 index 00000000..34d89c6d --- /dev/null +++ b/s2clientprotocol/raw_pb2.pyi @@ -0,0 +1,272 @@ +from collections.abc import Iterable +from enum import Enum + +from google.protobuf.message import Message + +from .common_pb2 import ImageData, Point, Point2D, RectangleI, Size2DI + +class StartRaw(Message): + map_size: Size2DI + pathing_grid: ImageData + terrain_height: ImageData + placement_grid: ImageData + playable_area: RectangleI + start_locations: Iterable[Point2D] + def __init__( + self, + map_size: Size2DI = ..., + pathing_grid: ImageData = ..., + terrain_height: ImageData = ..., + placement_grid: ImageData = ..., + playable_area: RectangleI = ..., + start_locations: Iterable[Point2D] = ..., + ) -> None: ... + +class ObservationRaw(Message): + player: PlayerRaw + units: Iterable[Unit] + map_state: MapState + event: Event + effects: Iterable[Effect] + radar: Iterable[RadarRing] + def __init__( + self, + player: PlayerRaw = ..., + units: Iterable[Unit] = ..., + map_state: MapState = ..., + event: Event = ..., + effects: Iterable[Effect] = ..., + radar: Iterable[RadarRing] = ..., + ) -> None: ... + +class RadarRing(Message): + pos: Point + radius: float + def __init__(self, pos: Point = ..., radius: float = ...) -> None: ... + +class PowerSource(Message): + pos: Point + radius: float + tag: int + def __init__(self, pos: Point = ..., radius: float = ..., tag: int = ...) -> None: ... + +class PlayerRaw(Message): + power_sources: Iterable[PowerSource] + camera: Point + upgrade_ids: Iterable[int] + def __init__( + self, + power_sources: Iterable[PowerSource] = ..., + camera: Point = ..., + upgrade_ids: Iterable[int] = ..., + ) -> None: ... + +class UnitOrder(Message): + ability_id: int + target_world_space_pos: Point + target_unit_tag: int + progress: float + def __init__( + self, + ability_id: int = ..., + target_world_space_pos: Point = ..., + target_unit_tag: int = ..., + progress: float = ..., + ) -> None: ... + +class DisplayType(Enum): + Visible: int + Snapshot: int + Hidden: int + Placeholder: int + +class Alliance(Enum): + Self: int + Ally: int + Neutral: int + Enemy: int + +class CloakState(Enum): + CloakedUnknown: int + Cloaked: int + CloakedDetected: int + NotCloaked: int + CloakedAllied: int + +class PassengerUnit(Message): + tag: int + health: float + health_max: float + shield: float + shield_max: float + energy: float + energy_max: float + unit_type: int + def __init__( + self, + tag: int = ..., + health: float = ..., + health_max: float = ..., + shield: float = ..., + shield_max: float = ..., + energy: float = ..., + energy_max: float = ..., + unit_type: int = ..., + ) -> None: ... + +class RallyTarget(Message): + point: Point + tag: int + def __init__(self, point: Point = ..., tag: int = ...) -> None: ... + +class Unit(Message): + display_type: int + alliance: int + tag: int + unit_type: int + owner: int + pos: Point + facing: float + radius: float + build_progress: float + cloak: int + buff_ids: Iterable[int] + detect_range: float + radar_range: float + is_selected: bool + is_on_screen: bool + is_blip: bool + is_powered: bool + is_active: bool + attack_upgrade_level: int + armor_upgrade_level: int + shield_upgrade_level: int + health: float + health_max: float + shield: float + shield_max: float + energy: float + energy_max: float + mineral_contents: int + vespene_contents: int + is_flying: bool + is_burrowed: bool + is_hallucination: bool + orders: Iterable[UnitOrder] + add_on_tag: int + passengers: Iterable[PassengerUnit] + cargo_space_taken: int + cargo_space_max: int + assigned_harvesters: int + ideal_harvesters: int + weapon_cooldown: float + engaged_target_tag: int + buff_duration_remain: int + buff_duration_max: int + rally_targets: Iterable[RallyTarget] + def __init__( + self, + display_type: int = ..., + alliance: int = ..., + tag: int = ..., + unit_type: int = ..., + owner: int = ..., + pos: Point = ..., + facing: float = ..., + radius: float = ..., + build_progress: float = ..., + cloak: int = ..., + buff_ids: Iterable[int] = ..., + detect_range: float = ..., + radar_range: float = ..., + is_selected: bool = ..., + is_on_screen: bool = ..., + is_blip: bool = ..., + is_powered: bool = ..., + is_active: bool = ..., + attack_upgrade_level: int = ..., + armor_upgrade_level: int = ..., + shield_upgrade_level: int = ..., + health: float = ..., + health_max: float = ..., + shield: float = ..., + shield_max: float = ..., + energy: float = ..., + energy_max: float = ..., + mineral_contents: int = ..., + vespene_contents: int = ..., + is_flying: bool = ..., + is_burrowed: bool = ..., + is_hallucination: bool = ..., + orders: Iterable[UnitOrder] = ..., + add_on_tag: int = ..., + passengers: Iterable[PassengerUnit] = ..., + cargo_space_taken: int = ..., + cargo_space_max: int = ..., + assigned_harvesters: int = ..., + ideal_harvesters: int = ..., + weapon_cooldown: float = ..., + engaged_target_tag: int = ..., + buff_duration_remain: int = ..., + buff_duration_max: int = ..., + rally_targets: Iterable[RallyTarget] = ..., + ) -> None: ... + +class MapState(Message): + visibility: ImageData + creep: ImageData + def __init__(self, visibility: ImageData = ..., creep: ImageData = ...) -> None: ... + +class Event(Message): + dead_units: Iterable[int] + def __init__(self, dead_units: Iterable[int] = ...) -> None: ... + +class Effect(Message): + effect_id: int + pos: Iterable[Point2D] + alliance: int + owner: int + radius: float + def __init__( + self, + effect_id: int = ..., + pos: Iterable[Point2D] = ..., + alliance: int = ..., + owner: int = ..., + radius: float = ..., + ) -> None: ... + +class ActionRaw(Message): + unit_command: ActionRawUnitCommand + camera_move: ActionRawCameraMove + toggle_autocast: ActionRawToggleAutocast + def __init__( + self, + unit_command: ActionRawUnitCommand = ..., + camera_move: ActionRawCameraMove = ..., + toggle_autocast: ActionRawToggleAutocast = ..., + ) -> None: ... + +class ActionRawUnitCommand(Message): + ability_id: int + target_world_space_pos: Point2D + target_unit_tag: int + unit_tags: Iterable[int] + queue_command: bool + def __init__( + self, + ability_id: int = ..., + target_world_space_pos: Point2D = ..., + target_unit_tag: int = ..., + unit_tags: Iterable[int] = ..., + queue_command: bool = ..., + ) -> None: ... + +class ActionRawCameraMove(Message): + center_world_space: Point + def __init__(self, center_world_space: Point = ...) -> None: ... + +class ActionRawToggleAutocast(Message): + ability_id: int + unit_tags: Iterable[int] + def __init__(self, ability_id: int = ..., unit_tags: Iterable[int] = ...) -> None: ... diff --git a/s2clientprotocol/sc2api_pb2.pyi b/s2clientprotocol/sc2api_pb2.pyi new file mode 100644 index 00000000..67574e00 --- /dev/null +++ b/s2clientprotocol/sc2api_pb2.pyi @@ -0,0 +1,748 @@ +from __future__ import annotations + +from collections.abc import Iterable +from enum import Enum + +from google.protobuf.message import Message + +from s2clientprotocol.spatial_pb2 import ActionSpatial, ObservationFeatureLayer, ObservationRender + +from .common_pb2 import AvailableAbility, Point2D, Size2DI +from .data_pb2 import AbilityData, BuffData, EffectData, UnitTypeData, UpgradeData +from .debug_pb2 import DebugCommand +from .query_pb2 import RequestQuery, ResponseQuery +from .raw_pb2 import ActionRaw, ObservationRaw, StartRaw +from .score_pb2 import Score +from .ui_pb2 import ActionUI, ObservationUI + +class Request(Message): + create_game: RequestCreateGame + join_game: RequestJoinGame + restart_game: RequestRestartGame + start_replay: RequestStartReplay + leave_game: RequestLeaveGame + quick_save: RequestQuickSave + quick_load: RequestQuickLoad + quit: RequestQuit + game_info: RequestGameInfo + observation: RequestObservation + action: RequestAction + obs_action: RequestObserverAction + step: RequestStep + data: RequestData + query: RequestQuery + save_replay: RequestSaveReplay + map_command: RequestMapCommand + replay_info: RequestReplayInfo + available_maps: RequestAvailableMaps + save_map: RequestSaveMap + ping: RequestPing + debug: RequestDebug + id: int + def __init__( + self, + create_game: RequestCreateGame = ..., + join_game: RequestJoinGame = ..., + restart_game: RequestRestartGame = ..., + start_replay: RequestStartReplay = ..., + leave_game: RequestLeaveGame = ..., + quick_save: RequestQuickSave = ..., + quick_load: RequestQuickLoad = ..., + quit: RequestQuit = ..., + game_info: RequestGameInfo = ..., + observation: RequestObservation = ..., + action: RequestAction = ..., + obs_action: RequestObserverAction = ..., + step: RequestStep = ..., + data: RequestData = ..., + query: RequestQuery = ..., + save_replay: RequestSaveReplay = ..., + map_command: RequestMapCommand = ..., + replay_info: RequestReplayInfo = ..., + available_maps: RequestAvailableMaps = ..., + save_map: RequestSaveMap = ..., + ping: RequestPing = ..., + debug: RequestDebug = ..., + id: int = ..., + ) -> None: ... + +class Response(Message): + create_game: ResponseCreateGame + join_game: ResponseJoinGame + restart_game: ResponseRestartGame + start_replay: ResponseStartReplay + leave_game: ResponseLeaveGame + quick_save: ResponseQuickSave + quick_load: ResponseQuickLoad + quit: ResponseQuit + game_info: ResponseGameInfo + observation: ResponseObservation + action: ResponseAction + obs_action: ResponseObserverAction + step: ResponseStep + data: ResponseData + query: ResponseQuery + save_replay: ResponseSaveReplay + replay_info: ResponseReplayInfo + available_maps: ResponseAvailableMaps + save_map: ResponseSaveMap + map_command: ResponseMapCommand + ping: ResponsePing + debug: ResponseDebug + id: int + error: Iterable[str] + status: int + def __init__( + self, + create_game: ResponseCreateGame = ..., + join_game: ResponseJoinGame = ..., + restart_game: ResponseRestartGame = ..., + start_replay: ResponseStartReplay = ..., + leave_game: ResponseLeaveGame = ..., + quick_save: ResponseQuickSave = ..., + quick_load: ResponseQuickLoad = ..., + quit: ResponseQuit = ..., + game_info: ResponseGameInfo = ..., + observation: ResponseObservation = ..., + action: ResponseAction = ..., + obs_action: ResponseObserverAction = ..., + step: ResponseStep = ..., + data: ResponseData = ..., + query: ResponseQuery = ..., + save_replay: ResponseSaveReplay = ..., + replay_info: ResponseReplayInfo = ..., + available_maps: ResponseAvailableMaps = ..., + save_map: ResponseSaveMap = ..., + map_command: ResponseMapCommand = ..., + ping: ResponsePing = ..., + debug: ResponseDebug = ..., + id: int = ..., + error: Iterable[str] = ..., + status: int = ..., + ) -> None: ... + +class Status(Enum): + launched: int + init_game: int + in_game: int + in_replay: int + ended: int + quit: int + unknown: int + +class RequestCreateGame(Message): + local_map: LocalMap + battlenet_map_name: str + player_setup: Iterable[PlayerSetup] + disable_fog: bool + random_seed: int + realtime: bool + def __init__( + self, + local_map: LocalMap = ..., + battlenet_map_name: str = ..., + player_setup: Iterable[PlayerSetup] = ..., + disable_fog: bool = ..., + random_seed: int = ..., + realtime: bool = ..., + ) -> None: ... + +class LocalMap(Message): + map_path: str + map_data: bytes + def __init__(self, map_path: str = ..., map_data: bytes = ...) -> None: ... + +class ResponseCreateGame(Message): + class Error(Enum): + MissingMap: int + InvalidMapPath: int + InvalidMapData: int + InvalidMapName: int + InvalidMapHandle: int + MissingPlayerSetup: int + InvalidPlayerSetup: int + MultiplayerUnsupported: int + + error: int + error_details: str + def __init__(self, error: int = ..., error_details: str = ...) -> None: ... + +class RequestJoinGame(Message): + race: int + observed_player_id: int + options: InterfaceOptions + server_ports: PortSet + client_ports: Iterable[PortSet] + shared_port: int + player_name: str + host_ip: str + def __init__( + self, + race: int = ..., + observed_player_id: int = ..., + options: InterfaceOptions = ..., + server_ports: PortSet = ..., + client_ports: Iterable[PortSet] = ..., + shared_port: int = ..., + player_name: str = ..., + host_ip: str = ..., + ) -> None: ... + +class PortSet(Message): + game_port: int + base_port: int + def __init__(self, game_port: int = ..., base_port: int = ...) -> None: ... + +class ResponseJoinGame(Message): + class Error(Enum): + MissingParticipation: int + InvalidObservedPlayerId: int + MissingOptions: int + MissingPorts: int + GameFull: int + LaunchError: int + FeatureUnsupported: int + NoSpaceForUser: int + MapDoesNotExist: int + CannotOpenMap: int + ChecksumError: int + NetworkError: int + OtherError: int + + player_id: int + error: int + error_details: str + def __init__(self, player_id: int = ..., error: int = ..., error_details: str = ...) -> None: ... + +class RequestRestartGame(Message): + def __init__(self) -> None: ... + +class ResponseRestartGame(Message): + class Error(Enum): + LaunchError: int + + error: int + error_details: str + need_hard_reset: bool + def __init__(self, error: int = ..., error_details: str = ..., need_hard_reset: bool = ...) -> None: ... + +class RequestStartReplay(Message): + replay_path: str + replay_data: bytes + map_data: bytes + observed_player_id: int + options: InterfaceOptions + disable_fog: bool + realtime: bool + record_replay: bool + def __init__( + self, + replay_path: str = ..., + replay_data: bytes = ..., + map_data: bytes = ..., + observed_player_id: int = ..., + options: InterfaceOptions = ..., + disable_fog: bool = ..., + realtime: bool = ..., + record_replay: bool = ..., + ) -> None: ... + +class ResponseStartReplay(Message): + class Error(Enum): + MissingReplay: int + InvalidReplayPath: int + InvalidReplayData: int + InvalidMapData: int + InvalidObservedPlayerId: int + MissingOptions: int + LaunchError: int + + error: int + error_details: str + def __init__(self, error: int = ..., error_details: str = ...) -> None: ... + +class RequestMapCommand(Message): + trigger_cmd: str + def __init__(self, trigger_cmd: str = ...) -> None: ... + +class ResponseMapCommand(Message): + class Error(Enum): + NoTriggerError: int + + error: int + error_details: str + def __init__(self, error: int = ..., error_details: str = ...) -> None: ... + +class RequestLeaveGame(Message): + def __init__(self) -> None: ... + +class ResponseLeaveGame(Message): + def __init__(self) -> None: ... + +class RequestQuickSave(Message): + def __init__(self) -> None: ... + +class ResponseQuickSave(Message): + def __init__(self) -> None: ... + +class RequestQuickLoad(Message): + def __init__(self) -> None: ... + +class ResponseQuickLoad(Message): + def __init__(self) -> None: ... + +class RequestQuit(Message): + def __init__(self) -> None: ... + +class ResponseQuit(Message): + def __init__(self) -> None: ... + +class RequestGameInfo(Message): + def __init__(self) -> None: ... + +class ResponseGameInfo(Message): + map_name: str + mod_names: Iterable[str] + local_map_path: str + player_info: Iterable[PlayerInfo] + start_raw: StartRaw + options: InterfaceOptions + def __init__( + self, + map_name: str = ..., + mod_names: Iterable[str] = ..., + local_map_path: str = ..., + player_info: Iterable[PlayerInfo] = ..., + start_raw: StartRaw = ..., + options: InterfaceOptions = ..., + ) -> None: ... + +class RequestObservation(Message): + disable_fog: bool + game_loop: int + def __init__(self, disable_fog: bool = ..., game_loop: int = ...) -> None: ... + +class ResponseObservation(Message): + actions: Iterable[Action] + action_errors: Iterable[ActionError] + observation: Observation + player_result: Iterable[PlayerResult] + chat: Iterable[ChatReceived] + def __init__( + self, + actions: Iterable[Action] = ..., + action_errors: Iterable[ActionError] = ..., + observation: Observation = ..., + player_result: Iterable[PlayerResult] = ..., + chat: Iterable[ChatReceived] = ..., + ) -> None: ... + +class ChatReceived(Message): + player_id: int + message: str + def __init__(self, player_id: int = ..., message: str = ...) -> None: ... + +class RequestAction(Message): + actions: Iterable[Action] + def __init__(self, actions: Iterable[Action] = ...) -> None: ... + +class ResponseAction(Message): + result: Iterable[int] + def __init__(self, result: Iterable[int] = ...) -> None: ... + +class RequestObserverAction(Message): + actions: Iterable[ObserverAction] + def __init__(self, actions: Iterable[ObserverAction] = ...) -> None: ... + +class ResponseObserverAction(Message): + def __init__(self) -> None: ... + +class RequestStep(Message): + count: int + def __init__(self, count: int = ...) -> None: ... + +class ResponseStep(Message): + simulation_loop: int + def __init__(self, simulation_loop: int = ...) -> None: ... + +class RequestData(Message): + ability_id: bool + unit_type_id: bool + upgrade_id: bool + buff_id: bool + effect_id: bool + def __init__( + self, + ability_id: bool = ..., + unit_type_id: bool = ..., + upgrade_id: bool = ..., + buff_id: bool = ..., + effect_id: bool = ..., + ) -> None: ... + +class ResponseData(Message): + abilities: Iterable[AbilityData] + units: Iterable[UnitTypeData] + upgrades: Iterable[UpgradeData] + buffs: Iterable[BuffData] + effects: Iterable[EffectData] + def __init__( + self, + abilities: Iterable[AbilityData] = ..., + units: Iterable[UnitTypeData] = ..., + upgrades: Iterable[UpgradeData] = ..., + buffs: Iterable[BuffData] = ..., + effects: Iterable[EffectData] = ..., + ) -> None: ... + +class RequestSaveReplay(Message): + def __init__(self) -> None: ... + +class ResponseSaveReplay(Message): + data: bytes + def __init__(self, data: bytes = ...) -> None: ... + +class RequestReplayInfo(Message): + replay_path: str + replay_data: bytes + download_data: bool + def __init__( + self, + replay_path: str = ..., + replay_data: bytes = ..., + download_data: bool = ..., + ) -> None: ... + +class PlayerInfoExtra(Message): + player_info: PlayerInfo + player_result: PlayerResult + player_mmr: int + player_apm: int + def __init__( + self, + player_info: PlayerInfo = ..., + player_result: PlayerResult = ..., + player_mmr: int = ..., + player_apm: int = ..., + ) -> None: ... + +class ResponseReplayInfo(Message): + class Error(Enum): + MissingReplay: int + InvalidReplayPath: int + InvalidReplayData: int + ParsingError: int + DownloadError: int + + map_name: str + local_map_path: str + player_info: Iterable[PlayerInfoExtra] + game_duration_loops: int + game_duration_seconds: float + game_version: str + data_version: str + data_build: int + base_build: int + error: int + error_details: str + def __init__( + self, + map_name: str = ..., + local_map_path: str = ..., + player_info: Iterable[PlayerInfoExtra] = ..., + game_duration_loops: int = ..., + game_duration_seconds: float = ..., + game_version: str = ..., + data_version: str = ..., + data_build: int = ..., + base_build: int = ..., + error: int = ..., + error_details: str = ..., + ) -> None: ... + +class RequestAvailableMaps(Message): + def __init__(self) -> None: ... + +class ResponseAvailableMaps(Message): + local_map_paths: Iterable[str] + battlenet_map_names: Iterable[str] + def __init__(self, local_map_paths: Iterable[str] = ..., battlenet_map_names: Iterable[str] = ...) -> None: ... + +class RequestSaveMap(Message): + map_path: str + map_data: bytes + def __init__(self, map_path: str = ..., map_data: bytes = ...) -> None: ... + +class ResponseSaveMap(Message): + class Error(Enum): + InvalidMapData: int + + error: int + def __init__(self, error: int = ...) -> None: ... + +class RequestPing(Message): + def __init__(self) -> None: ... + +class ResponsePing(Message): + game_version: str + data_version: str + data_build: int + base_build: int + def __init__( + self, + game_version: str = ..., + data_version: str = ..., + data_build: int = ..., + base_build: int = ..., + ) -> None: ... + +class RequestDebug(Message): + debug: Iterable[DebugCommand] + def __init__(self, debug: Iterable[DebugCommand] = ...) -> None: ... + +class ResponseDebug(Message): + def __init__(self) -> None: ... + +class Difficulty(Enum): + VeryEasy: int + Easy: int + Medium: int + MediumHard: int + Hard: int + Harder: int + VeryHard: int + CheatVision: int + CheatMoney: int + CheatInsane: int + +class PlayerType(Enum): + Participant: int + Computer: int + Observer: int + +class AIBuild(Enum): + RandomBuild: int + Rush: int + Timing: int + Power: int + Macro: int + Air: int + +class PlayerSetup(Message): + type: int + race: int + difficulty: int + player_name: str + ai_build: int + def __init__( + self, + type: int = ..., + race: int = ..., + difficulty: int = ..., + player_name: str = ..., + ai_build: int = ..., + ) -> None: ... + +class SpatialCameraSetup(Message): + resolution: Size2DI + minimap_resolution: Size2DI + width: float + crop_to_playable_area: bool + allow_cheating_layers: bool + def __init__( + self, + resolution: Size2DI = ..., + minimap_resolution: Size2DI = ..., + width: float = ..., + crop_to_playable_area: bool = ..., + allow_cheating_layers: bool = ..., + ) -> None: ... + +class InterfaceOptions(Message): + raw: bool + score: bool + feature_layer: SpatialCameraSetup + render: SpatialCameraSetup + show_cloaked: bool + show_burrowed_shadows: bool + show_placeholders: bool + raw_affects_selection: bool + raw_crop_to_playable_area: bool + def __init__( + self, + raw: bool = ..., + score: bool = ..., + feature_layer: SpatialCameraSetup = ..., + render: SpatialCameraSetup = ..., + show_cloaked: bool = ..., + show_burrowed_shadows: bool = ..., + show_placeholders: bool = ..., + raw_affects_selection: bool = ..., + raw_crop_to_playable_area: bool = ..., + ) -> None: ... + +class PlayerInfo(Message): + player_id: int + type: int + race_requested: int + race_actual: int + difficulty: int + ai_build: int + player_name: str + def __init__( + self, + player_id: int = ..., + type: int = ..., + race_requested: int = ..., + race_actual: int = ..., + difficulty: int = ..., + ai_build: int = ..., + player_name: str = ..., + ) -> None: ... + +class PlayerCommon(Message): + player_id: int + minerals: int + vespene: int + food_cap: int + food_used: int + food_army: int + food_workers: int + idle_worker_count: int + army_count: int + warp_gate_count: int + larva_count: int + def __init__( + self, + player_id: int = ..., + minerals: int = ..., + vespene: int = ..., + food_cap: int = ..., + food_used: int = ..., + food_army: int = ..., + food_workers: int = ..., + idle_worker_count: int = ..., + army_count: int = ..., + warp_gate_count: int = ..., + larva_count: int = ..., + ) -> None: ... + +class Observation(Message): + game_loop: int + player_common: PlayerCommon + alerts: Iterable[int] + abilities: Iterable[AvailableAbility] + score: Score + raw_data: ObservationRaw + feature_layer_data: ObservationFeatureLayer + render_data: ObservationRender + ui_data: ObservationUI + def __init__( + self, + game_loop: int = ..., + player_common: PlayerCommon = ..., + alerts: Iterable[int] = ..., + abilities: Iterable[AvailableAbility] = ..., + score: Score = ..., + raw_data: ObservationRaw = ..., + feature_layer_data: ObservationFeatureLayer = ..., + render_data: ObservationRender = ..., + ui_data: ObservationUI = ..., + ) -> None: ... + +class Action(Message): + action_raw: ActionRaw + action_feature_layer: ActionSpatial + action_render: ActionSpatial + action_ui: ActionUI + action_chat: ActionChat + game_loop: int + def __init__( + self, + action_raw: ActionRaw = ..., + action_feature_layer: ActionSpatial = ..., + action_render: ActionSpatial = ..., + action_ui: ActionUI = ..., + action_chat: ActionChat = ..., + game_loop: int = ..., + ) -> None: ... + +class Channel(Enum): + Broadcast: int + Team: int + +class ActionChat(Message): + channel: int + message: str + def __init__(self, channel: int = ..., message: str = ...) -> None: ... + +class ActionError(Message): + unit_tag: int + ability_id: int + result: int + def __init__(self, unit_tag: int = ..., ability_id: int = ..., result: int = ...) -> None: ... + +class ObserverAction(Message): + player_perspective: ActionObserverPlayerPerspective + camera_move: ActionObserverCameraMove + camera_follow_player: ActionObserverCameraFollowPlayer + camera_follow_units: ActionObserverCameraFollowUnits + def __init__( + self, + player_perspective: ActionObserverPlayerPerspective = ..., + camera_move: ActionObserverCameraMove = ..., + camera_follow_player: ActionObserverCameraFollowPlayer = ..., + camera_follow_units: ActionObserverCameraFollowUnits = ..., + ) -> None: ... + +class ActionObserverPlayerPerspective(Message): + player_id: int + def __init__(self, player_id: int = ...) -> None: ... + +class ActionObserverCameraMove(Message): + world_pos: Point2D + distance: float + def __init__(self, world_pos: Point2D = ..., distance: float = ...) -> None: ... + +class ActionObserverCameraFollowPlayer(Message): + player_id: int + def __init__(self, player_id: int = ...) -> None: ... + +class ActionObserverCameraFollowUnits(Message): + unit_tags: Iterable[int] + def __init__(self, unit_tags: Iterable[int] = ...) -> None: ... + +class Alert(Enum): + AlertError: int + AddOnComplete: int + BuildingComplete: int + BuildingUnderAttack: int + LarvaHatched: int + MergeComplete: int + MineralsExhausted: int + MorphComplete: int + MothershipComplete: int + MULEExpired: int + NuclearLaunchDetected: int + NukeComplete: int + NydusWormDetected: int + ResearchComplete: int + TrainError: int + TrainUnitComplete: int + TrainWorkerComplete: int + TransformationComplete: int + UnitUnderAttack: int + UpgradeComplete: int + VespeneExhausted: int + WarpInComplete: int + +class Result(Enum): + Victory: int + Defeat: int + Tie: int + Undecided: int + +class PlayerResult(Message): + player_id: int + result: int + def __init__(self, player_id: int = ..., result: int = ...) -> None: ... diff --git a/s2clientprotocol/score_pb2.pyi b/s2clientprotocol/score_pb2.pyi new file mode 100644 index 00000000..88c47391 --- /dev/null +++ b/s2clientprotocol/score_pb2.pyi @@ -0,0 +1,107 @@ +from __future__ import annotations + +from enum import Enum + +from google.protobuf.message import Message + +class ScoreType(Enum): + Curriculum: int + Melee: int + +class Score(Message): + score_type: int + score: int + score_details: ScoreDetails + def __init__( + self, + score_type: int = ..., + score: int = ..., + score_details: ScoreDetails = ..., + ) -> None: ... + +class CategoryScoreDetails(Message): + none: float + army: float + economy: float + technology: float + upgrade: float + def __init__( + self, + none: float = ..., + army: float = ..., + economy: float = ..., + technology: float = ..., + upgrade: float = ..., + ) -> None: ... + +class VitalScoreDetails(Message): + life: float + shields: float + energy: float + def __init__( + self, + life: float = ..., + shields: float = ..., + energy: float = ..., + ) -> None: ... + +class ScoreDetails(Message): + idle_production_time: float + idle_worker_time: float + total_value_units: float + total_value_structures: float + killed_value_units: float + killed_value_structures: float + collected_minerals: float + collected_vespene: float + collection_rate_minerals: float + collection_rate_vespene: float + spent_minerals: float + spent_vespene: float + food_used: CategoryScoreDetails + killed_minerals: CategoryScoreDetails + killed_vespene: CategoryScoreDetails + lost_minerals: CategoryScoreDetails + lost_vespene: CategoryScoreDetails + friendly_fire_minerals: CategoryScoreDetails + friendly_fire_vespene: CategoryScoreDetails + used_minerals: CategoryScoreDetails + used_vespene: CategoryScoreDetails + total_used_minerals: CategoryScoreDetails + total_used_vespene: CategoryScoreDetails + total_damage_dealt: VitalScoreDetails + total_damage_taken: VitalScoreDetails + total_healed: VitalScoreDetails + current_apm: float + current_effective_apm: float + def __init__( + self, + idle_production_time: float = ..., + idle_worker_time: float = ..., + total_value_units: float = ..., + total_value_structures: float = ..., + killed_value_units: float = ..., + killed_value_structures: float = ..., + collected_minerals: float = ..., + collected_vespene: float = ..., + collection_rate_minerals: float = ..., + collection_rate_vespene: float = ..., + spent_minerals: float = ..., + spent_vespene: float = ..., + food_used: CategoryScoreDetails = ..., + killed_minerals: CategoryScoreDetails = ..., + killed_vespene: CategoryScoreDetails = ..., + lost_minerals: CategoryScoreDetails = ..., + lost_vespene: CategoryScoreDetails = ..., + friendly_fire_minerals: CategoryScoreDetails = ..., + friendly_fire_vespene: CategoryScoreDetails = ..., + used_minerals: CategoryScoreDetails = ..., + used_vespene: CategoryScoreDetails = ..., + total_used_minerals: CategoryScoreDetails = ..., + total_used_vespene: CategoryScoreDetails = ..., + total_damage_dealt: VitalScoreDetails = ..., + total_damage_taken: VitalScoreDetails = ..., + total_healed: VitalScoreDetails = ..., + current_apm: float = ..., + current_effective_apm: float = ..., + ) -> None: ... diff --git a/s2clientprotocol/spatial_pb2.pyi b/s2clientprotocol/spatial_pb2.pyi new file mode 100644 index 00000000..a1b72a29 --- /dev/null +++ b/s2clientprotocol/spatial_pb2.pyi @@ -0,0 +1,154 @@ +from __future__ import annotations + +from collections.abc import Iterable +from enum import Enum + +from google.protobuf.message import Message + +from .common_pb2 import ImageData, PointI, RectangleI + +class ObservationFeatureLayer(Message): + renders: FeatureLayers + minimap_renders: FeatureLayersMinimap + def __init__( + self, + renders: FeatureLayers = ..., + minimap_renders: FeatureLayersMinimap = ..., + ) -> None: ... + +class FeatureLayers(Message): + height_map: ImageData + visibility_map: ImageData + creep: ImageData + power: ImageData + player_id: ImageData + unit_type: ImageData + selected: ImageData + unit_hit_points: ImageData + unit_hit_points_ratio: ImageData + unit_energy: ImageData + unit_energy_ratio: ImageData + unit_shields: ImageData + unit_shields_ratio: ImageData + player_relative: ImageData + unit_density_aa: ImageData + unit_density: ImageData + effects: ImageData + hallucinations: ImageData + cloaked: ImageData + blip: ImageData + buffs: ImageData + buff_duration: ImageData + active: ImageData + build_progress: ImageData + buildable: ImageData + pathable: ImageData + placeholder: ImageData + def __init__( + self, + height_map: ImageData = ..., + visibility_map: ImageData = ..., + creep: ImageData = ..., + power: ImageData = ..., + player_id: ImageData = ..., + unit_type: ImageData = ..., + selected: ImageData = ..., + unit_hit_points: ImageData = ..., + unit_hit_points_ratio: ImageData = ..., + unit_energy: ImageData = ..., + unit_energy_ratio: ImageData = ..., + unit_shields: ImageData = ..., + unit_shields_ratio: ImageData = ..., + player_relative: ImageData = ..., + unit_density_aa: ImageData = ..., + unit_density: ImageData = ..., + effects: ImageData = ..., + hallucinations: ImageData = ..., + cloaked: ImageData = ..., + blip: ImageData = ..., + buffs: ImageData = ..., + buff_duration: ImageData = ..., + active: ImageData = ..., + build_progress: ImageData = ..., + buildable: ImageData = ..., + pathable: ImageData = ..., + placeholder: ImageData = ..., + ) -> None: ... + +class FeatureLayersMinimap(Message): + height_map: ImageData + visibility_map: ImageData + creep: ImageData + camera: ImageData + player_id: ImageData + player_relative: ImageData + selected: ImageData + alerts: ImageData + buildable: ImageData + pathable: ImageData + unit_type: ImageData + def __init__( + self, + height_map: ImageData = ..., + visibility_map: ImageData = ..., + creep: ImageData = ..., + camera: ImageData = ..., + player_id: ImageData = ..., + player_relative: ImageData = ..., + selected: ImageData = ..., + alerts: ImageData = ..., + buildable: ImageData = ..., + pathable: ImageData = ..., + unit_type: ImageData = ..., + ) -> None: ... + +class ObservationRender(Message): + map: ImageData + minimap: ImageData + def __init__(self, map: ImageData = ..., minimap: ImageData = ...) -> None: ... + +class ActionSpatial(Message): + unit_command: ActionSpatialUnitCommand + camera_move: ActionSpatialCameraMove + unit_selection_point: ActionSpatialUnitSelectionPoint + unit_selection_rect: ActionSpatialUnitSelectionRect + def __init__( + self, + unit_command: ActionSpatialUnitCommand = ..., + camera_move: ActionSpatialCameraMove = ..., + unit_selection_point: ActionSpatialUnitSelectionPoint = ..., + unit_selection_rect: ActionSpatialUnitSelectionRect = ..., + ) -> None: ... + +class ActionSpatialUnitCommand(Message): + ability_id: int + target_screen_coord: PointI + target_minimap_coord: PointI + queue_command: bool + def __init__( + self, + ability_id: int = ..., + target_screen_coord: PointI = ..., + target_minimap_coord: PointI = ..., + queue_command: bool = ..., + ) -> None: ... + +class ActionSpatialCameraMove(Message): + center_minimap: PointI + def __init__(self, center_minimap: PointI = ...) -> None: ... + +class Type(Enum): + Select: int + Toggle: int + AllType: int + AddAllType: int + +class ActionSpatialUnitSelectionPoint(Message): + selection_screen_coord: PointI + type: int + def __init__(self, selection_screen_coord: PointI = ..., type: int = ...) -> None: ... + +class ActionSpatialUnitSelectionRect(Message): + selection_screen_coord: Iterable[RectangleI] + selection_add: bool + def __init__(self, selection_screen_coord: Iterable[RectangleI] = ..., selection_add: bool = ...) -> None: ... diff --git a/s2clientprotocol/ui_pb2.pyi b/s2clientprotocol/ui_pb2.pyi new file mode 100644 index 00000000..dbf39f3b --- /dev/null +++ b/s2clientprotocol/ui_pb2.pyi @@ -0,0 +1,184 @@ +from __future__ import annotations + +from collections.abc import Iterable +from enum import Enum + +from google.protobuf.message import Message + +class ObservationUI(Message): + groups: Iterable[ControlGroup] + single: SinglePanel + multi: MultiPanel + cargo: CargoPanel + production: ProductionPanel + def __init__( + self, + groups: Iterable[ControlGroup] = ..., + single: SinglePanel = ..., + multi: MultiPanel = ..., + cargo: CargoPanel = ..., + production: ProductionPanel = ..., + ) -> None: ... + +class ControlGroup(Message): + control_group_index: int + leader_unit_type: int + count: int + def __init__( + self, + control_group_index: int = ..., + leader_unit_type: int = ..., + count: int = ..., + ) -> None: ... + +class UnitInfo(Message): + unit_type: int + player_relative: int + health: int + shields: int + energy: int + transport_slots_taken: int + build_progress: float + add_on: UnitInfo + max_health: int + max_shields: int + max_energy: int + def __init__( + self, + unit_type: int = ..., + player_relative: int = ..., + health: int = ..., + shields: int = ..., + energy: int = ..., + transport_slots_taken: int = ..., + build_progress: float = ..., + add_on: UnitInfo = ..., + max_health: int = ..., + max_shields: int = ..., + max_energy: int = ..., + ) -> None: ... + +class SinglePanel(Message): + unit: UnitInfo + attack_upgrade_level: int + armor_upgrade_level: int + shield_upgrade_level: int + buffs: Iterable[int] + def __init__( + self, + unit: UnitInfo = ..., + attack_upgrade_level: int = ..., + armor_upgrade_level: int = ..., + shield_upgrade_level: int = ..., + buffs: Iterable[int] = ..., + ) -> None: ... + +class MultiPanel(Message): + units: Iterable[UnitInfo] + def __init__(self, units: Iterable[UnitInfo] = ...) -> None: ... + +class CargoPanel(Message): + unit: UnitInfo + passengers: Iterable[UnitInfo] + slots_available: int + def __init__( + self, + unit: UnitInfo = ..., + passengers: Iterable[UnitInfo] = ..., + slots_available: int = ..., + ) -> None: ... + +class BuildItem(Message): + ability_id: int + build_progress: float + def __init__(self, ability_id: int = ..., build_progress: float = ...) -> None: ... + +class ProductionPanel(Message): + unit: UnitInfo + build_queue: Iterable[UnitInfo] + production_queue: Iterable[BuildItem] + def __init__( + self, + unit: UnitInfo = ..., + build_queue: Iterable[UnitInfo] = ..., + production_queue: Iterable[BuildItem] = ..., + ) -> None: ... + +class ActionUI(Message): + control_group: ActionControlGroup + select_army: ActionSelectArmy + select_warp_gates: ActionSelectWarpGates + select_larva: ActionSelectLarva + select_idle_worker: ActionSelectIdleWorker + multi_panel: ActionMultiPanel + cargo_panel: ActionCargoPanelUnload + production_panel: ActionProductionPanelRemoveFromQueue + toggle_autocast: ActionToggleAutocast + def __init__( + self, + control_group: ActionControlGroup = ..., + select_army: ActionSelectArmy = ..., + select_warp_gates: ActionSelectWarpGates = ..., + select_larva: ActionSelectLarva = ..., + select_idle_worker: ActionSelectIdleWorker = ..., + multi_panel: ActionMultiPanel = ..., + cargo_panel: ActionCargoPanelUnload = ..., + production_panel: ActionProductionPanelRemoveFromQueue = ..., + toggle_autocast: ActionToggleAutocast = ..., + ) -> None: ... + +class ControlGroupAction(Enum): + Recall: int + Set: int + Append: int + SetAndSteal: int + AppendAndSteal: int + +class ActionControlGroup(Message): + action: int + control_group_index: int + def __init__(self, action: int = ..., control_group_index: int = ...) -> None: ... + +class ActionSelectArmy(Message): + selection_add: bool + def __init__(self, selection_add: bool = ...) -> None: ... + +class ActionSelectWarpGates(Message): + selection_add: bool + def __init__(self, selection_add: bool = ...) -> None: ... + +class ActionSelectLarva(Message): + def __init__(self) -> None: ... + +class ActionSelectIdleWorker(Message): + class Type(Enum): + Set: int + Add: int + All: int + AddAll: int + + type: int + def __init__(self, type: int = ...) -> None: ... + +class ActionMultiPanel(Message): + class Type(Enum): + SingleSelect: int + DeselectUnit: int + SelectAllOfType: int + DeselectAllOfType: int + + type: int + unit_index: int + def __init__(self, type: int = ..., unit_index: int = ...) -> None: ... + +class ActionCargoPanelUnload(Message): + unit_index: int + def __init__(self, unit_index: int = ...) -> None: ... + +class ActionProductionPanelRemoveFromQueue(Message): + unit_index: int + def __init__(self, unit_index: int = ...) -> None: ... + +class ActionToggleAutocast(Message): + ability_id: int + def __init__(self, ability_id: int = ...) -> None: ... diff --git a/sc2/action.py b/sc2/action.py index 0500309e..b43124ae 100644 --- a/sc2/action.py +++ b/sc2/action.py @@ -3,9 +3,7 @@ from itertools import groupby from typing import TYPE_CHECKING -# pyre-ignore[21] from s2clientprotocol import raw_pb2 as raw_pb - from sc2.position import Point2 from sc2.unit import Unit @@ -14,7 +12,7 @@ from sc2.unit_command import UnitCommand -def combine_actions(action_iter): +def combine_actions(action_iter: list[UnitCommand]): """ Example input: [ @@ -57,7 +55,6 @@ def combine_actions(action_iter): I imagine the same thing would happen to certain other abilities: Battlecruiser yamato on same target, queen transfuse on same target, ghost snipe on same target, all build commands with the same unit type and also all morphs (zergling to banelings) However, other abilities can and should be grouped, see constants.py 'COMBINEABLE_ABILITIES' """ - u: UnitCommand if target is None: for u in items: cmd = raw_pb.ActionRawUnitCommand( @@ -73,7 +70,6 @@ def combine_actions(action_iter): target_world_space_pos=target.as_Point2D, ) yield raw_pb.ActionRaw(unit_command=cmd) - elif isinstance(target, Unit): for u in items: cmd = raw_pb.ActionRawUnitCommand( diff --git a/sc2/bot_ai.py b/sc2/bot_ai.py index a4bcfb86..d98e72a0 100644 --- a/sc2/bot_ai.py +++ b/sc2/bot_ai.py @@ -1187,7 +1187,7 @@ async def chat_send(self, message: str, team_only: bool = False) -> None: assert isinstance(message, str), f"{message} is not a string" await self.client.chat_send(message, team_only) - def in_map_bounds(self, pos: Point2 | tuple | list) -> bool: + def in_map_bounds(self, pos: Point2 | tuple[float, float] | list[float]) -> bool: """Tests if a 2 dimensional point is within the map boundaries of the pixelmaps. :param pos:""" diff --git a/sc2/bot_ai_internal.py b/sc2/bot_ai_internal.py index bb5738c2..d2bde3f7 100644 --- a/sc2/bot_ai_internal.py +++ b/sc2/bot_ai_internal.py @@ -14,9 +14,7 @@ import numpy as np from loguru import logger -# pyre-ignore[21] from s2clientprotocol import sc2api_pb2 as sc_pb - from sc2.cache import property_cache_once_per_frame from sc2.constants import ( ALL_GAS, @@ -35,14 +33,13 @@ from sc2.ids.unit_typeid import UnitTypeId from sc2.ids.upgrade_id import UpgradeId from sc2.pixel_map import PixelMap -from sc2.position import Point2 +from sc2.position import Point2, _PointLike from sc2.unit import Unit from sc2.unit_command import UnitCommand from sc2.units import Units with warnings.catch_warnings(): warnings.simplefilter("ignore") - # pyre-ignore[21] from scipy.spatial.distance import cdist, pdist if TYPE_CHECKING: @@ -582,7 +579,7 @@ async def _do_actions(self, actions: list[UnitCommand], prevent_double: bool = T @final @staticmethod - def prevent_double_actions(action) -> bool: + def prevent_double_actions(action: UnitCommand) -> bool: """ :param action: """ @@ -609,7 +606,13 @@ def prevent_double_actions(action) -> bool: @final def _prepare_start( - self, client, player_id: int, game_info, game_data, realtime: bool = False, base_build: int = -1 + self, + client: Client, + player_id: int, + game_info: GameInfo, + game_data: GameData, + realtime: bool = False, + base_build: int = -1, ) -> None: """ Ran until game start to set game and player data. @@ -646,13 +649,13 @@ def _prepare_first_step(self) -> None: self._time_before_step: float = time.perf_counter() @final - def _prepare_step(self, state, proto_game_info) -> None: + def _prepare_step(self, state: GameState, proto_game_info: sc_pb.Response) -> None: """ :param state: :param proto_game_info: """ # Set attributes from new state before on_step.""" - self.state: GameState = state # See game_state.py + self.state = state # See game_state.py # update pathing grid, which unfortunately is in GameInfo instead of GameState self.game_info.pathing_grid = PixelMap(proto_game_info.game_info.start_raw.pathing_grid, in_bits=True) # Required for events, needs to be before self.units are initialized so the old units are stored @@ -1013,16 +1016,16 @@ def convert_tuple_to_numpy_array(pos: tuple[float, float]) -> np.ndarray: @final @staticmethod def distance_math_hypot( - p1: tuple[float, float] | Point2, - p2: tuple[float, float] | Point2, + p1: _PointLike, + p2: _PointLike, ) -> float: return math.hypot(p1[0] - p2[0], p1[1] - p2[1]) @final @staticmethod def distance_math_hypot_squared( - p1: tuple[float, float] | Point2, - p2: tuple[float, float] | Point2, + p1: _PointLike, + p2: _PointLike, ) -> float: return pow(p1[0] - p2[0], 2) + pow(p1[1] - p2[1], 2) diff --git a/sc2/cache.py b/sc2/cache.py index d3e9090d..b1682fc5 100644 --- a/sc2/cache.py +++ b/sc2/cache.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Callable, Hashable -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar if TYPE_CHECKING: from sc2.bot_ai import BotAI @@ -9,7 +9,7 @@ T = TypeVar("T") -class CacheDict(dict): +class CacheDict(dict[Hashable, Any]): def retrieve_and_set(self, key: Hashable, func: Callable[[], T]) -> T: """Either return the value at a certain key, or set the return value of a function to that key, then return that value.""" @@ -29,7 +29,7 @@ class property_cache_once_per_frame(property): # noqa: N801 Copied and modified from https://tedboy.github.io/flask/_modules/werkzeug/utils.html#cached_property #""" - def __init__(self, func: Callable[[BotAI], T], name=None) -> None: + def __init__(self, func: Callable[[BotAI], T], name: str | None = None) -> None: self.__name__ = name or func.__name__ self.__frame__ = f"__frame__{self.__name__}" self.func = func diff --git a/sc2/client.py b/sc2/client.py index 5888b190..8971b86f 100644 --- a/sc2/client.py +++ b/sc2/client.py @@ -3,31 +3,33 @@ from collections.abc import Iterable from pathlib import Path +from typing import Any +from aiohttp import ClientWebSocketResponse from loguru import logger -# pyre-ignore[21] from s2clientprotocol import debug_pb2 as debug_pb from s2clientprotocol import query_pb2 as query_pb from s2clientprotocol import raw_pb2 as raw_pb from s2clientprotocol import sc2api_pb2 as sc_pb from s2clientprotocol import spatial_pb2 as spatial_pb - from sc2.action import combine_actions from sc2.data import ActionResult, ChatChannel, Race, Result, Status from sc2.game_data import AbilityData, GameData from sc2.game_info import GameInfo from sc2.ids.ability_id import AbilityId from sc2.ids.unit_typeid import UnitTypeId +from sc2.portconfig import Portconfig from sc2.position import Point2, Point3 from sc2.protocol import ConnectionAlreadyClosedError, Protocol, ProtocolError from sc2.renderer import Renderer from sc2.unit import Unit +from sc2.unit_command import UnitCommand from sc2.units import Units class Client(Protocol): - def __init__(self, ws, save_replay_path: str = None) -> None: + def __init__(self, ws: ClientWebSocketResponse, save_replay_path: str | None = None) -> None: """ :param ws: """ @@ -52,7 +54,14 @@ def __init__(self, ws, save_replay_path: str = None) -> None: def in_game(self) -> bool: return self._status in {Status.in_game, Status.in_replay} - async def join_game(self, name=None, race=None, observed_player_id=None, portconfig=None, rgb_render_config=None): + async def join_game( + self, + name: str | None = None, + race: Race | None = None, + observed_player_id: int | None = None, + portconfig: Portconfig | None = None, + rgb_render_config: dict[str, Any] | None = None, + ): ifopts = sc_pb.InterfaceOptions( raw=True, score=True, @@ -121,14 +130,14 @@ async def leave(self) -> None: if is_resign: raise - async def save_replay(self, path) -> None: + async def save_replay(self, path: str) -> None: logger.debug("Requesting replay from server") result = await self._execute(save_replay=sc_pb.RequestSaveReplay()) with Path(path).open("wb") as f: f.write(result.save_replay.data) logger.info(f"Saved replay to {path}") - async def observation(self, game_loop: int = None): + async def observation(self, game_loop: int | None = None): if game_loop is not None: result = await self._execute(observation=sc_pb.RequestObservation(game_loop=game_loop)) else: @@ -152,13 +161,13 @@ async def observation(self, game_loop: int = None): return result - async def step(self, step_size: int = None): + async def step(self, step_size: int | None = None): """EXPERIMENTAL: Change self._client.game_step during the step function to increase or decrease steps per second""" step_size = step_size or self.game_step return await self._execute(step=sc_pb.RequestStep(count=step_size)) async def get_game_data(self) -> GameData: - result = await self._execute( + result: sc_pb.Response = await self._execute( data=sc_pb.RequestData(ability_id=True, unit_type_id=True, upgrade_id=True, buff_id=True, effect_id=True) ) return GameData(result.data) @@ -194,9 +203,9 @@ async def get_game_info(self) -> GameInfo: result = await self._execute(game_info=sc_pb.RequestGameInfo()) return GameInfo(result.game_info) - async def actions(self, actions, return_successes: bool = False): + async def actions(self, actions: list[UnitCommand], return_successes: bool = False) -> list[ActionResult]: if not actions: - return None + return [] if not isinstance(actions, list): actions = [actions] @@ -470,8 +479,8 @@ def debug_text_simple(self, text: str) -> None: def debug_text_screen( self, text: str, - pos: Point2 | Point3 | tuple | list, - color: tuple | list | Point3 = None, + pos: Point2 | Point3 | tuple[float, float] | list[float], + color: tuple[float, float, float] | list[float] | Point3 | None = None, size: int = 8, ) -> None: """ @@ -491,14 +500,18 @@ def debug_text_screen( def debug_text_2d( self, text: str, - pos: Point2 | Point3 | tuple | list, - color: tuple | list | Point3 = None, + pos: Point2 | Point3 | tuple[float, float] | list[float], + color: tuple[float, float, float] | list[float] | Point3 | None = None, size: int = 8, ): return self.debug_text_screen(text, pos, color, size) def debug_text_world( - self, text: str, pos: Unit | Point3, color: tuple | list | Point3 = None, size: int = 8 + self, + text: str, + pos: Unit | Point3, + color: tuple[float, float, float] | list[float] | Point3 | None = None, + size: int = 8, ) -> None: """ Draws a text at Point3 position in the game world. @@ -514,10 +527,21 @@ def debug_text_world( assert isinstance(pos, Point3) self._debug_texts.append(DrawItemWorldText(text=text, color=color, start_point=pos, font_size=size)) - def debug_text_3d(self, text: str, pos: Unit | Point3, color: tuple | list | Point3 = None, size: int = 8): + def debug_text_3d( + self, + text: str, + pos: Unit | Point3, + color: tuple[float, float, float] | list[float] | Point3 | None = None, + size: int = 8, + ): return self.debug_text_world(text, pos, color, size) - def debug_line_out(self, p0: Unit | Point3, p1: Unit | Point3, color: tuple | list | Point3 = None) -> None: + def debug_line_out( + self, + p0: Unit | Point3, + p1: Unit | Point3, + color: tuple[float, float, float] | list[float] | Point3 | None = None, + ) -> None: """ Draws a line from p0 to p1. @@ -537,7 +561,7 @@ def debug_box_out( self, p_min: Unit | Point3, p_max: Unit | Point3, - color: tuple | list | Point3 = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, ) -> None: """ Draws a box with p_min and p_max as corners of the box. @@ -558,7 +582,7 @@ def debug_box2_out( self, pos: Unit | Point3, half_vertex_length: float = 0.25, - color: tuple | list | Point3 = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, ) -> None: """ Draws a box center at a position 'pos', with box side lengths (vertices) of two times 'half_vertex_length'. @@ -574,7 +598,12 @@ def debug_box2_out( p1 = pos + Point3((half_vertex_length, half_vertex_length, half_vertex_length)) self._debug_boxes.append(DrawItemBox(start_point=p0, end_point=p1, color=color)) - def debug_sphere_out(self, p: Unit | Point3, r: float, color: tuple | list | Point3 = None) -> None: + def debug_sphere_out( + self, + p: Unit | Point3, + r: float, + color: tuple[float, float, float] | list[float] | Point3 | None = None, + ) -> None: """ Draws a sphere at point p with radius r. @@ -752,12 +781,12 @@ async def quick_load(self) -> None: class DrawItem: @staticmethod - def to_debug_color(color: tuple | Point3): + def to_debug_color(color: tuple[float, float, float] | list[float] | Point3 | None = None) -> debug_pb.Color: """Helper function for color conversion""" if color is None: return debug_pb.Color(r=255, g=255, b=255) # Need to check if not of type Point3 because Point3 inherits from tuple - if isinstance(color, (tuple, list)) and not isinstance(color, Point3) and len(color) == 3: + if isinstance(color, (tuple, list)) or isinstance(color, Point3) and len(color) == 3: return debug_pb.Color(r=color[0], g=color[1], b=color[2]) # In case color is of type Point3 r = getattr(color, "r", getattr(color, "x", 255)) @@ -773,11 +802,17 @@ def to_debug_color(color: tuple | Point3): class DrawItemScreenText(DrawItem): - def __init__(self, start_point: Point2 = None, color: Point3 = None, text: str = "", font_size: int = 8) -> None: - self._start_point: Point2 = start_point - self._color: Point3 = color - self._text: str = text - self._font_size: int = font_size + def __init__( + self, + start_point: Point2, + color: tuple[float, float, float] | list[float] | Point3 | None = None, + text: str = "", + font_size: int = 8, + ) -> None: + self._start_point = start_point + self._color = color + self._text = text + self._font_size = font_size def to_proto(self): return debug_pb.DebugText( @@ -793,7 +828,13 @@ def __hash__(self) -> int: class DrawItemWorldText(DrawItem): - def __init__(self, start_point: Point3 = None, color: Point3 = None, text: str = "", font_size: int = 8) -> None: + def __init__( + self, + start_point: Point3 = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, + text: str = "", + font_size: int = 8, + ) -> None: self._start_point: Point3 = start_point self._color: Point3 = color self._text: str = text @@ -813,7 +854,12 @@ def __hash__(self) -> int: class DrawItemLine(DrawItem): - def __init__(self, start_point: Point3 = None, end_point: Point3 = None, color: Point3 = None) -> None: + def __init__( + self, + start_point: Point3 = None, + end_point: Point3 = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, + ) -> None: self._start_point: Point3 = start_point self._end_point: Point3 = end_point self._color: Point3 = color @@ -829,7 +875,12 @@ def __hash__(self) -> int: class DrawItemBox(DrawItem): - def __init__(self, start_point: Point3 = None, end_point: Point3 = None, color: Point3 = None) -> None: + def __init__( + self, + start_point: Point3 = None, + end_point: Point3 = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, + ) -> None: self._start_point: Point3 = start_point self._end_point: Point3 = end_point self._color: Point3 = color @@ -846,7 +897,12 @@ def __hash__(self) -> int: class DrawItemSphere(DrawItem): - def __init__(self, start_point: Point3 = None, radius: float = None, color: Point3 = None) -> None: + def __init__( + self, + start_point: Point3 = None, + radius: float = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, + ) -> None: self._start_point: Point3 = start_point self._radius: float = radius self._color: Point3 = color diff --git a/sc2/constants.py b/sc2/constants.py index e23d4c13..6478ff01 100644 --- a/sc2/constants.py +++ b/sc2/constants.py @@ -495,7 +495,7 @@ def return_NOTAUNIT() -> UnitTypeId: UnitTypeId.EXTRACTORRICH, } # pyre-ignore[11] -DAMAGE_BONUS_PER_UPGRADE: dict[UnitTypeId, dict[TargetType, Any]] = { +DAMAGE_BONUS_PER_UPGRADE: dict[UnitTypeId, dict[int, Any]] = { # # Protoss # diff --git a/sc2/controller.py b/sc2/controller.py index 2e480330..e068aa3f 100644 --- a/sc2/controller.py +++ b/sc2/controller.py @@ -1,17 +1,22 @@ +from __future__ import annotations + import platform from pathlib import Path +from typing import TYPE_CHECKING +from aiohttp import ClientWebSocketResponse from loguru import logger -# pyre-ignore[21] from s2clientprotocol import sc2api_pb2 as sc_pb - from sc2.player import Computer from sc2.protocol import Protocol +if TYPE_CHECKING: + from sc2.sc2process import SC2Process + class Controller(Protocol): - def __init__(self, ws, process) -> None: + def __init__(self, ws: ClientWebSocketResponse, process: SC2Process) -> None: super().__init__(ws) self._process = process diff --git a/sc2/data.py b/sc2/data.py index b0c9425f..d376138b 100644 --- a/sc2/data.py +++ b/sc2/data.py @@ -1,23 +1,18 @@ # pyre-ignore-all-errors[16, 19] """For the list of enums, see here -https://github.com/Blizzard/s2client-api/blob/d9ba0a33d6ce9d233c2a4ee988360c188fbe9dbf/include/sc2api/sc2_gametypes.h -https://github.com/Blizzard/s2client-api/blob/d9ba0a33d6ce9d233c2a4ee988360c188fbe9dbf/include/sc2api/sc2_action.h -https://github.com/Blizzard/s2client-api/blob/d9ba0a33d6ce9d233c2a4ee988360c188fbe9dbf/include/sc2api/sc2_unit.h -https://github.com/Blizzard/s2client-api/blob/d9ba0a33d6ce9d233c2a4ee988360c188fbe9dbf/include/sc2api/sc2_data.h +https://github.com/Blizzard/s2client-proto/tree/bff45dae1fc685e6acbaae084670afb7d1c0832c/s2clientprotocol """ from __future__ import annotations import enum -# pyre-ignore[21] from s2clientprotocol import common_pb2 as common_pb from s2clientprotocol import data_pb2 as data_pb from s2clientprotocol import error_pb2 as error_pb from s2clientprotocol import raw_pb2 as raw_pb from s2clientprotocol import sc2api_pb2 as sc_pb - from sc2.ids.ability_id import AbilityId from sc2.ids.unit_typeid import UnitTypeId diff --git a/sc2/data.pyi b/sc2/data.pyi new file mode 100644 index 00000000..4783460e --- /dev/null +++ b/sc2/data.pyi @@ -0,0 +1,386 @@ +"""Type stubs for sc2.data module + +This stub provides static type information for dynamically generated enums. +The enums in sc2.data are created at runtime using enum.Enum() with protobuf +enum descriptors, which makes them invisible to static type checkers. + +This stub file (PEP 561 compliant) allows type checkers like Pylance, Pyright, +and mypy to understand the structure and members of these enums. +""" + +from __future__ import annotations + +from enum import Enum + +from sc2.ids.ability_id import AbilityId +from sc2.ids.unit_typeid import UnitTypeId + +class CreateGameError(Enum): + MissingMap = 1 + InvalidMapPath = 2 + InvalidMapData = 3 + InvalidMapName = 4 + InvalidMapHandle = 5 + MissingPlayerSetup = 6 + InvalidPlayerSetup = 7 + MultiplayerUnsupported = 8 + +class PlayerType(Enum): + Participant = 1 + Computer = 2 + Observer = 3 + +class Difficulty(Enum): + VeryEasy = 1 + Easy = 2 + Medium = 3 + MediumHard = 4 + Hard = 5 + Harder = 6 + VeryHard = 7 + CheatVision = 8 + CheatMoney = 9 + CheatInsane = 10 + +class AIBuild(Enum): + RandomBuild = 1 + Rush = 2 + Timing = 3 + Power = 4 + Macro = 5 + Air = 6 + +class Status(Enum): + launched = 1 + init_game = 2 + in_game = 3 + in_replay = 4 + ended = 5 + quit = 6 + unknown = 7 + +class Result(Enum): + Victory = 1 + Defeat = 2 + Tie = 3 + Undecided = 4 + +class Alert(Enum): + AlertError = 1 + AddOnComplete = 2 + BuildingComplete = 3 + BuildingUnderAttack = 4 + LarvaHatched = 5 + MergeComplete = 6 + MineralsExhausted = 7 + MorphComplete = 8 + MothershipComplete = 9 + MULEExpired = 10 + NuclearLaunchDetected = 11 + NukeComplete = 12 + NydusWormDetected = 13 + ResearchComplete = 14 + TrainError = 15 + TrainUnitComplete = 16 + TrainWorkerComplete = 17 + TransformationComplete = 18 + UnitUnderAttack = 19 + UpgradeComplete = 20 + VespeneExhausted = 21 + WarpInComplete = 22 + +class ChatChannel(Enum): + Broadcast = 1 + Team = 2 + +class Race(Enum): + """StarCraft II race enum. + + Members: + NoRace: No race specified + Terran: Terran race + Zerg: Zerg race + Protoss: Protoss race + Random: Random race selection + """ + + NoRace = 0 + Terran = 1 + Zerg = 2 + Protoss = 3 + Random = 4 + +# Enums created from raw_pb2 +class DisplayType(Enum): + Visible = 1 + Snapshot = 2 + Hidden = 3 + Placeholder = 4 + +class Alliance(Enum): + Self = 1 + Ally = 2 + Neutral = 3 + Enemy = 4 + +class CloakState(Enum): + CloakedUnknown = 1 + Cloaked = 2 + CloakedDetected = 3 + NotCloaked = 4 + CloakedAllied = 5 + +class Attribute(Enum): + Light = 1 + Armored = 2 + Biological = 3 + Mechanical = 4 + Robotic = 5 + Psionic = 6 + Massive = 7 + Structure = 8 + Hover = 9 + Heroic = 10 + Summoned = 11 + +class TargetType(Enum): + Ground = 1 + Air = 2 + Any = 3 + Invalid = 4 + +class Target(Enum): + # Note: The protobuf enum member 'None' is a Python keyword, + # so at runtime it may need special handling + Point = 1 + Unit = 2 + PointOrUnit = 3 + PointOrNone = 4 + +class ActionResult(Enum): + """Action result codes from game engine. + + This enum contains a large number of members (~200+) representing + various action results and error conditions. + """ + + Success = 1 + NotSupported = 2 + Error = 3 + CantQueueThatOrder = 4 + Retry = 5 + Cooldown = 6 + QueueIsFull = 7 + RallyQueueIsFull = 8 + NotEnoughMinerals = 9 + NotEnoughVespene = 10 + NotEnoughTerrazine = 11 + NotEnoughCustom = 12 + NotEnoughFood = 13 + FoodUsageImpossible = 14 + NotEnoughLife = 15 + NotEnoughShields = 16 + NotEnoughEnergy = 17 + LifeSuppressed = 18 + ShieldsSuppressed = 19 + EnergySuppressed = 20 + NotEnoughCharges = 21 + CantAddMoreCharges = 22 + TooMuchMinerals = 23 + TooMuchVespene = 24 + TooMuchTerrazine = 25 + TooMuchCustom = 26 + TooMuchFood = 27 + TooMuchLife = 28 + TooMuchShields = 29 + TooMuchEnergy = 30 + MustTargetUnitWithLife = 31 + MustTargetUnitWithShields = 32 + MustTargetUnitWithEnergy = 33 + CantTrade = 34 + CantSpend = 35 + CantTargetThatUnit = 36 + CouldntAllocateUnit = 37 + UnitCantMove = 38 + TransportIsHoldingPosition = 39 + BuildTechRequirementsNotMet = 40 + CantFindPlacementLocation = 41 + CantBuildOnThat = 42 + CantBuildTooCloseToDropOff = 43 + CantBuildLocationInvalid = 44 + CantSeeBuildLocation = 45 + CantBuildTooCloseToCreepSource = 46 + CantBuildTooCloseToResources = 47 + CantBuildTooFarFromWater = 48 + CantBuildTooFarFromCreepSource = 49 + CantBuildTooFarFromBuildPowerSource = 50 + CantBuildOnDenseTerrain = 51 + CantTrainTooFarFromTrainPowerSource = 52 + CantLandLocationInvalid = 53 + CantSeeLandLocation = 54 + CantLandTooCloseToCreepSource = 55 + CantLandTooCloseToResources = 56 + CantLandTooFarFromWater = 57 + CantLandTooFarFromCreepSource = 58 + CantLandTooFarFromBuildPowerSource = 59 + CantLandTooFarFromTrainPowerSource = 60 + CantLandOnDenseTerrain = 61 + AddOnTooFarFromBuilding = 62 + MustBuildRefineryFirst = 63 + BuildingIsUnderConstruction = 64 + CantFindDropOff = 65 + CantLoadOtherPlayersUnits = 66 + NotEnoughRoomToLoadUnit = 67 + CantUnloadUnitsThere = 68 + CantWarpInUnitsThere = 69 + CantLoadImmobileUnits = 70 + CantRechargeImmobileUnits = 71 + CantRechargeUnderConstructionUnits = 72 + CantLoadThatUnit = 73 + NoCargoToUnload = 74 + LoadAllNoTargetsFound = 75 + NotWhileOccupied = 76 + CantAttackWithoutAmmo = 77 + CantHoldAnyMoreAmmo = 78 + TechRequirementsNotMet = 79 + MustLockdownUnitFirst = 80 + MustTargetUnit = 81 + MustTargetInventory = 82 + MustTargetVisibleUnit = 83 + MustTargetVisibleLocation = 84 + MustTargetWalkableLocation = 85 + MustTargetPawnableUnit = 86 + YouCantControlThatUnit = 87 + YouCantIssueCommandsToThatUnit = 88 + MustTargetResources = 89 + RequiresHealTarget = 90 + RequiresRepairTarget = 91 + NoItemsToDrop = 92 + CantHoldAnyMoreItems = 93 + CantHoldThat = 94 + TargetHasNoInventory = 95 + CantDropThisItem = 96 + CantMoveThisItem = 97 + CantPawnThisUnit = 98 + MustTargetCaster = 99 + CantTargetCaster = 100 + MustTargetOuter = 101 + CantTargetOuter = 102 + MustTargetYourOwnUnits = 103 + CantTargetYourOwnUnits = 104 + MustTargetFriendlyUnits = 105 + CantTargetFriendlyUnits = 106 + MustTargetNeutralUnits = 107 + CantTargetNeutralUnits = 108 + MustTargetEnemyUnits = 109 + CantTargetEnemyUnits = 110 + MustTargetAirUnits = 111 + CantTargetAirUnits = 112 + MustTargetGroundUnits = 113 + CantTargetGroundUnits = 114 + MustTargetStructures = 115 + CantTargetStructures = 116 + MustTargetLightUnits = 117 + CantTargetLightUnits = 118 + MustTargetArmoredUnits = 119 + CantTargetArmoredUnits = 120 + MustTargetBiologicalUnits = 121 + CantTargetBiologicalUnits = 122 + MustTargetHeroicUnits = 123 + CantTargetHeroicUnits = 124 + MustTargetRoboticUnits = 125 + CantTargetRoboticUnits = 126 + MustTargetMechanicalUnits = 127 + CantTargetMechanicalUnits = 128 + MustTargetPsionicUnits = 129 + CantTargetPsionicUnits = 130 + MustTargetMassiveUnits = 131 + CantTargetMassiveUnits = 132 + MustTargetMissile = 133 + CantTargetMissile = 134 + MustTargetWorkerUnits = 135 + CantTargetWorkerUnits = 136 + MustTargetEnergyCapableUnits = 137 + CantTargetEnergyCapableUnits = 138 + MustTargetShieldCapableUnits = 139 + CantTargetShieldCapableUnits = 140 + MustTargetFlyers = 141 + CantTargetFlyers = 142 + MustTargetBuriedUnits = 143 + CantTargetBuriedUnits = 144 + MustTargetCloakedUnits = 145 + CantTargetCloakedUnits = 146 + MustTargetUnitsInAStasisField = 147 + CantTargetUnitsInAStasisField = 148 + MustTargetUnderConstructionUnits = 149 + CantTargetUnderConstructionUnits = 150 + MustTargetDeadUnits = 151 + CantTargetDeadUnits = 152 + MustTargetRevivableUnits = 153 + CantTargetRevivableUnits = 154 + MustTargetHiddenUnits = 155 + CantTargetHiddenUnits = 156 + CantRechargeOtherPlayersUnits = 157 + MustTargetHallucinations = 158 + CantTargetHallucinations = 159 + MustTargetInvulnerableUnits = 160 + CantTargetInvulnerableUnits = 161 + MustTargetDetectedUnits = 162 + CantTargetDetectedUnits = 163 + CantTargetUnitWithEnergy = 164 + CantTargetUnitWithShields = 165 + MustTargetUncommandableUnits = 166 + CantTargetUncommandableUnits = 167 + MustTargetPreventDefeatUnits = 168 + CantTargetPreventDefeatUnits = 169 + MustTargetPreventRevealUnits = 170 + CantTargetPreventRevealUnits = 171 + MustTargetPassiveUnits = 172 + CantTargetPassiveUnits = 173 + MustTargetStunnedUnits = 174 + CantTargetStunnedUnits = 175 + MustTargetSummonedUnits = 176 + CantTargetSummonedUnits = 177 + MustTargetUser1 = 178 + CantTargetUser1 = 179 + MustTargetUnstoppableUnits = 180 + CantTargetUnstoppableUnits = 181 + MustTargetResistantUnits = 182 + CantTargetResistantUnits = 183 + MustTargetDazedUnits = 184 + CantTargetDazedUnits = 185 + CantLockdown = 186 + CantMindControl = 187 + MustTargetDestructibles = 188 + CantTargetDestructibles = 189 + MustTargetItems = 190 + CantTargetItems = 191 + NoCalldownAvailable = 192 + WaypointListFull = 193 + MustTargetRace = 194 + CantTargetRace = 195 + MustTargetSimilarUnits = 196 + CantTargetSimilarUnits = 197 + CantFindEnoughTargets = 198 + AlreadySpawningLarva = 199 + CantTargetExhaustedResources = 200 + CantUseMinimap = 201 + CantUseInfoPanel = 202 + OrderQueueIsFull = 203 + CantHarvestThatResource = 204 + HarvestersNotRequired = 205 + AlreadyTargeted = 206 + CantAttackWeaponsDisabled = 207 + CouldntReachTarget = 208 + TargetIsOutOfRange = 209 + TargetIsTooClose = 210 + TargetIsOutOfArc = 211 + CantFindTeleportLocation = 212 + InvalidItemClass = 213 + CantFindCancelOrder = 214 + +# Module-level dictionaries +race_worker: dict[Race, UnitTypeId] +race_townhalls: dict[Race, set[UnitTypeId]] +warpgate_abilities: dict[AbilityId, AbilityId] +race_gas: dict[Race, UnitTypeId] diff --git a/sc2/expiring_dict.py b/sc2/expiring_dict.py index ebbcb23c..7c6cc2c0 100644 --- a/sc2/expiring_dict.py +++ b/sc2/expiring_dict.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections import OrderedDict -from collections.abc import Iterable +from collections.abc import Hashable, Iterable from threading import RLock from typing import TYPE_CHECKING, Any @@ -10,7 +10,7 @@ from sc2.bot_ai import BotAI -class ExpiringDict(OrderedDict): +class ExpiringDict(OrderedDict[Hashable, Any]): """ An expiring dict that uses the bot.state.game_loop to only return items that are valid for a specific amount of time. @@ -45,7 +45,7 @@ def frame(self) -> int: # pyre-ignore[16] return self.bot.state.game_loop - def __contains__(self, key) -> bool: + def __contains__(self, key: Hashable) -> bool: """Return True if dict has key, else False, e.g. 'key in dict'""" with self.lock: if OrderedDict.__contains__(self, key): @@ -56,7 +56,7 @@ def __contains__(self, key) -> bool: del self[key] return False - def __getitem__(self, key, with_age: bool = False) -> Any: + def __getitem__(self, key: Hashable, with_age: bool = False) -> Any: """Return the item of the dict using d[key]""" with self.lock: # Each item is a list of [value, frame time] @@ -68,7 +68,7 @@ def __getitem__(self, key, with_age: bool = False) -> Any: OrderedDict.__delitem__(self, key) raise KeyError(key) - def __setitem__(self, key, value) -> None: + def __setitem__(self, key: Hashable, value: Any) -> None: """Set d[key] = value""" with self.lock: OrderedDict.__setitem__(self, key, (value, self.frame)) @@ -83,10 +83,10 @@ def __repr__(self) -> str: print_str = ", ".join(print_list) return f"ExpiringDict({print_str})" - def __str__(self): + def __str__(self) -> str: return self.__repr__() - def __iter__(self): + def __iter__(self) -> Iterable[Hashable]: """Override 'for key in dict:'""" with self.lock: return self.keys() @@ -101,7 +101,7 @@ def __len__(self) -> int: count += 1 return count - def pop(self, key, default=None, with_age: bool = False): + def pop(self, key: Hashable, default: Any = None, with_age: bool = False): """Return the item and remove it""" with self.lock: if OrderedDict.__contains__(self, key): @@ -118,7 +118,7 @@ def pop(self, key, default=None, with_age: bool = False): return default, self.frame return default - def get(self, key, default=None, with_age: bool = False): + def get(self, key: Hashable, default: Any = None, with_age: bool = False): """Return the value for key if key is in dict, else default""" with self.lock: if OrderedDict.__contains__(self, key): @@ -134,26 +134,26 @@ def get(self, key, default=None, with_age: bool = False): return None return None - def update(self, other_dict: dict) -> None: + def update(self, other_dict: dict[Hashable, Any]) -> None: with self.lock: for key, value in other_dict.items(): self[key] = value - def items(self) -> Iterable: + def items(self) -> Iterable[tuple[Hashable, Any]]: """Return iterator of zipped list [keys, values]""" with self.lock: for key, value in OrderedDict.items(self): if self.frame - value[1] < self.max_age: yield key, value[0] - def keys(self) -> Iterable: + def keys(self) -> Iterable[Hashable]: """Return iterator of keys""" with self.lock: for key, value in OrderedDict.items(self): if self.frame - value[1] < self.max_age: yield key - def values(self) -> Iterable: + def values(self) -> Iterable[Any]: """Return iterator of values""" with self.lock: for value in OrderedDict.values(self): diff --git a/sc2/game_data.py b/sc2/game_data.py index 3bc4fc78..4be84ee2 100644 --- a/sc2/game_data.py +++ b/sc2/game_data.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from functools import lru_cache +from s2clientprotocol import data_pb2, sc2api_pb2 from sc2.data import Attribute, Race from sc2.ids.ability_id import AbilityId from sc2.ids.unit_typeid import UnitTypeId @@ -20,7 +21,7 @@ class GameData: - def __init__(self, data) -> None: + def __init__(self, data: sc2api_pb2.ResponseData) -> None: """ :param data: """ @@ -77,14 +78,14 @@ class AbilityData: ability_ids: list[int] = [ability_id.value for ability_id in AbilityId][1:] # sorted list @classmethod - def id_exists(cls, ability_id): + def id_exists(cls, ability_id: int) -> bool: assert isinstance(ability_id, int), f"Wrong type: {ability_id} is not int" if ability_id == 0: return False i = bisect_left(cls.ability_ids, ability_id) # quick binary search return i != len(cls.ability_ids) and cls.ability_ids[i] == ability_id - def __init__(self, game_data, proto) -> None: + def __init__(self, game_data: GameData, proto: data_pb2.AbilityData) -> None: self._game_data = game_data self._proto = proto @@ -131,7 +132,7 @@ def cost(self) -> Cost: class UnitTypeData: - def __init__(self, game_data: GameData, proto) -> None: + def __init__(self, game_data: GameData, proto: data_pb2.UnitTypeData) -> None: """ :param game_data: :param proto: @@ -172,12 +173,10 @@ def footprint_radius(self) -> float | None: return self.creation_ability._proto.footprint_radius @property - # pyre-ignore[11] def attributes(self) -> list[Attribute]: - return self._proto.attributes + return [Attribute(i) for i in self._proto.attributes] - def has_attribute(self, attr) -> bool: - # pyre-ignore[6] + def has_attribute(self, attr: Attribute) -> bool: assert isinstance(attr, Attribute) return attr in self.attributes @@ -225,7 +224,6 @@ def unit_alias(self) -> UnitTypeId | None: return UnitTypeId(self._proto.unit_alias) @property - # pyre-ignore[11] def race(self) -> Race: return Race(self._proto.race) @@ -236,8 +234,7 @@ def cost(self) -> Cost: @property def cost_zerg_corrected(self) -> Cost: """This returns 25 for extractor and 200 for spawning pool instead of 75 and 250 respectively""" - # pyre-ignore[16] - if self.race == Race.Zerg and Attribute.Structure.value in self.attributes: + if self.race == Race.Zerg and Attribute.Structure.value in self._proto.attributes: return Cost(self._proto.mineral_cost - 50, self._proto.vespene_cost, self._proto.build_time) return self.cost @@ -280,7 +277,7 @@ def morph_cost(self) -> Cost | None: class UpgradeData: - def __init__(self, game_data: GameData, proto) -> None: + def __init__(self, game_data: GameData, proto: data_pb2.UpgradeData) -> None: """ :param game_data: :param proto: diff --git a/sc2/game_info.py b/sc2/game_info.py index c00a0428..aab025d5 100644 --- a/sc2/game_info.py +++ b/sc2/game_info.py @@ -9,6 +9,7 @@ import numpy as np +from s2clientprotocol import sc2api_pb2 from sc2.pixel_map import PixelMap from sc2.player import Player from sc2.position import Point2, Rect, Size @@ -122,10 +123,10 @@ def depot_in_middle(self) -> Point2 | None: raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") @cached_property - def corner_depots(self) -> frozenset[Point2]: + def corner_depots(self) -> set[Point2]: """Finds the 2 depot positions on the outside""" if not self.upper2_for_ramp_wall: - return frozenset() + return set() if len(self.upper2_for_ramp_wall) == 2: points = set(self.upper2_for_ramp_wall) p1 = points.pop().offset((self.x_offset, self.y_offset)) @@ -133,7 +134,7 @@ def corner_depots(self) -> frozenset[Point2]: center = p1.towards(p2, p1.distance_to_point2(p2) / 2) depot_position = self.depot_in_middle if depot_position is None: - return frozenset() + return set() # Offset from middle depot to corner depots is (2, 1) intersects = center.circle_intersection(depot_position, 5**0.5) return intersects @@ -217,7 +218,7 @@ def protoss_wall_warpin(self) -> Point2 | None: class GameInfo: - def __init__(self, proto) -> None: + def __init__(self, proto: sc2api_pb2.ResponseGameInfo) -> None: self._proto = proto self.players: list[Player] = [Player.from_proto(p) for p in self._proto.player_info] self.map_name: str = self._proto.map_name diff --git a/sc2/game_state.py b/sc2/game_state.py index b17fd12e..7ff4bb89 100644 --- a/sc2/game_state.py +++ b/sc2/game_state.py @@ -7,6 +7,7 @@ from loguru import logger +from s2clientprotocol import raw_pb2, sc2api_pb2 from sc2.constants import IS_ENEMY, IS_MINE, FakeEffectID, FakeEffectRadii from sc2.data import Alliance, DisplayType from sc2.ids.ability_id import AbilityId @@ -25,7 +26,7 @@ class Blip: - def __init__(self, proto) -> None: + def __init__(self, proto: raw_pb2.Unit) -> None: """ :param proto: """ @@ -82,7 +83,7 @@ class Common: "larva_count", ] - def __init__(self, proto) -> None: + def __init__(self, proto: sc2api_pb2.PlayerCommon) -> None: self._proto = proto def __getattr__(self, attr) -> int: @@ -91,7 +92,7 @@ def __getattr__(self, attr) -> int: class EffectData: - def __init__(self, proto, fake: bool = False) -> None: + def __init__(self, proto: raw_pb2.Effect | raw_pb2.Unit, fake: bool = False) -> None: """ :param proto: :param fake: @@ -101,20 +102,20 @@ def __init__(self, proto, fake: bool = False) -> None: @property def id(self) -> EffectId | str: - if self.fake: + if isinstance(self._proto, raw_pb2.Unit): # Returns the string from constants.py, e.g. "KD8CHARGE" return FakeEffectID[self._proto.unit_type] return EffectId(self._proto.effect_id) @property def positions(self) -> set[Point2]: - if self.fake: + if isinstance(self._proto, raw_pb2.Unit): return {Point2.from_proto(self._proto.pos)} return {Point2.from_proto(p) for p in self._proto.pos} @property def alliance(self) -> Alliance: - return self._proto.alliance + return Alliance(self._proto.alliance) @property def is_mine(self) -> bool: @@ -191,7 +192,11 @@ class ActionError(AbilityLookupTemplateClass): class GameState: - def __init__(self, response_observation, previous_observation=None) -> None: + def __init__( + self, + response_observation: sc2api_pb2.ResponseObservation, + previous_observation: sc2api_pb2.ResponseObservation | None = None, + ) -> None: """ :param response_observation: :param previous_observation: @@ -252,7 +257,7 @@ def alerts(self) -> list[int]: """ Game alerts, see https://github.com/Blizzard/s2client-proto/blob/01ab351e21c786648e4c6693d4aad023a176d45c/s2clientprotocol/sc2api.proto#L683-L706 """ - if self.previous_observation: + if self.previous_observation is not None: return list(chain(self.previous_observation.observation.alerts, self.observation.alerts)) return self.observation.alerts @@ -265,7 +270,7 @@ def actions(self) -> list[ActionRawUnitCommand | ActionRawToggleAutocast | Actio Each action is converted into Python dataclasses: ActionRawUnitCommand, ActionRawToggleAutocast, ActionRawCameraMove """ previous_frame_actions = self.previous_observation.actions if self.previous_observation else [] - actions = [] + actions: list[ActionRawUnitCommand | ActionRawToggleAutocast | ActionRawCameraMove] = [] for action in chain(previous_frame_actions, self.response_observation.actions): action_raw = action.action_raw game_loop = action.game_loop diff --git a/sc2/main.py b/sc2/main.py index fd86c6a7..8d07314e 100644 --- a/sc2/main.py +++ b/sc2/main.py @@ -10,19 +10,21 @@ from dataclasses import dataclass from io import BytesIO from pathlib import Path +from typing import Any import mpyq import portpicker from aiohttp import ClientSession, ClientWebSocketResponse from loguru import logger -from s2clientprotocol import sc2api_pb2 as sc_pb +from s2clientprotocol import sc2api_pb2 as sc_pb from sc2.bot_ai import BotAI from sc2.client import Client from sc2.controller import Controller from sc2.data import CreateGameError, Result, Status from sc2.game_state import GameState from sc2.maps import Map +from sc2.observer_ai import ObserverAI from sc2.player import AbstractPlayer, Bot, BotProcess, Human from sc2.portconfig import Portconfig from sc2.protocol import ConnectionAlreadyClosedError, ProtocolError @@ -71,7 +73,7 @@ def needed_sc2_count(self) -> int: return sum(player.needs_sc2 for player in self.players) @property - def host_game_kwargs(self) -> dict: + def host_game_kwargs(self) -> dict[str, Any]: return { "map_settings": self.map_sc2, "players": self.players, @@ -203,7 +205,12 @@ async def run_bot_iteration(iteration: int): async def _play_game( - player: AbstractPlayer, client: Client, realtime, portconfig, game_time_limit=None, rgb_render_config=None + player: AbstractPlayer, + client: Client, + realtime: bool, + portconfig: Portconfig, + game_time_limit: int | None = None, + rgb_render_config: dict[str, Any] | None = None, ) -> Result: assert isinstance(realtime, bool), repr(realtime) @@ -225,7 +232,7 @@ async def _play_game( return result -async def _play_replay(client, ai, realtime: bool = False, player_id: int = 0): +async def _play_replay(client: Client, ai, realtime: bool = False, player_id: int = 0): ai._initialize_variables() game_data = await client.get_game_data() @@ -328,16 +335,16 @@ async def _setup_host_game( async def _host_game( - map_settings, - players, + map_settings: Map, + players: list[AbstractPlayer], realtime: bool = False, - portconfig=None, - save_replay_as=None, - game_time_limit=None, - rgb_render_config=None, - random_seed=None, - sc2_version=None, - disable_fog=None, + portconfig: Portconfig | None = None, + save_replay_as: str | None = None, + game_time_limit: int | None = None, + rgb_render_config: dict[str, Any] | None = None, + random_seed: int | None = None, + sc2_version: str | None = None, + disable_fog: bool = False, ): assert players, "Can't create a game without players" @@ -410,19 +417,19 @@ def _host_game_iter(*args, **kwargs): async def _join_game( - players, - realtime, - portconfig, - save_replay_as=None, - game_time_limit=None, - sc2_version=None, + players: list[AbstractPlayer], + realtime: bool = False, + portconfig: Portconfig | None = None, + save_replay_as: str | None = None, + game_time_limit: int | None = None, + sc2_version: str | None = None, ): async with SC2Process(fullscreen=players[1].fullscreen, sc2_version=sc2_version) as server: await server.ping() client = Client(server._ws) # Bot can decide if it wants to launch with 'raw_affects_selection=True' - if not isinstance(players[1], Human) and getattr(players[1].ai, "raw_affects_selection", None) is not None: + if isinstance(players[1], Bot) and getattr(players[1].ai, "raw_affects_selection", None) is not None: client.raw_affects_selection = players[1].ai.raw_affects_selection result = await _play_game(players[1], client, realtime, portconfig, game_time_limit) @@ -442,7 +449,9 @@ async def _setup_replay(server, replay_path, realtime, observed_id): return Client(server._ws) -async def _host_replay(replay_path, ai, realtime, _portconfig, base_build, data_version, observed_id): +async def _host_replay( + replay_path, ai: ObserverAI, realtime: bool, _portconfig: Portconfig, base_build, data_version, observed_id +): async with SC2Process(fullscreen=False, base_build=base_build, data_hash=data_version) as server: client = await _setup_replay(server, replay_path, realtime, observed_id) result = await _play_replay(client, ai, realtime) @@ -461,21 +470,47 @@ def get_replay_version(replay_path: str | Path) -> tuple[str, str]: # TODO Deprecate run_game function in favor of run_multiple_games -def run_game(map_settings, players, **kwargs) -> Result | list[Result | None]: +def run_game( + map_settings: Map, + players: list[AbstractPlayer], + realtime: bool, + portconfig: Portconfig | None = None, + save_replay_as: str | None = None, + game_time_limit: int | None = None, + rgb_render_config: dict[str, Any] | None = None, + random_seed: int | None = None, + sc2_version: str | None = None, + disable_fog: bool = False, +) -> Result | list[Result | None]: """ Returns a single Result enum if the game was against the built-in computer. Returns a list of two Result enums if the game was "Human vs Bot" or "Bot vs Bot". """ if sum(isinstance(p, (Human, Bot)) for p in players) > 1: - host_only_args = ["save_replay_as", "rgb_render_config", "random_seed", "disable_fog"] - join_kwargs = {k: v for k, v in kwargs.items() if k not in host_only_args} - portconfig = Portconfig() async def run_host_and_join(): return await asyncio.gather( - _host_game(map_settings, players, **kwargs, portconfig=portconfig), - _join_game(players, **join_kwargs, portconfig=portconfig), + _host_game( + map_settings, + players, + realtime=realtime, + portconfig=portconfig, + save_replay_as=save_replay_as, + game_time_limit=game_time_limit, + rgb_render_config=rgb_render_config, + random_seed=random_seed, + sc2_version=sc2_version, + disable_fog=disable_fog, + ), + _join_game( + players, + realtime=realtime, + portconfig=portconfig, + save_replay_as=save_replay_as, + game_time_limit=game_time_limit, + sc2_version=sc2_version, + ), return_exceptions=True, ) @@ -483,12 +518,25 @@ async def run_host_and_join(): assert isinstance(result, list) assert all(isinstance(r, Result) for r in result) else: - result: Result = asyncio.run(_host_game(map_settings, players, **kwargs)) + result: Result = asyncio.run( + _host_game( + map_settings, + players, + realtime=realtime, + portconfig=portconfig, + save_replay_as=save_replay_as, + game_time_limit=game_time_limit, + rgb_render_config=rgb_render_config, + random_seed=random_seed, + sc2_version=sc2_version, + disable_fog=disable_fog, + ) + ) assert isinstance(result, Result) return result -def run_replay(ai, replay_path: Path | str, realtime: bool = False, observed_id: int = 0): +def run_replay(ai: ObserverAI, replay_path: Path | str, realtime: bool = False, observed_id: int = 0): portconfig = Portconfig() assert Path(replay_path).is_file(), f"Replay does not exist at the given path: {replay_path}" assert Path(replay_path).is_absolute(), ( @@ -725,7 +773,7 @@ async def a_run_multiple_games_nokill(matches: list[GameMatch]) -> list[dict[Abs # Start the matches results = [] - controllers = [] + controllers: list[Controller] = [] for m in matches: logger.info(f"Starting match {1 + len(results)} / {len(matches)}: {m}") result = None diff --git a/sc2/pixel_map.py b/sc2/pixel_map.py index 6871a516..c6925d80 100644 --- a/sc2/pixel_map.py +++ b/sc2/pixel_map.py @@ -5,11 +5,12 @@ import numpy as np +from s2clientprotocol.common_pb2 import ImageData from sc2.position import Point2 class PixelMap: - def __init__(self, proto, in_bits: bool = False) -> None: + def __init__(self, proto: ImageData, in_bits: bool = False) -> None: """ :param proto: :param in_bits: diff --git a/sc2/player.py b/sc2/player.py index 74ee5463..bd1410a5 100644 --- a/sc2/player.py +++ b/sc2/player.py @@ -4,6 +4,7 @@ from abc import ABC from pathlib import Path +from s2clientprotocol import sc2api_pb2 from sc2.bot_ai import BotAI from sc2.data import AIBuild, Difficulty, PlayerType, Race @@ -12,10 +13,10 @@ class AbstractPlayer(ABC): def __init__( self, p_type: PlayerType, - race: Race = None, + race: Race | None = None, name: str | None = None, - difficulty=None, - ai_build=None, + difficulty: Difficulty | None = None, + ai_build: AIBuild | None = None, fullscreen: bool = False, ) -> None: assert isinstance(p_type, PlayerType), f"p_type is of type {type(p_type)}" @@ -50,7 +51,7 @@ def needs_sc2(self) -> bool: class Human(AbstractPlayer): - def __init__(self, race, name: str | None = None, fullscreen: bool = False) -> None: + def __init__(self, race: Race, name: str | None = None, fullscreen: bool = False) -> None: super().__init__(PlayerType.Participant, race, name=name, fullscreen=fullscreen) def __str__(self) -> str: @@ -60,7 +61,7 @@ def __str__(self) -> str: class Bot(AbstractPlayer): - def __init__(self, race, ai, name: str | None = None, fullscreen: bool = False) -> None: + def __init__(self, race: Race, ai: BotAI, name: str | None = None, fullscreen: bool = False) -> None: """ AI can be None if this player object is just used to inform the server about player types. @@ -76,7 +77,9 @@ def __str__(self) -> str: class Computer(AbstractPlayer): - def __init__(self, race, difficulty=Difficulty.Easy, ai_build=AIBuild.RandomBuild) -> None: + def __init__( + self, race: Race, difficulty: Difficulty = Difficulty.Easy, ai_build: AIBuild = AIBuild.RandomBuild + ) -> None: super().__init__(PlayerType.Computer, race, difficulty=difficulty, ai_build=ai_build) def __str__(self) -> str: @@ -95,19 +98,19 @@ class Player(AbstractPlayer): def __init__( self, player_id: int, - p_type, - requested_race, - difficulty=None, - actual_race=None, + p_type: PlayerType, + requested_race: Race, + difficulty: Difficulty | None = None, + actual_race: Race | None = None, name: str | None = None, - ai_build=None, + ai_build: AIBuild | None = None, ) -> None: super().__init__(p_type, requested_race, difficulty=difficulty, name=name, ai_build=ai_build) self.id: int = player_id - self.actual_race: Race = actual_race + self.actual_race: Race | None = actual_race @classmethod - def from_proto(cls, proto) -> Player: + def from_proto(cls, proto: sc2api_pb2.PlayerInfo) -> Player: if PlayerType(proto.type) == PlayerType.Observer: return cls(proto.player_id, PlayerType(proto.type), None, None, None) return cls( @@ -168,7 +171,9 @@ def __repr__(self) -> str: return f"Bot {self.name}({self.race.name} from {self.launch_list})" return f"Bot({self.race.name} from {self.launch_list})" - def cmd_line(self, sc2port: int | str, matchport: int | str, hostaddress: str, realtime: bool = False) -> list[str]: + def cmd_line( + self, sc2port: int | str, matchport: int | str | None, hostaddress: str, realtime: bool = False + ) -> list[str]: """ :param sc2port: the port that the launched sc2 instance listens to diff --git a/sc2/portconfig.py b/sc2/portconfig.py index 2e646faf..9041b90f 100644 --- a/sc2/portconfig.py +++ b/sc2/portconfig.py @@ -25,9 +25,11 @@ class Portconfig: E.g. for 1v1, there will be only 1 guest. For 2v2 (coming soonTM), there would be 3 guests. """ - def __init__(self, guests: int = 1, server_ports=None, player_ports=None) -> None: + def __init__( + self, guests: int = 1, server_ports: list[int] | None = None, player_ports: list[int] | None = None + ) -> None: self.shared = None - self._picked_ports = [] + self._picked_ports: list[int] = [] if server_ports: self.server = server_ports else: diff --git a/sc2/position.py b/sc2/position.py index 36a0922f..f3d9bd70 100644 --- a/sc2/position.py +++ b/sc2/position.py @@ -1,18 +1,29 @@ -# pyre-ignore-all-errors[6, 14, 15, 58] from __future__ import annotations import itertools import math import random from collections.abc import Iterable -from typing import TYPE_CHECKING, SupportsFloat, SupportsIndex +from typing import ( + Any, + Protocol, + SupportsFloat, + SupportsIndex, + TypeVar, + Union, +) -# pyre-fixme[21] from s2clientprotocol import common_pb2 as common_pb -if TYPE_CHECKING: - from sc2.unit import Unit - from sc2.units import Units + +class HasPosition2D(Protocol): + @property + def position(self) -> Point2: ... + + +_PointLike = Union[tuple[float, float], tuple[float, float], tuple[float, ...]] +_PosLike = Union[HasPosition2D, _PointLike] +_TPosLike = TypeVar("_TPosLike", bound=_PosLike) EPSILON: float = 10**-8 @@ -21,116 +32,118 @@ def _sign(num: SupportsFloat | SupportsIndex) -> float: return math.copysign(1, num) -class Pointlike(tuple): +class Pointlike(tuple[float, ...]): + T = TypeVar("T", bound="Pointlike") + @property - def position(self) -> Pointlike: + def position(self: T) -> T: return self - def distance_to(self, target: Unit | Point2) -> float: + def distance_to(self, target: _PosLike) -> float: """Calculate a single distance from a point or unit to another point or unit :param target:""" - p = target.position + p: tuple[float, ...] = target if isinstance(target, tuple) else target.position return math.hypot(self[0] - p[0], self[1] - p[1]) - def distance_to_point2(self, p: Point2 | tuple[float, float]) -> float: + def distance_to_point2(self, p: _PointLike) -> float: """Same as the function above, but should be a bit faster because of the dropped asserts and conversion. :param p:""" return math.hypot(self[0] - p[0], self[1] - p[1]) - def _distance_squared(self, p2: Point2) -> float: + def _distance_squared(self, p2: _PointLike) -> float: """Function used to not take the square root as the distances will stay proportionally the same. This is to speed up the sorting process. :param p2:""" return (self[0] - p2[0]) ** 2 + (self[1] - p2[1]) ** 2 - def sort_by_distance(self, ps: Units | Iterable[Point2]) -> list[Point2]: + def sort_by_distance(self, ps: Iterable[_TPosLike]) -> list[_TPosLike]: """This returns the target points sorted as list. You should not pass a set or dict since those are not sortable. If you want to sort your units towards a point, use 'units.sorted_by_distance_to(point)' instead. :param ps:""" - return sorted(ps, key=lambda p: self.distance_to_point2(p.position)) + return sorted(ps, key=lambda p: self.distance_to_point2(p if isinstance(p, tuple) else p.position)) - def closest(self, ps: Units | Iterable[Point2]) -> Unit | Point2: + def closest(self, ps: Iterable[_TPosLike]) -> _TPosLike: """This function assumes the 2d distance is meant :param ps:""" assert ps, "ps is empty" - return min(ps, key=lambda p: self.distance_to(p)) + return min(ps, key=lambda p: self.distance_to_point2(p if isinstance(p, tuple) else p.position)) - def distance_to_closest(self, ps: Units | Iterable[Point2]) -> float: + def distance_to_closest(self, ps: Iterable[_TPosLike]) -> float: """This function assumes the 2d distance is meant :param ps:""" assert ps, "ps is empty" closest_distance = math.inf - for p2 in ps: - p2 = p2.position - distance = self.distance_to(p2) + for p in ps: + p2: tuple[float, ...] = p if isinstance(p, tuple) else p.position + distance = self.distance_to_point2(p2) if distance <= closest_distance: closest_distance = distance return closest_distance - def furthest(self, ps: Units | Iterable[Point2]) -> Unit | Pointlike: + def furthest(self, ps: Iterable[_TPosLike]) -> _TPosLike: """This function assumes the 2d distance is meant :param ps: Units object, or iterable of Unit or Point2""" assert ps, "ps is empty" - return max(ps, key=lambda p: self.distance_to(p)) + return max(ps, key=lambda p: self.distance_to_point2(p if isinstance(p, tuple) else p.position)) - def distance_to_furthest(self, ps: Units | Iterable[Point2]) -> float: + def distance_to_furthest(self, ps: Iterable[_PosLike]) -> float: """This function assumes the 2d distance is meant :param ps:""" assert ps, "ps is empty" furthest_distance = -math.inf - for p2 in ps: - p2 = p2.position - distance = self.distance_to(p2) + for p in ps: + p2: tuple[float, ...] = p if isinstance(p, tuple) else p.position + distance = self.distance_to_point2(p2) if distance >= furthest_distance: furthest_distance = distance return furthest_distance - def offset(self, p) -> Pointlike: + def offset(self: T, p: _PointLike) -> T: """ :param p: """ return self.__class__(a + b for a, b in itertools.zip_longest(self, p[: len(self)], fillvalue=0)) - def unit_axes_towards(self, p) -> Pointlike: + def unit_axes_towards(self: T, p: _PointLike) -> T: """ :param p: """ return self.__class__(_sign(b - a) for a, b in itertools.zip_longest(self, p[: len(self)], fillvalue=0)) - def towards(self, p: Unit | Pointlike, distance: int | float = 1, limit: bool = False) -> Pointlike: + def towards(self: T, p: _PosLike, distance: float = 1, limit: bool = False) -> T: """ :param p: :param distance: :param limit: """ - p = p.position + p2: tuple[float, ...] = p if isinstance(p, tuple) else p.position # assert self != p, f"self is {self}, p is {p}" # TODO test and fix this if statement - if self == p: + if self == p2: return self # end of test - d = self.distance_to(p) + d = self.distance_to_point2(p2) if limit: distance = min(d, distance) return self.__class__( - a + (b - a) / d * distance for a, b in itertools.zip_longest(self, p[: len(self)], fillvalue=0) + a + (b - a) / d * distance for a, b in itertools.zip_longest(self, p2[: len(self)], fillvalue=0) ) - def __eq__(self, other: object) -> bool: + def __eq__(self, other: Any) -> bool: try: return all(abs(a - b) <= EPSILON for a, b in itertools.zip_longest(self, other, fillvalue=0)) except TypeError: @@ -141,23 +154,25 @@ def __hash__(self) -> int: class Point2(Pointlike): + T = TypeVar("T", bound="Point2") + @classmethod - def from_proto(cls, data) -> Point2: + def from_proto( + cls, data: common_pb.Point | common_pb.Point2D | common_pb.Size2DI | common_pb.PointI | Point2 | Point3 + ) -> Point2: """ :param data: """ return cls((data.x, data.y)) @property - # pyre-fixme[11] def as_Point2D(self) -> common_pb.Point2D: return common_pb.Point2D(x=self.x, y=self.y) @property - # pyre-fixme[11] def as_PointI(self) -> common_pb.PointI: """Represents points on the minimap. Values must be between 0 and 64.""" - return common_pb.PointI(x=self.x, y=self.y) + return common_pb.PointI(x=int(self[0]), y=int(self[1])) @property def rounded(self) -> Point2: @@ -169,12 +184,12 @@ def length(self) -> float: return math.hypot(self[0], self[1]) @property - def normalized(self) -> Point2: + def normalized(self: Point2 | Point3) -> Point2: """This property exists in case Point2 is used as a vector.""" length = self.length # Cannot normalize if length is zero assert length - return self.__class__((self[0] / length, self[1] / length)) + return Point2((self[0] / length, self[1] / length)) @property def x(self) -> float: @@ -196,18 +211,19 @@ def round(self, decimals: int) -> Point2: """Rounds each number in the tuple to the amount of given decimals.""" return Point2((round(self[0], decimals), round(self[1], decimals))) - def offset(self, p: Point2) -> Point2: - return Point2((self[0] + p[0], self[1] + p[1])) + def offset(self: T, p: _PointLike) -> T: + return self.__class__((self[0] + p[0], self[1] + p[1])) - def random_on_distance(self, distance) -> Point2: + def random_on_distance(self, distance: float | tuple[float, float] | list[float]) -> Point2: if isinstance(distance, (tuple, list)): # interval - distance = distance[0] + random.random() * (distance[1] - distance[0]) - - assert distance > 0, "Distance is not greater than 0" + dist = distance[0] + random.random() * (distance[1] - distance[0]) + else: + dist = distance + assert dist > 0, "Distance is not greater than 0" angle = random.random() * 2 * math.pi dx, dy = math.cos(angle), math.sin(angle) - return Point2((self.x + dx * distance, self.y + dy * distance)) + return Point2((self.x + dx * dist, self.y + dy * dist)) def towards_with_random_angle( self, @@ -220,7 +236,7 @@ def towards_with_random_angle( angle = (angle - max_difference) + max_difference * 2 * random.random() return Point2((self.x + math.cos(angle) * distance, self.y + math.sin(angle) * distance)) - def circle_intersection(self, p: Point2, r: int | float) -> set[Point2]: + def circle_intersection(self, p: Point2, r: float) -> set[Point2]: """self is point1, p is point2, r is the radius for circles originating in both points Used in ramp finding @@ -248,68 +264,66 @@ def circle_intersection(self, p: Point2, r: int | float) -> set[Point2]: return {intersect1, intersect2} @property - def neighbors4(self) -> set: + def neighbors4(self: T) -> set[T]: return { - Point2((self.x - 1, self.y)), - Point2((self.x + 1, self.y)), - Point2((self.x, self.y - 1)), - Point2((self.x, self.y + 1)), + self.__class__((self[0] - 1, self[1])), + self.__class__((self[0] + 1, self[1])), + self.__class__((self[0], self[1] - 1)), + self.__class__((self[0], self[1] + 1)), } @property - def neighbors8(self) -> set: + def neighbors8(self: T) -> set[T]: return self.neighbors4 | { - Point2((self.x - 1, self.y - 1)), - Point2((self.x - 1, self.y + 1)), - Point2((self.x + 1, self.y - 1)), - Point2((self.x + 1, self.y + 1)), + self.__class__((self[0] - 1, self[1] - 1)), + self.__class__((self[0] - 1, self[1] + 1)), + self.__class__((self[0] + 1, self[1] - 1)), + self.__class__((self[0] + 1, self[1] + 1)), } - def negative_offset(self, other: Point2) -> Point2: + def negative_offset(self: T, other: Point2) -> T: return self.__class__((self[0] - other[0], self[1] - other[1])) - def __add__(self, other: Point2) -> Point2: + def __add__(self, other: Point2) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] return self.offset(other) def __sub__(self, other: Point2) -> Point2: return self.negative_offset(other) - def __neg__(self) -> Point2: + def __neg__(self: T) -> T: return self.__class__(-a for a in self) def __abs__(self) -> float: - return math.hypot(self.x, self.y) + return math.hypot(self[0], self[1]) def __bool__(self) -> bool: - return self.x != 0 or self.y != 0 + return self[0] != 0 or self[1] != 0 - def __mul__(self, other: int | float | Point2) -> Point2: - try: - # pyre-ignore[16] - return self.__class__((self.x * other.x, self.y * other.y)) - except AttributeError: - return self.__class__((self.x * other, self.y * other)) + def __mul__(self, other: _PointLike | float) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] + if isinstance(other, (int, float)): + return Point2((self[0] * other, self[1] * other)) + return Point2((self[0] * other[0], self[1] * other[1])) - def __rmul__(self, other: int | float | Point2) -> Point2: + def __rmul__(self, other: _PointLike | float) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] return self.__mul__(other) - def __truediv__(self, other: int | float | Point2) -> Point2: - if isinstance(other, self.__class__): - return self.__class__((self.x / other.x, self.y / other.y)) - return self.__class__((self.x / other, self.y / other)) + def __truediv__(self, other: float | Point2) -> Point2: + if isinstance(other, (int, float)): + return self.__class__((self[0] / other, self[1] / other)) + return self.__class__((self[0] / other[0], self[1] / other[1])) def is_same_as(self, other: Point2, dist: float = 0.001) -> bool: return self.distance_to_point2(other) <= dist def direction_vector(self, other: Point2) -> Point2: """Converts a vector to a direction that can face vertically, horizontally or diagonal or be zero, e.g. (0, 0), (1, -1), (1, 0)""" - return self.__class__((_sign(other.x - self.x), _sign(other.y - self.y))) + return self.__class__((_sign(other[0] - self[0]), _sign(other[1] - self[1]))) def manhattan_distance(self, other: Point2) -> float: """ :param other: """ - return abs(other.x - self.x) + abs(other.y - self.y) + return abs(other[0] - self[0]) + abs(other[1] - self[1]) @staticmethod def center(points: list[Point2]) -> Point2: @@ -324,14 +338,13 @@ def center(points: list[Point2]) -> Point2: class Point3(Point2): @classmethod - def from_proto(cls, data) -> Point3: + def from_proto(cls, data: common_pb.Point | Point3) -> Point3: # pyright: ignore[reportIncompatibleMethodOverride] """ :param data: """ return cls((data.x, data.y, data.z)) @property - # pyre-fixme[11] def as_Point(self) -> common_pb.Point: return common_pb.Point(x=self.x, y=self.y, z=self.z) @@ -348,13 +361,21 @@ def to3(self) -> Point3: return Point3(self) def __add__(self, other: Point2 | Point3) -> Point3: - if not isinstance(other, Point3) and isinstance(other, Point2): - return Point3((self.x + other.x, self.y + other.y, self.z)) - # pyre-ignore[16] - return Point3((self.x + other.x, self.y + other.y, self.z + other.z)) + if not isinstance(other, Point3): + return Point3((self[0] + other[0], self[1] + other[1], self[2])) + return Point3((self[0] + other[0], self[1] + other[1], self[2] + other[2])) class Size(Point2): + @classmethod + def from_proto( + cls, data: common_pb.Point | common_pb.Point2D | common_pb.Size2DI | common_pb.PointI | Point2 + ) -> Size: + """ + :param data: + """ + return cls((data.x, data.y)) + @property def width(self) -> float: return self[0] @@ -364,9 +385,9 @@ def height(self) -> float: return self[1] -class Rect(tuple): +class Rect(Point2): @classmethod - def from_proto(cls, data) -> Rect: + def from_proto(cls, data: common_pb.RectangleI) -> Rect: # pyright: ignore[reportIncompatibleMethodOverride] """ :param data: """ @@ -404,8 +425,8 @@ def size(self) -> Size: return Size((self[2], self[3])) @property - def center(self) -> Point2: + def center(self) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] return Point2((self.x + self.width / 2, self.y + self.height / 2)) - def offset(self, p) -> Rect: + def offset(self, p: _PointLike) -> Rect: return self.__class__((self[0] + p[0], self[1] + p[1], self[2], self[3])) diff --git a/sc2/power_source.py b/sc2/power_source.py index 8c64bb62..07814833 100644 --- a/sc2/power_source.py +++ b/sc2/power_source.py @@ -2,6 +2,7 @@ from dataclasses import dataclass +from s2clientprotocol import raw_pb2 from sc2.position import Point2 @@ -15,7 +16,7 @@ def __post_init__(self) -> None: assert self.radius > 0 @classmethod - def from_proto(cls, proto) -> PowerSource: + def from_proto(cls, proto: raw_pb2.PowerSource) -> PowerSource: return PowerSource(Point2.from_proto(proto.pos), proto.radius, proto.tag) def covers(self, position: Point2) -> bool: @@ -30,7 +31,7 @@ class PsionicMatrix: sources: list[PowerSource] @classmethod - def from_proto(cls, proto) -> PsionicMatrix: + def from_proto(cls, proto: list[raw_pb2.PowerSource]) -> PsionicMatrix: return PsionicMatrix([PowerSource.from_proto(p) for p in proto]) def covers(self, position: Point2) -> bool: diff --git a/sc2/protocol.py b/sc2/protocol.py index 5577b08f..2722abe0 100644 --- a/sc2/protocol.py +++ b/sc2/protocol.py @@ -3,13 +3,14 @@ import asyncio import sys from contextlib import suppress +from typing import overload from aiohttp.client_ws import ClientWebSocketResponse from loguru import logger # pyre-fixme[21] from s2clientprotocol import sc2api_pb2 as sc_pb - +from s2clientprotocol.query_pb2 import RequestQuery from sc2.data import Status @@ -34,7 +35,7 @@ def __init__(self, ws: ClientWebSocketResponse) -> None: # pyre-fixme[11] self._status: Status | None = None - async def __request(self, request): + async def __request(self, request: sc_pb.Request) -> sc_pb.Response: logger.debug(f"Sending request: {request!r}") try: await self._ws.send_bytes(request.SerializeToString()) @@ -65,7 +66,51 @@ async def __request(self, request): logger.debug("Response received") return response - async def _execute(self, **kwargs): + @overload + async def _execute(self, create_game: sc_pb.RequestCreateGame) -> sc_pb.Response: ... + @overload + async def _execute(self, join_game: sc_pb.RequestJoinGame) -> sc_pb.Response: ... + @overload + async def _execute(self, restart_game: sc_pb.RequestRestartGame) -> sc_pb.Response: ... + @overload + async def _execute(self, start_replay: sc_pb.RequestStartReplay) -> sc_pb.Response: ... + @overload + async def _execute(self, leave_game: sc_pb.RequestLeaveGame) -> sc_pb.Response: ... + @overload + async def _execute(self, quick_save: sc_pb.RequestQuickSave) -> sc_pb.Response: ... + @overload + async def _execute(self, quick_load: sc_pb.RequestQuickLoad) -> sc_pb.Response: ... + @overload + async def _execute(self, quit: sc_pb.RequestQuit) -> sc_pb.Response: ... + @overload + async def _execute(self, game_info: sc_pb.RequestGameInfo) -> sc_pb.Response: ... + @overload + async def _execute(self, action: sc_pb.RequestAction) -> sc_pb.Response: ... + @overload + async def _execute(self, observation: sc_pb.RequestObservation) -> sc_pb.Response: ... + @overload + async def _execute(self, obs_action: sc_pb.RequestObserverAction) -> sc_pb.Response: ... + @overload + async def _execute(self, step: sc_pb.RequestStep) -> sc_pb.Response: ... + @overload + async def _execute(self, data: sc_pb.RequestData) -> sc_pb.Response: ... + @overload + async def _execute(self, query: RequestQuery) -> sc_pb.Response: ... + @overload + async def _execute(self, save_replay: sc_pb.RequestSaveReplay) -> sc_pb.Response: ... + @overload + async def _execute(self, map_command: sc_pb.RequestMapCommand) -> sc_pb.Response: ... + @overload + async def _execute(self, replay_info: sc_pb.RequestReplayInfo) -> sc_pb.Response: ... + @overload + async def _execute(self, available_maps: sc_pb.RequestAvailableMaps) -> sc_pb.Response: ... + @overload + async def _execute(self, save_map: sc_pb.RequestSaveMap) -> sc_pb.Response: ... + @overload + async def _execute(self, ping: sc_pb.RequestPing) -> sc_pb.Response: ... + @overload + async def _execute(self, debug: sc_pb.RequestDebug) -> sc_pb.Response: ... + async def _execute(self, **kwargs) -> sc_pb.Response: assert len(kwargs) == 1, "Only one request allowed by the API" response = await self.__request(sc_pb.Request(**kwargs)) diff --git a/sc2/proxy.py b/sc2/proxy.py index f2690322..340570cd 100644 --- a/sc2/proxy.py +++ b/sc2/proxy.py @@ -15,7 +15,6 @@ # pyre-fixme[21] from s2clientprotocol import sc2api_pb2 as sc_pb - from sc2.controller import Controller from sc2.data import Result, Status from sc2.player import BotProcess diff --git a/sc2/py.typed b/sc2/py.typed new file mode 100644 index 00000000..d360fd84 --- /dev/null +++ b/sc2/py.typed @@ -0,0 +1 @@ +# Required by https://peps.python.org/pep-0561/#packaging-type-information diff --git a/sc2/renderer.py b/sc2/renderer.py index 4d9f94ff..17e3599e 100644 --- a/sc2/renderer.py +++ b/sc2/renderer.py @@ -1,13 +1,18 @@ +from __future__ import annotations + import datetime +from typing import TYPE_CHECKING -# pyre-ignore[21] from s2clientprotocol import score_pb2 as score_pb - +from s2clientprotocol.sc2api_pb2 import ResponseObservation from sc2.position import Point2 +if TYPE_CHECKING: + from sc2.client import Client + class Renderer: - def __init__(self, client, map_size, minimap_size) -> None: + def __init__(self, client: Client, map_size: tuple[float, float], minimap_size: tuple[float, float]) -> None: self._client = client self._window = None @@ -22,7 +27,7 @@ def __init__(self, client, map_size, minimap_size) -> None: self._text_score = None self._text_time = None - async def render(self, observation) -> None: + async def render(self, observation: ResponseObservation) -> None: render_data = observation.observation.render_data map_size = render_data.map.size diff --git a/sc2/sc2process.py b/sc2/sc2process.py index 846dc480..391d30b1 100644 --- a/sc2/sc2process.py +++ b/sc2/sc2process.py @@ -2,7 +2,6 @@ import asyncio import os -import os.path import shutil import signal import subprocess @@ -14,8 +13,6 @@ from typing import Any import aiohttp - -# pyre-ignore[21] import portpicker from aiohttp.client_ws import ClientWebSocketResponse from loguru import logger @@ -143,7 +140,6 @@ def find_data_hash(self, target_sc2_version: str) -> str | None: def find_base_dir(self, target_sc2_version: str) -> str | None: """Returns the base directory from the matching version string.""" - version: dict for version in self.versions: if version["label"] == target_sc2_version: return "Base" + str(version["base-version"]) diff --git a/sc2/score.py b/sc2/score.py index 9b8f5f2c..aba9c8ff 100644 --- a/sc2/score.py +++ b/sc2/score.py @@ -1,14 +1,19 @@ +from __future__ import annotations + +from s2clientprotocol import score_pb2 + + class ScoreDetails: """Accessable in self.state.score during step function For more information, see https://github.com/Blizzard/s2client-proto/blob/master/s2clientprotocol/score.proto """ - def __init__(self, proto) -> None: + def __init__(self, proto: score_pb2.Score) -> None: self._data = proto self._proto = proto.score_details @property - def summary(self): + def summary(self) -> list[list[int | float]]: """ TODO this is super ugly, how can we improve this summary? Print summary to file with: diff --git a/sc2/unit.py b/sc2/unit.py index 07b63e90..236116f9 100644 --- a/sc2/unit.py +++ b/sc2/unit.py @@ -7,6 +7,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any +from s2clientprotocol import raw_pb2 from sc2.cache import CacheDict from sc2.constants import ( CAN_BE_ATTACKED, @@ -57,11 +58,12 @@ from sc2.ids.buff_id import BuffId from sc2.ids.unit_typeid import UnitTypeId from sc2.ids.upgrade_id import UpgradeId -from sc2.position import Point2, Point3 +from sc2.position import HasPosition2D, Point2, Point3, _PointLike from sc2.unit_command import UnitCommand if TYPE_CHECKING: from sc2.bot_ai import BotAI + from sc2.bot_ai_internal import BotAIInternal from sc2.game_data import AbilityData, UnitTypeData @@ -71,7 +73,7 @@ class RallyTarget: tag: int | None = None @classmethod - def from_proto(cls, proto: Any) -> RallyTarget: + def from_proto(cls, proto: raw_pb2.RallyTarget) -> RallyTarget: return cls( Point2.from_proto(proto.point), proto.tag if proto.HasField("tag") else None, @@ -85,7 +87,7 @@ class UnitOrder: progress: float = 0 @classmethod - def from_proto(cls, proto: Any, bot_object: BotAI) -> UnitOrder: + def from_proto(cls, proto: raw_pb2.UnitOrder, bot_object: BotAI) -> UnitOrder: target: int | Point2 | None = proto.target_unit_tag if proto.HasField("target_world_space_pos"): target = Point2.from_proto(proto.target_world_space_pos) @@ -101,13 +103,13 @@ def __repr__(self) -> str: return f"UnitOrder({self.ability}, {self.target}, {self.progress})" -class Unit: +class Unit(HasPosition2D): class_cache = CacheDict() def __init__( self, - proto_data, - bot_object: BotAI, + proto_data: raw_pb2.Unit, + bot_object: BotAI | BotAIInternal, distance_calculation_index: int = -1, base_build: int = -1, ) -> None: @@ -118,7 +120,7 @@ def __init__( :param base_build: """ self._proto = proto_data - self._bot_object: BotAI = bot_object + self._bot_object = bot_object self.game_loop: int = bot_object.state.game_loop self.base_build = base_build # Index used in the 2D numpy array to access the 2D distance between two units @@ -162,37 +164,37 @@ def tag(self) -> int: @property def is_structure(self) -> bool: """Checks if the unit is a structure.""" - return IS_STRUCTURE in self._type_data.attributes + return IS_STRUCTURE in self._type_data._proto.attributes @property def is_light(self) -> bool: """Checks if the unit has the 'light' attribute.""" - return IS_LIGHT in self._type_data.attributes + return IS_LIGHT in self._type_data._proto.attributes @property def is_armored(self) -> bool: """Checks if the unit has the 'armored' attribute.""" - return IS_ARMORED in self._type_data.attributes + return IS_ARMORED in self._type_data._proto.attributes @property def is_biological(self) -> bool: """Checks if the unit has the 'biological' attribute.""" - return IS_BIOLOGICAL in self._type_data.attributes + return IS_BIOLOGICAL in self._type_data._proto.attributes @property def is_mechanical(self) -> bool: """Checks if the unit has the 'mechanical' attribute.""" - return IS_MECHANICAL in self._type_data.attributes + return IS_MECHANICAL in self._type_data._proto.attributes @property def is_massive(self) -> bool: """Checks if the unit has the 'massive' attribute.""" - return IS_MASSIVE in self._type_data.attributes + return IS_MASSIVE in self._type_data._proto.attributes @property def is_psionic(self) -> bool: """Checks if the unit has the 'psionic' attribute.""" - return IS_PSIONIC in self._type_data.attributes + return IS_PSIONIC in self._type_data._proto.attributes @cached_property def tech_alias(self) -> list[UnitTypeId] | None: @@ -527,7 +529,7 @@ def position_tuple(self) -> tuple[float, float]: return self._proto.pos.x, self._proto.pos.y @cached_property - def position(self) -> Point2: + def position(self) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] """Returns the 2d position of the unit.""" return Point2.from_proto(self._proto.pos) @@ -536,7 +538,7 @@ def position3d(self) -> Point3: """Returns the 3d position of the unit.""" return Point3.from_proto(self._proto.pos) - def distance_to(self, p: Unit | Point2) -> float: + def distance_to(self, p: Unit | _PointLike) -> float: """Using the 2d distance between self and p. To calculate the 3d distance, use unit.position3d.distance_to(p) @@ -546,7 +548,7 @@ def distance_to(self, p: Unit | Point2) -> float: return self._bot_object._distance_squared_unit_to_unit(self, p) ** 0.5 return self._bot_object.distance_math_hypot(self.position_tuple, p) - def distance_to_squared(self, p: Unit | Point2) -> float: + def distance_to_squared(self, p: Unit | _PointLike) -> float: """Using the 2d distance squared between self and p. Slightly faster than distance_to, so when filtering a lot of units, this function is recommended to be used. To calculate the 3d distance, use unit.position3d.distance_to(p) @@ -705,7 +707,7 @@ def calculate_damage_vs_target( # TODO: hardcode hellbats when they have blueflame or attack upgrades for bonus in weapon.damage_bonus: # More about damage bonus https://github.com/Blizzard/s2client-proto/blob/b73eb59ac7f2c52b2ca585db4399f2d3202e102a/s2clientprotocol/data.proto#L55 - if bonus.attribute in target._type_data.attributes: + if bonus.attribute in target._type_data._proto.attributes: bonus_damage_per_upgrade = ( 0 if not self.attack_upgrade_level @@ -1034,7 +1036,7 @@ def order_target(self) -> int | Point2 | None: from the first order, returns None if the unit is idle""" if self.orders: target = self.orders[0].target - if isinstance(target, int): + if target is None or isinstance(target, int): return target return Point2.from_proto(target) return None diff --git a/sc2/units.py b/sc2/units.py index 38813601..1871dfc6 100644 --- a/sc2/units.py +++ b/sc2/units.py @@ -6,6 +6,7 @@ from itertools import chain from typing import TYPE_CHECKING, Any +from s2clientprotocol import raw_pb2 from sc2.ids.unit_typeid import UnitTypeId from sc2.position import Point2 from sc2.unit import Unit @@ -14,11 +15,11 @@ from sc2.bot_ai import BotAI -class Units(list): +class Units(list[Unit]): """A collection of Unit objects. Makes it easy to select units by selectors.""" @classmethod - def from_proto(cls, units, bot_object: BotAI) -> Units: + def from_proto(cls, units: list[raw_pb2.Unit], bot_object: BotAI) -> Units: return cls((Unit(raw_unit, bot_object=bot_object) for raw_unit in units), bot_object) def __init__(self, units: Iterable[Unit], bot_object: BotAI) -> None: diff --git a/test/autotest_bot.py b/test/autotest_bot.py index 10abb82e..a6a0dc23 100644 --- a/test/autotest_bot.py +++ b/test/autotest_bot.py @@ -458,7 +458,7 @@ async def test_botai_actions12(self): # Pick scv scv: Unit = self.workers.random # Pick location to build depot on - placement_position: Point2 = await self.find_placement( + placement_position: Point2 | None = await self.find_placement( UnitTypeId.SUPPLYDEPOT, near=self.townhalls.random.position ) if placement_position: diff --git a/test/generate_pickle_files_bot.py b/test/generate_pickle_files_bot.py index ff12aa29..3b4158ee 100644 --- a/test/generate_pickle_files_bot.py +++ b/test/generate_pickle_files_bot.py @@ -10,9 +10,7 @@ from loguru import logger -# pyre-ignore[21] from s2clientprotocol import sc2api_pb2 as sc_pb - from sc2 import maps from sc2.bot_ai import BotAI from sc2.data import Difficulty, Race diff --git a/test/test_pickled_data.py b/test/test_pickled_data.py index ca0e7767..8a6c5e69 100644 --- a/test/test_pickled_data.py +++ b/test/test_pickled_data.py @@ -20,7 +20,6 @@ from pathlib import Path from typing import Any -# pyre-ignore[21] from google.protobuf.internal import api_implementation from hypothesis import given, settings from hypothesis import strategies as st