frcrawler-scripts/scripts/label-cluster.py
Gaël Berthaud-Müller 23e47ec6e7 add license mention
2024-03-06 14:51:03 +01:00

94 lines
2.8 KiB
Python

#
# SPDX-FileCopyrightText: 2023 Afnic
#
# SPDX-License-Identifier: GPL-3.0-or-later
#
import os
import uuid
import logging
from string import Template
import click
from dotenv import load_dotenv
from clickhouse_driver import Client
# Read the dotenv file before anything else so that the environment lookups
# further down can pick up values it provides. The file path itself can be
# overridden through FRCRAWLER_SCRIPT_ENV_FILE.
load_dotenv(os.environ.get('FRCRAWLER_SCRIPT_ENV_FILE', 'crawler.env'))

# Root logger setup: INFO level, with level, logger name, timestamp and
# source location in front of every message.
FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)

# ClickHouse connection URL and the cluster name used for `on cluster`
# statements; the defaults below are used when the variables are unset.
CH_DB_URL = os.environ.get('CH_DB_URL', 'clickhouse://test:test@localhost:9001/test')
CH_CLUSTER = os.environ.get('CH_CLUSTER', 'dev_cluster')
# Root click group; the sub-commands below register themselves on it through
# `@cli.command()`. It performs no work of its own.
# (Documented with comments rather than a docstring because click would show
# a docstring as the group's --help text.)
@click.group()
def cli():
    pass
# Print every row of `clustering_labels` (id and name), ordered by label name.
# (Documented with comments rather than a docstring because click would show
# a docstring as the command's --help text.)
@cli.command()
def list_labels():
    ch = Client.from_url(CH_DB_URL)
    rows = ch.execute('select label_id, label_name from clustering_labels order by label_name')

    # 'Label ID' plus 28 spaces is 36 columns wide, presumably so the header
    # lines up with 36-character UUID strings in the first column.
    header = 'Label ID' + ' ' * 28 + ' Label Name'
    click.echo(header)

    for label_id, label_name in rows:
        click.echo(str(label_id) + ' ' + label_name)
# Insert a new row into `clustering_labels` made of a freshly generated
# random UUID and the name given with `--label-name`.
# (Documented with comments rather than a docstring because click would show
# a docstring as the command's --help text.)
@cli.command()
@click.option('--label-name', required=True)
def create_label(label_name):
    client = Client.from_url(CH_DB_URL)
    logging.info('Inserting new label')
    # One-row insert; the return value of execute() is not needed here.
    client.execute(
        'insert into clustering_labels values',
        ((uuid.uuid4(), label_name),),
    )
    logging.info('Done')
# Attach an existing label to a cluster, in three steps:
#   1. look up the id of `--label-name` in `clustering_labels`;
#   2. copy the `lzjd` / `ssdeep` hashes of up to 100 randomly picked members
#      of `--cluster-id` into `clustering_label_hints` under that label id;
#   3. rewrite `cluster_label` from the cluster id to the label id in
#      `clustering_results_data`, issued `on cluster` CH_CLUSTER.
# Exits with status 1 (after logging) when the label name is unknown.
# (Documented with comments rather than a docstring because click would show
# a docstring as the command's --help text.)
@cli.command()
@click.option('--cluster-id', required=True, type=click.UUID)
@click.option('--label-name', required=True)
def label_cluster(cluster_id, label_name):
    client = Client.from_url(CH_DB_URL)

    logging.info('Fetching label id from label name')
    result = client.execute(
        'select label_id from clustering_labels where label_name = %(label_name)s',
        {'label_name': label_name},
    )
    if not result:
        logging.fatal('Label `%s` does not exist', label_name)
        # Raise SystemExit directly: the bare `exit()` helper is injected by
        # the `site` module for interactive use and is not guaranteed to exist.
        raise SystemExit(1)
    label_id = result[0][0]

    logging.info('Inserting new hints from cluster %s, this can take some time...', cluster_id)
    # Sample at most 100 random members of the cluster and store their hashes
    # as hints for the label.
    client.execute("""
        insert into clustering_label_hints
        select
            generateUUIDv4(),
            %(label_id)s,
            lzjd,
            ssdeep
        from (
            select batch_id, job_id from clustering_results
            where cluster_label = %(cluster_id)s
            order by rand() limit 100
        ) clustering_results
        join (
            select batch_id, job_id, lzjd, ssdeep from
            clustering_hashes_data
        ) hashes using(batch_id, job_id)
    """, {'label_id': label_id, 'cluster_id': cluster_id})

    logging.info('Changing cluster label from %s to %s', cluster_id, label_id)
    # Relabel every result row of the cluster with the label id.
    client.execute("""
        alter table clustering_results_data on cluster %(ch_cluster)s
        update cluster_label = %(label_id)s
        where cluster_label = %(cluster_id)s
    """, {'ch_cluster': CH_CLUSTER, 'label_id': label_id, 'cluster_id': cluster_id})

    logging.info('Done')
# Run the click group only when the file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    cli()