Application of FuncX for Training and Validation of Deep Learning (BiLSTM) models

22 May 2023 - Ritwik Anand Deshpande, University of Illinois Urbana-Champaign

I am a CS master’s student at the University of Illinois. Together with my class project team, I used funcX to complete a final project for CS598, Deep Generative & Dyn. Models. I also work as a graduate research assistant on the funcX project, and I configured a funcX endpoint on NCSA’s Delta GPU high-performance computing cluster. I secured the endpoint with a Globus Group and invited my teammates so that they could use it for model training.

Problem Statement

Description: We have a set of clinical documents containing doctors’ clinical notes, covering 16 different morbidities in total. For our application, we developed 16 binary classifiers, one for each morbidity. Each model has around 600,000 parameters. For each morbidity, we have around 500 documents, with an average document length of around 700 and a word embedding vector size of 300.
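
As a back-of-the-envelope check on these figures, the size of the embedded input for one morbidity can be estimated as follows. This is only a sketch: it assumes every document is padded or truncated to the average length and that each embedding value is stored as a 4-byte float, neither of which is stated above.

```python
# Figures taken from the problem description.
docs_per_morbidity = 500
avg_doc_length = 700
embedding_dim = 300

# Assumption (not from the description): values are stored as 4-byte floats.
BYTES_PER_VALUE = 4

# One embedding vector per position in every document.
values = docs_per_morbidity * avg_doc_length * embedding_dim
size_mb = values * BYTES_PER_VALUE / 1e6

print(f"{values:,} values, about {size_mb:.0f} MB per morbidity")
# prints "105,000,000 values, about 420 MB per morbidity"
```

Under these assumptions, the embedded data for a single classifier is on the order of hundreds of megabytes.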

Our project aimed to reproduce the results of the paper “Ensembling Classical Machine Learning and Deep Learning Approaches for Morbidity Identification From Clinical Notes” [1].

Approach

To run our training jobs with funcX, we first had to set up the required infrastructure. We created a funcX endpoint on the Delta cluster at NCSA and built a Singularity image to serve as the container in which our jobs execute. Thanks to help from Matthew Feickert, we started from earlier images that already had the appropriate PyTorch and CUDA setup, and installed the remaining modules from our project’s requirements.txt file. We modified our code to use the Nvidia A100 GPUs of the Delta cluster. Lastly, thanks to help from Ben Galewsky, we were able to make our funcX endpoint public (for members of a Globus group) so that other members of the team could also use the endpoint ID. Because there is a cap on the size of a submitted function together with its parameters, we did not send the training data with each submission. Instead, we mounted a shared volume on the Delta cluster, which contained the training data files, into the Singularity container so that our code could point to these files during execution. We also specified the “--nv” flag so that PyTorch could detect CUDA correctly when executing inside the container. These options were passed as a string in the container_cmds parameter of the HighThroughputExecutor (from the Parsl SDK) in the config.py used to set up the funcX endpoint.
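
The options string described above could be assembled along these lines. This is a minimal sketch rather than our actual config.py: the helper name and both paths are hypothetical, and it only covers the two Singularity options mentioned here (the GPU flag and a bind mount).

```python
import shlex


def build_container_options(host_dir: str, container_dir: str, use_gpu: bool = True) -> str:
    """Build a Singularity options string with an optional GPU flag and one bind mount."""
    options = []
    if use_gpu:
        # --nv exposes the host's NVIDIA driver and GPUs inside the container.
        options.append("--nv")
    # --bind makes a host directory visible at a path inside the container.
    options.extend(["--bind", f"{host_dir}:{container_dir}"])
    # Quote each part so that paths with spaces survive shell parsing.
    return " ".join(shlex.quote(part) for part in options)


# Hypothetical paths; the real shared volume on Delta is not shown in this post.
container_options = build_container_options("/path/to/shared/train_data", "/data")
print(container_options)
# prints "--nv --bind /path/to/shared/train_data:/data"
```

The resulting string is what would be handed to the executor’s container options parameter in the endpoint configuration.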

Using funcX we were able to decouple and parallelize the whole process.

For each of the 16 BiLSTM models, we submitted the train_and_validate function in parallel (leveraging the default asynchronous paradigm of Parsl used in the funcX SDK). The train_and_validate function initialized the model, split the data into training and validation sets, fed batched data to the model for training, and ran validation to obtain the F1 micro and macro scores, which it returned. Thus, all 16 BiLSTM models (one per morbidity) were trained and validated in parallel. Each submission returned a future. Using the as_result_completed module of the SDK, we collected the F1 scores as each model finished, that is, in ascending order of execution time. These scores were saved to a CSV file for further analysis.
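
The submit-and-collect pattern can be sketched as follows. This is not our project code: a local thread pool from the Python standard library stands in for the funcX executor, the train_and_validate body is a dummy that returns made-up scores, and the morbidity names are placeholders. Only the overall flow (submit 16 jobs, receive futures, gather results as they complete, write a CSV) mirrors the description above.

```python
import csv
import io
import random
from concurrent.futures import ThreadPoolExecutor, as_completed

# Placeholder labels; the real project used the 16 morbidity names.
MORBIDITIES = [f"morbidity_{i:02d}" for i in range(1, 17)]


def train_and_validate(morbidity: str) -> dict:
    """Dummy stand-in for training one BiLSTM classifier and scoring it."""
    rng = random.Random(morbidity)  # deterministic fake scores per label
    return {
        "morbidity": morbidity,
        "f1_micro": round(rng.random(), 4),
        "f1_macro": round(rng.random(), 4),
    }


def run_all(executor, morbidities, out_file):
    """Submit one job per morbidity and write each score row as its job finishes."""
    futures = [executor.submit(train_and_validate, m) for m in morbidities]
    writer = csv.DictWriter(out_file, fieldnames=["morbidity", "f1_micro", "f1_macro"])
    writer.writeheader()
    results = []
    # as_completed yields futures in completion order, not submission order.
    for future in as_completed(futures):
        row = future.result()
        writer.writerow(row)
        results.append(row)
    return results


# An in-memory buffer stands in for the CSV file on disk.
buffer = io.StringIO()
with ThreadPoolExecutor(max_workers=16) as pool:
    scores = run_all(pool, MORBIDITIES, buffer)
print(buffer.getvalue())
```

Because the collection loop only relies on futures, the same structure should carry over to any executor that returns standard futures from its submit call.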

Results

In the original paper, the whole process took nearly 40 hours using:

  • NVIDIA GP102
  • 12GB Memory
  • CUDA 10.2
  • NVIDIA Driver 440.33.01

By contrast, the infrastructure and flexibility provided by funcX enabled us to execute the whole process in under 5 minutes. All the PyTorch, TensorFlow, and other complex library dependencies were handled by the Singularity container, so the only dependency required on our client was the funcX SDK.

See The Code on GitHub

The full source code for this project can be found in Ritwik’s GitHub repository.

References

[1] V. Kumar, D. R. Recupero, D. Riboni and R. Helaoui, “Ensembling Classical Machine Learning and Deep Learning Approaches for Morbidity Identification From Clinical Notes,” in IEEE Access, vol. 9, pp. 7107-7126, 2021, doi: 10.1109/ACCESS.2020.3043221.