Contrastive Learning of Temporal Distinctiveness for Survival Analysis in Electronic Health Records

Survival analysis plays a crucial role in many healthcare decisions, where the risk prediction for the events of interest can support an informative outlook for a patient's medical journey. Given the existence of data censoring, an effective approach to survival analysis is to enforce the pairwise temporal concordance between censored and observed data, aiming to utilize the time interval before censoring as partially observed time-to-event labels for supervised learning. Although existing studies have mostly employed ranking methods to pursue an ordering objective, contrastive methods, which learn a discriminative embedding by contrasting data against each other, have not been explored thoroughly for survival analysis. Therefore, in this paper, we propose a novel Ontology-aware Temporality-based Contrastive Survival (OTCSurv) analysis framework that utilizes survival durations from both censored and observed data to define temporal distinctiveness and construct negative sample pairs with adjustable hardness for contrastive learning. Specifically, we first use an ontological encoder and a sequential self-attention encoder to represent the longitudinal EHR data with rich contexts. Second, we design a temporal contrastive loss to capture varying survival durations in a supervised setting through a hardness-aware negative sampling mechanism. Last, we incorporate the contrastive task into the time-to-event predictive task with multiple loss components. We conduct extensive experiments using a large EHR dataset to forecast the risk of hospitalized patients who are in danger of developing acute kidney injury (AKI), a critical and urgent medical condition. The effectiveness and explainability of the proposed model are validated through comprehensive quantitative and qualitative studies.


INTRODUCTION
The increasing abundance of electronic health records (EHRs) has provided an unprecedented opportunity to apply predictive analytics to support healthcare decisions [27,36]. To achieve the optimal outcomes for a patient's medical journey, an important question faced by healthcare providers is how to precisely anticipate adverse events (e.g., kidney injury, heart failure, and stroke), so that these critical incidents can be responded to in a timely manner with sufficient clinical attention. Therefore, it is crucial to investigate the application of survival analysis (SA) to longitudinal EHR data, which aims to identify the significant factors that influence the degree of risk and to further forecast the time to events of interest.
For survival analysis, a key challenge is how to deal with censored data in time-to-event modeling. In the case of censoring, events of interest may not be observed for some patients due to the limited duration of observation or the withdrawal of patients during the study. To address this challenge, various traditional survival analysis models have been developed, although they suffer from multiple limitations. Parametric survival models [25] assume a specific distribution for the baseline hazard function, such as the exponential, Weibull, or log-normal distribution. However, events in the real world are usually too complex to be captured by such predefined distributions. On the other hand, although the semi-parametric Cox model [5] makes no assumptions about the baseline hazard function, it requires the effect of the covariates on the hazard function to be multiplicative and proportional over time. Moreover, most of these approaches (e.g., Cox-based models) only focus on predicting the relative ordering of the survival durations of individuals, overlooking their actual event times. Therefore, the capability of time estimation for future event occurrences is unfortunately compromised.
To overcome the limitations of early studies, deep learning techniques have been increasingly applied to survival predictive tasks [12,22,30], which offer the capacity to capture complex survival patterns without making explicit distributional assumptions. While some studies have explored the enforcement of patient concordance by survival probabilities to accommodate both observed and censored survival data, there exists very limited literature on contrastive learning (CL) methods aimed at learning a discriminative representation of patient records to achieve better predictive performance. In contrastive learning, data are contrasted against each other in self-supervised [3], semi-supervised [39], or supervised [19] settings. Generally, it trains an objective to distinguish the subtle characteristics in data, by maximizing the similarity between positive pairs (instances that share the same label) and minimizing the similarity between negative pairs (instances that have different labels). Although the positive vs. negative labeling strategy for contrasted pairs has been defined by self-augmentation (i.e., whether the pair originates from a single data point) or supervised classes [13] (i.e., whether the pair belongs to the same class), the exploration of contrastive labeling based on temporal distinctiveness (which is based on the time difference between the two survival durations) for survival analysis is still lacking. Furthermore, given that survival duration is a numerical quantity, accounting for the hardness defined by the time difference for contrastive labels can help the model learn the survival data with more flexibility.
Another challenge associated with survival analysis in EHR is possible data insufficiency. Usually, a large variety of medical codes are recorded in a dataset, but many codes may have a relatively small number of occurrences (e.g., rare diseases). As a result, for patients with rare codes or sparse visits, the embedding of their medical history is often sub-optimal. One way to address this issue is to incorporate the domain-specific knowledge inherent in medical ontology into the representation of EHR features [37]. Medical ontology is a hierarchical classification structure of medical concepts (e.g., diagnoses and medications), which can serve as an auxiliary categorization for knowledge representation [4,24,32]. For example, GRAM [4] proposes a graph-based attention model that employs the attention mechanism on the hierarchical levels of each medical code to learn medically meaningful EHR feature embeddings. With ontological encoding, survival models can better build the association between codes or patients, and transfer the medical knowledge from one sample to another. Therefore, to further improve the quality of patient profiling, the ontology learning of EHR features can be integrated with the contrastive learning core of survival analysis.
In this paper, we introduce an Ontology-aware Temporality-based Contrastive Survival analysis framework called OTCSurv, which combines the ontology-enhanced EHR data encoder, the contrastive learning of temporal distinctiveness, and the survival probability predictor with multiple loss components, for interpretable, data-efficient, and discriminative survival analysis. Specifically, the main contributions of this study are threefold:
• We design a Supervised Weighted Contrastive (SupWCon) learning loss function that uses survival duration as its pairing criterion and is able to utilize both observed and censored instances. SupWCon considers the hardness of negative pairs based on the survival duration differences to refine the granularity of contrastive learning.
• We use a sequential attention-based ontological encoder to learn medically informed embeddings for the sequential hospital visits of patients. Ontology information brings data efficiency to our model by referring to higher-level medical concepts when the observation is sparse.
• We optimize survival prediction through multiple loss components, focusing on two key goals: accurately predicting survival duration and precisely ranking the risks or survival probabilities of patients at each time point. We train our model with a carefully tuned configuration of SupWCon, accompanied by three more loss functions, guiding the training towards an optimum that satisfies these two goals.
Finally, we evaluate our proposed method and demonstrate the strength of our model on a real-world EHR dataset for Acute Kidney Injury (AKI) by performing baseline comparison, ablation study, and interpretability analysis.

RELATED WORK
In this section, we provide an overview of key contributions and advancements in the survival analysis field, concentrating on relevant methodologies and techniques in the literature.
One of the most widely used statistical methods in survival analysis is the Kaplan-Meier (KM) estimator [16], a nonparametric method that builds the survival curve from the ratio, at each event time, of the number of individuals who survive that time to the number of patients at risk just before it. However, KM does not take into account the covariates of patients. Early works in survival analysis primarily revolved around the Cox proportional hazards (CPH) model [6], which assumes a proportional relationship between covariates and the hazard function. Due to the advantages of CPH, such as simplicity and interpretability, many survival analysis models have been proposed based on it, such as the incorporation of time-varying covariates [23], accounting for competing risks [9], and CoxTime [20], which extends the Cox model beyond the assumption of proportional hazards.
In recent years, there has been an increasing interest in applying machine learning techniques to survival analysis. Random Survival Forests [15], Deep Exponential Families [28,29], and semiparametric Bayesian models based on Gaussian Processes [8] offer flexibility in capturing complex survival patterns and handling non-linear relationships. As for deep learning-based approaches, DeepSurv [17] introduced the application of deep neural networks for survival prediction, capturing complex relationships between covariates and survival outcomes using the Cox partial likelihood loss function. This opened the door to utilizing deep learning in survival analysis, leading to the development of models like DeepHit [22], which is a multitask deep learning model capable of handling competing risks, DRSA [30] and RNN-SURV [12], both of which exploit a recurrent neural network (RNN) to handle sequential data, and N-MTLR [10], which leverages deep neural networks to replace the linear core of MTLR [38].
Some more recent state-of-the-art deep learning-based SA models are Dynamic-DeepHit [21], Survtrace [35], and Deep-CSA [13]. Dynamic-DeepHit is an extension of DeepHit which, instead of a simple neural network, uses a recurrent neural network to dynamically capture longitudinal dependencies in the presence of competing risks. Survtrace proposes a transformer-based SA model that handles competing risks and benefits from a multitask learning framework to learn a strong shared representation. Transformer-Based Deep Survival Analysis [14] tries to make a trade-off between time predictive power and risk ranking power using both the absolute error and ranking evaluation metrics.
Generally, the existing architectures suffer from multiple limitations. Some works, such as Cox-based survival models, show suboptimal performance due to certain assumptions about the underlying stochastic process. Violations of these assumptions can lead to incorrect conclusions. Some of the deep learning methods are black boxes and do not offer sufficient interpretability. Also, many works only employ ranking methods to reach survival rate concordance and, to the best of our knowledge, there is no exploration of the use of contrastive methods based on temporality for healthcare survival analysis. OTCSurv mitigates the aforementioned challenges and supports interpretable, data-efficient, and discriminative survival analysis. As demonstrated in the following sections, OTCSurv exhibits enhanced performance compared to its predecessors.

PROPOSED METHOD
In this section, we first describe the notations and formulate the EHR survival analysis problem. We then present the overview of the model. Last, we introduce each module in detail.

Problem Formulation
Electronic health records (EHRs) usually contain comprehensive information about a patient's medical history. Each patient normally has multiple hospital visits where diagnoses, prescriptions, and procedures are recorded using standardized codes in the hospital's database. For survival analysis, EHRs can be organized into three sets of information: 1) covariates, 2) time to the event, and 3) a label indicating the type of the event (censored/observed). A discrete and finite time window with a maximum length of $T_{max}$ is considered for the time prediction. Therefore, our goal is to predict in which time interval $t \in \{0, \dots, T_{max}\}$ the event of interest is most likely to happen, or to determine the probability of survival in each time interval. We denote the event label by a binary variable $e$: $e = 1$ if the instance is observed, and $e = 0$ otherwise (censored). We consider each instance (i.e., patient) as a triple $(X, T, e)$, where $X = \{x_j\}_{j=1}^{V}$ is a sequence of covariates over $V$ visits, and $x_j = \{c_1, c_2, \dots, c_{|C|}, d_1, d_2, \dots, d_{|D|}\}$ contains both binary and continuous features. Binary medical codes are denoted by $c_i$, and continuous features such as demographics are denoted by $d_i$, where $|C|$ and $|D|$ represent the sizes of the two feature sets. For each medical code $c_i$, we extract the set of its ancestor codes (higher-level concepts) in the hierarchy of the medical ontology, represented as a directed acyclic graph (DAG).
We denote the probability by $P$, the hazard function by $h(t)$, the probability density function by $f(t)$, and the survival probability by $S(t)$. By adding the caret symbol to each notation, we indicate their estimated forms, e.g., $\hat{S}(t)$ is the estimated survival probability.
Task: Given the patient's sequential medical history in terms of longitudinal hospital visits containing medical codes, we aim to build a model to estimate the survival probability of patients in each time interval inside the prediction time window in the future.

Model Overview
In this subsection, we introduce an overview of our proposed OTCSurv model architecture. As shown in Figure 1, the model consists of three main components. The first component is the sequential attention-based ontological encoder, which consists of two main blocks. The first block is the ontological encoder, which utilizes the information inherent in the medical ontologies to generate informed embedding vectors for medical codes. The second block is the sequential attention encoder, which consists of three attention-based parts: visit-level attention-pooling, transformer encoder, and instance-level attention-pooling. The visit-level attention-pooling uses the attention mechanism to reduce the dimension of the visit representations. Then, the output of the visit-level attention-pooling is integrated with demographic information inside a data integration block to produce a representation containing all the patient's information. This representation, along with the positional encoding of visits, is fed to the transformer encoder. The multi-head attention of the transformer extracts the interactions between medical visits to produce a rich representation of the patient. Instance-level attention-pooling is applied to the output of the transformer encoder to compress and combine the information of the different visits of an instance using the attention mechanism and produce the ultimate instance representation. This ultimate instance representation goes through the second and the third main components of OTCSurv in parallel. The second main component is the Contrastive Learning component, where a projection head, which is a nonlinear transformation (e.g., a simple multilayer perceptron (MLP) with a nonlinear activation function), transfers the ultimate instance representation to a different latent space. This is where our proposed SupWCon loss comes into play, adding temporally distinctive refinements to the ultimate representations. In parallel, the ultimate representation is fed into the Survival Prediction component (the third main component), which consists of a fully connected neural network. This neural network predicts $T_{max}$ probabilities, which are the complements of the hazard rates, one for each of the $T_{max}$ predefined time intervals. To train this model, three loss functions (Loglikelihood loss, Pairwise Ranking loss, and Mean Squared Error loss) are combined with SupWCon to guide the model towards an optimum regarding predictive, discriminative, and ranking ability in survival analysis.

Sequential Attention-based Ontological Encoder
The sequential attention-based ontological encoder is responsible for generating instance representations using the attention-based components explained below.

Ontological Encoder.
To address the challenge of data limitation in the healthcare domain, acquire comprehensive representations of medical codes, and increase predictability, we utilize the attention-based graph representation approach known as GRAM [4]. First, an initial embedding vector $h_i \in \mathbb{R}^{d_c}$ is assigned to each medical code as well as to its ancestors (higher-level concepts) in the medical ontology, where $d_c$ is the code embedding dimension. Then, each code's final representation $g_i \in \mathbb{R}^{d_c}$ is calculated as a convex combination of the initial embeddings of itself and its ancestors using the attention mechanism:

$$g_i = \sum_{j \in \mathcal{A}(i)} \alpha_{ij} h_j, \qquad \sum_{j \in \mathcal{A}(i)} \alpha_{ij} = 1 \quad (1)$$

where $\mathcal{A}(i)$ is the set containing the indices of the code $c_i$ and its ancestors. $\alpha_{ij} \in \mathbb{R}^{+}$ is the attention weight given to the ancestor code embedding $h_j$ when calculating $g_i$, which is the final representation of $c_i$. Using a Softmax function, $\alpha_{ij}$ is formulated as:

$$\alpha_{ij} = \frac{\exp\big(\phi(h_i; h_j)\big)}{\sum_{k \in \mathcal{A}(i)} \exp\big(\phi(h_i; h_k)\big)} \quad (2)$$

$$\phi(h_i; h_j) = u_a^{\top} \tanh\big(W_a [h_i; h_j] + b_a\big) \quad (3)$$

where $[h_i; h_j]$ is the concatenation of $h_i$ and $h_j$ in child-ancestor order, and $\phi(\cdot)$ is an MLP operator with learnable parameters $W_a$, $b_a$, and $u_a$.
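The ontological attention step can be illustrated with a minimal pure-Python sketch. It assumes the GRAM-style scoring function (a one-hidden-layer MLP with a tanh activation); the function names, the toy embeddings, and the parameter values are hypothetical and chosen only for illustration, not taken from the OTCSurv implementation.

```python
import math

def mlp_score(h_child, h_anc, W, b, u):
    """Compatibility score u^T tanh(W [h_child; h_anc] + b) (assumed GRAM form)."""
    x = h_child + h_anc  # list concatenation in child-ancestor order
    hidden = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + bi)
              for row, bi in zip(W, b)]
    return sum(ui * hi for ui, hi in zip(u, hidden))

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def code_representation(i, ancestors, H, W, b, u):
    """Final embedding of code i: convex combination of the initial
    embeddings of the code itself and of its ancestors."""
    idx = [i] + ancestors[i]  # the code plus its ancestors
    weights = softmax([mlp_score(H[i], H[j], W, b, u) for j in idx])
    dim = len(H[i])
    g = [sum(a * H[j][k] for a, j in zip(weights, idx)) for k in range(dim)]
    return g, weights

# Toy example: code 0 has ancestors 1 and 2; embeddings are 2-dimensional.
H = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.5, 0.5]}
ancestors = {0: [1, 2]}
W = [[0.1, -0.2, 0.3, 0.0], [0.0, 0.2, -0.1, 0.4], [0.3, 0.1, 0.0, -0.2]]
b = [0.0, 0.1, -0.1]
u = [0.5, -0.3, 0.2]
g, alpha = code_representation(0, ancestors, H, W, b, u)
```

Because the weights come from a softmax, they are positive and sum to one, so the final embedding always stays inside the convex hull of the code's own embedding and those of its ancestors. This is what lets a rarely observed code borrow information from better-trained higher-level concepts.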

Attention-Pooling.
We use two attention-pooling components in our architecture to compress the information flow using the attention mechanism: the visit-level attention-pooling, placed after the ontological encoder, and the instance-level attention-pooling, placed after the transformer encoder.
Assume that the input of the visit-level attention-pooling for a patient $i$ is a tensor $G_i \in \mathbb{R}^{V \times M \times d_c}$, where $V$, $M$, and $d_c$ are the number of visits, the specified maximum number of possible codes inside each visit, and the dimension of the code embedding¹, respectively. Using the attention mechanism, we assign a weight to each code in a visit and use those weights to calculate the weighted average of the medical code vectors. Thus, instead of having a representation of size $(V \times M \times d_c)$ for each patient, we reduce it to a representation of size $(V \times d_c)$. Given the $j$-th visit representation $G_j \in \mathbb{R}^{M \times d_c}$, which is the concatenation of $M$ code embeddings $g_{jm}$ $(1 \le m \le M)$, we calculate an attention energy $e_{jm} \in \mathbb{R}$ for each of the $M$ medical code embeddings:

$$e_j = \psi(G_j) = \sigma(G_j W_1 + b_1)\, W_2 \quad (4)$$

where $e_j \in \mathbb{R}^{M \times 1}$ contains the $M$ attention energies for the codes within the $j$-th visit, and $\psi(\cdot)$ is an MLP operator with a ReLU activation function $\sigma$ and learnable parameters $W_1$, $W_2$, and $b_1$. Using softmax on the attention energies, we calculate the attention weights $a_j \in \mathbb{R}^{M \times 1}$:

$$a_j = \mathrm{Softmax}(e_j) \quad (5)$$

where $a_j$ is the concatenation of the $M$ attention weights $a_{jm}$ $(1 \le m \le M)$. Finally, we have

$$r_j = \sum_{m=1}^{M} a_{jm}\, g_{jm} \quad (6)$$

where $r_j \in \mathbb{R}^{d_c}$ $(1 \le j \le V)$ represents the $j$-th visit of the patient. So, for each patient, we have $R \in \mathbb{R}^{V \times d_c}$ as the concatenation of the $V$ visit representations $r_j \in \mathbb{R}^{d_c}$. The output of the visit-level attention block $R \in \mathbb{R}^{V \times d_c}$ is concatenated with the patient's demographic embedding $E^{demo} \in \mathbb{R}^{V \times d_d}$ ($d_d$ is the dimension of the demographic feature embedding) to obtain $\tilde{R} = \mathrm{Concat}(R, E^{demo}) \in \mathbb{R}^{V \times d}$, where $d = d_c + d_d$.
For the instance-level attention-pooling, which is applied to the output of the transformer encoder, we use the same technique described above to reduce the dimensionality. The output of the transformer encoder for a patient is $O \in \mathbb{R}^{V \times d_t}$, where $d_t$ is the transformer dimension. $O$ is fed to the instance-level attention-pooling, where the attention mechanism generates $V$ attention weights, one for each of the visit representations. These weights are used to calculate the weighted average of the visit representations, thereby reducing the dimension of $O$ and outputting $o \in \mathbb{R}^{d_t}$ as the ultimate instance (patient) representation to be used in both the contrastive task and the survival prediction downstream task.
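The attention-pooling operation described above can be sketched in a few lines of pure Python. The sketch assumes a one-hidden-layer MLP with ReLU for the attention energies; the function name `attention_pool`, the parameter names, and the toy values are illustrative assumptions rather than the authors' code.

```python
import math

def attention_pool(vectors, W1, b1, w2):
    """Collapse a list of vectors into one vector by attention-weighted averaging.

    Energies are w2 . ReLU(W1 v + b1); weights are the softmax of the energies.
    """
    energies = []
    for v in vectors:
        hidden = [max(0.0, sum(w * x for w, x in zip(row, v)) + bi)
                  for row, bi in zip(W1, b1)]
        energies.append(sum(w * h for w, h in zip(w2, hidden)))
    m = max(energies)                      # subtract the max for stability
    exps = [math.exp(e - m) for e in energies]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(vectors[0])
    pooled = [sum(a * v[k] for a, v in zip(weights, vectors)) for k in range(dim)]
    return pooled, weights

# Visit-level use: pool the code embeddings of one visit into a visit vector.
codes_in_visit = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W1 = [[1.0, 0.0], [0.0, 1.0]]
b1 = [0.0, 0.0]
w2 = [1.0, 1.0]
visit_vector, code_weights = attention_pool(codes_in_visit, W1, b1, w2)
```

The same function can be reused at the instance level by passing the per-visit outputs of the transformer encoder instead of code embeddings, in which case the result plays the role of the ultimate patient representation.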

Transformer Encoder.
The encoder of the transformer architecture serves as the primary block for obtaining representations for survival analysis. For each patient, the input to the transformer encoder is the sequence of final visit embeddings. The transformer's multi-head attention mechanism captures complex relationships among the different hospital visits of a patient, enabling the model to encode comprehensive information about their dependencies over time. This results in rich representations that can capture survival patterns and time-dependent features.

Temporal Distinctiveness with Supervised Weighted Contrastive Learning
Contrastive learning aims to learn meaningful representations by maximizing agreement between similar examples while minimizing agreement between dissimilar examples. One technique to measure the similarity between two vectors is the cosine similarity, which equals the dot product of the two vectors after they are normalized to unit length. It is the cosine of the angle between the two vectors and therefore expresses how closely they align in the vector space. In this study, we formulate a contrastive learning loss function featuring an adaptive temperature parameter, referred to as the Supervised Weighted Contrastive (SupWCon) loss. SupWCon is an extended version of the method proposed in [19] and has been tailored for survival analysis, particularly for handling censored data. We formulate

$$\mathcal{L}^{SupWCon} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau_{ia})} \quad (7)$$

where $z_i$ is the projected representation of instance $i$, $I$ is the set of indices of all the instances, $P(i)$ is the set of indices of the instances that make a positive pair with the instance $i$, and $A(i) \equiv I \setminus \{i\}$. The dot $(\cdot)$ operator in the formulation represents the dot product of two vectors. $\tau \in \mathbb{R}^{+}$ is the constant scalar temperature parameter for positive pairs (so $\tau_{ia} = \tau$ when $a \in P(i)$), and $\tau_{ia} \in \mathbb{R}^{+}$ is the adaptive scalar temperature parameter for negative pairs, which will be explained shortly. The instance with the index $i$ is called the anchor.
Positive and negative pairs are generated considering both the survival durations and the labels of instances (observed/censored). For an anchor $i$ that is an observed instance with survival duration $T_i$, any other observed instance whose survival duration $T$ (the duration from day one to the day before the event time) lies inside the time window $T_i - w/2 \le T < T_i + w/2$ (referred to as the positive window) makes a positive pair with the anchor and belongs to $P(i)$. The time window length $w$ is a hyperparameter that needs to be tuned with respect to $T_{max}$, the data distribution, and the nature of the problem. Any observed instance with a survival duration $T$ outside the positive window ($T < T_i - w/2$ or $T_i + w/2 \le T$), plus any censored instance whose survival duration (the duration from day one to the day of censoring) is greater than or equal to $T_i + w/2$, makes a negative pair with the anchor. We do not consider censored instances with a censoring time smaller than $T_i + w/2$ for either positive or negative pair generation, because what happened to the patient after censoring is unknown (whether they were diagnosed with AKI or not, and if so, when that happened). In fact, if the event (AKI) happens after censoring but inside the positive window of the anchor, making a negative pair would be wrong. Conversely, for patients with a censoring time greater than or equal to $T_i + w/2$, we are sure that their survival duration is outside the positive window of the anchor, so they are safe to be considered for negative pair generation. The temperature parameter for positive pairs $\tau$ is a constant positive scalar for all positive pairs and is chosen by hyperparameter tuning. However, we adjust the temperature parameter for each negative pair to encourage our model to regulate the amount of dissimilarity between the representations of negative pairs that exhibit various differences in survival duration. Hence, the model can better capture the distinction for negative pairs of various hardness. For example, if patient $i$ and patient $j$ make a negative pair, the adjusted temperature parameter for this negative pair is calculated as follows:

$$\tau_{ij} = \frac{1}{|T_i - T_j|} \quad (8)$$

which is the inverse of their difference in survival duration. The more distant their survival durations are, the more SupWCon pulls their representations apart in the latent space. To the best of our knowledge, this is the first time in the context of survival analysis that contrastive learning is used to create hardness-aware temporal distinctiveness based on the known survival durations of subjects. We use two more techniques that have been established as effective in the contrastive learning literature. First, we introduce a learnable nonlinear transformation, such as a simple two-layer fully connected neural network with a nonlinear activation function, between the ultimate instance representation and the point where the SupWCon loss is applied. This substantially improves the quality of the ultimate instance representations compared to applying SupWCon directly to them [3]². Second, we normalize the vector representations of instances onto the unit sphere ($\ell_2$ normalization) prior to using them in SupWCon, which has also been shown experimentally to be effective [3].
It is noteworthy that the contrastive learning component is only used during training to add hardness-aware distinctive refinements to the ultimate representations and is discarded during inference.
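The pairing rules and the adaptive temperature can be made concrete with a small pure-Python sketch. Several choices here are assumptions made for illustration only: anchors are restricted to observed instances, the denominator contains only the positives and negatives of the anchor (censored instances excluded from pairing are left out), and all function and variable names are hypothetical.

```python
import math

def l2_normalize(v):
    """Project a vector onto the unit sphere."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def pair_sets(anchor, durations, observed, window):
    """Return (positives, negatives) index lists for an observed anchor."""
    t_i = durations[anchor]
    lo, hi = t_i - window / 2, t_i + window / 2
    pos, neg = [], []
    for j, (t_j, e_j) in enumerate(zip(durations, observed)):
        if j == anchor:
            continue
        if e_j:                      # observed: inside the window -> positive
            (pos if lo <= t_j < hi else neg).append(j)
        elif t_j >= hi:              # censored beyond the window -> negative
            neg.append(j)
        # censored before the end of the window: skipped (outcome unknown)
    return pos, neg

def supwcon_loss(z, durations, observed, window, tau):
    """Sketch of a weighted supervised contrastive loss over normalized z."""
    total = 0.0
    for i in range(len(z)):
        if not observed[i]:
            continue
        pos, neg = pair_sets(i, durations, observed, window)
        if not pos:
            continue
        denom = sum(math.exp(dot(z[i], z[p]) / tau) for p in pos)
        for n in neg:
            tau_in = 1.0 / abs(durations[i] - durations[n])  # adaptive temperature
            denom += math.exp(dot(z[i], z[n]) / tau_in)
        total += -sum(dot(z[i], z[p]) / tau - math.log(denom) for p in pos) / len(pos)
    return total

durations = [3, 3, 7, 8, 2]
observed = [1, 1, 1, 0, 0]
z = [l2_normalize(v) for v in ([1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.0, 1.0], [1.0, 0.0])]
loss = supwcon_loss(z, durations, observed, window=2, tau=0.5)
```

With a window of 2 and the first patient as anchor, the second patient (observed, same duration) is the only positive, the third (observed at 7) and fourth (censored at 8) are negatives, and the fifth (censored at 2) is ignored. A negative never has a zero duration gap, because any instance closer than half a window is either a positive or excluded, so the adaptive temperature is always well defined when the window is positive.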

Survival Prediction
For continuous survival models, the hazard function, denoted as $h(t)$, represents the instantaneous rate of an event occurring at time $t$, given that the individual has survived up to time $t$. However, in the discrete setting, where time is considered as a sequence of distinct points, the hazard function is defined differently. Instead of dealing with infinitesimal intervals, the hazard function represents the conditional probability that the patient experiences the event at a specific time $t$, given that the patient was event-free before $t$. Given that the training data consists of pairs of covariates and time $(x, T)$, our goal is to model the distribution of event times. The probability density function $f(t|x)$, the survival function $S(t|x)$, and the hazard function $h(t|x)$ are respectively defined as:

$$f(t|x) = P(T = t \mid x) \quad (9)$$

which represents the probability mass assigned to an event occurring exactly at time $t$,

$$S(t|x) = P(T > t \mid x) \quad (10)$$

which gives the probability that an event has not occurred up to and including time $t$, and finally the hazard function formulation,

$$h(t|x) = P(T = t \mid T > t-1, x) = \frac{f(t|x)}{S(t-1|x)} \quad (11)$$

Since $f(t|x) = S(t-1|x) - S(t|x)$, we can use the above formulation to rewrite the survival function as follows:

$$S(t|x) = \big(1 - h(t|x)\big)\, S(t-1|x) \quad (12)$$

If we denote the complement of the hazard function by $q(t|x) = 1 - h(t|x)$, we have:

$$S(t|x) = q(t|x)\, S(t-1|x) \quad (13)$$

By recursively expanding equation 13 with $S(0|x) = 1$, the survival function can be expressed as:

$$S(t|x) = \prod_{s=1}^{t} q(s|x) \quad (14)$$

A feed-forward neural network (FFN) is the core of the survival prediction component, which predicts the complement of the hazard function $q(t|x)$ for all times up to $T_{max}$. Thus, the output of the survival prediction component for a patient $i$ is a vector of size $T_{max}$ as follows:

$$\hat{y}_i = \big[\hat{q}(1|x_i), \hat{q}(2|x_i), \dots, \hat{q}(T_{max}|x_i)\big] \quad (15)$$

In continuous-time survival analysis, the mean lifetime of a patient, i.e., the expected value of the random variable $T$, which represents the average time until an event occurs, can be calculated by integrating the survival function over time. Mathematically, the mean lifetime $\mu$ (also known as the expected lifetime or average survival time) is derived in the following manner:

$$\mu = E[T] = \int_{0}^{\infty} t\, f(t)\, dt \quad (16)$$

Using the technique of integration by parts, we arrive at

$$\mu = \int_{0}^{\infty} S(t)\, dt \quad (17)$$

which is the area under the survival curve. In the discrete-time formulation, we can approximate it by the sum of the survival probabilities up to $T_{max}$ as follows:

$$\hat{\mu} = \sum_{t=1}^{T_{max}} \hat{S}(t|x) \quad (18)$$

We consider $\hat{\mu}$ as our predicted survival duration.
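The mapping from predicted hazard complements to a survival curve and a predicted duration can be sketched in a few lines of Python. The function names are hypothetical; the only assumption is that the network output is a list holding the complement of the hazard for intervals 1 through the maximum time.

```python
def survival_curve(q):
    """Cumulative product of hazard complements: S(t) = q(1) * ... * q(t)."""
    curve, prod = [], 1.0
    for q_t in q:
        prod *= q_t
        curve.append(prod)
    return curve

def expected_duration(q):
    """Discrete approximation of the area under the survival curve."""
    return sum(survival_curve(q))

# A patient predicted to survive intervals 1 and 2 and have the event in interval 3.
q_hat = [1.0, 1.0, 0.0, 1.0]
curve = survival_curve(q_hat)         # [1.0, 1.0, 0.0, 0.0]
duration = expected_duration(q_hat)   # 2.0
```

In this idealized case the predicted duration is 2, which matches the convention used earlier in the paper that the survival duration runs from day one to the day before the event. Once the survival probability reaches zero it stays at zero, regardless of later hazard outputs.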

Loss Functions
In this section, we describe the different loss functions used to train our model. Besides the SupWCon loss, which was explained in Section 3.4, we have three more losses working in combination with SupWCon. The motivation is to optimize the model with respect to the two important objectives of survival analysis: 1) accurately predicting the survival duration for observed data, and 2) accurately ranking patients (both observed and censored) in terms of their risk and survival rate at different time points.
3.6.1 Loglikelihood Loss. Loglikelihood loss is the main loss used to train the survival task. For observed data points, we minimize the following loss:

$$\mathcal{L}^{ob}_{Loglikelihood} = -\sum_{t=1}^{T-1} \log \hat{S}(t|x) - \sum_{t=T}^{T_{max}} \log\big(1 - \hat{S}(t|x)\big) \quad (19)$$

and for censored data, the loss is defined as follows:

$$\mathcal{L}^{cen}_{Loglikelihood} = -\sum_{t=1}^{T} \log \hat{S}(t|x) \quad (20)$$

$$\mathcal{L}_{Loglikelihood} = \mathcal{L}^{cen}_{Loglikelihood} + \mathcal{L}^{ob}_{Loglikelihood} \quad (21)$$

where $T$ is either the event time or the censoring time. In other words, for observed data points, we maximize the survival probabilities for $1 \le t < T$ (since the patient has survived in this time window) and minimize the survival probabilities for $t \ge T$ (which means there is no survival starting from the occurrence of the event). For censored data points, we only maximize the survival probabilities for $1 \le t \le T$. In essence, for observed instances, the survival probabilities of all the time intervals are optimized, whereas for censored data, only the survival probabilities up to $T$, which is the time of censoring, are optimized. This is because, after the censoring time, we do not have any information about the survival of the patients.
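The different treatment of observed and censored patients can be illustrated with a short Python sketch. It assumes a cross-entropy style formulation over the predicted survival probabilities (log of the survival probability before the event, log of its complement from the event onward); the function name and the clipping constant `eps` are illustrative assumptions.

```python
import math

def loglikelihood_loss(surv, T, observed, eps=1e-7):
    """Negative log-likelihood for one patient.

    surv[t-1] is the predicted survival probability for interval t (t = 1..T_max);
    T is the event time if observed, otherwise the censoring time.
    """
    loss = 0.0
    for t, s in enumerate(surv, start=1):
        s = min(max(s, eps), 1.0 - eps)   # clip to keep the logarithm finite
        if observed:
            # survive before the event, no survival from the event onward
            loss -= math.log(s) if t < T else math.log(1.0 - s)
        elif t <= T:
            # censored: only intervals up to the censoring time contribute
            loss -= math.log(s)
    return loss

perfect = loglikelihood_loss([1.0, 1.0, 0.0, 0.0], T=3, observed=True)
uncertain = loglikelihood_loss([0.5, 0.5, 0.5, 0.5], T=3, observed=True)
```

For a censored patient, predictions after the censoring time do not change the loss at all, which mirrors the statement that nothing is known about the patient beyond that point.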
3.6.2 Pairwise Ranking Loss. We employ a pairwise ranking loss function that incorporates the concept of concordance and is based on the method used in [14]. Such ranking losses have been widely used in the literature [21,22] for survival analysis. According to this idea, a patient who experiences an event at time $s$ should have a shorter predicted survival duration (a higher risk) at time $s$ compared to a patient who survives beyond time $s$. In other words, we want to penalize the discordant pairs. Let $T_i$ and $T_j$ represent the observed event times for patients $i$ and $j$, respectively, with $T_i < T_j$. The predicted survival durations $\hat{T}_i$ and $\hat{T}_j$ (obtained from Eq. 18) are considered discordant if $\hat{T}_i > \hat{T}_j$. Our aim is to minimize the number of such discordant pairs. For every observed patient $i$ in the training set, we randomly select (with replacement) another patient $j$, ensuring that $T_i < T_j$. We only compare each patient with one other randomly selected data point, since comparing with all the possible data points is too computationally expensive. As $T_j$ can be subject to censoring, the actual survival duration for patient $j$ cannot be smaller than $T_j$. Consequently, the difference between the predicted durations $\hat{T}_j$ and $\hat{T}_i$ should be at least $T_j - T_i$. Hence, the ranking loss formulation is as follows:

$$\mathcal{L}_{Ranking} = \sum_{i:\, e_i = 1} \max\big(0,\ (T_j - T_i) - (\hat{T}_j - \hat{T}_i)\big) \quad (22)$$

where $j$ is the patient sampled for the observed patient $i$.
3.6.3 Mean Squared Error Loss. This loss ensures that the proposed model performs well at accurate time prediction for observed patients instead of only being able to rank patients in terms of their risk. Therefore, for observed patients, MSE is calculated as follows:

$$\mathcal{L}_{MSE} = \frac{1}{N_{ob}} \sum_{i=1}^{N_{ob}} \big(T_i - \hat{T}_i\big)^2 \quad (23)$$

where $N_{ob}$ is the number of observed instances, $T_i$ is the true survival duration, and $\hat{T}_i$ is the predicted survival duration obtained as $\hat{\mu}$ from equation 18.
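The ranking and MSE terms can be sketched together in plain Python. The hinge form of the ranking penalty, the seeded sampler, and the function names are assumptions made for illustration; the partner for each observed patient is drawn uniformly from all patients (observed or censored) with a longer recorded time.

```python
import random

def ranking_loss(durations, observed, predicted, seed=0):
    """Hinge penalty on one randomly sampled partner per observed patient."""
    rng = random.Random(seed)
    total = 0.0
    for i, (t_i, e_i) in enumerate(zip(durations, observed)):
        if not e_i:
            continue
        candidates = [j for j, t_j in enumerate(durations) if t_j > t_i]
        if not candidates:
            continue
        j = rng.choice(candidates)            # one partner, sampled with replacement
        margin = durations[j] - t_i           # required gap between predictions
        total += max(0.0, margin - (predicted[j] - predicted[i]))
    return total

def mse_loss(durations, observed, predicted):
    """Mean squared error of the predicted duration over observed patients only."""
    errs = [(t - p) ** 2 for t, e, p in zip(durations, observed, predicted) if e]
    return sum(errs) / len(errs)

# Patient 0 has the event at time 2; patient 1 is censored at time 5.
concordant = ranking_loss([2, 5], [1, 0], [2.0, 5.0])   # 0.0, gap of 3 is respected
violated = ranking_loss([2, 5], [1, 0], [3.0, 3.5])     # 2.5, gap is only 0.5
```

The ranking term can use censored patients as partners because their censoring time is a lower bound on their true duration, whereas the MSE term is restricted to observed patients whose true duration is known.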

EXPERIMENTS

Dataset and Preprocessing
We test our model on a real-world EHR dataset acquired from the University of Kansas Medical Center (KUMC), gathered from early 2009 to late 2021 for the purpose of the Acute Kidney Injury (AKI) study. In this dataset, each patient has a history of one year of hospital visits before the final hospital visit, called the onset visit, which was monitored for the occurrence of AKI. Each hospital visit comprises a collection of documented medical codes. Diagnosis codes were recorded using the International Classification of Diseases system in both the ninth and tenth revisions (ICD-9 and ICD-10).
The prescription codes follow the RxNorm format, which provides standardized names for clinical drugs. After preprocessing, we acquired a dataset with the statistics shown in Table 1. Since the dataset is highly imbalanced, we balance the training dataset by duplicating the observed data so that we have a 50% censored, 50% observed training set. For implementing the GRAM method, we used the hierarchical ontology of ICD-9³ and the Anatomical Therapeutic Chemical (ATC) classification system for diagnosis codes and prescription codes, respectively.
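The balancing step can be illustrated with a minimal Python sketch. The record layout (a tuple whose third element is the event label), the function name, and the way the remainder is filled by random sampling are assumptions for illustration; the paper only states that observed data are duplicated until the training set is half censored and half observed.

```python
import random

def balance_by_duplication(records, seed=0):
    """Duplicate observed records until they match the number of censored ones.

    Each record is a tuple (features, duration, event) with event in {0, 1}.
    """
    rng = random.Random(seed)
    observed = [r for r in records if r[2] == 1]
    censored = [r for r in records if r[2] == 0]
    if not observed or len(observed) >= len(censored):
        return list(records)                      # nothing to balance
    copies, remainder = divmod(len(censored), len(observed))
    oversampled = observed * copies + rng.sample(observed, remainder)
    balanced = censored + oversampled
    rng.shuffle(balanced)
    return balanced

train = [("c%d" % k, 5, 0) for k in range(7)] + [("o%d" % k, 2, 1) for k in range(2)]
balanced_train = balance_by_duplication(train)
```

Such duplication should be applied to the training split only, so that validation and test sets keep the original censoring rate.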

Experimental Setting
After hyperparameter tuning, we chose 128 as the dimension of the code embeddings for the ontological encoder. The main encoder consists of two transformer encoder layers, each with two attention heads and a hidden dimension of 512, and outputs a representation vector of size 256. The survival prediction part is a three-layer fully connected neural network with hidden dimensions of [256, 128, 9], which outputs 9 probabilities for each instance ($T_{max} = 9$). Each probability corresponds to a time interval of one day. $T_{max}$ was set to 9 because 96% of hospitalized patients were diagnosed with AKI or discharged (censored) within 9 days. The RMSprop optimizer with a learning rate of 1e-3 and a weight decay of 2e-5 was employed for training the proposed model. For implementing SupWCon, after a thorough hyperparameter search, we chose a time window length of 2 as the positive contrastive pairing criterion. As for the training strategy, we let the model first run for 40 epochs with the log-likelihood, ranking, and MSE losses, and then add the SupWCon loss and train for 50 more epochs. This strategy gives us the best performance. As mentioned earlier, the contrastive component is discarded during inference. We released the GitHub implementation code of OTCSurv.⁴

³ In preprocessing, all the ICD-10 codes were converted to ICD-9.
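The two-stage training strategy can be summarized as a schedule of active loss terms per epoch. The sketch below only encodes the epoch counts stated above; the zero-based epoch indexing, the constant names, and the loss labels are assumptions made for illustration.

```python
WARMUP_EPOCHS = 40       # log-likelihood, ranking, and MSE only
CONTRASTIVE_EPOCHS = 50  # same three losses plus SupWCon


def active_losses(epoch):
    """Return the loss terms used at a given zero-based epoch."""
    losses = ["loglikelihood", "ranking", "mse"]
    if epoch >= WARMUP_EPOCHS:
        losses.append("supwcon")  # contrastive term joins after the warm-up
    return losses


# Full schedule over the 90 training epochs.
schedule = [active_losses(e) for e in range(WARMUP_EPOCHS + CONTRASTIVE_EPOCHS)]
```

Delaying the contrastive term in this way lets the encoder first learn a representation driven by the prediction losses before negative pairs start reshaping the embedding space.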

Evaluation Metrics
The model was evaluated with two main metrics: the time-dependent discrimination index $C^{td}$ [1] and the Mean Absolute Error (MAE). $C^{td}$, which is one of the most widely used evaluation metrics in survival analysis, is an extension of Harrell's concordance index (C-index). Unlike the conventional C-index, $C^{td}$ assesses the model's discriminatory ability at specific time points, capturing changes in predictive performance over time. In addition, we used the MAE of the predicted survival duration to measure the model's performance in estimating the exact survival duration for observed data.
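A minimal sketch of both metrics for a discrete-time model is given below. It assumes that `probs[i][k]` is the predicted probability that patient i experiences the event in interval k + 1, that times are integer days, and that tied risks count as discordant; these conventions and the function names are illustrative choices, not necessarily those used for the reported results.

```python
def cumulative_incidence(probs_i, t):
    """Predicted probability that the event has occurred by day t."""
    return sum(probs_i[:t])


def c_td(times, observed, probs):
    """Time-dependent concordance over comparable pairs.

    A pair (i, j) is comparable when patient i has an observed event
    and patient j has a strictly longer recorded time.
    """
    concordant, comparable = 0, 0
    for i, (t_i, obs_i) in enumerate(zip(times, observed)):
        if not obs_i:
            continue
        risk_i = cumulative_incidence(probs[i], t_i)
        for j, t_j in enumerate(times):
            if t_j > t_i:
                comparable += 1
                # Concordant if i is predicted to be at higher risk by time t_i.
                if risk_i > cumulative_incidence(probs[j], t_i):
                    concordant += 1
    return concordant / comparable if comparable else float("nan")


def mae_observed(times, observed, predicted_durations):
    """Mean absolute error of predicted durations over observed patients."""
    errors = [abs(t - p) for t, o, p in zip(times, observed, predicted_durations) if o]
    return sum(errors) / len(errors) if errors else float("nan")
```

Because the comparison is made at the event time of patient i rather than on a single risk score, the index can change when survival curves cross, which is what distinguishes it from the conventional C-index.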

Results and Discussion
4.4.1 Baselines. We compare the results of our model with various popular baselines, which are briefly introduced below. Table 2 shows the performance of each model on the AKI survival analysis task. Our proposed model outperforms all of the baselines on both evaluation metrics. In addition, Figure 2 and Figure 3 compare the mean survival curve of each model with the Kaplan-Meier curve, which is the survival curve based on the true data. Our proposed model's mean survival curve is the closest to the Kaplan-Meier curve, both when considering all the data and when considering only observed data, indicating that our model accurately captures the survival behavior and provides survival predictions that are more consistent with the actual outcomes.
• Nnet-survival [11]: Nnet-survival, which is trained with stochastic gradient descent, employs a parameterization of discrete hazards, optimizes the survival likelihood, and allows for non-proportional hazards.
• N-MTLR [10]: The Neural Multi-Task Logistic Regression uses the Multi-Task Logistic Regression (MTLR) [38] model as its base and a deep learning architecture as its core.
• DeepHit [22]: DeepHit is a deep learning-based survival analysis method that uses a multi-task learning framework to simultaneously estimate the survival time and the event type probabilities, thereby handling competing risks.
• Cox-Time [20]: Cox-Time is an extension of Cox regression that goes beyond the proportional hazards assumption and incorporates the concept of relative risk.
• DeepSurv (CoxPH) [17]: DeepSurv, a personalized treatment recommender system, is a Cox proportional hazards deep neural network that models interactions between a patient's covariates and treatment effectiveness.
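The Kaplan-Meier reference curve used in Figures 2 and 3 is the standard product-limit estimator. A minimal sketch follows; the function name and the input layout (parallel lists of durations and event flags) are chosen here for illustration only.

```python
def kaplan_meier(times, observed):
    """Product-limit estimate of S(t) at each distinct event time.

    times    -- recorded durations (event time or censoring time)
    observed -- True if the event was observed, False if censored
    Returns a list of (t, S(t)) pairs in increasing order of t.
    """
    curve, survival = [], 1.0
    event_times = sorted({t for t, o in zip(times, observed) if o})
    for t in event_times:
        # Patients still at risk just before t (censored ones included).
        at_risk = sum(1 for u in times if u >= t)
        events = sum(1 for u, o in zip(times, observed) if o and u == t)
        survival *= 1.0 - events / at_risk
        curve.append((t, survival))
    return curve
```

Censored patients contribute to the at-risk count up to their censoring time and then drop out without lowering the curve, which is why the estimator can serve as a reference for all patients as well as for observed patients only.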

Ablation Study
The ablation study was conducted to determine the contribution of each component of the model to its performance. We experimented with different combinations of loss components and show the results in Table 3. Training the model with $\mathcal{L}_{Loglikelihood}$ alone performs relatively poorly, particularly in $C^{td}$. Having only the ranking loss makes the model much stronger in terms of $C^{td}$, but the accurate time prediction ability of the model is reduced, since MAE increases by 0.5 compared to training only with $\mathcal{L}_{Loglikelihood}$. Using $\mathcal{L}_{SupWCon}$ along with $\mathcal{L}_{Loglikelihood}$ increases $C^{td}$ by 0.0161 and decreases the MAE by 0.4, demonstrating the effectiveness of our SupWCon loss in improving both evaluation metrics. We also tried adding $\mathcal{L}_{Ranking}$ to $\mathcal{L}_{Loglikelihood}$, which results in a substantial increase in $C^{td}$ of 0.026 but an undesired increase in MAE of 0.13. The last two combinations bring the best performance. With $\mathcal{L}_{Loglikelihood}$, $\mathcal{L}_{Ranking}$, and $\mathcal{L}_{SupWCon}$, we achieve the highest $C^{td}$. With all four loss components, we achieve the best trade-off, with a small compromise on $C^{td}$ but an improvement to the lowest MAE. Eventually, we utilized a weighted summation of these four losses, in which the four weighting coefficients are hyperparameters.

Table 4, which reports the ablation study of the ontological encoder and the attention-pooling blocks, shows that they play a significant role in improving the final results. We first remove the ontological encoder from the architecture, which leads to a drop in $C^{td}$ and an increase in MAE, indicating the effectiveness of incorporating domain knowledge from medical ontologies in the overall model performance. The same happens when we remove both attention-pooling parts, which increases the number of the model's parameters, making the model more complex and less generalizable, and also loses the performance boost and interpretability that attention provides.
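For reference, the weighted summation of the four loss components mentioned above could be written as follows. The symbol $\lambda$ for the weights and the order in which the weights are paired with the losses are assumptions made here; the text only states that four hyperparameter weights are used.

```latex
\mathcal{L}_{Total} = \lambda_1 \mathcal{L}_{Loglikelihood}
                    + \lambda_2 \mathcal{L}_{Ranking}
                    + \lambda_3 \mathcal{L}_{MSE}
                    + \lambda_4 \mathcal{L}_{SupWCon}
```

Under the two-stage training strategy described in the experimental setting, this would correspond to keeping the SupWCon weight at zero during the first 40 epochs.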

INTERPRETABILITY
Our proposed model can be interpreted by analyzing the attention weights learned in each of the model's components. In the ontological encoder, the attention weights assigned to each medical code and its ancestors reveal their importance in generating the medical code embeddings. The weights learned in the visit-level attention-pooling determine the relative significance of each diagnosis or prescription code in calculating the visit-level representations. Likewise, by examining the attention weights learned in the instance-level attention-pooling, we can infer the relative importance of each visit in the patient's medical history in composing the instance (patient) representations that are used in the downstream SA task.
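To make the role of these weights concrete, the sketch below shows a generic attention-pooling step that returns both the pooled vector and the weights one would inspect. The dot-product scoring against a single query vector is an assumption for illustration; the scoring function used in OTCSurv may differ, and all names are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(vectors, query):
    """Pool a list of vectors into one, weighted by attention.

    vectors -- e.g. code embeddings within a visit, or visit representations
    query   -- stand-in for a learned scoring vector (assumed dot-product scoring)
    Returns (pooled_vector, attention_weights).
    """
    scores = [sum(v_k * q_k for v_k, q_k in zip(v, query)) for v in vectors]
    weights = softmax(scores)
    dim = len(vectors[0])
    pooled = [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(dim)]
    return pooled, weights
```

Since the weights sum to one, they can be read directly as the relative contribution of each code or visit to the pooled representation, which is the quantity plotted as bar heights in an attention analysis.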
To illustrate the interpretability of our model, we select from the test set a random patient diagnosed with AKI on the second day of hospitalization. Extracting the visits' attention weights from the instance-level attention-pooling and the codes' attention weights from the visit-level attention-pooling, we plot Figure 4, from which we can identify the most important visits, and the most important medical codes inside each visit, for the model's decision-making. In Figure 4, five stacked bars represent the five hospital visits of the patient. The height of each stacked bar shows the attention weight assigned to the corresponding visit, and the segments inside each stacked bar indicate the individual codes and their attention weights. Visit 5 has the highest attention weight and consequently is the most important hospital visit for this patient. Among all the codes in this visit, "572", "410", and "287" have the highest attention weights. "572" is the ICD-9 code associated with liver abscess and sequelae of chronic liver disease, which can potentially lead to AKI [2,7,26]. ICD-9 code "410" is for acute myocardial infarction (AMI), commonly known as a heart attack. AMI can also be closely associated with the onset of AKI, as discussed in the medical literature [31,33]. ICD-9 code "287" represents purpura and other hemorrhagic conditions. Some hemorrhagic conditions, including certain types of purpura and other bleeding disorders, can potentially result in AKI as a complication [18,34].

Furthermore, in Figure 5, we demonstrate how the ontological encoder learns code representations and refers to higher-level medical concepts when it encounters a rare medical code. The "460-519" ICD-9 code, which is the most general ancestor of the "514" ICD-9 code, receives the highest attention weight for two reasons: first, the "514" code is not frequent in the training set, and second, there are enough samples with the children of "460-519" (as their parent) in the training set.

CONCLUSION
This paper introduces a novel survival model for longitudinal healthcare data, termed Ontology-aware Temporality-based Contrastive Survival analysis (OTCSurv), which combines the benefits of a contrastive learning approach adapted for survival analysis with those of attention-based methods. Specifically, we designed a supervised weighted contrastive learning (SupWCon) loss function that is formulated to handle data censoring and to improve patients' representations using the time labels as the contrastive pairing criterion. SupWCon regulates the weights (temperature parameters) assigned to each negative pair by considering their differences in survival duration. We also used a sequential attention-based ontological encoder, which consists of an ontological encoder block to incorporate domain knowledge through medical ontologies and a sequential attention encoder to capture temporal dependencies while making the model interpretable. Along with SupWCon, three other losses are employed to guide the training towards two goals of survival analysis: risk ranking ability and precise time prediction capability. Experimental results on a real-world EHR dataset, including a baseline comparison and an ablation study, show that the proposed model outperforms existing approaches with respect to both goals. In addition, an attention analysis was conducted to demonstrate the interpretability of OTCSurv.

Figure 1: Architecture of the proposed OTCSurv model. There are three main components: 1) Sequential attention-based ontological encoder, which mainly consists of an ontological encoder, two attention-pooling blocks, and a transformer encoder to learn the ultimate instance-level representations of patients. 2) Contrastive learning, which uses an intermediary transformation to transfer the ultimate representations to another latent space where SupWCon operates. 3) Survival prediction, which is a fully connected neural network that outputs the probabilities necessary for the survival analysis calculation.

Figure 2: Comparison of the mean survival curves of the proposed model and the baselines with the Kaplan-Meier curve (for all patients).

Figure 3: Comparison of the mean survival curves of the proposed model and the baselines with the Kaplan-Meier curve (only for observed patients).

Figure 4: Analysis of the attention weights of the visits and the medical codes (ICD-9) inside each visit. The height of each stacked bar is the attention weight of the corresponding visit, and each bar inside a stacked bar represents a medical code inside the visit and its attention weight.

Figure 5: Attention weights that GRAM assigned to the ICD-9 diagnosis code $c_1$: 514 and its ancestors ($c_2$: 510-519, $c_3$: 460-519). The size of each node and the height of its bar show the amount of attention it received.

Table 2: Evaluation based on $C^{td}$ and MAE for AKI survival prediction

Table 4: Ontological encoder and attention-pooling contributions