Source code for btgym.datafeed.casual

import copy
import random
import math
import datetime
from logbook import WARNING

from .base import BTgymBaseData
from .derivative import BTgymEpisode, BTgymDataTrial, BTgymRandomDataDomain


class BTgymCasualTrial(BTgymDataTrial):
    """
    Intermediate-level data class. Implements the conception of a `Trial` object.
    Supports exact train/test data separation by means of `global_time`.
    Do not use directly.
    """
    trial_params = dict(
        nested_class_ref=BTgymEpisode,
    )

    def __init__(self, name='TimeTrial', **kwargs):
        """
        Args:
            filename:           not used;
            sampling_params:    dict, sample retrieving options, see base class description for details;
            task:               int, optional;
            parsing_params:     csv parsing options, see base class description for details;
            log_level:          int, optional, logbook.level;
            _config_stack:      dict, holding configuration for nested child samples;
        """
        super(BTgymCasualTrial, self).__init__(name=name, **kwargs)
        # self.log.warning('self.frozen_time_split: {}'.format(self.frozen_time_split))
    def set_global_timestamp(self, timestamp):
        """
        Performs validity checks and sets current global_time.

        Args:
            timestamp: POSIX timestamp

        Returns:

        """
        if self.data is not None:
            if self.frozen_split_timestamp is not None:
                self.global_timestamp = self.frozen_split_timestamp

            else:
                if self.metadata['type']:
                    if timestamp is not None:
                        assert timestamp < self.final_timestamp, \
                            'global time passed <{}> is out of upper bound <{}> for provided data.'.\
                            format(
                                datetime.datetime.fromtimestamp(timestamp),
                                datetime.datetime.fromtimestamp(self.final_timestamp)
                            )
                        if timestamp < self.start_timestamp:
                            if self.global_timestamp == 0:
                                self.global_timestamp = self.start_timestamp

                        else:
                            if timestamp > self.global_timestamp:
                                self.global_timestamp = timestamp

                    else:
                        if self.global_timestamp == 0:
                            self.global_timestamp = self.start_timestamp

                else:
                    self.global_timestamp = self.start_timestamp
    def get_global_index(self):
        """
        Returns:
            data row corresponding to current global_time
        """
        if self.is_ready:
            return self.data.index.get_loc(
                datetime.datetime.fromtimestamp(self.global_timestamp),
                method='backfill'
            )

        else:
            return 0
    def get_intervals(self):
        """
        Estimates exact sampling intervals such that the test episode starts as close
        to the current global time point as data consistency allows, but no earlier.

        Returns:
            train and test sampling intervals for the current global_time point
        """
        if self.is_ready:
            if self.metadata['type']:
                # Intervals for target trial:
                current_index = self.get_global_index()
                self.log.debug(
                    'current_index: {}, total_num_records: {}, sample_num_records: {}'.format(
                        current_index,
                        self.total_num_records,
                        self.sample_num_records
                    )
                )
                assert 0 <= current_index <= self.total_num_records - self.sample_num_records,\
                    'global_time: {} outside data interval: {} - {}, considering sample duration: {}'.format(
                        self.data.index[current_index],
                        self.data.index[0],
                        self.data.index[-1],
                        self.max_sample_len_delta
                    )
                train_interval = [0, current_index]
                test_interval = [current_index + 1, self.total_num_records - 1]

            else:
                # Intervals for source trial:
                train_interval = [0, self.train_num_records - 1]
                test_interval = [self.train_num_records, self.total_num_records - 1]  # TODO: ?!

            self.log.debug(
                'train_interval: {}, datetimes: {} - {}'.format(
                    train_interval,
                    self.data.index[train_interval[0]],
                    self.data.index[train_interval[-1]],
                )
            )
            self.log.debug(
                'test_interval: {}, datetimes: {} - {}'.format(
                    test_interval,
                    self.data.index[test_interval[0]],
                    self.data.index[test_interval[-1]],
                )
            )
        else:
            train_interval = None
            test_interval = None

        return train_interval, test_interval
    def sample(
            self,
            get_new=True,
            sample_type=0,
            timestamp=None,
            align_left=True,
            b_alpha=1.0,
            b_beta=1.0,
            **kwargs
    ):
        """
        Samples continuous subset of data.

        Args:
            get_new (bool):             sample new (True) or reuse (False) last made sample;
            sample_type (int or bool):  0 (train) or 1 (test) - get sample from train or test data subsets
                                        respectively;
            timestamp:                  POSIX timestamp;
            align_left (bool):          if True: set test interval as close to current timepoint as possible;
            b_alpha (float):            beta-distribution sampling alpha > 0, valid for train episodes;
            b_beta (float):             beta-distribution sampling beta > 0, valid for train episodes.
        """
        try:
            assert self.is_ready

        except AssertionError:
            self.log.exception(
                'Sampling attempt: data not ready. Hint: forgot to call data.reset()?'
            )
            raise AssertionError

        try:
            assert sample_type in [0, 1]

        except AssertionError:
            self.log.exception(
                'Sampling attempt: expected sample type be in {}, got: {}'.format([0, 1], sample_type)
            )
            raise AssertionError

        # Set actual time:
        if timestamp is not None:
            self.set_global_timestamp(timestamp)

        if 'interval' not in kwargs.keys():
            train_interval, test_interval = self.get_intervals()

        else:
            train_interval = test_interval = kwargs.pop('interval')

        if self.sample_instance is None or get_new:
            if sample_type == 0:
                # Get beta-distributed sample in train interval:
                self.sample_instance = self._sample_interval(
                    train_interval,
                    b_alpha=b_alpha,
                    b_beta=b_beta,
                    name='train_' + self.sample_name,
                    **kwargs
                )

            else:
                # If parent is target - get left-aligned (i.e. as close as possible to current global_time)
                # sample in test interval; else (parent is source) - uniformly sample from test interval:
                if self.metadata['parent_sample_type']:
                    align = align_left

                else:
                    align = False

                self.sample_instance = self._sample_aligned_interval(
                    test_interval,
                    align_left=align,
                    b_alpha=1,
                    b_beta=1,
                    name='test_' + self.sample_name,
                    **kwargs
                )

            self.sample_instance.metadata['type'] = sample_type
            self.sample_instance.metadata['sample_num'] = self.sample_num
            self.sample_instance.metadata['parent_sample_num'] = copy.deepcopy(self.metadata['sample_num'])
            self.sample_instance.metadata['parent_sample_type'] = copy.deepcopy(self.metadata['type'])
            self.sample_num += 1

        else:
            # Do nothing:
            self.log.debug('Reusing sample, id: {}'.format(self.sample_instance.filename))

        return self.sample_instance
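
# Illustrative call sequence (an assumption, not part of the original source): BTgymCasualTrial
# instances are normally created and driven by BTgymCasualDataDomain below rather than used
# directly. Conceptually, the driving code does roughly:
#
#     trial.set_global_timestamp(timestamp)        # advance / clamp current global time
#     train_iv, test_iv = trial.get_intervals()    # rows [0, t] for train, (t, end] for test
#     episode = trial.sample(sample_type=0)        # beta-sampled train episode from known data
#     episode = trial.sample(sample_type=1)        # test episode aligned to current time point
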
class BTgymCasualDataDomain(BTgymRandomDataDomain):
    """
    Imitates an online data stream by implementing the conception of a sliding `current time point`
    and enabling sampling control according to it.

    The objective is to enable a proper train/evaluation/test data split and prevent data leakage
    by allowing training on known, past data only and testing on unknown, future data,
    providing a realistic training cycle.

    The source trials set is defined as all trials starting somewhere in the past and ending no later
    than the current time point; the target trials set is defined as the set of trials such that the
    trial train period ends at the current time point and the trial test period starts from now on,
    for all time points in the available dataset range.

    Sampling control is defined by:
    - `current time point` is set arbitrarily and is stateful in the sense that it can only be
      increased (no backward time);
    - source trials can be sampled from past (known) data multiple times;
    - target trial can only be sampled once, according to the current time point or later (unknown data);
    - as any sampled target trial is being evaluated by the outer algorithm, current time should be
      incremented either by providing the `timestamp` arg. to the sample() method or by calling the
      set_global_timestamp() method, to match the last evaluated record (marking all evaluated data
      as already known and making it available for training).
    """
    trial_class_ref = BTgymCasualTrial
    episode_class_ref = BTgymEpisode

    def __init__(
            self,
            filename,
            trial_params,
            episode_params,
            frozen_time_split=None,
            name='TimeDataDomain',
            data_names=('default_asset',),
            **kwargs
    ):
        """
        Args:
            filename:           str or list of str, file_names containing CSV historic data;
            parsing_params:     csv parsing options, see base class description for details;
            trial_params:       dict, describes trial parameters, should contain keys:
                                {sample_duration, time_gap, start_00, start_weekdays, test_period, expanding};
            episode_params:     dict, describes episode parameters, should contain keys:
                                {sample_duration, time_gap, start_00, start_weekdays};
            name:               str, optional
            task:               int, optional
            log_level:          int, logbook.level
        """
        self.train_range_row = 0
        self.test_range_row = 0
        self.test_mean_row = 0

        self.global_step = 0
        self.total_samples = -1
        self.sample_num = -1
        self.sample_stride = -1

        # if frozen_time_split is not None:
        #     self.frozen_time_split = datetime.datetime(**frozen_time_split)
        #
        # else:
        #     self.frozen_time_split = None
        #
        # self.frozen_split_timestamp = None

        kwargs.update({'target_period': episode_params['sample_duration']})

        trial_params['start_00'] = False
        trial_params['frozen_time_split'] = frozen_time_split

        super(BTgymCasualDataDomain, self).__init__(
            filename=filename,
            trial_params=trial_params,
            episode_params=episode_params,
            use_target_backshift=False,
            name=name,
            data_names=data_names,
            frozen_time_split=frozen_time_split,
            **kwargs
        )
        # self.log.warning('2: self.frozen_time_split: {}'.format(self.frozen_time_split))

        self.log.debug('trial_class_ref: {}'.format(self.trial_class_ref))
        self.log.debug('episode_class_ref: {}'.format(self.episode_class_ref))

        self.log.debug('sampling_params: {}'.format(self.sampling_params))
        self.log.debug('nested_params: {}'.format(self.nested_params))
    def set_global_timestamp(self, timestamp):
        """
        Performs validity checks and sets current global_time.

        Args:
            timestamp: POSIX timestamp

        Returns:

        """
        if self.data is not None:
            if self.frozen_split_timestamp is not None:
                self.global_timestamp = self.frozen_split_timestamp

            else:
                if timestamp is not None:
                    assert timestamp < self.final_timestamp, \
                        'global time passed <{}> is out of upper bound <{}> for provided data.'.\
                        format(
                            datetime.datetime.fromtimestamp(timestamp),
                            datetime.datetime.fromtimestamp(self.final_timestamp)
                        )
                    if timestamp < self.start_timestamp:
                        if self.global_timestamp == 0:
                            self.global_timestamp = self.start_timestamp

                    else:
                        if timestamp > self.global_timestamp:
                            self.global_timestamp = timestamp

                else:
                    if self.global_timestamp == 0:
                        self.global_timestamp = self.start_timestamp
    def get_global_index(self):
        """
        Returns:
            data row corresponding to current global_time
        """
        if self.is_ready:
            return self.data.index.get_loc(
                datetime.datetime.fromtimestamp(self.global_timestamp),
                method='backfill'
            )

        else:
            return 0
    def get_intervals(self):
        """
        Estimates exact sampling intervals such that the train period of the target trial
        overlaps known (up-to-date) data only.

        Returns:
            train and test sampling intervals for the current global_time point
        """
        if self.is_ready:
            current_index = self.get_global_index()
            assert current_index >= self.train_num_records
            assert current_index + self.test_num_records <= self.total_num_records, 'End of data!'

            self.log.debug(
                'current_index: {}, total_num_records: {}, sample_num_records: {}'.format(
                    current_index,
                    self.total_num_records,
                    self.sample_num_records
                )
            )
            if self.expanding:
                train_interval = [0, current_index]

            else:
                train_interval = [current_index - self.sample_num_records, current_index]

            test_interval = [current_index - self.train_num_records, current_index + self.test_num_records]

            self.log.debug(
                'train_interval: {}, datetimes: {} - {}'.format(
                    train_interval,
                    self.data.index[train_interval[0]],
                    self.data.index[train_interval[-1]],
                )
            )
            self.log.debug(
                'test_interval: {}, datetimes: {} - {}'.format(
                    test_interval,
                    self.data.index[test_interval[0]],
                    self.data.index[test_interval[-1]],
                )
            )
        else:
            train_interval = None
            test_interval = None

        return train_interval, test_interval
    def _reset(self, data_filename=None, timestamp=None, **kwargs):
        self.read_csv(data_filename)

        # Maximum data time gap allowed within sample as pydatetimedelta obj:
        self.max_time_gap = datetime.timedelta(**self.time_gap)

        # Max. gap number of records:
        self.max_gap_num_records = int(self.max_time_gap.total_seconds() / (60 * self.timeframe))

        # ... maximum episode time duration:
        self.max_sample_len_delta = datetime.timedelta(**self.sample_duration)

        # Maximum possible number of data records (rows) within episode:
        self.sample_num_records = int(self.max_sample_len_delta.total_seconds() / (60 * self.timeframe))

        self.log.debug('sample_num_records: {}'.format(self.sample_num_records))
        self.log.debug('sliding_test_period: {}'.format(self.test_period))

        # Train/test timedeltas:
        self.test_range_delta = datetime.timedelta(**self.test_period)
        self.train_range_delta = datetime.timedelta(**self.sample_duration) - datetime.timedelta(**self.test_period)

        self.test_num_records = round(self.test_range_delta.total_seconds() / (60 * self.timeframe))
        self.train_num_records = self.sample_num_records - self.test_num_records

        self.log.debug('test_num_records: {}'.format(self.test_num_records))
        self.log.debug('train_num_records: {}'.format(self.train_num_records))

        self.start_timestamp = self.data.index[self.sample_num_records].timestamp()
        self.final_timestamp = self.data.index[-self.test_num_records].timestamp()

        if self.frozen_time_split is not None:
            frozen_index = self.data.index.get_loc(self.frozen_time_split, method='ffill')
            self.frozen_split_timestamp = self.data.index[frozen_index].timestamp()
            self.set_global_timestamp(self.frozen_split_timestamp)

        else:
            self.frozen_split_timestamp = None
            self.set_global_timestamp(timestamp)

        current_index = self.get_global_index()

        try:
            assert self.train_num_records >= self.test_num_records

        except AssertionError:
            self.log.exception(
                'Train subset should contain at least one episode, got: train_set size: {} rows, episode_size: {} rows'.
                format(self.train_num_records, self.test_num_records)
            )
            raise AssertionError

        self.sample_num = 0

        self.is_ready = True
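
    # Worked example of the record-count arithmetic above, on illustrative values only
    # (assuming a 1-minute timeframe, i.e. 60 * self.timeframe = 60 seconds per row):
    #   sample_duration = {'days': 30}  ->  sample_num_records = 30 * 24 * 60 = 43200 rows
    #   test_period     = {'days': 6}   ->  test_num_records   =  6 * 24 * 60 =  8640 rows
    #   train_num_records = 43200 - 8640 = 34560 rows,
    # which satisfies the `train_num_records >= test_num_records` assertion in _reset().
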
    def sample(self, get_new=True, sample_type=0, timestamp=None, b_alpha=1.0, b_beta=1.0, **kwargs):
        """
        Samples from sequence of `Trials`.

        Args:
            get_new (bool):             sample new (True) or reuse (False) last made sample; n/a for target trials;
            sample_type (int or bool):  0 (train) or 1 (test) - get sample from source or target data subsets
                                        respectively;
            timestamp:                  POSIX timestamp indicating current global time of training loop;
            b_alpha (float):            beta-distribution sampling alpha > 0, valid for train episodes;
            b_beta (float):             beta-distribution sampling beta > 0, valid for train episodes.

        Returns:
            Trial as `BTgymBaseDataTrial` instance;
            False, if trial's sequence is exhausted (global time is up).
        """
        self.set_global_timestamp(timestamp)

        if 'interval' not in kwargs.keys():
            train_interval, test_interval = self.get_intervals()

        else:
            train_interval = test_interval = kwargs.pop('interval')

        if get_new or self.sample_instance is None:
            if sample_type:
                self.sample_instance = self._sample_interval(
                    interval=test_interval,
                    b_alpha=b_alpha,
                    b_beta=b_beta,
                    name='target_trial_',
                    **kwargs
                )
                if self.sample_instance is None:
                    # Exhausted:
                    return False

            else:
                self.sample_instance = self._sample_interval(
                    interval=train_interval,
                    b_alpha=b_alpha,
                    b_beta=b_beta,
                    name='source_trial_',
                    **kwargs
                )
                if self.sample_instance is None:
                    # Exhausted:
                    return False

            self.log.debug(
                'sampled new trial <{}> with metadata: {}'.format(
                    self.sample_instance.filename, self.sample_instance.metadata
                )
            )
        else:
            self.log.debug(
                'reused trial <{}> with metadata: {}'.format(
                    self.sample_instance.filename, self.sample_instance.metadata
                )
            )

        self.sample_instance.metadata['type'] = sample_type
        self.sample_instance.metadata['sample_num'] = self.sample_num
        self.sample_instance.metadata['parent_sample_num'] = copy.deepcopy(self.metadata['sample_num'])
        self.sample_instance.metadata['parent_sample_type'] = copy.deepcopy(self.metadata['type'])

        return self.sample_instance
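

# Minimal usage sketch (illustrative only, not part of the original module): demonstrates the
# sliding `current time point` workflow described in the BTgymCasualDataDomain docstring.
# The CSV path and all parameter values below are assumptions - adjust them to your own data.
# The reset()/sample() choreography mirrors, roughly, what the btgym server does internally;
# in normal use the gym environment drives these objects for you.
if __name__ == '__main__':
    from logbook import INFO

    domain = BTgymCasualDataDomain(
        filename='./data/DAT_ASCII_EURUSD_M1_201704.csv',   # hypothetical 1-minute CSV
        trial_params=dict(
            start_weekdays=[0, 1, 2, 3, 4],
            sample_duration={'days': 30, 'hours': 0, 'minutes': 0},
            time_gap={'days': 2, 'hours': 15},
            test_period={'days': 6, 'hours': 0, 'minutes': 0},
            expanding=False,
        ),
        episode_params=dict(
            start_weekdays=[0, 1, 2, 3, 4],
            start_00=False,
            sample_duration={'days': 2, 'hours': 23, 'minutes': 55},
            time_gap={'days': 2, 'hours': 10},
        ),
        log_level=INFO,
    )
    domain.reset()  # reset() is assumed to be provided by the base data class

    # Train on known (past) data: source trials can be drawn repeatedly.
    source_trial = domain.sample(sample_type=0, timestamp=domain.global_timestamp)
    source_trial.reset()
    train_episode = source_trial.sample(sample_type=0)

    # Evaluate on unknown (future) data: the target trial is aligned to the current global time;
    # passing a later `timestamp` on subsequent calls marks evaluated data as known.
    target_trial = domain.sample(sample_type=1, timestamp=domain.global_timestamp)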