#!/usr/bin/env python
"""Ties together the learning environment and main learning loop.
Authors:
Michele Albach, Shibhansh Dohare, David Quail, Parash Rahman, Niko Yasui.
"""
from __future__ import division

import time
from Queue import Empty, Queue
from multiprocessing import Value

import cv2
import geometry_msgs.msg as geom_msg
import numpy as np
import rosbag
import rospy
import std_msgs.msg as std_msg

from state_representation import StateManager
import tools
from tools import timing
from visualize_pixels import Visualize
class LearningForeground:
    """Connects the environment through sensors and ROS with the learning algs.

    Args:
        time_scale (float): Length of a time step in seconds.
        gvfs (list of GVF): List of GVFs to learn.
        features_to_use (set of str): Union of the features that each GVF
            uses to learn their respective predictions.
        behavior_policy (Policy): Policy for the robot to follow.
        stats (list of str): List of statistics to record and publish.
        control_gvf (GVF): GVF that needs to be reset when the episode
            restarts.
        cumulant_counter (multiprocessing Value): Record for the number of
            times the cumulant is non-zero. Could be incorporated into the
            Evaluator.
        reset_episode (fun): Called with no arguments when the control GVF
            reports that the episode finished; returns the actions to take
            to reset the episode.
        custom_stats (dictionary[string:lambda]): The custom topics defined
            by the user. Each value is called with a GVF and its result is
            published as a ``Float64``.

    Attributes:
        COLLECT_DATA_FLAG (bool): Whether or not to save data in bags.
        vis (bool): Whether or not to use the visualizer.
        to_replay_experience (bool): Whether or not to use experience replay.
        recent (dict of queue): Dictionary mapping topic names to the queue
            of recent values from their respective topics.
        publishers (dict of ROS publishers): Publishers for each of the
            data we want to publish.
    """

    def __init__(self,
                 time_scale,
                 gvfs,
                 features_to_use,
                 behavior_policy,
                 stats,
                 control_gvf=None,
                 cumulant_counter=None,
                 reset_episode=None,
                 custom_stats=None):
        # function that generates a list of actions to perform to reset
        # the episode
        self.reset_episode = reset_episode

        # set up ros
        rospy.init_node('agent', anonymous=True)

        self.COLLECT_DATA_FLAG = False

        # counts the total cumulant for the session
        if cumulant_counter is not None:
            self.cumulant_counter = cumulant_counter
        else:
            self.cumulant_counter = Value('d', 0)

        # capture this session's data and actions
        if self.COLLECT_DATA_FLAG:
            self.history = rosbag.Bag('results.bag', 'w')
        self.current_time = rospy.Time.now()

        self.vis = False

        # these sensors are always read, whatever the GVFs ask for
        extras = {'core', 'ir', 'odom'}
        self.features_to_use = set(features_to_use).union(extras)

        # A real list (not a lazy ``filter`` object) because it is iterated
        # twice below; features without a topic map to a falsy value.
        topics = [tools.features[f] for f in self.features_to_use
                  if tools.features[f]]

        # set up dictionary to receive sensor info
        self.recent = {topic: Queue(0) for topic in topics}

        # setup sensor parsers: every incoming message is queued untouched
        for topic in topics:
            rospy.Subscriber(topic,
                             tools.topic_format[topic],
                             self.recent[topic].put)
        rospy.loginfo("Started sensor threads.")

        # smooth out the actions
        self.time_scale = time_scale
        # Pass the rate as a float: truncating with int() lengthens the
        # time step (e.g. 0.3 s -> 3 Hz) and gives 0 Hz for time_scale > 1.
        self.r = rospy.Rate(1.0 / self.time_scale)

        # agent info
        self.gvfs = gvfs
        self.control_gvf = control_gvf
        self.behavior_policy = behavior_policy
        self.avg_td_err = None

        self.state_manager = StateManager(features_to_use)
        if self.vis:
            rospy.loginfo("Creating visualization.")
            self.visualization = Visualize(self.state_manager.pixel_mask,
                                           imsizex=640,
                                           imsizey=480)
            rospy.loginfo("Done creatiing visualization.")

        # previous timestep information
        self.last_action = None
        self.last_phi = None
        self.last_observation = None
        self.last_mu = 1

        # experience replay
        self.to_replay_experience = False

        self.publishers = {
            'action': rospy.Publisher('action_cmd',
                                      geom_msg.Twist,
                                      queue_size=1),
            'pause': rospy.Publisher('pause',
                                     std_msg.Bool,
                                     queue_size=1),
            'termination': rospy.Publisher('termination',
                                           std_msg.Bool,
                                           queue_size=1),
        }

        valid_stats = ['prediction', 'td_error', 'avg_td_error', 'rupee',
                       'MSRE', 'cumulant', 'phi', 'e', 'rho', 'ESS']

        # maps a statistic name to a getter taking the GVF
        self.stat_data = {'prediction': lambda g: g.last_prediction,
                          'cumulant': lambda g: g.last_cumulant,
                          'td_error': lambda g: g.evaluator.td_error,
                          'avg_td_error': lambda g: g.evaluator.avg_td_error,
                          'rupee': lambda g: g.evaluator.rupee,
                          'MSRE': lambda g: g.evaluator.MSRE,
                          'phi': lambda g: g.phi.sum(),
                          'e': lambda g: g.learner.e.sum(),
                          'rho': lambda g: g.rho,
                          'ESS': lambda g: g.evaluator.ESS}

        # unknown built-in statistic names are silently dropped
        self.stats = [s for s in stats if s in valid_stats]

        if custom_stats is not None:
            self.stat_data.update(custom_stats)
            self.stats += list(custom_stats.keys())

        def publisher_name(gvf, label):
            return '{}/{}'.format(gvf, label) if gvf else label

        def make_publisher(gvf, label):
            return rospy.Publisher(publisher_name(gvf, label),
                                   std_msg.Float64,
                                   queue_size=10)

        # statistic publishers are keyed by the GVF object itself
        stat_publishers = {gvf: {stat: make_publisher(gvf.name, stat)
                                 for stat in self.stats}
                           for gvf in self.gvfs}
        self.publishers.update(stat_publishers)

        rospy.loginfo("Done LearningForeground init.")

    @timing
    def update_gvfs(self, phi_prime, observation, action):
        """
        Calls the GVF update function for each GVF and publishes their updated
        statistics.

        Args:
            phi_prime (numpy array): Feature vector for timestep t+1.
            observation (dict): Ancillary state information.
            action (action): Action taken at time t+1.
        """
        for gvf in self.gvfs:
            gvf.update(self.last_observation,
                       self.last_phi,
                       self.last_action,
                       observation,
                       phi_prime,
                       self.last_mu,
                       action)

        # publishing
        for gvf in self.gvfs:
            gvf_publishers = self.publishers[gvf]
            for stat in self.stats:
                gvf_publishers[stat].publish(self.stat_data[stat](gvf))

    def read_source(self, source, history=False):
        """Drains the queue of a source and returns what was waiting in it.

        Args:
            source (str): Feature name, looked up in ``tools.features`` to
                find the topic.
            history (bool): If true, return every queued message in arrival
                order; otherwise return only the most recent one.

        Returns:
            A list of messages (possibly empty) when ``history`` is true,
            else the newest message or ``None`` if nothing was queued. A
            source with no topic, or whose topic is not subscribed to,
            yields the empty result.
        """
        latest = [] if history else None

        # ``.get`` on both lookups: features such as 'charging' have no
        # topic of their own, and only topics of used features have queues.
        stream = tools.features.get(source)
        pending = self.recent.get(stream)
        if pending is None:
            return latest

        while True:
            try:
                message = pending.get_nowait()
            except Empty:
                break
            if history:
                latest.append(message)
            else:
                latest = message
        return latest

    @timing
    def create_state(self):
        """Uses data from :py:attr:`recent` to create the state representation.

        1. Reads data from the :py:attr:`recent` dictionary.
        2. Process data into a format to pass to :doc:`state_representation`.
        3. Pass data to :py:func:`~state_representation.StateManager.get_phi`
           and :py:func:`~state_representation.StateManager.get_observation`.

        Returns:
            (numpy array, dict): Feature vector and ancillary state
            information.
        """
        # bumper constants from
        # http://docs.ros.org/hydro/api/kobuki_msgs/html/msg/SensorState.html
        bump_codes = [1, 4, 2]

        # initialize data
        additional_features = set(tools.features) | {'charging'}
        sensors = self.features_to_use.union(additional_features)

        # read data (used to make phi)
        data = {sensor: None for sensor in sensors}
        for source in sensors - {'ir', 'core'}:
            data[source] = self.read_source(source)
        # 'ir' and 'core' are aggregated over the whole time step, so keep
        # every queued message (at most the 10 newest for 'ir')
        data['ir'] = self.read_source('ir', history=True)[-10:]
        data['core'] = self.read_source('core', history=True)

        # process data
        if data['core']:
            # a bumper counts as pressed if any message in this time step
            # had its bit set
            pressed = [[bool(code & msg.bumper) for code in bump_codes]
                       for msg in data['core']]
            data['bump'] = np.any(pressed, axis=0).tolist()
            # NOTE(review): bit 2 of ``charger`` presumably marks a docked /
            # charging state -- confirm against kobuki_msgs SensorState.
            data['charging'] = bool(data['core'][-1].charger & 2)

            # enter the data into rosbag
            if self.COLLECT_DATA_FLAG:
                for index, bumped in enumerate(data['bump']):
                    bump_bool = std_msg.Bool()
                    bump_bool.data = bool(bumped)
                    self.history.write('bump' + str(index), bump_bool,
                                       t=self.current_time)
                charge_bool = std_msg.Bool()
                charge_bool.data = data['charging']
                self.history.write('charging', charge_bool,
                                   t=self.current_time)

        if data['ir']:
            # bitwise 'or' of all the ir data in last time_step, one value
            # per receiver; only the low 6 bits are kept, matching the
            # 6-bit encoding the original bit-string code produced
            combined = [0] * 3
            for msg in data['ir']:
                # bytearray yields ints for both str and bytes payloads
                readings = bytearray(msg.data)
                combined = [(acc | value) & 0x3F
                            for value, acc in zip(readings, combined)]
            data['ir'] = combined

            # enter the data into rosbag
            if self.COLLECT_DATA_FLAG:
                ir_array = std_msg.Int32MultiArray()
                ir_array.data = data['ir']
                self.history.write('ir', ir_array, t=self.current_time)

        if data['image'] is not None:
            # enter the data into rosbag
            if self.COLLECT_DATA_FLAG:
                self.history.write('image', data['image'],
                                   t=self.current_time)

            # uncompressed image; copy so the array stays writable like the
            # one ``np.fromstring`` used to return
            raw = np.frombuffer(data['image'].data, np.uint8)
            data['image'] = raw.reshape(480, 640, 3).copy()

        # compressed image: decoded into (and overriding) data['image']
        if data['cimage'] is not None:
            encoded = np.frombuffer(data['cimage'].data, np.uint8)
            data['image'] = cv2.imdecode(encoded, 1)

        if data['odom'] is not None:
            pos = data['odom'].pose.pose.position
            lin_vel = data['odom'].twist.twist.linear.x
            ang_vel = data['odom'].twist.twist.angular.z
            data['odom'] = np.array([pos.x, pos.y, lin_vel, ang_vel])

            # enter the data into rosbag
            if self.COLLECT_DATA_FLAG:
                for topic, value in (('odom_x', pos.x),
                                     ('odom_y', pos.y),
                                     ('odom_lin', lin_vel),
                                     ('odom_ang', ang_vel)):
                    odom_msg = std_msg.Float64()
                    odom_msg.data = value
                    self.history.write(topic, odom_msg, t=self.current_time)

        if data['imu'] is not None:
            data['imu'] = data['imu'].orientation.z
            # TODO: enter the data into rosbag

        if 'bias' in self.features_to_use:
            data['bias'] = True

        data['weights'] = self.gvfs[0].learner.theta if self.gvfs else None

        phi = self.state_manager.get_phi(**data)

        if 'last_action' in self.features_to_use:
            # one-hot encoding of the previously chosen action index
            last_action = np.zeros(self.behavior_policy.action_space.size)
            last_action[self.behavior_policy.last_index] = True
            phi = np.concatenate([phi, last_action])

        # update the visualization of the image data
        if self.vis:
            self.visualization.update_colours(data['image'])

        observation = self.state_manager.get_observations(**data)
        observation['action'] = self.last_action

        if observation['bump']:
            # adds a tally for the added cumulant
            self.cumulant_counter.value += 1

        return phi, observation

    def take_action(self, action):
        """Publishes ``action`` on the 'action' publisher."""
        self.publishers['action'].publish(action)

    def run(self):
        """Main learning loop.

        Repeat:
            1. Get new state.
            2. Take an action.
            3. Learn.

        The rosbag (when data collection is on) is closed however the loop
        ends, including when ``rospy`` interrupts a sleep at shutdown.
        """
        avg_time = 0
        time_step = 0
        max_time = 0

        try:
            while not rospy.is_shutdown():
                start_time = time.time()
                self.current_time = rospy.Time.now()

                # get new state
                phi_prime, observation = self.create_state()

                # select and take an action
                self.behavior_policy.update(phi_prime, observation)
                action = self.behavior_policy.choose_action()
                mu = self.behavior_policy.get_probability(action)
                self.take_action(action)
                if self.COLLECT_DATA_FLAG:
                    self.history.write('action', action,
                                       t=self.current_time)

                # learn
                if self.last_observation is not None:
                    self.update_gvfs(phi_prime, observation, action)

                # check if episode is over and reset accordingly [episodic]
                if self.control_gvf is not None:
                    learner = self.control_gvf.learner
                    if learner.episode_finished_last_step:
                        # without a reset function there is nothing to do
                        if self.reset_episode is not None:
                            reset_actions = self.reset_episode()
                        else:
                            reset_actions = []
                        for reset_action in reset_actions:
                            self.take_action(reset_action)
                            msg = 'taking random action number: {}'.format(
                                reset_action)
                            rospy.loginfo(msg)
                            if self.to_replay_experience:
                                learner.uniform_experience_replay()
                            self.r.sleep()
                            # NOTE(review): the original loop reused the
                            # name ``action``, so the last reset action is
                            # what gets stored as ``last_action`` below
                            # (while ``last_mu`` still belongs to the policy
                            # action). Kept as-is; confirm this is intended.
                            action = reset_action
                    elif self.to_replay_experience:
                        # not to replay when the episode resets at it will
                        # also include the experience at the start of new
                        # episode
                        learner.uniform_experience_replay()

                # save values
                self.last_phi = phi_prime if len(phi_prime) else None
                self.last_action = action
                self.last_mu = mu
                self.last_observation = observation

                # timestep logging
                total_time = time.time() - start_time
                max_time = max(max_time, total_time)
                time_step += 1
                avg_time += (total_time - avg_time) / time_step

                time_msg = "Current timestep took {:.4f} sec.".format(
                    total_time)
                rospy.loginfo(time_msg)

                # a step that performed an episode reset is expected to
                # overrun, so it is not reported
                resetting = (
                    self.control_gvf is not None and
                    self.control_gvf.learner.episode_finished_last_step)
                if total_time > self.time_scale and not resetting:
                    rospy.logerr("Timestep took too long!")

                # sleep until next time step
                self.r.sleep()
        finally:
            if self.COLLECT_DATA_FLAG:
                self.history.close()
def start_learning_foreground(time_scale,
                              GVFs,
                              topics,
                              policy,
                              stats,
                              control_gvf=None,
                              cumulant_counter=None,
                              reset_episode=None,
                              custom_stats=None):
    """Function to call with multiprocessing or multithreading.

    Builds a :py:class:`LearningForeground` from the arguments and runs its
    main loop. A ``rospy.ROSInterruptException`` raised during construction
    or while running is logged and swallowed so the worker exits quietly.

    Args:
        time_scale: Passed through as ``time_scale``.
        GVFs: Passed through as ``gvfs``.
        topics: Passed through as ``features_to_use``.
        policy: Passed through as ``behavior_policy``.
        stats: Passed through as ``stats``.
        control_gvf: Passed through as ``control_gvf``.
        cumulant_counter: Passed through as ``cumulant_counter``.
        reset_episode: Passed through as ``reset_episode``.
        custom_stats: Passed through as ``custom_stats``.
    """
    try:
        foreground = LearningForeground(time_scale,
                                        GVFs,
                                        topics,
                                        policy,
                                        stats,
                                        control_gvf=control_gvf,
                                        cumulant_counter=cumulant_counter,
                                        reset_episode=reset_episode,
                                        # previously accepted but never
                                        # forwarded, so user-defined stats
                                        # were silently ignored
                                        custom_stats=custom_stats)
        foreground.run()
    except rospy.ROSInterruptException as detail:
        rospy.loginfo("Handling: {}".format(detail))