Loss functions used internally by `survdnn()` to train deep neural networks on right-censored survival data.
cox_loss(pred, true)
cox_l2_loss(pred, true, lambda = 0.001)
aft_loss(pred, true, sigma = 1, aft_loc = 0, eps = 1e-12)
coxtime_loss(pred, true)
Value: a scalar `torch_tensor` representing the loss value.
`pred`: A torch tensor of model predictions. Its interpretation depends on the loss function:
- `loss = "cox"` or `"cox_l2"`: linear predictors (log hazard ratios).
- `loss = "aft"`: predicted log survival times.
- `loss = "coxtime"`: predicted time-dependent risk scores.
`true`: A tensor with two columns: observed time and status (1 = event, 0 = censored).
`lambda`: Regularization parameter for `cox_l2_loss` (default: `1e-3`).
`sigma`: Positive numeric scale parameter for the log-normal AFT model (default: `1`). In `survdnn()`, a learnable global scale can be used via `survdnn__aft_lognormal_nll_factory()`.
`aft_loc`: Numeric scalar location offset for the AFT model on the log-time scale (default: `0`). When non-zero, the model is trained on centered log-times `log(time) - aft_loc` for better numerical stability. At prediction time, add the offset back: `mu = mu_resid + aft_loc`.
`eps`: Small constant for numerical stability (default: `1e-12`).
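The roles of the scale and location arguments in the log-normal AFT loss can be illustrated with a small base-R sketch. This is not the package's torch implementation: `aft_lognormal_nll` is a hypothetical helper, it takes time and status as separate vectors, it averages over observations (the package may normalize differently), and it leaves out the `eps` safeguard. It shows that training on centered log-times and adding `aft_loc` back at prediction time describes the same model.

```r
# Hypothetical base-R helper, for illustration only (not part of survdnn).
# mu: predicted log survival times (relative to aft_loc); time, status: observed data.
aft_lognormal_nll <- function(mu, time, status, sigma = 1, aft_loc = 0) {
  # Standardized residual on the (centered) log-time scale
  z <- (log(time) - aft_loc - mu) / sigma
  # Events contribute the normal log-density of log(time);
  # the Jacobian term -log(time) is constant in mu and omitted here
  ll_event <- dnorm(z, log = TRUE) - log(sigma)
  # Censored observations contribute the log survival probability
  ll_cens <- pnorm(z, lower.tail = FALSE, log.p = TRUE)
  -mean(ifelse(status == 1, ll_event, ll_cens))
}

time   <- c(2, 5, 8)
status <- c(1, 0, 1)
mu     <- c(1, 1.5, 2)

# Center the log-times, so the network output is a residual mean
aft_loc  <- mean(log(time))
mu_resid <- mu - aft_loc

nll_raw      <- aft_lognormal_nll(mu, time, status)
nll_centered <- aft_lognormal_nll(mu_resid, time, status, aft_loc = aft_loc)

# Both parameterizations give the same loss, since mu = mu_resid + aft_loc
stopifnot(isTRUE(all.equal(nll_raw, nll_centered)))
```

Using `pnorm(..., log.p = TRUE)` for the censored term avoids taking the log of a survival probability that has underflowed to zero.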
- **Cox partial likelihood loss** (`cox_loss`): Negative partial log-likelihood used in proportional hazards modeling.
- **L2-penalized Cox loss** (`cox_l2_loss`): Adds L2 regularization to the Cox loss.
- **Accelerated Failure Time (AFT) loss** (`aft_loss`): Log-normal AFT **censored negative log-likelihood** (uses both events and censored observations).
- **CoxTime loss** (`coxtime_loss`): Placeholder (see details). A correct CoxTime loss requires access to the network and the full input tensor.
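As a rough illustration of the Cox negative partial log-likelihood named above, here is a minimal base-R sketch. The helper `cox_npll` is hypothetical and is not the package's torch code; it assumes no tied event times, takes time and status as separate vectors, and returns the unnormalized sum over events, whereas the package may scale the loss differently.

```r
# Hypothetical base-R helper, for illustration only (not part of survdnn).
# lp: linear predictors (log hazard ratios); time, status: observed data.
cox_npll <- function(lp, time, status) {
  # Sort by decreasing time so a running sum covers each subject's risk set
  ord    <- order(time, decreasing = TRUE)
  lp     <- lp[ord]
  status <- status[ord]
  # log(sum of exp(lp) over subjects still at risk); assumes no tied times
  log_risk <- log(cumsum(exp(lp)))
  # Only observed events contribute terms to the partial likelihood
  -sum((lp - log_risk)[status == 1])
}

# Three subjects with equal risk scores, all events:
# the risk sets have sizes 3, 2 and 1, so the loss is log(3) + log(2) + log(1)
loss <- cox_npll(lp = c(0, 0, 0), time = c(1, 2, 3), status = c(1, 1, 1))
stopifnot(isTRUE(all.equal(loss, log(6))))
```

Censored subjects still enter the risk sets through the cumulative sum, which is how the loss uses their information without giving them a term of their own.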
# Used internally by survdnn()