Trainer
Train
class Trainer(token_handler: TokenHandler, task: str | Task | None = None, yaml_path: str | None = None)
Bases: NetsPressoBase
set_dataset_config(name: str, root_path: str, train_image: str = 'images/train', train_label: str = 'labels/train', valid_image: str = 'images/valid', valid_label: str = 'labels/valid', test_image: str = 'images/valid', test_label: str = 'labels/valid', id_mapping: List[str] | Dict[str, str] | str | None = None)
Set the dataset configuration for the Trainer.
- Parameters:
- name (str) – The name of the dataset.
- root_path (str) – Root directory of the dataset.
- train_image (str, optional) – Directory of training images, relative to the root directory. Defaults to "images/train".
- train_label (str, optional) – Directory of training labels, relative to the root directory. Defaults to "labels/train".
- valid_image (str, optional) – Directory of validation images, relative to the root directory. Defaults to "images/valid".
- valid_label (str, optional) – Directory of validation labels, relative to the root directory. Defaults to "labels/valid".
- test_image (str, optional) – Directory of test images, relative to the root directory. Defaults to "images/valid".
- test_label (str, optional) – Directory of test labels, relative to the root directory. Defaults to "labels/valid".
- id_mapping (Union[List[str], Dict[str, str], str], optional) – ID mapping for the dataset. Defaults to None.
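Each split directory is given relative to root_path. The following minimal sketch shows how the signature's defaults combine with root_path, assuming a plain path join. resolve_dataset_dirs is a hypothetical helper, not a NetsPresso function. Note that, by default, the test split points at the validation directories unless test_image and test_label are set.

```python
from pathlib import PurePosixPath

# Defaults copied from the set_dataset_config signature.
DEFAULT_DIRS = {
    "train_image": "images/train",
    "train_label": "labels/train",
    "valid_image": "images/valid",
    "valid_label": "labels/valid",
    "test_image": "images/valid",  # test split reuses the validation split by default
    "test_label": "labels/valid",
}


def resolve_dataset_dirs(root_path: str, **overrides: str) -> dict:
    """Hypothetical helper: join each split directory onto root_path."""
    dirs = {**DEFAULT_DIRS, **overrides}
    for key, rel in dirs.items():
        # The docs require paths relative to the root directory.
        if PurePosixPath(rel).is_absolute():
            raise ValueError(f"{key} must be relative to root_path, got {rel!r}")
    return {key: str(PurePosixPath(root_path) / rel) for key, rel in dirs.items()}


paths = resolve_dataset_dirs("/root/traffic-sign")
print(paths["train_image"])  # /root/traffic-sign/images/train
print(paths["test_label"])   # /root/traffic-sign/labels/valid
```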
set_model_config(model_name: str, img_size: int, use_pretrained: bool = True, load_head: bool = False, path: str | None = None, fx_model_path: str | None = None, optimizer_path: str | None = None)
Set the model configuration for the Trainer.
- Parameters:
- model_name (str) – Name of the model.
- img_size (int) – Image size for the model.
- use_pretrained (bool, optional) – Whether to use a pre-trained model. Defaults to True.
- load_head (bool, optional) – Whether to load the model head. Defaults to False.
- path (str, optional) – Path to the model. Defaults to None.
- fx_model_path (str, optional) – Path to the FX model. Defaults to None.
- optimizer_path (str, optional) – Path to the optimizer. Defaults to None.
- Raises:
ValueError – If the specified model is not supported for the current task.
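set_model_config raises ValueError when the requested model is not supported for the current task, so it can help to check the name against trainer.available_models first. The sketch below performs that kind of check against a hypothetical support table. Both the table and check_model_supported are assumptions for illustration, and the only task/model pairing taken from this page is YOLOX-S for object detection, from the example.

```python
from typing import Dict, List

# Hypothetical support table for illustration only; in practice, inspect
# trainer.available_models for the task you declared.
SUPPORTED_MODELS: Dict[str, List[str]] = {
    "object_detection": ["YOLOX-S"],
}


def check_model_supported(task: str, model_name: str) -> None:
    """Raise ValueError when model_name is not listed for task."""
    available = SUPPORTED_MODELS.get(task, [])
    if model_name not in available:
        raise ValueError(
            f"Model {model_name!r} is not supported for task {task!r}; "
            f"available models: {available}"
        )


check_model_supported("object_detection", "YOLOX-S")  # passes silently
try:
    check_model_supported("object_detection", "NotAModel")
except ValueError as err:
    print(err)
```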
set_fx_model(fx_model_path: str)
Set the FX model path for retraining.
- Parameters:
- fx_model_path (str) – The path to the FX model.
- Raises:
ValueError – If the model is not set. Please use 'set_model_config' for model setup.
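set_fx_model only accepts an FX model path once a model has been configured. The hypothetical RetrainState class below mimics that documented precondition. It is not the Trainer's implementation. In the Retrain example further down, the trainer is created from hparams.yaml, which presumably supplies the model configuration so that set_fx_model can be called without an explicit set_model_config.

```python
from typing import Optional


class RetrainState:
    """Hypothetical stand-in for the Trainer's model state (not NetsPresso code)."""

    def __init__(self) -> None:
        self.model_name: Optional[str] = None
        self.fx_model_path: Optional[str] = None

    def set_model_config(self, model_name: str) -> None:
        self.model_name = model_name

    def set_fx_model(self, fx_model_path: str) -> None:
        # Mirrors the documented precondition: a model must be configured first.
        if self.model_name is None:
            raise ValueError(
                "The model is not set. Please use 'set_model_config' for model setup."
            )
        self.fx_model_path = fx_model_path


state = RetrainState()
state.set_model_config("YOLOX-S")
state.set_fx_model("./temp/FX_MODEL_PATH.pt")
print(state.fx_model_path)  # ./temp/FX_MODEL_PATH.pt
```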
set_training_config(optimizer, scheduler, epochs: int = 3, batch_size: int = 8)
Set the training configuration.
- Parameters:
- optimizer – The optimizer configuration (for example, AdamW(lr=6e-3)).
- scheduler – The learning rate scheduler configuration.
- epochs (int, optional) – The total number of training epochs. Defaults to 3.
- batch_size (int, optional) – The number of samples in a single batch. Defaults to 8.
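Together, epochs and batch_size determine how many optimizer steps a run takes. This worked sketch assumes one step per batch and a hypothetical dataset of 1,000 training images. It also shows the effect of dropping the last partial batch, a common data-loader option that this page does not mention. Whether batch_size is per GPU or split across the GPUs passed to train() is not stated here.

```python
import math


def total_iterations(num_samples: int, batch_size: int = 8, epochs: int = 3,
                     drop_last: bool = False) -> int:
    """Optimizer steps for a run, assuming one step per batch."""
    if drop_last:
        per_epoch = num_samples // batch_size  # incomplete final batch is skipped
    else:
        per_epoch = math.ceil(num_samples / batch_size)  # final batch may be smaller
    return per_epoch * epochs


print(total_iterations(1000))                                        # 375
print(total_iterations(1000, batch_size=16, epochs=40))              # 2520
print(total_iterations(1000, batch_size=16, epochs=40, drop_last=True))  # 2480
```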
set_augmentation_config(train_transforms: List | None = None, inference_transforms: List | None = None)
Set the augmentation configuration for training.
- Parameters:
- train_transforms (List, optional) – List of transforms for training. Defaults to None.
- inference_transforms (List, optional) – List of transforms for inference. Defaults to None.
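The training and inference pipelines are configured separately. A common pattern is to put random augmentations only in train_transforms and keep inference deterministic, for example resize only. The following sketch assumes that each list is applied in order and that None means no transforms. The transforms here are hypothetical functions acting on an image shape, not the netspresso.trainer.augmentations classes.

```python
from typing import Callable, List, Optional, Tuple

Size = Tuple[int, int]  # (height, width); stands in for an image in this sketch
Transform = Callable[[Size], Size]


def resize_to(size: int) -> Transform:
    """Hypothetical transform: resize any image to a size x size square."""
    return lambda _shape: (size, size)


def hflip(shape: Size) -> Size:
    """Hypothetical transform: a horizontal flip leaves the shape unchanged."""
    return shape


def apply_transforms(shape: Size, transforms: Optional[List[Transform]]) -> Size:
    # Assumed semantics: None means "no transforms"; otherwise apply in list order.
    for transform in transforms or []:
        shape = transform(shape)
    return shape


train_transforms = [resize_to(512), hflip]   # augmentation only at training time
inference_transforms = [resize_to(512)]      # deterministic preprocessing
print(apply_transforms((720, 1280), train_transforms))      # (512, 512)
print(apply_transforms((720, 1280), inference_transforms))  # (512, 512)
print(apply_transforms((720, 1280), None))                  # (720, 1280)
```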
set_logging_config(project_id: str | None = None, output_dir: str = './outputs', tensorboard: bool = True, csv: bool = False, image: bool = True, stdout: bool = True, save_optimizer_state: bool = True, validation_epoch: int = 10, save_checkpoint_epoch: int | None = None)
Set the logging configuration.
- Parameters:
- project_id (str, optional) – Project name under which the experiment is saved. If None, it is set to {task}_{model} (e.g. segmentation_segformer).
- output_dir (str, optional) – Root directory for saving the experiment. Defaults to "./outputs".
- tensorboard (bool, optional) – Whether to log to TensorBoard. Defaults to True.
- csv (bool, optional) – Whether to save the results in CSV format. Defaults to False.
- image (bool, optional) – Whether to save the validation results. Ignored for classification tasks. Defaults to True.
- stdout (bool, optional) – Whether to log to standard output. Defaults to True.
- save_optimizer_state (bool, optional) – Whether to save the optimizer state with the model checkpoint so that training can be resumed. Defaults to True.
- validation_epoch (int, optional) – Validation frequency, in epochs. Defaults to 10.
- save_checkpoint_epoch (int, optional) – Checkpoint-saving frequency, in epochs. Defaults to None.
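Two of these settings are easy to get wrong: the default project_id and the epochs at which validation runs. The following sketch applies the documented {task}_{model} fallback. It assumes, without confirmation from this page, that the experiment is saved under output_dir/project_id and that validation runs at every epoch that is a multiple of validation_epoch. Under that assumed rule, the default epochs=3 from set_training_config never reaches epoch 10. Whether the trainer also validates at the final epoch is not stated here. The behavior of save_checkpoint_epoch=None is not documented, so the sketch leaves it out.

```python
from pathlib import PurePosixPath
from typing import List, Optional


def resolve_project_id(task: str, model: str, project_id: Optional[str] = None) -> str:
    # Documented fallback when project_id is None: "{task}_{model}".
    return project_id if project_id is not None else f"{task}_{model}"


def validation_epochs(total_epochs: int, validation_epoch: int) -> List[int]:
    """1-based epochs that are multiples of validation_epoch (an assumed rule)."""
    return [e for e in range(1, total_epochs + 1) if e % validation_epoch == 0]


project = resolve_project_id("segmentation", "segformer")
print(PurePosixPath("./outputs") / project)  # outputs/segmentation_segformer
print(validation_epochs(40, 10))  # [10, 20, 30, 40]
print(validation_epochs(3, 10))   # []
```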
set_environment_config(seed: int = 1, num_workers: int = 4)
Set the environment configuration.
- Parameters:
- seed (int, optional) – Random seed. Defaults to 1.
- num_workers (int, optional) – The number of worker processes the data loader uses. Defaults to 4.
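A fixed seed makes runs reproducible only if every random number generator involved is seeded. The following sketch of a common PyTorch-style approach seeds Python's random module, NumPy, and PyTorch when they are installed. It is an illustration, not what NetsPresso does internally. If the underlying data loader is a PyTorch DataLoader (an assumption), num_workers=0 would load data in the main process, and larger values would use that many worker subprocesses.

```python
import random


def seed_everything(seed: int = 1) -> None:
    """Seed common RNGs (a sketch, not NetsPresso code)."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
        torch.manual_seed(seed)  # seeds the CPU and all CUDA devices
    except ImportError:
        pass  # PyTorch not installed; nothing to seed


seed_everything(1)
first = [random.random() for _ in range(3)]
seed_everything(1)
second = [random.random() for _ in range(3)]
assert first == second  # same seed, same sequence
```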
train(gpus: str, project_name: str, output_dir: str | None = './outputs') → TrainerMetadata
Train the model with the specified configuration.
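- Parameters:
- gpus (str) – GPU IDs to use, separated by commas (e.g. "0, 1").
- project_name (str) – Project name to save the experiment.
- output_dir (str, optional) – Root directory for saving the experiment. Defaults to "./outputs".
- Returns:
A dictionary containing information about the training.
- Return type:
Dict (annotated as TrainerMetadata in the signature)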
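train() takes GPU IDs as a single comma-separated string, and the example below passes "0, 1" with a space. The following sketch parses such a string into integer IDs. parse_gpu_ids is hypothetical, and how NetsPresso itself parses gpus, including whitespace handling, is an assumption here.

```python
from typing import List


def parse_gpu_ids(gpus: str) -> List[int]:
    """Hypothetical parser for the comma-separated `gpus` argument, e.g. "0, 1"."""
    # int() ignores surrounding whitespace; empty pieces (e.g. a trailing comma) are skipped.
    ids = [int(part) for part in gpus.split(",") if part.strip()]
    if len(set(ids)) != len(ids):
        raise ValueError(f"Duplicate GPU id in {gpus!r}")
    return ids


print(parse_gpu_ids("0, 1"))  # [0, 1]
print(parse_gpu_ids("3"))     # [3]
```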
Example
Train
from netspresso import NetsPresso
from netspresso.enums import Task
from netspresso.trainer.augmentations import Resize
from netspresso.trainer.optimizers import AdamW
from netspresso.trainer.schedulers import CosineAnnealingWarmRestartsWithCustomWarmUp
netspresso = NetsPresso(email="YOUR_EMAIL", password="YOUR_PASSWORD")
# 1. Declare trainer
trainer = netspresso.trainer(task=Task.OBJECT_DETECTION)
# 2. Set config for training
# 2-1. Data
trainer.set_dataset_config(
name="traffic_sign_config_example",
root_path="/root/traffic-sign",
train_image="images/train",
train_label="labels/train",
valid_image="images/valid",
valid_label="labels/valid",
id_mapping=["prohibitory", "danger", "mandatory", "other"],
)
# 2-2. Model
print(trainer.available_models) # ['EfficientFormer', 'YOLOX-S', ...]
trainer.set_model_config(model_name="YOLOX-S", img_size=512)
# 2-3. Augmentation
trainer.set_augmentation_config(
train_transforms=[Resize()],
inference_transforms=[Resize()],
)
# 2-4. Training
optimizer = AdamW(lr=6e-3)
scheduler = CosineAnnealingWarmRestartsWithCustomWarmUp(warmup_epochs=10)
trainer.set_training_config(
epochs=40,
batch_size=16,
optimizer=optimizer,
scheduler=scheduler,
)
# 3. Train
training_result = trainer.train(gpus="0, 1", project_name="project_sample")
Retrain
from netspresso import NetsPresso
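from netspresso.trainer.optimizers import AdamW
from netspresso.trainer.schedulers import CosineAnnealingWarmRestartsWithCustomWarmUp
netspresso = NetsPresso(email="YOUR_EMAIL", password="YOUR_PASSWORD")
# 1. Declare trainer
trainer = netspresso.trainer(yaml_path="./temp/hparams.yaml")
# 2. Set config for retraining
# 2-1. FX Model
trainer.set_fx_model(fx_model_path="./temp/FX_MODEL_PATH.pt")
# 2-2. Training (scheduler is a required argument of set_training_config)
trainer.set_training_config(
epochs=30,
batch_size=16,
optimizer=AdamW(lr=6e-3),
scheduler=CosineAnnealingWarmRestartsWithCustomWarmUp(warmup_epochs=10),
)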
# 3. Train
trainer.train(gpus="0, 1", project_name="project_retrain_sample")