How to Use
PyNetsPresso
If you are using PyNetsPresso for the first time, or want more detailed information about it, please visit https://github.com/Nota-NetsPresso/PyNetsPresso.
from netspresso import NetsPresso
from netspresso.enums import Task
from netspresso.trainer.augmentations import Resize
from netspresso.trainer.optimizers import AdamW
from netspresso.trainer.schedulers import CosineAnnealingWarmRestartsWithCustomWarmUp
# 0. Log in
# `netspresso` has to exist before `netspresso.trainer(...)` can be called below;
# replace the placeholders with your own NetsPresso account credentials.
netspresso = NetsPresso(email="YOUR_EMAIL", password="YOUR_PASSWORD")

# 1. Declare trainer
# Supported tasks: IMAGE_CLASSIFICATION, OBJECT_DETECTION, SEMANTIC_SEGMENTATION
trainer = netspresso.trainer(task=Task.OBJECT_DETECTION)

# 2. Set config for training
# 2-1. Data
# NOTE(review): the image/label paths look relative to `root_path`, and
# `id_mapping` looks like the class names in label-index order -- confirm
# against the PyNetsPresso documentation.
trainer.set_dataset_config(
    name="traffic_sign_config_example",
    root_path="/root/traffic-sign",
    train_image="images/train",
    train_label="labels/train",
    valid_image="images/valid",
    valid_label="labels/valid",
    id_mapping=["prohibitory", "danger", "mandatory", "other"],
)

# 2-2. Model
print(trainer.available_models)  # ['YOLOX-S', 'YOLOX-M', 'YOLOX-L', 'YOLOX-X']
trainer.set_model_config(model_name="YOLOX-S", img_size=512)

# 2-3. Augmentation
trainer.set_augmentation_config(
    train_transforms=[Resize()],
    inference_transforms=[Resize()],
)

# 2-4. Training
optimizer = AdamW(lr=6e-3)
scheduler = CosineAnnealingWarmRestartsWithCustomWarmUp(warmup_epochs=10)
trainer.set_training_config(
    epochs=40,
    batch_size=16,
    optimizer=optimizer,
    scheduler=scheduler,
)

# 3. Train
# NOTE(review): `gpus` appears to be a comma-separated string of GPU indices -- confirm.
training_result = trainer.train(gpus="0, 1", project_name="PROJECT_TRAIN_SAMPLE")
To directly use the source code of NetsPresso Trainer, please visit https://github.com/Nota-NetsPresso/netspresso-trainer.
Updated 2 months ago