forked from anenriquez/mrta_stn
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from anenriquez/develop
Update release version
- Loading branch information
Showing
24 changed files
with
1,034 additions
and
350 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
from stn.stn import STN | ||
from stn.pstn.pstn import PSTN | ||
from stn.stnu.stnu import STNU | ||
from stn.methods.srea import srea | ||
from stn.methods.fpc import get_minimal_network | ||
from stn.methods.dsc_lp import DSC_LP | ||
|
||
|
||
class STNFactory(object): | ||
|
||
def __init__(self): | ||
self._stns = {} | ||
|
||
def register_stn(self, solver_name, stn): | ||
""" Registers an stn type based and the solver that uses it | ||
Saves the stn in a dictionary of stns | ||
key - name of the solver that uses the stn | ||
value - stn class | ||
:param solver_name: solver name | ||
:param stn: stn class | ||
""" | ||
self._stns[solver_name] = stn | ||
|
||
def get_stn(self, solver_name): | ||
""" Returns an stn based on a solver name | ||
:param solver_name: solver name | ||
:return: stn class | ||
""" | ||
stn = self._stns.get(solver_name) | ||
if not stn: | ||
raise ValueError(solver_name) | ||
return stn() | ||
|
||
|
||
class STPSolverFactory(object): | ||
|
||
def __init__(self): | ||
self._solvers = {} | ||
|
||
def register_solver(self, solver_name, solver): | ||
""" Registers stp problem solvers | ||
Saves the solver in a dictionary of solvers | ||
key - solver name | ||
value - class that implements the solver | ||
:param solver_name: solver name | ||
:param solver: solver class | ||
""" | ||
self._solvers[solver_name] = solver | ||
|
||
def get_solver(self, solver_name): | ||
""" Returns the class that implements the solver | ||
:param solver_name: solver name | ||
:return: class that implements the solver | ||
""" | ||
solver = self._solvers.get(solver_name) | ||
if not solver: | ||
raise ValueError(solver_name) | ||
|
||
return solver() | ||
|
||
|
||
class StaticRobustExecution(object): | ||
|
||
def __init__(self): | ||
self.compute_dispatchable_graph = self.srea_algorithm | ||
|
||
@staticmethod | ||
def srea_algorithm(stn): | ||
""" Computes the dispatchable graph of an stn using the | ||
srea algorithm | ||
:param stn: stn (object) | ||
""" | ||
result = srea(stn, debug=True) | ||
if result is None: | ||
return | ||
risk_metric, dispatchable_graph = result | ||
|
||
return risk_metric, dispatchable_graph | ||
|
||
|
||
class DegreeStongControllability(object): | ||
|
||
def __init__(self): | ||
self.compute_dispatchable_graph = self.dsc_lp_algorithm | ||
|
||
@staticmethod | ||
def dsc_lp_algorithm(stn): | ||
""" Computes the dispatchable graph of an stn using the | ||
degree of strong controllability lp solver | ||
:param stn: stn (object) | ||
""" | ||
dsc_lp = DSC_LP(stn) | ||
status, bounds, epsilons = dsc_lp.original_lp() | ||
|
||
if epsilons is None: | ||
return | ||
original_intervals, shrinked_intervals = dsc_lp.new_interval(epsilons) | ||
|
||
dsc = dsc_lp.compute_dsc(original_intervals, shrinked_intervals) | ||
|
||
stnu = dsc_lp.get_stnu(bounds) | ||
|
||
# Returns a schedule because it is an offline approach | ||
schedule = dsc_lp.get_schedule(bounds) | ||
|
||
# A strongly controllable STNU has a DSC of 1, i.e., a DSC value of 1 is better. We take | ||
# 1 − DC to be the risk metric, so that small values are preferable | ||
risk_metric = 1 - dsc | ||
|
||
return risk_metric, schedule | ||
|
||
|
||
class FullPathConsistency(object): | ||
|
||
def __init__(self): | ||
self.compute_dispatchable_graph = self.fpc_algorithm | ||
|
||
@staticmethod | ||
def fpc_algorithm(stn): | ||
""" Computes the dispatchable graph of an stn using | ||
full path consistency | ||
:param stn: stn (object) | ||
""" | ||
dispatchable_graph = get_minimal_network(stn) | ||
if dispatchable_graph is None: | ||
return | ||
risk_metric = 1 | ||
return risk_metric, dispatchable_graph | ||
|
||
|
||
stn_factory = STNFactory() | ||
stn_factory.register_stn('fpc', STN) | ||
stn_factory.register_stn('srea', PSTN) | ||
stn_factory.register_stn('dsc_lp', STNU) | ||
|
||
stp_solver_factory = STPSolverFactory() | ||
stp_solver_factory.register_solver('fpc', FullPathConsistency) | ||
stp_solver_factory.register_solver('srea', StaticRobustExecution) | ||
stp_solver_factory.register_solver('dsc_lp', DegreeStongControllability) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,45 +1,46 @@ | ||
from stn.utils.uuid import generate_uuid | ||
|
||
|
||
class Node(object): | ||
"""Represents a timepoint in the STN """ | ||
|
||
def __init__(self, task_id='', pose='', type='zero_timepoint'): | ||
def __init__(self, task_id, pose, node_type): | ||
# id of the task represented by this node | ||
self.task_id = task_id | ||
# Pose in the map where the node has to be executed | ||
self.pose = pose | ||
# The node can be of type zero_timepoint, navigation, start or finish | ||
self.type = type | ||
# The node can be of node_type zero_timepoint, navigation, start or finish | ||
self.node_type = node_type | ||
|
||
def __str__(self): | ||
to_print = "" | ||
to_print += "node {} {}".format(self.task_id, self.type) | ||
to_print += "node {} {}".format(self.task_id, self.node_type) | ||
return to_print | ||
|
||
def __repr__(self): | ||
return str(self.to_dict()) | ||
|
||
def __hash__(self): | ||
return hash((self.task_id, self.pose, self.type)) | ||
return hash((self.task_id, self.pose, self.node_type)) | ||
|
||
def __eq__(self, other): | ||
if other is None: | ||
return False | ||
return (self.task_id == other.task_id and | ||
self.pose == other.pose and | ||
self.type == other.type) | ||
self.node_type == other.node_type) | ||
|
||
def to_dict(self): | ||
node_dict = dict() | ||
node_dict['task_id'] = self.task_id | ||
node_dict['pose'] = self.pose | ||
node_dict['type'] = self.type | ||
node_dict['node_type'] = self.node_type | ||
return node_dict | ||
|
||
@staticmethod | ||
def from_dict(node_dict): | ||
node = Node() | ||
node.task_id = node_dict['task_id'] | ||
node.pose = node_dict['pose'] | ||
node.type = node_dict['type'] | ||
task_id = node_dict['task_id'] | ||
pose = node_dict['pose'] | ||
node_type = node_dict['node_type'] | ||
node = Node(task_id, pose, node_type) | ||
return node |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.