Transformers enable accurate prediction of acute and chronic chemical toxicity in aquatic organisms

Environmental hazard assessments rely on toxicity data that cover multiple organism groups. Generating experimental toxicity data is, however, resource-intensive and time-consuming. Computational methods are fast and cost-efficient alternatives, but their low accuracy and narrow applicability domains have made their adoption slow. Here, we present an AI-based model for predicting chemical toxicity. The model uses transformers to capture toxicity-specific features directly from the chemical structures and deep neural networks to predict effect concentrations. The model showed high predictive performance for all tested organism groups (algae, aquatic invertebrates and fish) and has, in comparison to commonly used QSAR methods, a larger applicability domain and a considerably lower error. When the model was trained on data with multiple effect concentrations (EC50/EC10), the performance was further improved. We conclude that deep learning and transformers have the potential to markedly advance computational prediction of chemical toxicity.


Hyperparameter and parameter sweep setup
Initial testing showed that a fully trainable ChemBERTa transformer had the highest accuracy, and consequently neither the embedding layer nor any of the encoders were frozen during the rest of the model development. Initial testing also showed that the dropout rate did not notably affect the accuracy, and it was therefore fixed at 0.2. The final hyperparameter and parameter configurations were determined by Bayesian optimization, training the model on the largest datasets per effect concentration (fish EC50 and EC10). The optimization included the number of frozen and/or reinitialized encoders and frozen embedding layers in ChemBERTa, batch size, learning rate, dropout rate, as well as the number of hidden layers and the number of neurons per layer for the deep neural network. Due to the high number of possible parameter combinations, the optimization was performed in two separate steps. Parameters were assigned to a step either because they appeared decoupled from the others during initial training or because they had little or no impact on model accuracy during initial testing (Table S2, Table S3). The resulting model configurations were decided based on overall performance (weighted median and mean error) and inter-model uniformity.

Selection of ChemBERTa version and loss function
Initially, we explored the available pre-trained versions of ChemBERTa and evaluated our model using two loss functions, mean absolute error (MAE/L1) and mean squared error (MSE). ChemBERTa is available in versions pre-trained using either a Byte-Pair Encoding (BPE) or a SMILES tokenizer and with varying amounts of SMILES included in the pre-training. The largest number of unique SMILES included in the pre-training was 10 and 1 million for the BPE and SMILES tokenizers, respectively. Hyperparameter sweeps using Bayesian optimization were performed for both versions, using either the MAE or MSE loss functions. Thus, the optimizer changed both the loss function and the learning rate while minimizing the weighted median loss (Table S1). In total, 30 individual models per ChemBERTa version were trained (i.e., testing ~15 different learning rates per version and loss function combination). Training was performed over 30 epochs using five-fold cross-validation, and the best-performing model was used for the evaluation (Fig. S1). This optimization was performed using the largest individual dataset (fish EC50).
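
The five-fold cross-validation used in these sweeps can be illustrated with a minimal sketch. It assumes a plain shuffled split over SMILES indices; the function name `k_fold_indices` and the seed are hypothetical, and the splitting procedure actually used may differ. The sample count of 2417 is the sum of the per-fold training and validation sizes reported for the fish EC50 dataset in Fig. S1.

```python
import random

def k_fold_indices(n_samples, k=5, seed=0):
    """Yield (train_idx, val_idx) pairs for a shuffled k-fold split."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Spread the remainder so that fold sizes differ by at most one.
    base, extra = divmod(n_samples, k)
    start = 0
    for fold in range(k):
        size = base + (1 if fold < extra else 0)
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        start += size
        yield train_idx, val_idx

# 2417 = 1934 + 483, the per-fold sizes reported in Fig. S1.
folds = list(k_fold_indices(2417, k=5))
for train_idx, val_idx in folds:
    print(len(train_idx), len(val_idx))
```

Each index appears in exactly one validation fold, so every chemical is predicted once by a model that was not trained on it.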

Principal Component Analysis for invertebrates and algae
PCA projections of the CLS-embeddings are shown for the model trained using the aquatic invertebrate EC50 and EC10 datasets (Fig. S2) and the algae EC50 and EC10 datasets (Fig. S3).

Model performance by cosine similarity
The absolute prediction error increased with decreasing cosine similarity between the CLS-embeddings of the validation chemical and those of the chemicals in the training set (Fig. S4).
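
As an illustration of how a validation chemical could be assigned to a similarity class, the sketch below computes the median cosine similarity between one embedding and a set of training embeddings. The function names and the toy two-dimensional vectors are hypothetical. The thresholds of 0.2 and 0.3 follow the caption of Fig. S4, and treating values below 0.2 as low similarity is an assumption.

```python
import math
from statistics import median

def cosine_similarity(u, v):
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def median_similarity_to_training(query, training_embeddings):
    """Median cosine similarity of one embedding to all training embeddings."""
    return median(cosine_similarity(query, t) for t in training_embeddings)

def similarity_class(sim, low=0.2, high=0.3):
    """Bin a similarity value; the lower bound of 0.2 is an assumption."""
    if sim > high:
        return "high"
    if sim >= low:
        return "intermediate"
    return "low"

# Toy 2-D vectors standing in for CLS-embeddings (hypothetical values).
training = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
query = [1.0, 0.0]
sim = median_similarity_to_training(query, training)
print(round(sim, 3), similarity_class(sim))
```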

Combined model comparison
The accuracy and residual distribution from the combined EC50 and EC10 model showed a consistent improvement across all species groups, demonstrating the model's ability to integrate and utilize different types of data (Fig. S5). The improvement is seen both in the absolute prediction error (fold change) and in a decrease in the number of residuals outside a factor of 10, 100 and 1000.
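
The two quantities used here can be sketched in a few lines. The fold change follows the definition given in the caption of Fig. S4 (the larger of the measured and predicted value is always the numerator). The function names and the measured/predicted pairs are hypothetical and serve only as an illustration.

```python
def absolute_fold_change(measured, predicted):
    """Ratio >= 1, with the larger of the two values as the numerator."""
    if measured <= 0 or predicted <= 0:
        raise ValueError("concentrations must be positive")
    return max(measured, predicted) / min(measured, predicted)

def percent_exceeding(fold_changes, factors=(10, 100, 1000)):
    """Percentage of predictions that are off by more than each factor."""
    n = len(fold_changes)
    return {f: 100.0 * sum(fc > f for fc in fold_changes) / n for f in factors}

# Hypothetical (measured, predicted) effect concentrations.
pairs = [(1.0, 2.0), (5.0, 0.1), (0.02, 30.0), (10.0, 10.0)]
fold_changes = [absolute_fold_change(m, p) for m, p in pairs]
shares = percent_exceeding(fold_changes)
print(shares)
```

Because the ratio is symmetric, over- and underprediction by the same factor contribute equally to the error.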

Venn diagrams over QSAR applicability
The number of SMILES predictable by each QSAR method is shown for all chemicals with measured data, as well as for all chemicals inside the QSARs' applicability domains, for the six individual datasets (Fig. S6).

QSAR accuracy and residual analysis
The accuracies were analyzed, for each of the six datasets, for the set of chemicals that were inside the shared applicability domain of all three QSARs (ECOSAR, VEGA, T.E.S.T.) and that none of the QSARs had been trained on. For our model, the validation accuracy from the ten times repeated ten-fold cross-validation is used. Thus, none of the models had been trained with the chemicals that are predicted, and no differences in coverage influence the comparison of accuracy. The residuals, per chemical, for all chemicals inside each model's applicability domain were also analyzed, and the results are summarized in the main text in Table 2. That table is complemented here for fish (Fig. S7), aquatic invertebrates (Fig. S8) and algae (Fig. S9).
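
The selection of the comparison set amounts to a set intersection followed by a set difference. The sketch below is a minimal illustration; the function name, the SMILES strings and their assignment to tools are hypothetical and do not reflect the actual applicability domains or training sets of ECOSAR, VEGA or T.E.S.T.

```python
def shared_evaluation_set(in_domain, in_training):
    """Chemicals inside every tool's domain and in no tool's training data."""
    shared = set.intersection(*in_domain.values())
    seen_in_training = set().union(*in_training.values())
    return shared - seen_in_training

# Hypothetical SMILES per tool, for illustration only.
in_domain = {
    "ECOSAR": {"CCO", "c1ccccc1", "CC(=O)O", "CCN"},
    "VEGA": {"CCO", "c1ccccc1", "CCN"},
    "T.E.S.T.": {"CCO", "c1ccccc1", "CCCl", "CCN"},
}
in_training = {
    "ECOSAR": {"CCO"},
    "VEGA": set(),
    "T.E.S.T.": {"CCCl"},
}
evaluation_set = shared_evaluation_set(in_domain, in_training)
print(sorted(evaluation_set))
```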

QSAR residual analysis per effect type for EC10 datasets
The residual error was determined individually for each effect to ensure that the error distributions within the fish and aquatic invertebrate EC10 datasets were not dependent on the measured effect (Fig. S10, Fig. S11). The proportions of errors exceeding 10-, 100- and 1000-fold absolute prediction error show that the model also performs well within each effect individually. Note that algae are not presented, as the algae model only predicts population effects.

Table S1: ChemBERTa version and loss-function optimization. Hyperparameter/parameter sweep to determine the best-performing ChemBERTa version and loss function based on the fish EC50 dataset. This sweep uses a Bayesian optimizer and a five-fold cross-validation with the average median loss as the target metric. The choice of a log-uniform distribution for the learning rate ensures that learning rates of different magnitudes are equally likely to be tested.
Hidden layer sizes [350,20]
a The resulting learning rate is not specifically of interest, but an interval is necessary as the loss function varied between the tests.

Table S2: Batch size sweep. Hyperparameter/parameter optimization to determine batch size based on the fish EC50 dataset, performed using Bayesian optimization with the average median loss across a five-fold cross-validation as the target metric. The choice of a log-uniform distribution for the learning rate ensures that learning rates of different magnitudes are equally likely to be tested.

ChemBERTa version PubChem10M_SMILES_BPE_450k
a The resulting learning rate is not specifically of interest, but an interval is necessary as the batch sizes vary. The learning rate will therefore be subject to change in subsequent sweeps.
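
The table captions state that a log-uniform distribution makes learning rates of different magnitudes equally likely. A minimal sketch of such sampling is given below; the bounds of 1e-6 and 1e-2, the seed and the function name are assumptions chosen for illustration and are not the intervals used in the sweeps.

```python
import math
import random

def sample_log_uniform(low, high, rng):
    """Draw a value whose log10 is uniform on [log10(low), log10(high)]."""
    return 10 ** rng.uniform(math.log10(low), math.log10(high))

rng = random.Random(0)
# Hypothetical learning-rate bounds spanning four orders of magnitude.
samples = [sample_log_uniform(1e-6, 1e-2, rng) for _ in range(10000)]

# Count how many samples fall into each decade.
counts = {}
for s in samples:
    decade = math.floor(math.log10(s))
    counts[decade] = counts.get(decade, 0) + 1
print(sorted(counts.items()))
```

Each decade receives roughly a quarter of the samples, whereas a plain uniform draw on the same interval would place about 90% of them in the top decade.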

Figure S1: ChemBERTa version and loss-function optimization. Mean, median, weighted mean and weighted median validation loss for the five-fold cross-validation on the fish EC50 dataset. Training and validation were performed using a transformer pre-trained using either a byte-pair encoding (BPE) or a SMILES tokenizer, and a pre-training set of either 1 or 10 million SMILES. The error bars show the standard error of the mean when associated with mean errors, and the median absolute deviation (MAD) when associated with median errors. The validation losses were recorded at the epoch where the lowest normalized median validation loss was observed within each fold. The bars are based on the validations from the five 80/20 splits between training and validation. Thus, per fold, n = 1934 SMILES were used for the training and n = 483 SMILES were used for the validation.
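
The two dispersion measures used for the error bars can be computed as in the following sketch. It assumes the sample standard deviation for the standard error and an unscaled MAD (no consistency constant); the function names and the per-fold values are hypothetical.

```python
import math
from statistics import median, stdev

def standard_error_of_mean(values):
    """Sample standard deviation divided by the square root of n."""
    return stdev(values) / math.sqrt(len(values))

def median_absolute_deviation(values):
    """Median of absolute deviations from the median (unscaled)."""
    med = median(values)
    return median(abs(v - med) for v in values)

# Hypothetical per-fold validation losses with one outlying fold.
fold_losses = [1.0, 2.0, 3.0, 4.0, 100.0]
print(standard_error_of_mean(fold_losses))
print(median_absolute_deviation(fold_losses))
```

The MAD is far less sensitive to the outlying fold than the standard error, which is one reason to pair it with median errors.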

Figure S2: PCA projection of CLS-embeddings from the transformer when trained on aquatic invertebrate EC50 and EC10 data. Principal Component Analysis of CLS-embeddings from the transformer when trained using the (a) EC50 dataset (n = 3741) and (b) EC10 dataset (n = 2647).

Figure S3: PCA projection of CLS-embeddings from the transformer when trained on algae EC50 and EC10 data. Principal Component Analysis of CLS-embeddings from the transformer when trained using the (a) EC50 dataset (n = 2843) and (b) EC10 dataset (n = 2756).

Figure S4: Model performance by cosine similarity. The mean absolute prediction error, measured as the absolute fold change (i.e., always using the larger of the measured and predicted value as the numerator when calculating the ratio), determined from ten-fold cross-validations repeated ten times, split by the median cosine similarity of the validation chemical to the training dataset. High similarity is defined as a cosine similarity > 0.3, intermediate similarity as 0.2 to 0.3, and low similarity as < 0.2. Panels show the (a) fish EC50 model (n = 52666), (b) aquatic invertebrate EC50 model (n = 34820), (c) algae EC50 model (n = 13019), (d) fish EC10 model (n = 19751), (e) aquatic invertebrate EC10 model (n = 15372), and (f) algae EC10 model (n = 11830). The error bars show the standard error of the mean. The reported percentage values show the percentage of validation chemicals that belonged to the respective classification during training.

Figure S5: Combined model performance for fish, aquatic invertebrates, and algae. Panels (a,d,g) show the performance as the absolute median and mean error, measured as the absolute fold change between predicted and experimental values, determined from the ten-fold cross-validations for the (a) fish EC50 model (n = 52666), EC10 model (n = 19751), and the model able to predict both EC50/EC10 (n = 72417), (d) aquatic invertebrate EC50 model (n = 34820), EC10 model (n = 15372), and the model able to predict both EC50/EC10 (n = 50192), and (g) algae EC50 model (n = 13019), EC10 model (n = 11830), and the model able to predict both EC50/EC10 (n = 24849). The error bars show the median absolute deviation and the standard error of the mean for the respective prediction error. Panels (b-c, e-f, h-i) show the histogram of residuals for the (b-c) fish, (e-f) aquatic invertebrate, and (h-i) algae model able to predict both EC50/EC10, when predictions were evaluated on the EC50 and EC10 datasets. The reported percentage values show the percentage of chemicals that are erroneously predicted by a factor of more than 10, 100 or 1000.

Figure S6: QSAR models applicability intersection. The number of SMILES predictable by each QSAR. (a,b) The number of predictable SMILES both in and outside of the applicability domain for the three QSARs based on all chemicals in the fish EC50 dataset. (c,d) The number of predictable SMILES both in and outside of the applicability domain for the three QSARs based on all chemicals in the fish EC10 dataset. (e,f) The number of predictable SMILES both in and outside of the applicability domain for the three QSARs based on all chemicals in the aquatic invertebrate EC50 dataset. (g,h) The number of predictable SMILES both in and outside of the applicability domain for the three QSARs based on all chemicals in the aquatic invertebrate EC10 dataset. (i,j) The number of predictable SMILES both in and outside of the applicability domain for the three QSARs based on all chemicals in the algae EC50 dataset. (k,l) The number of predictable SMILES both in and outside of the applicability domain for the three QSARs based on all chemicals in the algae EC10 dataset.

Figure S7: Comparison of model performance and absolute error distribution for fish. The mean and median absolute error for the chemicals that are within the applicability domains, but not included in the training, of ECOSAR, VEGA and T.E.S.T. for the models trained using the (a) EC50 dataset (n = 734) and (b) EC10 dataset (n = 130). (c-j) The absolute error distribution for all chemicals within the applicability domain of the transformer-based model, ECOSAR, VEGA and T.E.S.T.

Figure S8: Comparison of model performance and absolute error distribution for aquatic invertebrates. The mean and median absolute error for the chemicals that are within the applicability domains, but not included in the training, of ECOSAR, VEGA and T.E.S.T. for the models trained using the (a) EC50 dataset (n = 752) and (b) EC10 dataset (n = 518). (c-j) The absolute error distribution for all chemicals within the applicability domain of the transformer-based model, ECOSAR, VEGA and T.E.S.T.

Figure S9: Comparison of model performance and absolute error distribution for algae. The mean and median absolute error for the chemicals that are within the applicability domains, but not included in the training, of ECOSAR, VEGA and T.E.S.T. for the models trained using the (a) EC50 dataset (n = 72) and (b) EC10 dataset (n = 120). (c-j) The absolute error distribution for all chemicals within the applicability domain of the transformer-based model, ECOSAR, VEGA and T.E.S.T.

Figure S10: Absolute error distribution per effect for fish. The absolute error distribution for all chemicals in the fish EC10 dataset for the transformer-based model, ECOSAR, VEGA and T.E.S.T., split by the effects that are inside the applicability domain of our model. Effect abbreviations: DVP = development, GRO = growth, ITX = intoxication, MOR = mortality, MPH = morphology, POP = population, REP = reproduction.

Figure S11: Absolute error distribution per effect for aquatic invertebrates. The absolute error distribution for all chemicals in the aquatic invertebrate EC10 dataset for the transformer-based model, ECOSAR, VEGA and T.E.S.T., split by the effects that are inside the applicability domain of our model. Effect abbreviations: DVP = development, ITX = intoxication, MOR = mortality, MPH = morphology, POP = population, REP = reproduction.