diff --git a/HelloWorldEnv.py b/HelloWorldEnv.py index 154e1354993671fcf90c720b7a7d82dc1a11d709..fc570735469f9b81afaf703ffdd8ad13a71c97db 100644 --- a/HelloWorldEnv.py +++ b/HelloWorldEnv.py @@ -13,7 +13,7 @@ from tf_agents.specs.array_spec import BoundedArraySpec, ArraySpec class HelloWorldEnv(PyEnvironment): def __init__(self, ip: str, handle_auto_reset: bool = False): self._action_spec = BoundedArraySpec(shape=(), dtype=np.int32, minimum=0, maximum=2, name='action') - self._observation_spec = ArraySpec(shape=(3,), dtype=np.float32, name='observation') + self._observation_spec = ArraySpec(shape=(1,), dtype=np.float32, name='observation') self._client = airsim.MultirotorClient(ip=ip) super(HelloWorldEnv, self).__init__(handle_auto_reset) @@ -52,8 +52,10 @@ class HelloWorldEnv(PyEnvironment): def _getObservation(self): geo_point = self._client.getGpsData().gnss.geo_point - return geo_point.latitude, geo_point.longitude, geo_point.altitude + # return geo_point.latitude, geo_point.longitude, geo_point.altitude + return (geo_point.altitude,) @staticmethod def _calculate_reward(obs: tuple[float, float, float]): - return -abs(133 - obs[2]) + # return -abs(133 - obs[2]) + return -abs(133 - (obs[0]))