Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
H
Hello World
Manage
Activity
Members
Labels
Plan
Issues
6
Issue boards
Milestones
Wiki
Code
Merge requests
0
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Operate
Terraform modules
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
OST
SA
projects
Hello World
Commits
a8e4c8f4
Commit
a8e4c8f4
authored
1 year ago
by
Andri Joos
Browse files
Options
Downloads
Patches
Plain Diff
add environment
parent
0c672e90
No related branches found
Branches containing commit
No related tags found
1 merge request
!1
Resolve "Setup Environment"
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
HelloWorldEnv.py
+59
-0
59 additions, 0 deletions
HelloWorldEnv.py
with
59 additions
and
0 deletions
HelloWorldEnv.py
0 → 100644
+
59
−
0
View file @
a8e4c8f4
from
typing
import
Text
import
airsim
import
numpy
as
np
import
HelloWorldEnv
from
tf_agents.environments.py_environment
import
PyEnvironment
from
tf_agents.environments.tf_py_environment
import
TFPyEnvironment
from
tf_agents.trajectories
import
time_step
as
ts
from
tf_agents.typing
import
types
from
tf_agents.typing.types
import
NestedArraySpec
,
TimeStep
,
NestedArray
from
tf_agents.specs.array_spec
import
BoundedArraySpec
,
ArraySpec
class
HelloWorldEnv
(
PyEnvironment
):
def
__init__
(
self
,
ip
:
str
,
handle_auto_reset
:
bool
=
False
):
self
.
_action_spec
=
BoundedArraySpec
(
shape
=
(),
dtype
=
np
.
int32
,
minimum
=
0
,
maximum
=
2
,
name
=
'
action
'
)
self
.
_observation_spec
=
ArraySpec
(
shape
=
(
3
,),
dtype
=
np
.
float32
,
name
=
'
observation
'
)
self
.
_client
=
airsim
.
MultirotorClient
(
ip
=
ip
)
super
(
HelloWorldEnv
,
self
).
__init__
(
handle_auto_reset
)
def
action_spec
(
self
)
->
NestedArraySpec
:
return
self
.
_action_spec
def
observation_spec
(
self
)
->
NestedArraySpec
:
return
self
.
_observation_spec
def
_reset
(
self
)
->
TimeStep
:
self
.
_client
.
reset
()
self
.
_client
.
confirmConnection
()
self
.
_client
.
armDisarm
(
True
)
self
.
_client
.
enableApiControl
(
True
)
return
ts
.
restart
(
self
.
_getObservation
())
def
_step
(
self
,
action
:
NestedArray
)
->
TimeStep
:
vz
=
0
if
action
==
1
:
# up
vz
=
-
1
elif
action
==
2
:
# down
vz
=
1
elif
action
==
0
:
# stop
vz
==
0
self
.
_client
.
cancelLastTask
()
self
.
_client
.
moveByVelocityAsync
(
0
,
0
,
vz
,
5
)
# max 5 seconds flight without a new command until the drone stops
obs
=
self
.
_getObservation
()
reward
=
HelloWorldEnv
.
_calculate_reward
(
obs
)
return
ts
.
transition
(
obs
,
reward
=
reward
)
def
render
(
self
,
mode
:
Text
=
'
rgb_array
'
)
->
NestedArray
|
None
:
raise
NotImplementedError
()
def
_getObservation
(
self
):
geo_point
=
self
.
_client
.
getGpsData
().
gnss
.
geo_point
return
geo_point
.
latitude
,
geo_point
.
longitude
,
geo_point
.
altitude
@staticmethod
def
_calculate_reward
(
obs
:
tuple
[
float
,
float
,
float
]):
return
-
abs
(
133
-
obs
[
2
])
This diff is collapsed.
Click to expand it.
Preview
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment