# Source code for vaex.jupyter.model

import asyncio
import enum
import typing

import traitlets
import numpy as np
import xarray

import vaex
import vaex.jupyter
from .decorators import signature_has_traits
from .traitlets import Expression
import logging
from .vendor import contextlib


logger = logging.getLogger('vaex.jupyter.model')


class _HasState(traitlets.HasTraits):
    """Base for traitlets models whose ``status`` trait is driven by async state transitions.

    Subclasses define a ``status`` trait (an enum) and move it with the
    :meth:`_state_change_to` async context manager. When ``_debug`` is True, each
    transition first waits on a semaphore. A test can then step the state machine
    one transition at a time via :meth:`_allow_state_change_to` /
    :meth:`_allow_state_change_cancel`.
    """
    _debug = traitlets.Bool(False)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Starts at 0: in debug mode every transition blocks until a test releases it.
        self._allow_state_change = asyncio.Semaphore(0)
        self._current_status_wait_future = None

    def _debug_wait_for_status(self, status):
        """Return a future for the next change of the ``status`` trait.

        The future resolves to None if the next change is to *status*. It fails
        with RuntimeError if the change is to any other value. The observer
        installed here removes itself after that first change, so it never
        touches the future again.
        """
        # Bind the future locally. A later call replaces _current_status_wait_future,
        # and this observer must only ever complete its own future.
        future = asyncio.Future()
        self._current_status_wait_future = future

        def on_status_change(change):
            # one-shot observer, also on the failure path
            self.unobserve(on_status_change, 'status')
            if future.done():  # e.g. already failed through _error()
                return
            if change['new'] == status:
                future.set_result(None)
            else:
                future.set_exception(RuntimeError(f'Did not expect change to status: {change["new"]}'))
        self.observe(on_status_change, 'status')
        return future

    def _error(self, e):
        """Report an exception raised by a debounced computation.

        In debug mode the exception is also forwarded to a pending
        :meth:`_debug_wait_for_status` future, so the waiting test fails.
        """
        if self._debug:
            future = self._current_status_wait_future
            # The future may already be resolved; setting it again would raise InvalidStateError.
            if future is not None and not future.done():
                future.set_exception(e)
        if isinstance(e, asyncio.CancelledError):
            print("cancelled")
        else:
            print("Error: ", e, type(e))
            try:
                vaex.utils.print_exception_trace(e)
            except Exception as trace_error:
                print(trace_error)

    async def _allow_state_change_to(self, status):
        """(Debug) allow one pending transition and wait until ``status`` becomes *status*."""
        self._allow_state_change.release()
        result = await self._debug_wait_for_status(status)
        return result

    async def _allow_state_change_cancel(self, previous=False):
        """(Debug) allow one pending transition, then wait for the last ``computation`` call to finish."""
        self._allow_state_change.release()
        await self.computation._await_last_call(previous)

    @contextlib.asynccontextmanager
    async def _state_change_to(self, new_status):
        """Async context manager that sets ``status`` to *new_status* when its block exits cleanly.

        The status seen on entry is remembered. If it differs on exit, someone else
        moved the state machine while the block ran (e.g. a cancellation). In that
        case the block's result is stale, and asyncio.CancelledError is raised
        instead of changing the status. If the block itself raises, the status is
        left untouched.
        """
        current_status = self.status
        logger.debug('Current state is %s', current_status)
        if self._debug:
            await self._allow_state_change.acquire()
        yield
        if self.status != current_status:
            raise asyncio.CancelledError(f"Status expected to be {current_status}, but is {self.status}")
        # log before assigning so "from" really is the previous status
        logger.debug('State change %s from %s to %s', type(self), current_status, new_status)
        self.status = new_status


@signature_has_traits
class Axis(_HasState):
    """One binned dimension of a DataFrame: an expression, its [min, max] limits and its bin centers.

    If both ``min`` and ``max`` are given at construction, the axis is READY right
    away. Otherwise the limits are computed asynchronously by :meth:`computation`:
    the category range for categorical expressions, ``df.minmax`` otherwise.
    """

    class Status(enum.Enum):
        """Lifecycle of the axis limits.

        Normal flow::

            NO_LIMITS -> STAGED_CALCULATING_LIMITS -> CALCULATING_LIMITS -> CALCULATED_LIMITS -> READY

        When the expression changes, min/max are reset and the status goes to
        NO_LIMITS. A running limit calculation is cancelled and a new one is
        scheduled.

        When min/max change while a limit calculation is pending, that calculation
        is cancelled. The status becomes READY if both limits are set, and
        NO_LIMITS otherwise.

        Cancelling means ``_continue_calculation`` is set to False, so the progress
        callback of the running ``df.minmax`` returns False.

        EXCEPTION and ABORTED are not assigned within this class.
        """
        NO_LIMITS = 1
        STAGED_CALCULATING_LIMITS = 2
        CALCULATING_LIMITS = 3
        CALCULATED_LIMITS = 4
        READY = 5
        EXCEPTION = 6
        ABORTED = 7

    status = traitlets.UseEnum(Status, Status.NO_LIMITS)
    df = traitlets.Instance(vaex.dataframe.DataFrame)
    expression = Expression()
    # bin index to slice this axis on (GridCalculator.reslice indexes with it), or None
    slice = traitlets.CInt(None, allow_none=True)
    min = traitlets.CFloat(None, allow_none=True)
    max = traitlets.CFloat(None, allow_none=True)
    # filled in by _calculate_centers()
    bin_centers = traitlets.Any()
    # number of bins; when None, shape_default is used
    shape = traitlets.CInt(None, allow_none=True)
    shape_default = traitlets.CInt(64)
    # delayed df.minmax() result of the running limit computation; reset to None once awaited
    _calculation = traitlets.Any(None, allow_none=True)
    exception = traitlets.Any(None, allow_none=True)
    # NOTE(review): not read anywhere in this class
    _status_change_delay = traitlets.Float(0)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if self.min is not None and self.max is not None:
            # limits given explicitly: nothing to compute
            self.status = Axis.Status.READY
            self._calculate_centers()
        else:
            # schedule the (debounced) limit computation
            self.computation()
        self.observe(self.on_change_expression, 'expression')
        self.observe(self.on_change_shape, 'shape')
        self.observe(self.on_change_shape_default, 'shape_default')

    def __repr__(self):
        """Show all public traits except ``df``; expressions are rendered as strings."""
        def myrepr(value, key):
            if isinstance(value, vaex.expression.Expression):
                return str(value)
            return value
        args = ', '.join('{}={}'.format(key, myrepr(getattr(self, key), key)) for key in self.traits().keys() if key != 'df' and not key.startswith('_'))
        return '{}({})'.format(self.__class__.__name__, args)

    @property
    def has_missing_limit(self):
        """True when ``min`` or ``max`` is still None."""
        # return not self.df.is_category(self.expression) and (self.min is None or self.max is None)
        return (self.min is None or self.max is None)

    def on_change_expression(self, change):
        """Reset the limits, cancel any running limit computation and schedule a new one."""
        self.min = None
        self.max = None
        self.status = Axis.Status.NO_LIMITS
        if self._calculation is not None:
            self._cancel_computation()
        self.computation()

    def on_change_shape(self, change):
        """Recompute bin centers for the new number of bins (only when limits are known)."""
        if self.min is not None and self.max is not None:
            self._calculate_centers()

    def on_change_shape_default(self, change):
        """Recompute bin centers for the new default number of bins (only when limits are known)."""
        if self.min is not None and self.max is not None:
            self._calculate_centers()

    def _cancel_computation(self):
        # picked up by _progress(), which makes the running minmax stop
        self._continue_calculation = False

    @traitlets.observe('min', 'max')
    def on_change_limits(self, change):
        """Keep bin centers and ``status`` consistent with the current min/max."""
        if self.min is not None and self.max is not None:
            self._calculate_centers()
        if self.status == Axis.Status.NO_LIMITS:
            if self.min is not None and self.max is not None:
                self.status = Axis.Status.READY
        elif self.status == Axis.Status.READY:
            if self.min is None or self.max is None:
                self.status = Axis.Status.NO_LIMITS
            else:
                # in this case, grids may want to be computed
                # this happens when a user change min/max
                pass
        else:
            # a limit computation is staged/running/just finished
            if self._calculation is not None:
                # the user overrides the limits while they are being computed: cancel it
                self._cancel_computation()
                if self.min is not None and self.max is not None:
                    self.status = Axis.Status.READY
                else:
                    self.status = Axis.Status.NO_LIMITS
            else:
                # in this case we've set min/max after the calculation
                assert self.min is not None or self.max is not None

    @vaex.jupyter.debounced(delay_seconds=0.1, reentrant=False, on_error=_HasState._error)
    async def computation(self):
        """(Debounced) compute the limits of ``expression`` and move the axis to READY.

        Categorical expressions get limits [-0.5, N - 0.5] immediately. Otherwise a
        delayed ``df.minmax`` is scheduled, and the status is walked through
        STAGED_CALCULATING_LIMITS -> CALCULATING_LIMITS -> CALCULATED_LIMITS -> READY.
        UserAbort and CancelledError (raised by _state_change_to when the status was
        changed by someone else) end the computation silently.
        """
        categorical = self.df.is_category(self.expression)
        if categorical:
            N = self.df.category_count(self.expression)
            # one bin per category, centered on the integer codes
            self.min, self.max = -0.5, N-0.5
            # centers = np.arange(N)
            # self.shape = N
            self._calculate_centers()
            self.status = Axis.Status.READY
        else:
            try:
                self._continue_calculation = True
                self._calculation = self.df.minmax(self.expression, delay=True, progress=self._progress)
                self.df.widget.execute_debounced()
                # Keep an early reference to the pre-hook future: the awaits below trigger
                # the execution AND reset this attribute.
                execute_prehook_future = self.df.widget.execute_debounced.pre_hook_future
                async with self._state_change_to(Axis.Status.STAGED_CALCULATING_LIMITS):
                    pass
                async with self._state_change_to(Axis.Status.CALCULATING_LIMITS):
                    await execute_prehook_future
                async with self._state_change_to(Axis.Status.CALCULATED_LIMITS):
                    vmin, vmax = await self._calculation
                # indicate we are done with the calculation
                self._calculation = None
                if not self._continue_calculation:
                    assert self.status == Axis.Status.READY
                async with self._state_change_to(Axis.Status.READY):
                    self.min, self.max = vmin, vmax
                    self._calculate_centers()
            except vaex.execution.UserAbort:
                # probably means expression or min/max changed, we don't have to take action
                pass
            except asyncio.CancelledError:
                pass

    def _progress(self, f):
        # we use the progress callback to cancel a calculation: returning False stops it
        return self._continue_calculation

    def _calculate_centers(self):
        """Set ``bin_centers``; for categorical expressions this also sets ``shape`` to the category count."""
        categorical = self.df.is_category(self.expression)
        if categorical:
            N = self.df.category_count(self.expression)
            centers = np.arange(N)
            self.shape = N
        else:
            centers = self.df.bin_centers(self.expression, [self.min, self.max], shape=self.shape or self.shape_default)
        self.bin_centers = centers
@signature_has_traits
class DataArray(_HasState):
    """Model for a binned grid over one or more Axis objects of the same DataFrame.

    Its ``status`` summarises the status of its axes (see _on_axis_status_change)
    and is then advanced through the grid states by a GridCalculator, which also
    assigns ``grid`` / ``grid_sliced``.
    """

    class Status(enum.Enum):
        """Combined state of the axes' limits followed by the grid computation."""
        MISSING_LIMITS = 1
        STAGED_CALCULATING_LIMITS = 3
        CALCULATING_LIMITS = 4
        CALCULATED_LIMITS = 5
        NEEDS_CALCULATING_GRID = 6
        STAGED_CALCULATING_GRID = 7
        CALCULATING_GRID = 8
        CALCULATED_GRID = 9
        READY = 10
        EXCEPTION = 11

    status = traitlets.UseEnum(Status, Status.MISSING_LIMITS)
    status_text = traitlets.Unicode('Initializing')
    exception = traitlets.Any(None)
    df = traitlets.Instance(vaex.dataframe.DataFrame)
    axes = traitlets.List(traitlets.Instance(Axis), [])
    grid = traitlets.Instance(xarray.DataArray, allow_none=True)
    grid_sliced = traitlets.Instance(xarray.DataArray, allow_none=True)
    # linked to every axis' shape_default
    shape = traitlets.CInt(64)
    selection = traitlets.Any(None)

    def __init__(self, **kwargs):
        super(DataArray, self).__init__(**kwargs)
        self.signal_slice = vaex.events.Signal()
        self.signal_regrid = vaex.events.Signal()
        self.signal_grid_progress = vaex.events.Signal()
        self.observe(lambda change: self.signal_regrid.emit(), 'selection')
        self._on_axis_status_change()

        # keep a set of axis that need new limits
        self._dirty_axes = set()
        for axis in self.axes:
            assert axis.df is self.df, "axes should have the same dataframe"
            traitlets.link((self, 'shape'), (axis, 'shape_default'))
            axis.observe(self._on_axis_status_change, 'status')
            axis.observe(lambda _: self.signal_slice.emit(self), ['slice'])

            def on_change_min_max(change):
                if change.owner.status == Axis.Status.READY:
                    # this indicates a user changed the min/max
                    self.status = DataArray.Status.NEEDS_CALCULATING_GRID
            axis.observe(on_change_min_max, ['min', 'max'])

        self._on_axis_status_change()
        self.df.signal_selection_changed.connect(self._on_change_selection)

    def _on_change_selection(self, df, name):
        """Request a new grid when the changed selection *name* is (one of) ours."""
        # TODO: check if the selection applies to us
        def _translate_selection(selection):
            # None/False -> no selection, True -> the 'default' selection
            if selection in [None, False]:
                return None
            if selection is True:
                return 'default'
            else:
                return selection
        if name == _translate_selection(self.selection) or (isinstance(self.selection, (list, tuple)) and name in [_translate_selection(k) for k in self.selection]):
            self.status = DataArray.Status.NEEDS_CALCULATING_GRID

    async def _allow_state_change_cancel(self):
        """(Debug) allow one pending state transition without waiting on a computation."""
        self._allow_state_change.release()

    def _on_axis_status_change(self, change=None):
        """Derive ``status``/``status_text`` from the axes (also callable without a change)."""
        missing_limits = [axis for axis in self.axes if axis.status == Axis.Status.NO_LIMITS]
        staged_calculating_limits = [axis for axis in self.axes if axis.status == Axis.Status.STAGED_CALCULATING_LIMITS]
        calculating_limits = [axis for axis in self.axes if axis.status == Axis.Status.CALCULATING_LIMITS]
        calculated_limits = [axis for axis in self.axes if axis.status == Axis.Status.CALCULATED_LIMITS]

        def names(axes):
            return ", ".join([str(axis.expression) for axis in axes])

        # the first matching bucket wins, in this priority order
        if staged_calculating_limits:
            self.status = DataArray.Status.STAGED_CALCULATING_LIMITS
            self.status_text = 'Staged limit computation for {}'.format(names(staged_calculating_limits))
        elif missing_limits:
            self.status = DataArray.Status.MISSING_LIMITS
            self.status_text = 'Missing limits for {}'.format(names(missing_limits))
        elif calculating_limits:
            self.status = DataArray.Status.CALCULATING_LIMITS
            self.status_text = 'Computing limits for {}'.format(names(calculating_limits))
        elif calculated_limits:
            self.status = DataArray.Status.CALCULATED_LIMITS
            # list the axes whose limits were computed (calculating_limits is empty here)
            self.status_text = 'Computed limits for {}'.format(names(calculated_limits))
        else:
            assert all([axis.status == Axis.Status.READY for axis in self.axes])
            self.status = DataArray.Status.NEEDS_CALCULATING_GRID

    @traitlets.observe('status')
    def _on_change_status(self, change):
        """Keep ``status_text`` in sync for the exception and grid-related states.

        The limit-related texts are set by _on_axis_status_change. A GridCalculator
        can also change the status.
        """
        if self.status == DataArray.Status.EXCEPTION:
            self.status_text = f'Exception: {self.exception}'
        elif self.status == DataArray.Status.NEEDS_CALCULATING_GRID:
            self.status_text = 'Grid needs to be calculated'
        elif self.status == DataArray.Status.STAGED_CALCULATING_GRID:
            self.status_text = 'Staged grid computation'
        elif self.status == DataArray.Status.CALCULATING_GRID:
            self.status_text = 'Calculating grid'
        elif self.status == DataArray.Status.CALCULATED_GRID:
            self.status_text = 'Calculated grid'
        elif self.status == DataArray.Status.READY:
            self.status_text = 'Ready'

    @property
    def has_missing_limits(self):
        """True when any axis still lacks ``min`` or ``max``."""
        return any(axis.has_missing_limit for axis in self.axes)

    def on_progress_grid(self, f):
        """Forward grid progress to listeners; returns False if any listener asks to stop."""
        return all(self.signal_grid_progress.emit(f))
class Histogram(DataArray):
    """DataArray binned along a single axis, ``x``."""
    x = traitlets.Instance(Axis)

    def __init__(self, **kwargs):
        # The generic DataArray machinery works on ``axes``; derive it from ``x``.
        super().__init__(**{**kwargs, 'axes': [kwargs['x']]})
class Heatmap(DataArray):
    """DataArray binned along two axes, ``x`` and ``y``."""
    x = traitlets.Instance(Axis)
    y = traitlets.Instance(Axis)

    def __init__(self, **kwargs):
        # The generic DataArray machinery works on ``axes``; build it as [x, y].
        axes = [kwargs[name] for name in ('x', 'y')]
        super().__init__(**{**kwargs, 'axes': axes})
class GridCalculator(_HasState):
    '''Schedules one shared grid computation for a set of models and slices the result per model.

    All models must live on the same DataFrame and use the same selection. A single
    ``df.count`` is computed over the axes of all models. :meth:`reslice` then
    produces each model's ``grid`` / ``grid_sliced`` from it: the axes of the
    *other* models are summed out, or indexed when that axis has ``slice`` set.
    '''

    class Status(enum.Enum):
        VOID = 1
        STAGED_CALCULATION = 3
        CALCULATING = 4
        READY = 9

    status = traitlets.UseEnum(Status, Status.VOID)
    df = traitlets.Instance(vaex.dataframe.DataFrame)
    models = traitlets.List(traitlets.Instance(DataArray))
    _calculation = traitlets.Any(None, allow_none=True)
    _debug = traitlets.Bool(False)

    def __init__(self, df, models):
        super().__init__(df=df, models=[])
        self._callbacks_regrid = []
        self._callbacks_slice = []
        # Initialise all instance state *before* adding models: model_add schedules
        # computations and installs observers that read these attributes.
        self._testing_exeception_regrid = False  # used for testing, to throw an exception
        self._testing_exeception_reslice = False  # used for testing, to throw an exception
        self._continue_calculation = True
        self._observed_axes = []  # axes whose 'slice' we already observe
        self.grid = None  # full grid of the last finished computation
        for model in models:
            self.model_add(model)

    def model_add(self, model):
        '''Add a DataArray *model* (same DataFrame) and schedule a grid computation if it needs one.'''
        # validate before registering anything, so a bad model leaves no observers behind
        assert model.df == self.df
        self.models = self.models + [model]
        if model.status == DataArray.Status.NEEDS_CALCULATING_GRID:
            self._schedule_computation()

        def on_status_changed(change):
            if change.owner.status == DataArray.Status.NEEDS_CALCULATING_GRID:
                self._schedule_computation()
        model.observe(on_status_changed, 'status')

        # reslice() updates every model, so one observer per axis is enough even
        # when several models share the same axis
        for axis in model.axes:
            if any(axis is seen for seen in self._observed_axes):
                continue
            self._observed_axes.append(axis)
            axis.observe(lambda change: self.reslice(), 'slice')

    def _schedule_computation(self):
        '''Cancel the running grid computation (if any) and schedule a new one.'''
        if self._calculation is not None:
            self._cancel_computation()
        self.computation()

    def reslice(self, source_model=None):
        '''Derive ``grid`` and ``grid_sliced`` of every model from ``self.grid``.

        For each model, its own axes are kept (with bin centers as coords and
        min/max as coord attrs). Axes of other models are summed out. If such an
        axis has ``slice`` set, ``grid_sliced`` takes that bin instead of the sum.
        Does nothing until the first grid has been computed.

        :param source_model: unused, kept for backwards compatibility.
        '''
        if self._testing_exeception_reslice:
            raise RuntimeError("test:reslice")
        if self.grid is None or not self.models:
            # e.g. a slice changed before the first grid computation finished;
            # computation() reslices once the grid is available
            return
        selection_was_list, [selections] = vaex.utils.listify(self.models[0].selection)
        selections = [k for k in selections if k is None or self.df.has_selection(k)]
        for model in self.models:
            # without a selection list, drop the leading selection dimension
            subgrid = self.grid if selection_was_list else self.grid[0]
            subgrid_sliced = self.grid if selection_was_list else self.grid[0]
            axis_index = 1 if selection_was_list else 0
            has_slice = False
            dims = ["selection"] if selection_was_list else []
            coords = [selections.copy()] if selection_was_list else []
            mins = []
            maxs = []
            for other_model in self.models:
                if other_model == model:
                    # keep our own axes: step over them
                    for axis in other_model.axes:
                        axis_index += 1
                        dims.append(str(axis.expression))
                        coords.append(axis.bin_centers)
                        mins.append(axis.min)
                        maxs.append(axis.max)
                else:
                    # reduce the other models' axes; each reduction removes the
                    # dimension, so axis_index does not advance
                    for axis in other_model.axes:
                        if axis.slice is not None:
                            subgrid_sliced = subgrid_sliced.__getitem__(tuple([slice(None)] * axis_index + [axis.slice])).copy()
                            subgrid = np.sum(subgrid, axis=axis_index)
                            has_slice = True
                        else:
                            subgrid_sliced = np.sum(subgrid_sliced, axis=axis_index)
                            subgrid = np.sum(subgrid, axis=axis_index)
            grid = xarray.DataArray(subgrid, dims=dims, coords=coords)
            # +1 to skip the selection axis
            dim_offset = 1 if selection_was_list else 0
            for i, (vmin, vmax) in enumerate(zip(mins, maxs)):
                grid.coords[dims[i+dim_offset]].attrs['min'] = vmin
                grid.coords[dims[i+dim_offset]].attrs['max'] = vmax
            model.grid = grid
            if has_slice:
                model.grid_sliced = xarray.DataArray(subgrid_sliced)
            else:
                model.grid_sliced = None

    def _regrid_error(self, e):
        '''Error hook for computation(): report *e* and put every model in the EXCEPTION state.'''
        try:
            self._error(e)
            for model in self.models:
                model._error(e)
            for model in self.models:
                model.exception = e
                model.status = DataArray.Status.EXCEPTION
        except Exception as e2:
            # best effort: never let error reporting raise
            print(e2)

    def on_regrid(self, ignore=None):
        '''Signal-handler friendly entry point: (re)schedule the grid computation.'''
        # previously called self.regrid(), which does not exist (AttributeError)
        self._schedule_computation()

    @vaex.jupyter.debounced(delay_seconds=0.5, reentrant=False, on_error=_regrid_error)
    async def computation(self):
        '''(Debounced) compute one count grid over the axes of all models.

        Every model is walked through STAGED_CALCULATING_GRID -> CALCULATING_GRID ->
        CALCULATED_GRID -> READY. The result is stored in ``self.grid`` and resliced
        for each model. UserAbort and CancelledError end the computation silently;
        CancelledError is raised by _state_change_to when a model's status was
        changed by someone else. Other exceptions are handed to the decorator's
        ``on_error`` hook (_regrid_error).
        '''
        try:
            logger.debug('Starting grid computation')
            if self._testing_exeception_regrid:
                raise RuntimeError("test:regrid")
            if not self.models:
                return
            binby = []
            shapes = []
            limits = []
            selection = self.models[0].selection
            _, [selections] = vaex.utils.listify(selection)
            selections = [k for k in selections if k is None or self.df.has_selection(k)]
            for model in self.models:
                if model.selection != selection:
                    raise ValueError('Selections for all models should be the same')
                for axis in model.axes:
                    binby.append(axis.expression)
                    limits.append([axis.min, axis.max])
                    shapes.append(axis.shape or axis.shape_default)
            self._continue_calculation = True
            logger.debug('Setting up grid computation...')
            self._calculation = self.df.count(binby=binby, shape=shapes, limits=limits, selection=selections, progress=self.progress, delay=True)
            logger.debug('Setting up grid computation done tasks=%r', self.df.executor.tasks)

            logger.debug('Schedule debounced execute')
            self.df.widget.execute_debounced()
            # Keep an early reference to the pre-hook future: the awaits below trigger
            # the execution AND reset this attribute.
            execute_prehook_future = self.df.widget.execute_debounced.pre_hook_future

            async with contextlib.AsyncExitStack() as stack:
                for model in self.models:
                    await stack.enter_async_context(model._state_change_to(DataArray.Status.STAGED_CALCULATING_GRID))
            async with contextlib.AsyncExitStack() as stack:
                for model in self.models:
                    await stack.enter_async_context(model._state_change_to(DataArray.Status.CALCULATING_GRID))
                await execute_prehook_future
            async with contextlib.AsyncExitStack() as stack:
                for model in self.models:
                    await stack.enter_async_context(model._state_change_to(DataArray.Status.CALCULATED_GRID))
                # first assign to a local
                grid = await self._calculation
                # indicate we are done with the calculation
                self._calculation = None
            async with contextlib.AsyncExitStack() as stack:
                for model in self.models:
                    await stack.enter_async_context(model._state_change_to(DataArray.Status.READY))
                self.grid = grid
                self.reslice()
        except vaex.execution.UserAbort:
            pass  # a user changed the limits or expressions
        except asyncio.CancelledError:
            pass  # cancelled...

    def _cancel_computation(self):
        # picked up by progress(), which then tells the running count to stop
        logger.debug('Cancelling grid computation')
        self._continue_calculation = False

    def progress(self, f):
        '''Progress callback for the grid computation; returning False aborts it.'''
        # a list (not a generator) so that every model is notified, no short-circuit
        return self._continue_calculation and all([model.on_progress_grid(f) for model in self.models])