from typing import Text
import airsim
import numpy as np
# NOTE(review): this imports the module into itself and the name is rebound by
# the class definition below; it looks unused -- confirm and remove.
import HelloWorldEnv

from tf_agents.environments.py_environment import PyEnvironment
from tf_agents.environments.tf_py_environment import TFPyEnvironment
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types
from tf_agents.typing.types import NestedArraySpec, TimeStep, NestedArray
from tf_agents.specs.array_spec import BoundedArraySpec, ArraySpec


class HelloWorldEnv(PyEnvironment):
    """TF-Agents environment that drives an AirSim multirotor vertically.

    Actions are a scalar int32 in ``[0, 2]``: ``0`` = stop, ``1`` = up,
    ``2`` = down. Observations are a float32 array
    ``(latitude, longitude, altitude)`` read from the simulated GPS. The
    reward is the negative absolute difference between the GPS altitude and
    ``_TARGET_ALTITUDE``.

    ``_step`` always returns a mid-episode transition, so this environment
    never emits a terminal time step on its own.
    """

    # Altitude the reward pulls the vehicle towards.
    # NOTE(review): presumably metres, as reported by AirSim's GPS -- confirm.
    _TARGET_ALTITUDE = 133

    def __init__(self, ip: str, handle_auto_reset: bool = False):
        """Create the specs and the AirSim client.

        Args:
            ip: Address of the AirSim server, passed to ``MultirotorClient``.
            handle_auto_reset: Forwarded unchanged to ``PyEnvironment``.
        """
        self._action_spec = BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=2, name='action')
        self._observation_spec = ArraySpec(
            shape=(3,), dtype=np.float32, name='observation')
        self._client = airsim.MultirotorClient(ip=ip)
        # Zero-argument super() does not depend on the global name
        # ``HelloWorldEnv``, which the module-level import above also binds.
        super().__init__(handle_auto_reset)

    def action_spec(self) -> NestedArraySpec:
        """Return the spec of the scalar discrete action."""
        return self._action_spec

    def observation_spec(self) -> NestedArraySpec:
        """Return the spec of the 3-element float32 observation."""
        return self._observation_spec

    def _reset(self) -> TimeStep:
        """Reset the simulator, take control of the vehicle and arm it.

        Returns:
            The first time step of a new episode, carrying the current
            observation.
        """
        self._client.confirmConnection()
        self._client.reset()
        # API control is enabled before arming, following the order used in
        # AirSim's own examples; it is re-enabled here after every reset().
        self._client.enableApiControl(True)
        self._client.armDisarm(True)
        return ts.restart(self._getObservation())

    def _step(self, action: NestedArray) -> TimeStep:
        """Apply a vertical velocity command and observe the result.

        Args:
            action: ``0`` = stop, ``1`` = up, ``2`` = down. Any other value
                leaves the vertical velocity at zero.

        Returns:
            A transition time step with the new observation and its reward.
        """
        vz = 0
        if action == 1:  # up (negative z velocity moves the vehicle upwards)
            vz = -1
        elif action == 2:  # down
            vz = 1
        elif action == 0:  # stop
            vz = 0

        self._client.cancelLastTask()
        # max 5 seconds flight without a new command until the drone stops
        self._client.moveByVelocityAsync(0, 0, vz, 5)

        obs = self._getObservation()
        reward = self._calculate_reward(obs)

        return ts.transition(obs, reward=reward)

    def render(self, mode: Text = 'rgb_array') -> NestedArray | None:
        """Rendering is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def _getObservation(self):
        """Read the GPS position as a float32 array matching the spec.

        Returns:
            ``np.ndarray`` of shape ``(3,)`` and dtype ``float32`` holding
            ``(latitude, longitude, altitude)``.
        """
        geo_point = self._client.getGpsData().gnss.geo_point
        # Build a single array so the value matches ``observation_spec``;
        # a plain tuple would be treated as a nest of three separate scalars.
        return np.array(
            [geo_point.latitude, geo_point.longitude, geo_point.altitude],
            dtype=np.float32,
        )

    @staticmethod
    def _calculate_reward(obs: np.ndarray | tuple[float, float, float]):
        """Return the negative distance of the altitude from the target.

        Args:
            obs: ``(latitude, longitude, altitude)``; only index 2 is read.

        Returns:
            ``-abs(_TARGET_ALTITUDE - altitude)`` as a Python float; ``0.0``
            when the altitude equals the target, negative otherwise.
        """
        return float(-abs(HelloWorldEnv._TARGET_ALTITUDE - obs[2]))