Source code for nirs4all.core.task_detection

"""
Task Type Detection - Standalone utility to avoid circular imports

This module provides task type detection functionality that can be used
by both data and controller modules without creating circular import issues.
"""

import numpy as np
from nirs4all.core.task_type import TaskType


def detect_task_type(
    y: np.ndarray,
    threshold: float = 0.05,
    *,
    max_classes: int = 100,
) -> TaskType:
    """
    Detect the task type (regression or classification) from target values.

    Decision rules, applied in order after flattening ``y`` and dropping
    missing values:

    1. Non-numeric labels (string/bytes/object values that cannot be
       converted to float) are always classification: binary when there
       are exactly two distinct labels, multiclass otherwise.
    2. If every value is integer-valued: exactly 2 distinct values is
       binary classification; 3 to ``max_classes`` distinct values is
       multiclass classification; anything else (a single distinct value,
       or more than ``max_classes``) is regression.
    3. If every value lies in [0, 1] and the number of distinct values is
       at most ``len(y) * threshold``: binary classification for exactly
       2 distinct values, multiclass classification otherwise.
    4. Everything else is regression.

    Args:
        y: Target values; any shape, flattened before inspection.
        threshold: Maximum ratio of distinct values to samples for
            non-integer targets in [0, 1] to be treated as classification.
            It is not used for integer-valued targets.
        max_classes: Largest number of distinct integer values still
            treated as class labels rather than integer regression targets.

    Returns:
        TaskType: Detected task type.

    Raises:
        ValueError: If no values remain after removing missing entries
            (this includes an empty input).
    """
    y_flat = np.asarray(y).ravel()

    # np.isnan is undefined for string/bytes/object arrays, so coerce those
    # to float first; if that fails the targets are categorical labels.
    if y_flat.dtype.kind in "OUS":
        try:
            y_flat = y_flat.astype(float)
        except (TypeError, ValueError):
            # `label == label` is False only for NaN, which filters float
            # NaNs stored inside an object array alongside real labels.
            labels = {
                label
                for label in y_flat.tolist()
                if label is not None and label == label
            }
            if not labels:
                raise ValueError("Target array contains only NaN values")
            if len(labels) == 2:
                return TaskType.BINARY_CLASSIFICATION
            return TaskType.MULTICLASS_CLASSIFICATION

    # Remove NaN values if any
    y_clean = y_flat[~np.isnan(y_flat)]
    if len(y_clean) == 0:
        raise ValueError("Target array contains only NaN values")

    # Integer-valued targets are candidate class labels.
    if np.all(np.equal(np.mod(y_clean, 1), 0)):
        n_unique = len(np.unique(y_clean))
        if n_unique == 2:
            return TaskType.BINARY_CLASSIFICATION
        if 2 < n_unique <= max_classes:
            return TaskType.MULTICLASS_CLASSIFICATION
        # Either a single constant value or too many distinct integers to
        # be class labels.
        return TaskType.REGRESSION

    # Fractional values confined to [0, 1] with very few distinct levels
    # are treated as discrete classes. A pure {0, 1} array never reaches
    # this point: it is integer-valued and was handled above.
    if np.all(y_clean >= 0) and np.all(y_clean <= 1):
        n_unique = len(np.unique(y_clean))
        if n_unique <= len(y_clean) * threshold:
            if n_unique == 2:
                return TaskType.BINARY_CLASSIFICATION
            return TaskType.MULTICLASS_CLASSIFICATION

    # Default to regression for continuous values
    return TaskType.REGRESSION