| import os |
| import shutil |
| import logging |
| import pretty_errors |
| import huggingface_hub |
| from datasets import Dataset, load_dataset, disable_caching |
| import schedule |
| import time |
|
|
# Turn off the `datasets` on-disk cache so each run re-reads fresh data.
disable_caching()

# Module-level logger with a single INFO-level console handler.
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(stream_handler)

# Source dataset on the Hugging Face Hub, and the local working directory
# used as the datasets cache between runs.
DS_NAME = "amaye15/object-segmentation"
DATA_DIR = "data"
|
|
|
|
def get_data():
    """Stream rows from the source dataset's train split, one at a time.

    Opens the Hub dataset in streaming mode and lazily yields each record,
    so the caller (``Dataset.from_generator``) never needs the full split
    in memory at once.
    """
    # NOTE(review): with streaming=True the data is read remotely rather
    # than downloaded, so download_mode="force_redownload" is presumably a
    # no-op here — confirm the intent.
    stream = load_dataset(
        DS_NAME,
        cache_dir=os.path.join(os.getcwd(), DATA_DIR),
        streaming=True,
        download_mode="force_redownload",
    )
    yield from stream["train"]
|
|
|
|
def process_and_push_data():
    """Scheduled job: reset the local data directory, rebuild the dataset
    from the streaming generator, and push the result to the Hub.

    Failures are logged rather than raised: an unhandled exception here
    would propagate out of ``schedule.run_pending()`` and kill the
    scheduler loop that drives this job, so we log and let the next
    scheduled run retry instead.
    """
    p = os.path.join(os.getcwd(), DATA_DIR)

    try:
        # Start from a clean slate so stale shards from a previous run
        # never leak into the new push.
        if os.path.exists(p):
            shutil.rmtree(p)

        # makedirs(..., exist_ok=True) is robust against the directory
        # reappearing between the rmtree above and this call (os.mkdir
        # would raise FileExistsError in that race).
        os.makedirs(p, exist_ok=True)

        ds_processed = Dataset.from_generator(get_data)
        ds_processed.push_to_hub("amaye15/tmp")
        logger.info("Dataset processed and pushed successfully.")
    except Exception:
        # Network / Hub errors are expected occasionally for a recurring
        # job; record the traceback and keep the scheduler alive.
        logger.exception("process_and_push_data failed")
| |
|
|
|
|
| |
# Register the sync job to run once per minute.
schedule.every(1).minute.do(process_and_push_data)

# Busy-wait driver: schedule only fires jobs when run_pending() is called,
# so poll once a second — plenty of resolution for a per-minute job.
while True:
    schedule.run_pending()
    time.sleep(1)
|
|