Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
H
Hello World
Manage
Activity
Members
Labels
Plan
Issues
6
Issue boards
Milestones
Wiki
Code
Merge requests
0
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Operate
Terraform modules
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
OST
SA
projects
Hello World
Commits
9ae6fc92
Commit
9ae6fc92
authored
1 year ago
by
Andri Joos
Browse files
Options
Downloads
Patches
Plain Diff
fix env observation shape
parent
0608cd70
Loading
Loading
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
HelloWorldEnv.py
+2
-2
2 additions, 2 deletions
HelloWorldEnv.py
with
2 additions
and
2 deletions
HelloWorldEnv.py
+
2
−
2
View file @
9ae6fc92
...
...
@@ -13,7 +13,7 @@ from tf_agents.specs.array_spec import BoundedArraySpec, ArraySpec
class
HelloWorldEnv
(
PyEnvironment
):
def
__init__
(
self
,
ip
:
str
,
handle_auto_reset
:
bool
=
False
):
self
.
_action_spec
=
BoundedArraySpec
(
shape
=
(),
dtype
=
np
.
int32
,
minimum
=
0
,
maximum
=
2
,
name
=
'
action
'
)
self
.
_observation_spec
=
ArraySpec
(
shape
=
(
1
,),
dtype
=
np
.
float
32
,
name
=
'
observation
'
)
self
.
_observation_spec
=
ArraySpec
(
shape
=
(
1
,),
dtype
=
np
.
float
64
,
name
=
'
observation
'
)
self
.
_client
=
airsim
.
MultirotorClient
(
ip
=
ip
)
super
(
HelloWorldEnv
,
self
).
__init__
(
handle_auto_reset
)
...
...
@@ -53,7 +53,7 @@ class HelloWorldEnv(PyEnvironment):
def
_getObservation
(
self
):
geo_point
=
self
.
_client
.
getGpsData
().
gnss
.
geo_point
# return geo_point.latitude, geo_point.longitude, geo_point.altitude
return
(
geo_point
.
altitude
,
)
return
np
.
array
([
geo_point
.
altitude
]
)
@staticmethod
def
_calculate_reward
(
obs
:
tuple
[
float
,
float
,
float
]):
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment