Transformer for Reaction Informatics – utilizing PyTorch Lightning

Esbenbjerrum / April 24, 2021 / Blog, Neural Network, PyTorch, Reaction Prediction, Science

In the last blogpost I covered how LSTM-to-LSTM networks could be used to “translate” reactants into the products of chemical reactions. However, the performance of that small, untuned network was not very good. Today, another architecture rules the natural language processing field: the Transformer. The Transformer relies solely on attention mechanisms, which have also been used to boost LSTM and other RNN networks. The architecture has many benefits, one of them being a much higher fidelity in the transfer of information from the encoder to the decoder. In this blogpost, I’ll show how to use the standard PyTorch classes for Transformer encoder and decoder layers to build a much better reaction predictor. I’ll also use PyTorch Lightning as a framework to help with the training and logging.
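For reference, the central operation is the scaled dot-product attention of Vaswani et al. (2017). Here Q, K and V are the matrices of query, key and value vectors, and d_k is the dimension of the keys. The PyTorch encoder and decoder layers used below run several of these attention heads in parallel, which is what the nhead argument controls.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```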
TensorBoard will be used to monitor the training, via PyTorch Lightning’s TensorBoard logger. An inline widget can be loaded in Google Colab to show the TensorBoard server, but first the extension needs to be loaded.

%load_ext tensorboard

Installing RDKit (via the kora package) and PyTorch Lightning, and importing the modules that will be used.

!pip install kora -q
import kora.install.rdkit
!pip -qq install pytorch-lightning
import os
import urllib.request
import tqdm
import pandas as pd
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import PandasTools

 

import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning 

PyTorch Lightning datamodule

PyTorch Lightning is a marvelous framework for simplifying training and organizing PyTorch code. First, a datamodule needs to be created. The datamodule takes care of procuring the data, setting it up and creating the DataLoaders. I’ll build it stepwise while explaining, and then provide the full object at the end. The datamodule must contain some methods that are expected by PyTorch Lightning. The first is the prepare function, which downloads and prepares the data files if they are not present. It should not set any state on the object. (Lightning’s own hook for this step is called prepare_data(); here the method is named prepare() and is called manually before training.) It’s basically the code from a previous blogpost, now wrapped in a reusable datamodule.

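# Temporary stand-in object, so the methods can be developed stepwise outside the class
self = pytorch_lightning.LightningDataModule()
# In the full class below, prepare() is a @staticmethod; here it is a plain function so it can be called directly
def prepare():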
  base_url = "https://raw.githubusercontent.com/pandegroup/reaction_prediction_seq2seq/master/processed_data/"
  sets = ["train", "test", "valid"]
  types = ["sources", "targets"]
  files = ["vocab"]
  for s in sets:
      for t in types:
          files.append("%s_%s"%(s, t))
  target_dir = "./pande_data"
  if not os.path.exists(target_dir):
      os.mkdir(target_dir)
  for filename in files:
      target_file = '%s/%s'%(target_dir, filename)
      if not os.path.exists(target_file):
          urllib.request.urlretrieve(base_url + filename, target_file)
  def parse_line_source(line):
      tokens = line.split(" ")
      klass = tokens[0]
      smiles = "".join(tokens[1:])
      return klass, smiles
  def parse_line_target(line):
      tokens = line.split(" ")
      smiles = "".join(tokens)
      return smiles
  dataframe_file= f"{target_dir}/dataframe.csv"
  if not os.path.exists(dataframe_file):
    dataframes = []
    for s in sets:
        target_file = f"{target_dir}/{s}_targets"
        source_file = f"{target_dir}/{s}_sources"    
        with open(target_file, "r") as f:
            target_lines = f.readlines()
        with open(source_file, "r") as f:
            source_lines = f.readlines()
        parsed_sources = [parse_line_source(line.strip()) for line in source_lines]
        parsed_targets = [parse_line_target(line.strip()) for line in target_lines]
        data_dict = {"reactants":parsed_targets,
                    "reaction_class": [t[0] for t in parsed_sources],
                    "products": [t[1] for t in parsed_sources],
                    "set": [s]*len(parsed_sources)}
        dataframe = pd.DataFrame(data_dict)
        dataframes.append(dataframe)
                      
    data = pd.concat(dataframes, ignore_index=True)
    data.to_csv(dataframe_file, index=False)
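To make the parsing step concrete, here is a small standalone sketch of the two helper functions applied to made-up lines. The lines are hypothetical. They only assume the space-separated token format that the parsing code implies: a reaction class token followed by SMILES characters in the *_sources files (stored in the products column), and only SMILES characters in the *_targets files (stored in the reactants column).

```python
def parse_line_source(line):
    # First token is the reaction class; the remaining tokens are joined into the SMILES
    tokens = line.split(" ")
    klass = tokens[0]
    smiles = "".join(tokens[1:])
    return klass, smiles

def parse_line_target(line):
    # Target lines contain only space-separated SMILES tokens
    tokens = line.split(" ")
    return "".join(tokens)

# Hypothetical lines in the assumed space-separated format
source_line = "<RX_7> C C ( O ) c 1 c c c c c 1"
target_line = "C C ( = O ) c 1 c c c c c 1"

klass, product = parse_line_source(source_line)
reactant = parse_line_target(target_line)
print(klass, product, reactant)
```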

The next function is setup(), which loads the prepared data and sets up the data splits on the object. There’s also a need for a tokenizer that can turn the SMILES strings into lists of token indices. A simple character-based tokenizer suffices for this blogpost.

class SimpleTokenizer():
  def __init__(self):
    self.start = "^"
    self.end = "$"
    self.mask = "?"
    self.pad = " "
  def create_vocabulary(self, smiles):
    charset = sorted(set("".join(list(smiles))))  # sorted, so the token order is reproducible between sessions
    self.tokenlist = [self.pad, self.mask, self.start, self.end] + list(charset)
    
  @property
  def tokenlist(self):
    return self._tokenlist
    
  @tokenlist.setter
  def tokenlist(self, tokenlist):
    self._tokenlist = tokenlist   
    self.char_to_int = {c:i for i,c in enumerate(tokenlist)}
    self.int_to_char = {i:c for c,i in self.char_to_int.items()}
  def vectorize(self, smiles):
    return [self.char_to_int[self.start]] + [self.char_to_int[char] for char in smiles] + [self.char_to_int[self.end]]
  def devectorize(self, vector):
    return "".join([self.int_to_char[i] for i in vector]).strip("^$ ")
def setup(self, stage=None):  # stage ("fit", "test", ...) is passed in by the Lightning trainer
    # called on every GPU if parallelized
    data = pd.read_csv("pande_data/dataframe.csv")
    self.train_data = data[data.set == "train"]
    self.val_data = data[data.set == "valid"]
    self.test_data = data[data.set == "test"]
    self.tokenizer = SimpleTokenizer()
    self.tokenizer.create_vocabulary(data.reactants.values + data.products.values)
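A quick round trip shows what the tokenizer produces. The sketch below uses a compact, hypothetical stand-in (CharTokenizer) that mirrors SimpleTokenizer, with one assumption: the character set is sorted, so the token indices are the same every time the vocabulary is built. The iteration order of a Python set of strings can change between sessions. Note that the pad token ends up at index 0. Two-letter elements such as Cl and Br are split into two characters, and the model has to learn to pair them up.

```python
class CharTokenizer:
    """Hypothetical stand-in mirroring SimpleTokenizer: special tokens first, then characters."""

    def __init__(self, smiles_list):
        self.start, self.end, self.mask, self.pad = "^", "$", "?", " "
        # sorted() makes the vocabulary order reproducible between sessions
        charset = sorted(set("".join(smiles_list)))
        self.tokenlist = [self.pad, self.mask, self.start, self.end] + charset
        self.char_to_int = {c: i for i, c in enumerate(self.tokenlist)}
        self.int_to_char = {i: c for c, i in self.char_to_int.items()}

    def vectorize(self, smiles):
        # Wrap the character indices in start and end tokens
        return ([self.char_to_int[self.start]]
                + [self.char_to_int[c] for c in smiles]
                + [self.char_to_int[self.end]])

    def devectorize(self, vector):
        # Map indices back to characters and strip start, end and pad characters from the ends
        return "".join(self.int_to_char[i] for i in vector).strip("^$ ")

tok = CharTokenizer(["CCO", "ClCBr"])
vec = tok.vectorize("ClCBr")
print(tok.tokenlist)
print(vec)
print(tok.devectorize(vec))
```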

The last three functions needed for the PyTorch Lightning framework should return PyTorch DataLoaders that create minibatches for training, validation and testing. However, to create the DataLoaders, a PyTorch Dataset object is needed that provides the vectorized SMILES strings.

class MolDataset(Dataset):
    def __init__(self, reactants, products, tokenizer):
        self.reactants = reactants
        self.products = products
        self.tokenizer = tokenizer
    def __len__(self):
        return len(self.reactants)
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        reactants_smiles = self.reactants[idx]
        products_smiles = self.products[idx]
        reactants_tokens = self.tokenizer.vectorize(reactants_smiles)
        products_tokens = self.tokenizer.vectorize(products_smiles)
        
        return reactants_tokens, products_tokens

Now that the dataset class is ready, the dataloader creation methods can be made. As the SMILES strings are not of equal length, a custom collate function is provided. It uses the pad_sequence function to pad up to the maximum length in the minibatch. With its default batch_first=False, pad_sequence returns the tensor with the sequence as the first dimension, which is what the PyTorch RNN and Transformer classes expect by default. The get_dataloader method simply needs two lists of equal length that can be indexed with the __getitem__ method.

    @staticmethod
    def custom_collate_and_pad(batch):
        #Batch is a list of tuples (vectorized reactants, vectorized products)
        #Create separate lists from tuples
        reactants, products  = list(zip(*batch))
        #convert to tensors and pad; pad_sequence returns (sequence, batch) by default
        reactant_tensors = torch.nn.utils.rnn.pad_sequence( [torch.tensor(l) for l in reactants])
        product_tensors = torch.nn.utils.rnn.pad_sequence( [torch.tensor(l) for l in products])
        return reactant_tensors, product_tensors
    def get_dataloader(self, reactants, products , shuffle=True, batch_size=32):
        dataset = MolDataset(reactants, products, self.tokenizer)
        return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, pin_memory=True, collate_fn=self.custom_collate_and_pad)
    def train_dataloader(self):
        return self.get_dataloader(self.train_data.reactants.values, self.train_data.products.values)
        
    def val_dataloader(self):
        return self.get_dataloader(self.val_data.reactants.values, self.val_data.products.values, shuffle=False)
    def test_dataloader(self):
        return self.get_dataloader(self.test_data.reactants.values, self.test_data.products.values, shuffle=False)
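The shape convention of the collate function can be checked on two hand-made examples. This sketch uses hypothetical token indices, with 2 and 3 standing in for the start and end tokens, and a standalone copy of the collate logic. pad_sequence pads with 0 by default, which conveniently is the index of the pad token in the tokenizer above.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two hypothetical (reactant, product) pairs of vectorized SMILES with different lengths
batch = [([2, 5, 5, 6, 3], [2, 5, 6, 3]),
         ([2, 5, 3], [2, 6, 5, 5, 5, 3])]

def collate_and_pad(batch):
    # Split the list of (reactant, product) tuples into two tuples
    reactants, products = zip(*batch)
    # batch_first=False (the default) gives shape (max_len, batch_size), padded with 0
    reactant_tensors = pad_sequence([torch.tensor(l) for l in reactants])
    product_tensors = pad_sequence([torch.tensor(l) for l in products])
    return reactant_tensors, product_tensors

r, p = collate_and_pad(batch)
print(r.shape, p.shape)
print(r[:, 1].tolist())
```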

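The full object looks like this. It automatically downloads and creates the necessary data files with .prepare(), sets up the data splits with .setup() and returns ready-to-train DataLoaders with .train_dataloader(). Compared to the stepwise version, the batch size is now an attribute. The validation and test sets are also sorted by SMILES length, so each minibatch contains similarly sized sequences and needs less padding.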

class ReactionsDataModule(pytorch_lightning.LightningDataModule):
  def __init__(self):
    super().__init__()
    self.batch_size = 128
  @staticmethod
  def prepare():
    base_url = "https://raw.githubusercontent.com/pandegroup/reaction_prediction_seq2seq/master/processed_data/"
    sets = ["train", "test", "valid"]
    types = ["sources", "targets"]
    files = ["vocab"]
    for s in sets:
        for t in types:
            files.append("%s_%s"%(s, t))
    target_dir = "./pande_data"
    if not os.path.exists(target_dir):
        os.mkdir(target_dir)
    for filename in files:
        target_file = '%s/%s'%(target_dir, filename)
        if not os.path.exists(target_file):
            urllib.request.urlretrieve(base_url + filename, target_file)
    def parse_line_source(line):
        tokens = line.split(" ")
        klass = tokens[0]
        smiles = "".join(tokens[1:])
        return klass, smiles
    def parse_line_target(line):
        tokens = line.split(" ")
        smiles = "".join(tokens)
        return smiles
    dataframe_file= f"{target_dir}/dataframe.csv"
    if not os.path.exists(dataframe_file):
      dataframes = []
      for s in sets:
          target_file = f"{target_dir}/{s}_targets"
          source_file = f"{target_dir}/{s}_sources"    
          with open(target_file, "r") as f:
              target_lines = f.readlines()
          with open(source_file, "r") as f:
              source_lines = f.readlines()
          parsed_sources = [parse_line_source(line.strip()) for line in source_lines]
          parsed_targets = [parse_line_target(line.strip()) for line in target_lines]
          data_dict = {"reactants":parsed_targets,
                      "reaction_class": [t[0] for t in parsed_sources],
                      "products": [t[1] for t in parsed_sources],
                      "set": [s]*len(parsed_sources)}
          dataframe = pd.DataFrame(data_dict)
          dataframes.append(dataframe)
                        
      data = pd.concat(dataframes, ignore_index=True)
      data.to_csv(dataframe_file, index=False)
  def setup(self, stage=None):
      # called on every GPU if parallelized
      data = pd.read_csv("pande_data/dataframe.csv")
      self.train_data = data[data.set == "train"]
      
      # .copy() avoids pandas' SettingWithCopyWarning when adding the length column
      val_data = data[data.set == "valid"].copy()
      val_data["length"] = val_data.reactants.apply(len)
      self.val_data = val_data.sort_values(by="length")
      
      test_data = data[data.set == "test"].copy()
      test_data["length"] = test_data.reactants.apply(len)
      self.test_data = test_data.sort_values(by="length")
      
      self.tokenizer = SimpleTokenizer()
      self.tokenizer.create_vocabulary(data.reactants.values + data.products.values)
  @staticmethod
  def custom_collate_and_pad(batch):
      #Batch is a list of tuples (vectorized reactants, vectorized products)
      #Create separate lists from tuples
      reactants, products  = list(zip(*batch))
      #convert to tensors and pad; pad_sequence returns (sequence, batch) by default
      reactant_tensors = torch.nn.utils.rnn.pad_sequence( [torch.tensor(l) for l in reactants])
      product_tensors = torch.nn.utils.rnn.pad_sequence( [torch.tensor(l) for l in products])
      return reactant_tensors, product_tensors
  def get_dataloader(self, reactants, products , shuffle=True, batch_size=None):
      if batch_size is None:
        batch_size = self.batch_size
      dataset = MolDataset(reactants, products, self.tokenizer)
      return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, pin_memory=True, collate_fn=self.custom_collate_and_pad)
  def train_dataloader(self):
      return self.get_dataloader(self.train_data.reactants.values, self.train_data.products.values)
      
  def val_dataloader(self):
      return self.get_dataloader(self.val_data.reactants.values, self.val_data.products.values, shuffle=False)
  def test_dataloader(self):
      return self.get_dataloader(self.test_data.reactants.values, self.test_data.products.values, shuffle=False)   
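Sorting the validation and test sets by length reduces padding, because every minibatch is padded to its longest SMILES. The sketch below uses made-up lengths rather than the real data and counts how many pad tokens are needed with and without sorting. Only the validation and test sets are sorted; the training loader keeps shuffle=True. Because those two loaders use shuffle=False, their minibatches also come out in the same order as the rows of the sorted DataFrames.

```python
import random

def padding_tokens(lengths, batch_size):
    """Count pad tokens when each minibatch is padded to its own longest sequence."""
    total = 0
    for i in range(0, len(lengths), batch_size):
        chunk = lengths[i:i + batch_size]
        total += sum(max(chunk) - n for n in chunk)
    return total

random.seed(0)
# Hypothetical SMILES lengths; real values would come from data.reactants.apply(len)
lengths = [random.randint(5, 120) for _ in range(1000)]

unsorted_pad = padding_tokens(lengths, batch_size=128)
sorted_pad = padding_tokens(sorted(lengths), batch_size=128)
print(unsorted_pad, sorted_pad)
```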

With the full datamodule ready, it’s simple to prepare it for training.

datamodule = ReactionsDataModule()
datamodule.prepare()
datamodule.setup()

 

Let’s visualize an example batch.

example_batch = next(iter(datamodule.train_dataloader()))
plt.matshow(example_batch[0].numpy().T)

 


It creates some beautiful vectorized and padded SMILES in minibatch tensor format, just as intended.

Defining the Model

For the model, PyTorch Lightning again helps with a special module that is subclassed. As for the datamodule, some special methods need to be declared, as they are expected by the framework. The model itself is created as in plain PyTorch, but “step” methods are added that define how a minibatch is converted to a loss and which properties should be logged. A method that returns the optimizer, with an optional learning rate scheduler, must also be declared.
For reasons I don’t know, a positional encoding class is not part of the standard PyTorch classes, but there is one in the PyTorch examples repository that we can download and reuse.

! wget -qc https://raw.githubusercontent.com/pytorch/examples/master/word_language_model/model.py
from model import PositionalEncoding
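The downloaded PositionalEncoding adds fixed sine and cosine signals of different frequencies to the embeddings, so the otherwise order-agnostic attention layers can tell positions apart. As I understand the PyTorch example, it stores these values as a buffer and adds them to sequence-first input. The plain-Python sketch below only computes the values themselves, with a hypothetical d_model of 8.

```python
import math

def sinusoidal_encoding(max_len, d_model):
    """PE[pos][2k] = sin(pos / 10000^(2k/d_model)), PE[pos][2k+1] = cos(pos / 10000^(2k/d_model))."""
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):  # i runs over the even dimensions 2k
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe

pe = sinusoidal_encoding(max_len=50, d_model=8)
print(pe[0][:4])  # position 0: sin(0) = 0 and cos(0) = 1 alternate
print(pe[1][:2])
```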

The full model is shown below, I’ll explain some of the highlights.
self.save_hyperparameters() is a PyTorch Lightning method that saves the arguments of the __init__() method and their values to a self.hparams object. The hparams object is also used by the logger that is created later, so the parameters are saved with the training log. This is very handy for keeping track of which hyperparameters belong to which model when experimenting with hyperparameter tuning.
The setup_layers method defines all the layers needed: first an embedding, which is shared by the encoder and decoder inputs as reactants and products use the same vocabulary, then the positional encoding object, the encoder and decoder, and finally a fully connected layer that projects the output from the model dimension to the vocabulary size. The encoder and decoder are defined separately rather than using the PyTorch torch.nn.Transformer class directly, as this gives access to the memory tensor. Accordingly, the forward method is divided into encode_memory and decode_memory methods.
The training_step and validation_step methods are called by PyTorch Lightning. Here the losses are also logged using the self.log() method provided by the parent class.
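The loss computation in the step methods moves the token dimension to position 1. NLLLoss expects (batch, classes, sequence) input and (batch, sequence) targets. The sketch below checks this layout with random tensors of hypothetical sizes; LogSoftmax followed by NLLLoss gives the same value as cross_entropy on the raw logits. As written, padding positions (index 0) are included in the loss. NLLLoss also accepts an ignore_index argument if one wanted to exclude them.

```python
import torch

torch.manual_seed(0)
seq_len, batch_size, n_tokens = 7, 3, 12

# Decoder output in PyTorch's default sequence-first layout: (sequence, batch, tokens)
logits = torch.randn(seq_len, batch_size, n_tokens)
# Target token indices, also sequence-first: (sequence, batch)
targets = torch.randint(0, n_tokens, (seq_len, batch_size))

# As in training_step: move the class (token) dimension to position 1
logits_bcs = logits.permute(1, 2, 0)  # (batch, tokens, sequence)
log_probs = torch.nn.LogSoftmax(dim=1)(logits_bcs)
loss = torch.nn.NLLLoss(reduction="mean")(log_probs, targets.permute(1, 0))  # targets: (batch, sequence)

# LogSoftmax + NLLLoss is equivalent to cross-entropy on the raw logits
reference = torch.nn.functional.cross_entropy(logits_bcs, targets.permute(1, 0))
print(loss.item(), reference.item())
```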

class ReactionTransformerModel(pytorch_lightning.LightningModule):
    def __init__(
        self,
        n_tokens,
        d_model=256,
        nhead=8,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dim_feedforward=512,
        dropout=0.1,
        activation="relu",
        max_length = 1000,
        max_lr =  3e-3, #3e-3 works, 3e-2 is too much, learning stops over 8e-3, 6.5e-3 is also too much
        final_lr = 1e-3, #3e-4,
        start_lr = 1e-6,
        max_epochs = 50,
        notes = "higher final lr, extended training"
    ):
      super().__init__()
      self.save_hyperparameters() #Populates self.hparams from __init__ options.
      self.hparams["batch_size"] = datamodule.batch_size
      self.setup_layers()
      self.criterion = torch.nn.NLLLoss(reduction='mean') #Expects (minibatch,C,d)
    
    def configure_optimizers(self): #Expected by PyTorch Lightning
        optimizer = torch.optim.Adam(self.parameters())
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.hparams.max_lr, 
                       total_steps=None, epochs=self.hparams.max_epochs, steps_per_epoch=len(self.train_dataloader()),#We call train_dataloader, just to get the length, is this necessary?
                       pct_start=10/self.hparams.max_epochs, anneal_strategy='cos', cycle_momentum=True, 
                       base_momentum=0.85, max_momentum=0.95, #These need to be tuned?
                       div_factor=self.hparams.max_lr/self.hparams.start_lr,
                       final_div_factor=self.hparams.start_lr/self.hparams.final_lr,
                       last_epoch=-1)
        scheduler = {"scheduler": scheduler, "interval" : "step" }
        return [optimizer], [scheduler]
    def setup_layers(self):
      #Embedding layer to turn token idx into vectors. Note the use of self.hparams
      self.embedding = torch.nn.Embedding(self.hparams.n_tokens, self.hparams.d_model)
      self.positional_encoder = PositionalEncoding(self.hparams.d_model, dropout=self.hparams.dropout)
      
      #We specify the encoder and decoder seperately, so that we can get the memory during inference
      encoder_layer = torch.nn.TransformerEncoderLayer(self.hparams.d_model, self.hparams.nhead, self.hparams.dim_feedforward, self.hparams.dropout, self.hparams.activation)
      encoder_norm = torch.nn.LayerNorm(self.hparams.d_model)
      self.encoder = torch.nn.TransformerEncoder(encoder_layer, self.hparams.num_encoder_layers, encoder_norm)
      
      decoder_layer = torch.nn.TransformerDecoderLayer(self.hparams.d_model, self.hparams.nhead, self.hparams.dim_feedforward, self.hparams.dropout, self.hparams.activation)
      decoder_norm = torch.nn.LayerNorm(self.hparams.d_model)
      self.decoder = torch.nn.TransformerDecoder(decoder_layer, self.hparams.num_decoder_layers, decoder_norm)
      
      self.fc_out = torch.nn.Linear(self.hparams.d_model, self.hparams.n_tokens)
      self.logsoftmax = torch.nn.LogSoftmax(dim=1) #take batch,token,sequence
    def generate_square_subsequent_mask(self, sz: int):
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
            Unmasked positions are filled with float(0.0).
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask.to(self.device)
    def encode_memory(self, reactants):
      embedded = self.embedding(reactants)
      positional_encoded = self.positional_encoder(embedded)
      memory = self.encoder(positional_encoded)
      return memory
    def decode_memory(self, tgt_teachers_forcing, memory):
      embedded = self.embedding(tgt_teachers_forcing)
      positional_encoded = self.positional_encoder(embedded)
      tgt_mask = self.generate_square_subsequent_mask(tgt_teachers_forcing.shape[0]) 
      out = self.decoder(positional_encoded, memory, tgt_mask=tgt_mask)
      logits = self.fc_out(out)
      return logits
    def forward(self, reactants, products):
      memory = self.encode_memory(reactants)
      logits = self.decode_memory(products[:-1], memory)
      return logits 
    def training_step(self, batch, batch_idx): #Expected by PyTorch Lightning
        self.train()
        reactants, products = batch
        logits = self.forward(reactants, products) # Sequence, batch, tokens
        logits =  logits.permute(1,2,0) # Now batch, tokens, sequence
        log_softmax = self.logsoftmax(logits) 
        loss = self.criterion(log_softmax, products[1:].permute(1,0))# Skipping the start-char in the target
        self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss
    def validation_step(self, batch, batch_idx): #Optional for PyTorch Lightning
        self.eval()
        reactants, products = batch
        logits = self.forward(reactants, products) # Sequence, batch, tokens
        logits =  logits.permute(1,2,0) # Now batch, tokens, sequence
        log_softmax = self.logsoftmax(logits) 
        loss = self.criterion(log_softmax, products[1:].permute(1,0))# Skipping the start-char in the target
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss
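Splitting the forward pass into encode_memory and decode_memory means the encoder runs once per input, while the decoder can be called repeatedly with a growing target. The sketch below uses small, hypothetical layer sizes and skips the positional encoding for brevity. It also checks the effect of the square subsequent mask. With the mask, the decoder outputs for the first positions don’t change when later target tokens are added, and that is what allows step-by-step generation.

```python
import torch

torch.manual_seed(0)
d_model, nhead, n_tokens = 32, 4, 20  # hypothetical small sizes

embedding = torch.nn.Embedding(n_tokens, d_model)
encoder = torch.nn.TransformerEncoder(torch.nn.TransformerEncoderLayer(d_model, nhead, 64, 0.0), 2)
decoder = torch.nn.TransformerDecoder(torch.nn.TransformerDecoderLayer(d_model, nhead, 64, 0.0), 2)
fc_out = torch.nn.Linear(d_model, n_tokens)

def causal_mask(sz):
    # 0.0 on and below the diagonal, -inf above: position i may only attend to positions <= i
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    return mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)

src = torch.randint(0, n_tokens, (10, 2))  # (source length, batch)
tgt = torch.randint(0, n_tokens, (6, 2))   # (target length, batch)

encoder.eval()
decoder.eval()
with torch.no_grad():
    memory = encoder(embedding(src))  # computed once: (10, 2, d_model)
    out = decoder(embedding(tgt), memory, tgt_mask=causal_mask(tgt.shape[0]))
    logits = fc_out(out)  # (6, 2, n_tokens)
    # Decoding only the first three target tokens against the same memory
    out_prefix = decoder(embedding(tgt[:3]), memory, tgt_mask=causal_mask(3))

print(memory.shape, logits.shape)
print(torch.allclose(out[:3], out_prefix, atol=1e-5))
```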

We can now instantiate the model with the vocabulary size and the number of epochs we want to train.

model = ReactionTransformerModel(len(datamodule.tokenizer.char_to_int), max_epochs=200)
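The OneCycleLR arguments in configure_optimizers are chosen so the learning rate starts at start_lr, peaks at max_lr and ends at final_lr. OneCycleLR starts at max_lr / div_factor and ends at that value divided by final_div_factor. With max_epochs=200, pct_start is 10/200, so the first ten epochs are warm-up. The sketch below runs the scheduler with a dummy parameter and shrunken, hypothetical numbers of epochs and steps (and one warm-up epoch instead of ten) to confirm the three values. Note that final_div_factor is below 1 here because final_lr is larger than start_lr.

```python
import torch

# Learning rates from the model's defaults; epochs and steps are shrunk for the demo (hypothetical)
start_lr, max_lr, final_lr = 1e-6, 3e-3, 1e-3
max_epochs, steps_per_epoch = 4, 5

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param])
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=max_lr, epochs=max_epochs, steps_per_epoch=steps_per_epoch,
    pct_start=1 / max_epochs, anneal_strategy="cos",
    div_factor=max_lr / start_lr,          # initial lr = max_lr / div_factor = start_lr
    final_div_factor=start_lr / final_lr,  # final lr = initial lr / final_div_factor = final_lr
)

lrs = [scheduler.get_last_lr()[0]]
for _ in range(max_epochs * steps_per_epoch - 1):
    optimizer.step()  # optimizer step before scheduler step, once per minibatch ("interval": "step")
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(lrs[0], max(lrs), lrs[-1])
```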

In case there’s a previously saved model (see later) that we want to reload, it can be done like this:

import pickle
# The path assumes Google Drive is mounted in Colab
with open("drive/MyDrive/temp/statedict_200.pickle","rb") as f:
  save_dict = pickle.load(f)
model.load_state_dict(save_dict["state_dict"])
datamodule.tokenizer.tokenlist = save_dict["tokenlist"]

Training the Transformer

For training, there are some helper classes from PyTorch Lightning that can be used for logging to TensorBoard. Moreover, there’s a callback class for logging the learning rate that should be used with the trainer. Finally, the trainer object is created, which will take care of the training for us 🙂
It’s easy to switch between CPU and GPU with the “gpus” argument. The progress_bar_refresh_rate is lowered, as a progress bar that refreshes too fast can sometimes cause issues with non-local training. (In newer PyTorch Lightning releases, these two arguments have been removed in favour of accelerator/devices and a TQDMProgressBar(refresh_rate=...) callback.)

lr_logger = pytorch_lightning.callbacks.lr_monitor.LearningRateMonitor(logging_interval="epoch")
tb_logger = pytorch_lightning.loggers.TensorBoardLogger('tensorboard_logs/')
trainer = pytorch_lightning.Trainer(
    logger=tb_logger,  
    callbacks=[lr_logger], 
    max_epochs=model.hparams.max_epochs, 
    gpus=1, 
    progress_bar_refresh_rate=20)
    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores

Start an inline TensorBoard session so that the training can be followed in real time.

%tensorboard --logdir tensorboard_logs

 

The trainer object takes care of the training loops, the validation loss calculation and the logging, while providing some useful feedback about the training time and losses.

trainer.fit(model, datamodule)
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
    
      | Name               | Type               | Params
    ----------------------------------------------------------
    0 | embedding          | Embedding          | 13.3 K
    1 | positional_encoder | PositionalEncoding | 0     
    2 | encoder            | TransformerEncoder | 1.1 M 
    3 | decoder            | TransformerDecoder | 1.6 M 
    4 | fc_out             | Linear             | 13.4 K
    5 | logsoftmax         | LogSoftmax         | 0     
    6 | criterion          | NLLLoss            | 0     
    ----------------------------------------------------------
    2.7 M     Trainable params
    0         Non-trainable params
    2.7 M     Total params
    10.654    Total estimated model params size (MB)
    

After a long training run, it is wise to save the weights, the tokenizer vocabulary and the hyperparameters to a file for later reuse and reference.

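import pickle  # also imported here, in case the optional reload cell above was skipped
save_dict = {"state_dict": model.state_dict(), "tokenlist": datamodule.tokenizer.tokenlist, "hparams":model.hparams}
with open("drive/MyDrive/temp/statedict_200.pickle","wb") as f:
  pickle.dump(save_dict, f)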

As we saw last time, the raw output doesn’t necessarily make sense, as the teacher-forcing input disturbs the generation. So instead, sampling is implemented as a class.
One big difference from sampling RNNs is that the output generated so far needs to be fed to the decoder again for each new character that is sampled. Precomputing the memory saves some computation, but it’s still slower than sampling LSTM networks. The decoding is greedy: at each step, the most probable character is chosen.
The sample_reactions method creates a dataloader from a provided list of SMILES, while sample_dataloader samples from an existing dataloader. The dataloaders enable sampling in minibatches, so that we don’t run out of GPU memory.

class ReactionSampler():
  def __init__(self, model, datamodule,  max_length=250):
    self.model = model
    self.datamodule = datamodule
    self.tokenizer = datamodule.tokenizer
    self.softmax = torch.nn.Softmax(dim=2)
    self.max_length = max_length
  def sample_reactions(self, smiles_list):
    data_loader = self.datamodule.get_dataloader(smiles_list, [""]*len(smiles_list), shuffle=False)
    products = self.sample_dataloader(data_loader)
    return products
  def sample_batch(self, batch):
    reactants = batch[0].to(self.model.device)
    self.model.eval()
    with torch.no_grad():
      memory = self.model.encode_memory(reactants)
      predicted = torch.zeros(self.max_length+1, reactants.shape[1], dtype=torch.long).to(self.model.device)
      predicted[0,:] = self.datamodule.tokenizer.char_to_int["^"]
      for i in range(self.max_length):
        logits = self.model.decode_memory(predicted[0:i+1], memory)
        probabilities = self.softmax(logits)
        char_index = torch.argmax(probabilities, dim=2)[-1]
        predicted[i+1] = char_index
        if torch.all((predicted == self.tokenizer.char_to_int["$"]).sum(dim=0)):
          #Break if all samples have a stop token
          break
    return predicted
  def sample_dataloader(self, dataloader):
    products = []
    for batch in tqdm.tqdm(dataloader):
      predicted_tensor = self.sample_batch(batch)
      smiles_list = self.devectorize_tensor(predicted_tensor)
      products.extend(smiles_list)
    return products
  def devectorize_tensor(self, predicted_tensor):
    smiles_strings = []
    for vector in predicted_tensor.T:
      smiles = (self.datamodule.tokenizer.devectorize(vector.detach().cpu().numpy()))
      smiles_strings.append(smiles)
    return smiles_strings
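The decoding loop in sample_batch is greedy. The most probable next character is appended, and the whole prefix is fed back in until every sequence in the batch contains an end token. Sequences that finish early therefore keep generating characters until the rest are done, and devectorize() only strips start, end and pad characters from the ends of the string. The plain-Python sketch below illustrates this with a hypothetical toy_next_token function standing in for the model. It also shows a variant that cuts each sequence at its first end token. Whether this accounts for any of the invalid SMILES further down can’t be told from the output alone.

```python
START, END, PAD = "^", "$", " "

def greedy_decode(next_token, batch_size, max_length=250):
    """Feed the growing prediction back in until every sequence contains an end token.

    next_token stands in for model.decode_memory + argmax: it receives the prefixes
    generated so far and returns one new token per sequence.
    """
    predicted = [[START] for _ in range(batch_size)]
    for _ in range(max_length):
        new_tokens = next_token(predicted)
        for seq, tok in zip(predicted, new_tokens):
            seq.append(tok)
        if all(END in seq for seq in predicted):
            break
    return predicted

def devectorize_truncated(tokens):
    """Keep only the characters between the start token and the first end token."""
    body = tokens[1:]
    if END in body:
        body = body[:body.index(END)]
    return "".join(body).strip(PAD)

# Hypothetical "model" that cycles through fixed strings, so the second one continues after "$"
targets = ["CCO$", "C$C"]
def toy_next_token(prefixes):
    return [t[(len(p) - 1) % len(t)] for p, t in zip(prefixes, targets)]

raw = greedy_decode(toy_next_token, batch_size=2)
print(["".join(seq) for seq in raw])
print([devectorize_truncated(seq) for seq in raw])
print("".join(raw[1]).strip("^$ "))  # strip() only removes characters at the ends
```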
    

Moving the model to the GPU significantly speeds up sampling.

_ = model.to("cuda")

Instantiating the sampler makes it possible to sample from the datamodule’s test_dataloader.

sampler = ReactionSampler(model, datamodule)

We’ll get the test_dataloader from the datamodule, and sample the product SMILES given the input reactants.

data_loader = datamodule.test_dataloader()

 

sampled_smiles = sampler.sample_dataloader(data_loader)
    100%|██████████| 40/40 [00:56<00:00,  1.41s/it]

For convenience, we’ll use the test_data pandas DataFrame of the datamodule, here called test_dataset, to store the results.

test_dataset = datamodule.test_data

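Adding a column to test_dataset makes it easy to compare the sampled SMILES with the products column. As the test dataloader doesn’t shuffle, the sampled SMILES come out in the same order as the rows of the DataFrame. It’s also necessary to canonicalize the sampled SMILES strings so that the accuracy can be evaluated at the molecular level.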

test_dataset["sampled"] = sampled_smiles
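def canonicalize(smiles):
  # MolFromSmiles returns None for SMILES it cannot parse; the try/except guards against other input errors
  try:
    mol = Chem.MolFromSmiles(smiles)
  except Exception:
    return None
  if mol:
    return Chem.MolToSmiles(mol, canonical=True)
  return None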
test_dataset["can_sampled"] = test_dataset.sampled.apply(canonicalize)

 

test_dataset.head(20)
reactants reaction_class products set length sampled can_sampled
42682 CI.CO <RX_1> O=C(O)c1ccc(-c2ccn[nH]2)cc1 test 5 COC COC
43958 CCOC(=O)CBr <RX_9> CCOC(C)=O test 11 CCOC(=O)CBr CCOC(=O)CBr
44072 N#CCN1CCOCC1 <RX_7> NCCN1CCOCC1 test 12 NCCN1CCOCC1 NCCN1CCOCC1
43550 CCCCCCC(C)=O <RX_7> CCCCCCC(C)O test 12 CCCCCCCC(C)=O CCCCCCCC(C)=O
40959 C=CCCCCCCCCCO <RX_8> C=CCCCCCCCCC=O test 13 C=CCCCCCCCCCCO C=CCCCCCCCCCCO
44759 C=CCCOCC=O.NO <RX_9> C=CCCOCC=NO test 13 C=CCCOCC=NO)OCCOCC=C None
43005 OCCOCCc1ccccc1 <RX_8> O=CCOCCc1ccccc1 test 14 O=CCOCCc1ccccc1 O=CCOCCc1ccccc1
43924 CC(=O)c1ccccc1 <RX_7> CC(O)c1ccccc1 test 14 CC(O)c1ccccc1 CC(O)c1ccccc1
44098 CC(=O)c1ccccn1 <RX_7> C[C@@H](O)c1ccccn1 test 14 CC(O)c1ccccn1 CC(O)c1ccccn1
43753 OCc1cncc(Br)c1 <RX_8> O=Cc1cncc(Br)c1 test 14 O=Cc1cncc(Br)c1 O=Cc1cncc(Br)c1
42457 Cn1ncc(I)c1C=O <RX_7> Cn1ncc(I)c1CO test 14 Cn1ncc(I)c1C=O Cn1ncc(I)c1C=O
43325 FCC1CNC1.OCCBr <RX_1> OCCN1CC(CF)C1 test 14 OCCN1CC(CF)C1 OCCN1CC(CF)C1
42768 C=Cc1ccccc1.OO <RX_4> c1ccc(C2CO2)cc1 test 14 OCCCc1ccccc1 OCCCc1ccccc1
42511 N#CCC1(CC#N)CC1 <RX_7> N#CCC1(CCN)CC1 test 15 N#CCC1(Cc2ncc(C3(CcN)CC3)c2CC#N)CC1 None
44118 O=C(CBr)c1cncs1 <RX_7> OC(CBr)c1cncs1 test 15 O=C(CBr)c1cncs1 O=C(CBr)c1cncs1
42382 CC(C#N)c1ccccc1 <RX_7> CC(CN)c1ccccc1 test 15 CC(C#N)c1ccccc1 CC(C#N)c1ccccc1
40368 ClCBr.Cn1nnnc1S <RX_1> Cn1nnnc1SCCl test 15 Cn1nnnc1SCCl Cn1nnnc1SCCl
44595 CC(O)c1cnc(Br)s1 <RX_8> CC(=O)c1cnc(Br)s1 test 16 CC(=O)c1cnc(Br)s1 CC(=O)c1cnc(Br)s1
43735 O=c1cc(O)cc[nH]1 <RX_9> O=C1NC(=O)C(c2ccccc2)(c2ccccc2)N1 test 16 O=c1cc(O)cc[nH]1 O=c1cc(O)cc[nH]1
43141 Cc1cc(C)nc(CO)c1 <RX_8> Cc1cc(C)nc(C=O)c1 test 16 Cc1cc(C)nc(C=O)c1 Cc1cc(C)nc(C=O)c1

It definitely looks like the sampled SMILES strings are related to the input. The first row looks like an error in the dataset: there’s no way that an aromatic, nitrogen-containing compound can be made from iodomethane (CI) and methanol. The predicted dimethyl ether (COC) is probably a better suggestion. What is the overall accuracy?

succeeded = test_dataset.products == test_dataset.can_sampled
succeeded.sum()/len(test_dataset)
    0.5419664268585132

For about 54% of the test reactions, the sampled product matches the recorded product exactly. Much better than the LSTM networks! With more layers, a bigger model and larger datasets, the accuracy can be driven even higher. The validity also seems reasonable, but how high is it? If a sampled SMILES could not be converted to a molecule, its entry in the can_sampled column will be None.

1-test_dataset.can_sampled.isna().sum()/len(test_dataset)
    0.8283373301358913
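Both numbers are fractions of all test reactions. The exact-match accuracy counts an unparsable sample (None) as a failure, since None never equals a string in the pandas comparison. The small, hypothetical table below illustrates both calculations. The comparison also assumes that the products column is already in RDKit canonical form; if it isn’t, canonicalizing it with the same function would be the safer comparison. Stereochemistry has to match as well. That is why a prediction like CC(O)c1ccccn1 for C[C@@H](O)c1ccccn1 in the table above counts as a failure.

```python
import pandas as pd

# Hypothetical mini results table; None marks a sampled SMILES that RDKit could not parse
results = pd.DataFrame({
    "products":    ["CCO", "CC(O)c1ccccc1", "O=Cc1ccccc1", "CCN"],
    "can_sampled": ["CCO", "CC(O)c1ccccc1", None,          "CCC"],
})

succeeded = results.products == results.can_sampled  # None never equals a string
accuracy = succeeded.sum() / len(results)
validity = 1 - results.can_sampled.isna().sum() / len(results)
print(accuracy, validity)
```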

82%. I would have expected it to be higher, but it is a small dataset and we trained without augmentation. Maybe longer SMILES have a higher chance of being invalid, so is the success rate related to the sequence length? Judging from the histogram of reactant lengths for successes and failures below, it doesn’t seem so.

plt.hist(test_dataset[succeeded].length, alpha=0.5, label="success",bins=50)
plt.hist(test_dataset[~succeeded].length, alpha=0.5, label="failure", bins=50)
plt.legend()
    


The dataset contains ten different reaction classes. So are the failures related to the reaction type? The dataset is unbalanced, so maybe the less prevalent reaction types are not learned well enough.

#Getting the value counts
class_succes = test_dataset[succeeded].reaction_class.value_counts()
class_failure = test_dataset[~succeeded].reaction_class.value_counts()
test_class_size = test_dataset.reaction_class.value_counts()
train_class_size = datamodule.train_data.reaction_class.value_counts()
#Some pandas data wrangling
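#Some pandas data wrangling: one column with the test set success rate per class (kept under the
#name "reaction_class", as in the output below) and one with the training set size per class.
#Building the frame from a dict doesn't depend on the Series names that value_counts() returns,
#which differ between pandas versions.
df = pd.DataFrame({"reaction_class": class_succes / test_class_size,
                   "class_size": train_class_size})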
df
reaction_class class_size
<RX_10> 0.478261 181
<RX_1> 0.566138 12098
<RX_2> 0.679261 9531
<RX_3> 0.602837 4511
<RX_4> 0.188889 720
<RX_5> 0.738462 520
<RX_6> 0.459880 6683
<RX_7> 0.276688 3667
<RX_8> 0.432099 651
<RX_9> 0.461957 1467
plt.scatter(df.class_size.values, df.reaction_class.values)
plt.xlabel("class size")
_ = plt.ylabel("success rate")
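The same per-class success rate can also be computed in one step by grouping on the reaction class and averaging the boolean success column. The sketch below uses a tiny hypothetical frame; only the two class sizes are taken from the table above. This differs from dividing two value_counts() results in one respect: a class without any successes gets 0.0 here instead of NaN.

```python
import pandas as pd

# Hypothetical toy data: reaction class per test example and whether the prediction matched
results = pd.DataFrame({
    "reaction_class": ["<RX_1>", "<RX_1>", "<RX_4>", "<RX_4>", "<RX_4>"],
    "succeeded":      [True,     False,    False,    False,    True],
})
train_class_size = pd.Series({"<RX_1>": 12098, "<RX_4>": 720})  # sizes from the table above

# The mean of a boolean column is the fraction of True values, i.e. the success rate per class
summary = pd.DataFrame({
    "success_rate": results.groupby("reaction_class")["succeeded"].mean(),
    "class_size": train_class_size,
})
print(summary)
```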


From the plot, it doesn’t seem to be related to the training set size. Some reaction classes are very unsuccessful, like <RX_4>, whereas the almost equally sized <RX_5> has a high success rate. According to the original paper, these classes are heterocycle formation and protections, respectively. It seems plausible that ring-forming reactions are trickier to get right with all the changes in e.g. aromaticity that could happen.
I hope this blogpost has been instructive and has shown how to use Transformer networks for chemistry-related tasks using the SMILES chemical language. The network here was pretty small; larger datasets and larger models can lead to even better performance.

Happy Modelling

Esben
