Source code for surround.stage
from abc import ABC, abstractmethod
[docs]class Stage(ABC):
"""
Base class of all stages in a Surround pipeline.
See the following class for more information:
- :class:`surround.stage.Estimator`
"""
[docs] def dump_output(self, state, config):
"""
Dump the output of the stage after the stage has transformed the data.
.. note:: This is called by :meth:`surround.assembler.Assembler.run` (when dumping output is requested).
:param state: Stores intermediate data from each stage in the pipeline
:type state: Instance or child of the :class:`surround.State` class
:param config: Config of the pipeline
:type config: :class:`surround.config.Config`
"""
[docs] def operate(self, state, config):
"""
Main function to be called in an assembly.
:param state: Contains all pipeline state including input and output data
:param config: Config for the assembly
"""
[docs] def initialise(self, config):
"""
Initialise the stage, this may be loading a model or loading data.
.. note:: This is called by :meth:`surround.assembler.Assembler.init_assembler`.
:param config: Contains the settings for each stage
:type config: :class:`surround.config.Config`
"""
[docs]class Estimator(Stage):
"""
Base class for an estimator in a Surround pipeline. Responsible for performing estimation
or training using the input data.
This stage is executed by :meth:`surround.assembler.Assembler.run`.
Example::
class Predict(Estimator):
def initialise(self, config):
self.model = load_model(os.path.join(config["models_path"], "model.pb"))
def estimate(self, state, config):
state.output_data = run_model(self.model)
def fit(self, state, config):
state.output_data = train_model(self.model)
"""
[docs] @abstractmethod
def estimate(self, state, config):
"""
Process input data and store estimated values.
.. note:: This method is ONLY called by :meth:`surround.assembler.Assembler.run` when
running in predict/batch-predict mode.
:param state: Stores intermediate data from each stage in the pipeline
:type state: Instance or child of the :class:`surround.State` class
:param config: Contains the settings for each stage
:type config: :class:`surround.config.Config`
"""
[docs] def fit(self, state, config):
"""
Train a model using the input data.
.. note:: This method is ONLY called by :meth:`surround.assembler.Assembler.run` when
running in training mode.
:param state: Stores intermediate data from each stage in the pipeline
:type state: Instance or child of the :class:`surround.State` class
:param config: Contains the settings for each stage
:type config: :class:`surround.config.Config`
"""