import lost
import json
from lost.db import model
from datetime import datetime
import pandas as pd
import numpy as np
__author__ = "Jonas Jaeger"
[docs]class LabelTree(object):
'''A class that represants a LabelTree.
Args:
dbm (:class:`lost.db.access.DBMan`): Database manager object.
root_id (int): label_leaf_id of the root Leaf.
root_leaf (:class:`lost.db.model.LabelLeaf`): Root leaf of the tree.
name (str): Name of a label tree.
logger (logger): A logger.
group_id (int): Id of the group where the LabelTree belongs to.
'''
def __init__(self, dbm, root_id=None, root_leaf=None, name=None, logger=None, group_id=None):
self.dbm = dbm # type: lost.db.access.DBMan
self.root = None # type: lost.db.model.LabelLeaf
self.tree = {}
if logger is None:
import logging
self.logger = logging
else:
self.logger = logger
if root_leaf is not None:
self.root = root_leaf
self.__collect_tree(self.root, self.tree)
elif root_id is not None:
self.root = self.dbm.get_label_leaf(root_id)
self.__collect_tree(self.root, self.tree)
elif name is not None:
if group_id is None:
root_list = self.dbm.get_all_label_trees(global_only=True)
else:
root_list = self.dbm.get_all_label_trees(group_id=group_id, add_global=True)
for leaf in root_list:
print(leaf.name)
root = next(filter(lambda x: x.name==name, root_list), None)
if root is None:
raise Exception('LabelTree with name "{}" not found in database!'.format(name))
else:
self.root = root
self.__collect_tree(self.root, self.tree)
def __collect_tree(self, label_leaf, leaf_map):
'''Collect all LabelLeafs from Tree or Subtree
Args:
label_leaf (:class:`lost.db.model.LabelLeaf`): The leaf to start leaf collection.
leaf_map (dict): Dictionary that maps leaf ids to LabelLeaf objects
{leaf_id : LabelLeaf}
'''
leaf_map[label_leaf.idx] = label_leaf
for ll in label_leaf.label_leaves:
self.__collect_tree(ll, leaf_map)
[docs] def delete_subtree(self, leaf):
'''Recursive delete all leafs in subtree starting with leaf
Args:
leaf (:class:`lost.db.model.LabelLeaf`): Delete all childs
of this leaf. The leaf itself stays.
'''
for ll in leaf.label_leaves:
self.delete_subtree(ll)
self.logger.info('Deleting label leaf: {}'.format(ll.name))
self.dbm.delete(ll)
[docs] def delete_tree(self):
'''Delete whole tree from system'''
self.delete_subtree(self.root)
self.dbm.delete(self.root)
self.dbm.commit()
[docs] def create_root(self, name, external_id=None):
'''Create the root of a label tree.
Args:
name (str): Name of the root leaf.
external_id (str): Some id of an external label system.
Retruns:
:class:`lost.db.model.LabelLeaf` or None:
The created root leaf or None if a root leaf with same
name is already present in database.
'''
root_leafs = self.dbm.get_all_label_trees(global_only=True)
if root_leafs is not None:
for leaf in root_leafs:
if name == leaf.name:
return None
self.root = model.LabelLeaf(name=name,
external_id=external_id, is_root=True)
self.dbm.add(self.root)
self.dbm.commit()
self.tree[self.root.idx] = self.root
self.logger.info('Created root leaf: {}'.format(name))
return self.root
[docs] def create_child(self, parent_id, name, external_id=None):
'''Create a new leaf in label tree.
Args:
parent_id (int): Id of the parend leaf.
name (str): Name of the leaf e.g the class name.
external_id (str): Some id of an external label system.
Retruns:
:class:`lost.db.model.LabelLeaf`: The the created child leaf.
'''
leaf = model.LabelLeaf(name=name,
external_id=external_id, parent_leaf_id=parent_id)
self.dbm.add(leaf)
self.dbm.commit()
self.tree[leaf.idx] = leaf
self.logger.info('Created child leaf: {}'.format(name))
return leaf
[docs] def get_child_vec(self, parent_id, columns='idx'):
'''Get a vector of child labels.
Args:
parent_id (int): Id of the parent leaf.
columns (str or list of str): Can be any attribute of :class:`lost.db.model.LabelLeaf`
for example 'idx', 'external_idx', 'name' or a list of these e.g.
['name', 'idx']
Example:
>>> label_tree.get_child_vec(1, columns='idx')
[2, 3, 4]
>>> label_tree.get_child_vec(1, columns=['idx', 'name'])
[
[2, 'cow'],
[3, 'horse'],
[4, 'person']
]
Returns:
list in the requested columns:
'''
parent = self.tree[parent_id] # type: lost.db.model.LabelLeaf
df_list = []
for ll in parent.label_leaves:
df_list.append(ll.to_df()[columns])
df = pd.concat(df_list)
return df.values.tolist()
[docs] def to_df(self):
'''Transform this LabelTree to a pandas DataFrame.
Returns:
pandas.DataFrame
'''
df_list = []
for leaf_id, leaf in self.tree.items():
df_list.append(leaf.to_df())
df = pd.concat(df_list)
return df.reset_index().drop(columns=['index'])
# def to_list(self):
# leaves = list()
# for leaf_id, leaf in self.tree.items():
# leaves.append(leaf.to_dict())
# return leaves
def __collect_dict_tree(self, label_leaf, t_dict):
t_dict['children'] = []
for ll in label_leaf.label_leaves:
ll_dict = ll.to_dict()
t_dict['children'].append(ll_dict)
self.__collect_dict_tree(ll, ll_dict)
def to_hierarchical_dict(self):
my_dict = self.root.to_dict()
self.__collect_dict_tree(self.root, my_dict)
return my_dict
def _df_row_to_leaf(self, row, leaf):
'''Transfrom LabelLeaf in row style to a LabelLeaf object.
Args:
row (pandas.Series): A LabelLeaf in row style.
Returns:
:class:`lost.db.model.LabelLeaf`:
The transformed row.
'''
try:
leaf.abbreviation = row['abbreviation']
self.logger.info('\tabbreviation: {}'.format(leaf.abbreviation))
except KeyError:
self.logger.info('\tNo abbreviation provided.')
try:
leaf.description = row['description']
self.logger.info('\tdescription: {}'.format(leaf.description))
except KeyError:
self.logger.info('\tNo description provided.')
try:
leaf.timestamp = row['timestamp']
self.logger.info('\ttimestamp: {}'.format(leaf.timestamp))
except KeyError:
self.logger.info('\tNo timestamp provided.')
try:
if not np.isnan(row['external_id']):
leaf.external_id = row['external_id']
self.logger.info('\texternal_id: {}'.format(leaf.external_id))
except KeyError:
self.logger.info('\tNo external_id provided.')
try:
leaf.is_deleted = row['is_deleted']
self.logger.info('\tis_deleted: {}'.format(leaf.is_deleted))
except KeyError:
self.logger.info('\tNo is_deleted provided.')
try:
leaf.color = row['color']
self.logger.info('\tcolor: {}'.format(leaf.color))
except KeyError:
self.logger.info('\tNo color provided.')
def __create_childs_from_df(self, child_dict, parent, parent_row):
'''Create child leafs from a df.
Args:
child_dict (dict): A dictionary that maps parent_ids from DataFrame
to child rows from DataFrame.
parent (:class:`lost.db.model.LabelLeaf`): A parent LabelLeaf
that was already imported.
parent_row (pandas.Series): A row from the DataFrame to import.
'''
if parent_row['idx'] not in child_dict:
return
for child_row in child_dict[parent_row['idx']]:
child = self.create_child(parent.idx,
child_row['name'])
self._df_row_to_leaf(child_row, child)
self.__create_childs_from_df(child_dict, child, child_row)
[docs] def import_df(self, df):
'''Import LabelTree from DataFrame
Args:
df (pandas.DataFrame): LabelTree in DataFrame style.
Retruns:
:class:`lost.db.model.LabelLeaf` or None:
The created root leaf or None if a root leaf with same
name is already present in database.
'''
df = df.where((pd.notnull(df)), None)
root = df[df['parent_leaf_id'].isnull()]
no_root = df[~df['parent_leaf_id'].isnull()]
childs = {}
if len(root) != 1:
raise ValueError('''Can not import. There needs
to be exactly one root leaf for that tree!
Found: \n{}'''.format(root))
else:
try:
root_leaf = self.create_root(root['name'].values[0])
if root_leaf is None:
return None #A tree with the same name already exists.
self._df_row_to_leaf(root.loc[0], root_leaf)
#Create child dict
for index, row in no_root.iterrows():
if not row['parent_leaf_id'] in childs:
childs[row['parent_leaf_id']] = []
childs[row['parent_leaf_id']].append(row)
self.__create_childs_from_df(childs, root_leaf, root.loc[0])
self.dbm.commit()
return root_leaf
except KeyError:
self.logger.error('''At least the following columns
need to be provided: *idx*, *name*, *parent_leaf_id*''')
raise