Source code for vaex.graphql

import graphene
import graphql.execution

import vaex
import pandas as pd


[docs]class DataFrameAccessorGraphQL(object): """Exposes a GraphQL layer to a DataFrame See `the GraphQL example <http://docs.vaex.io/en/latest/example_graphql.html>`_ for more usage. The easiest way to learn to use the GraphQL language/vaex interface is to launch a server, and play with the GraphiQL graphical interface, its autocomplete, and the schema explorer. We try to stay close to the Hasura API: https://docs.hasura.io/1.0/graphql/manual/api-reference/graphql-api/query.html """
[docs] def __init__(self, df): self.df = df
[docs] def query(self, name='df'): """Creates a graphene query object exposing this DataFrame named `name`""" return create_query({name: self.df})
[docs] def schema(self, name='df', auto_camelcase=False, **kwargs): """Creates a graphene schema for this DataFrame""" return graphene.Schema(query=self.query(name), auto_camelcase=auto_camelcase, **kwargs)
[docs] def execute(self, *args, **kwargs): """Creates a schema, and execute the query (first argument) """ return self.schema().execute(*args, **kwargs)
[docs] def serve(self, port=9001, address='', name='df', verbose=True): """Serve the DataFrame via a http server""" from .tornado import Application schema = self.schema(name=name) app = Application(schema) app.listen(port, address) if not address: address = 'localhost' if verbose: print('serving at: http://{address}:{port}/graphql'.format(**locals()))
@pd.api.extensions.register_dataframe_accessor("graphql") class DataFrameAccessorGraphQLPandas(DataFrameAccessorGraphQL): def __init__(self, df): super(DataFrameAccessorGraphQLPandas, self).__init__(vaex.from_pandas(df)) def map_to_field(df, name): dtype = df[name].data_type() if vaex.array_types.is_string_type(dtype): return graphene.String elif dtype.kind == "f": return graphene.Float elif dtype.kind == "i": return graphene.Int elif dtype.kind == "b": return graphene.Boolean elif dtype.kind == "M": return graphene.DateTime elif dtype.kind == "m": return graphene.DateTime # TODO, not sure how we're gonna deal with timedelta else: raise ValueError('dtype not supported: %r' % dtype) class Compare(graphene.InputObjectType): def filter(self, df, name): expression = vaex.expression.Expression(df, '(1==1)') if self._eq: expression = expression & (df[name] == self._eq) if self._neq: expression = expression & (df[name] != self._neq) return expression class NumberCompare(Compare): def filter(self, df, name): expression = super(NumberCompare, self).filter(df, name) if self._gt is not None: expression = expression & (df[name] > self._gt) if self._lt is not None: expression = expression & (df[name] < self._lt) if self._gte is not None: expression = expression & (df[name] >= self._gte) if self._lte is not None: expression = expression & (df[name] <= self._lte) return expression class IntCompare(NumberCompare): _eq = graphene.Field(graphene.Int) _neq = graphene.Field(graphene.Int) _gt = graphene.Field(graphene.Int) _lt = graphene.Field(graphene.Int) _gte = graphene.Field(graphene.Int) _lte = graphene.Field(graphene.Int) class FloatCompare(NumberCompare): _eq = graphene.Field(graphene.Float) _neq = graphene.Field(graphene.Float) _gt = graphene.Field(graphene.Float) _lt = graphene.Field(graphene.Float) _gte = graphene.Field(graphene.Float) _lte = graphene.Field(graphene.Float) class BooleanCompare(NumberCompare): _eq = graphene.Field(graphene.Boolean) _neq = graphene.Field(graphene.Boolean) class StringCompare(Compare): _eq = graphene.Field(graphene.String) _neq = graphene.Field(graphene.String) class DateTimeCompare(Compare): _eq = graphene.Field(graphene.String) _neq = graphene.Field(graphene.String) def map_to_comparison(df, name): dtype = df[name].data_type() if vaex.array_types.is_string_type(dtype): return StringCompare elif dtype.kind == "f": return FloatCompare elif dtype.kind == "i": return IntCompare elif dtype.kind == "b": return BooleanCompare elif dtype.kind == "M": return DateTimeCompare elif dtype.kind == "m": return DateTimeCompare # TODO, not sure how we're gonna deal with timedelta else: raise ValueError('dtype not supported: %r' % dtype) def create_aggregation_on_field(df, groupby): postfix = "_".join(groupby) if postfix: postfix = "_" + postfix class AggregationOnFieldBase(graphene.ObjectType): def __init__(self, df, agg, groupby): self.df = df self.agg = agg self.groupby = groupby class Meta: def default_resolver(name, __, obj, info): if obj.groupby: groupby = obj.df.groupby(obj.groupby) agg = getattr(vaex.agg, obj.agg)(name) dfg = groupby.agg({'agg': agg}) agg_values = dfg['agg'] return agg_values.tolist() return getattr(obj.df, obj.agg)(name) if groupby: attrs = {name: graphene.List(map_to_field(df, name)) for name in df.get_column_names(alias=False)} else: attrs = {name: map_to_field(df, name)() for name in df.get_column_names(alias=False)} attrs['Meta'] = Meta DataFrameAggregationOnField = type("AggregationOnField"+postfix, (AggregationOnFieldBase, ), attrs) return DataFrameAggregationOnField def create_groupby(df, groupby): postfix = "_".join(groupby) if postfix: postfix = "_" + postfix class GroupByBase(graphene.ObjectType): count = graphene.List(graphene.Int) keys = graphene.List(graphene.Int) def __init__(self, df, by=None): self.df = df self.by = [] if by is None else by self._groupby = None # cached lazy object super(GroupByBase, self).__init__() @property def groupby(self): if self._groupby is None: self._groupby = self.df.groupby(self.by) return self._groupby def resolve_count(self, info): dfg = self.groupby.agg('count') return dfg['count'].tolist() def resolve_keys(self, info): return self.groupby.coords1d[-1] def field_groupby(name): Aggregate = create_aggregate(df, groupby + [name]) def resolver(*args, **kwargs): return Aggregate(df, groupby + [name]) return graphene.Field(Aggregate, resolver=resolver) attrs = {name: field_groupby(name) for name in df.get_column_names(alias=False)} GroupBy = type("GroupBy"+postfix, (GroupByBase, ), attrs) return GroupBy def create_row(df): class RowBase(graphene.ObjectType): pass attrs = {name: map_to_field(df, name)() for name in df.get_column_names(alias=False)} Row = type("Row", (RowBase, ), attrs) return Row def create_aggregate(df, groupby=None): postfix = "_".join(groupby) if postfix: postfix = "_" + postfix if groupby is None: groupby = [] AggregationOnField = create_aggregation_on_field(df, groupby) if len(groupby): CountType = graphene.List(graphene.Int) else: CountType = graphene.Int() Row = create_row(df) class AggregationBase(graphene.ObjectType): count = CountType min = graphene.Field(AggregationOnField) max = graphene.Field(AggregationOnField) mean = graphene.Field(AggregationOnField) if not groupby: row = graphene.List(Row, limit=graphene.Int(default_value=100), offset=graphene.Int()) def __init__(self, df, by=None): self.df = df self.by = [] if by is None else by assert self.by == groupby self._groupby = None # cached lazy object super(AggregationBase, self).__init__() if not groupby: def resolve_row(self, info, limit, offset=None): fields = [field for field in info.field_asts if field.name.value == info.field_name] subfields = fields[0].selection_set.selections names = [subfield.name.value for subfield in subfields] if offset is None: offset = 0 N = min(len(self.df)-offset, limit) df_pagination = self.df[offset:offset+N] matrix = [df_pagination[name].tolist() for name in names] rows = [] for i in range(len(df_pagination)): row = Row() for j, name in enumerate(names): setattr(row, name, matrix[j][i]) rows.append(row) return rows @property def groupby_object(self): if self._groupby is None: self._groupby = self.df.groupby(self.by) return self._groupby def resolve_count(self, info): if self.by: groupby = self.groupby_object dfg = self.groupby_object.agg('count') counts = dfg['count'] return counts.tolist() return len(self.df) def resolve_min(self, info): return AggregationOnField(self.df, 'min', self.by) def resolve_max(self, info): return AggregationOnField(self.df, 'max', self.by) def resolve_mean(self, info): return AggregationOnField(self.df, 'mean', self.by) attrs = {} if len(groupby) < 2: GroupBy = create_groupby(df, groupby) def resolver(*args, **kwargs): return GroupBy(df, groupby) attrs["groupby"] = graphene.Field(GroupBy, resolver=resolver, description="GroupBy a field for %s" % '...') Aggregation = type("Aggregation"+postfix, (AggregationBase, ), attrs) return Aggregation # similar to https://docs.hasura.io/1.0/graphql/manual/api-reference/graphql-api/query.html#whereexp def create_boolexp(df): column_names = df.get_column_names(alias=False) class BoolExpBase(graphene.InputObjectType): _and = graphene.List(lambda: BoolExp) _or = graphene.List(lambda: BoolExp) _not = graphene.Field(lambda: BoolExp) def filter(self, df): expression = vaex.expression.Expression(df, '(1==1)') if self._and: for value in self._and: expression = expression & value.filter(df) if self._or: or_expression = self._or[0].filter(df) for value in self._or[1:]: or_expression = or_expression | value.filter(df) expression = expression & or_expression if self._not: expression = expression & ~self._not.filter(df) for name in column_names: value = getattr(self, name) if value is not None: expression = expression & value.filter(df, name) return expression attrs = {} for name in column_names: attrs[name] = graphene.Field(map_to_comparison(df, name)) BoolExp = type("BoolExp", (BoolExpBase, ), attrs) return BoolExp def create_query(dfs): class QueryBase(graphene.ObjectType): hello = graphene.String(description='A typical hello world') fields = {} for name, df in dfs.items(): columns = df.get_column_names(alias=False) columns = [k for k in columns if df.is_string(k) or df[k].dtype.kind != 'O'] df = df[columns] Aggregate = create_aggregate(df, []) def closure(df=df, name=name, Aggregate=Aggregate): def resolve(parent, info, where=None): if where is not None: expression = where.filter(df) if expression.expression: return Aggregate(df=df[expression]) return Aggregate(df=df) return resolve Where = create_boolexp(df) fields[name] = graphene.Field(Aggregate, resolver=closure(), description="Aggregations for %s" % name, where=graphene.Argument(Where)) return type("Query", (QueryBase, ), fields)