Maintaining a computational graph in client state

Hi everyone,
I am a master's student using Flower for my thesis. As a baseline I tried to implement this paper: Federated Split Vision Transformer for COVID-19 CXR Diagnosis using Task-Agnostic Training. The basic idea of the paper is to train a ViT with split learning: the feature extractor and the classifier stay on the client, while the main transformer encoder runs on the server. To make this work, the forward and backward passes have to be split across the two, with intermediate features and gradients sent back and forth, as in the sketch below.

The authors also built this on top of Flower, so I thought it would be possible; however, I think I am running into a fundamental obstacle. I am aware that Flower's clients are stateless, so I planned to store everything in the Context. What I didn't realize is that the computational graph attached to the features cannot be serialized, so it is destroyed when the client is reinitialized, which makes the backward pass impossible this way.
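A minimal repro of the problem, independent of Flower:

```python
import torch

features = 2 * torch.randn(2, 4, requires_grad=True)
print(features.grad_fn)   # <MulBackward0 ...>: the graph is attached

# Anything stored in the Context must be serialized, e.g. as a NumPy array;
# only the values survive the round trip, not the autograd graph.
restored = torch.from_numpy(features.detach().numpy())
print(restored.grad_fn)   # None: backward() through it is no longer possible
```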

I am quite stuck and would greatly appreciate any help or insight into ways this could be overcome; otherwise I will have to abandon the idea.

Many thanks!

Hi @pnnguyen2209,

Thank you for raising this. You can design stateful ClientApps in Flower: Design stateful ClientApps - Flower Framework.
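For reference, the pattern from that guide looks roughly like this. This is a minimal sketch assuming a recent Flower release where `context.state` is a record store and `ArrayRecord` wraps NumPy arrays; the record name is illustrative:

```python
import numpy as np
from flwr.common import ArrayRecord, Context

def save_features(context: Context, features: np.ndarray) -> None:
    # context.state persists across invocations of the same ClientApp,
    # but only serialized values are kept.
    context.state["features"] = ArrayRecord([features])

def load_features(context: Context) -> np.ndarray | None:
    if "features" not in context.state:
        return None
    return context.state["features"].to_numpy_ndarrays()[0]
```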

Please let me know if this helps.

Best regards,

William

Unfortunately, writing the features to a RecordDict does not preserve the computational graph required for the backward pass. The only way to do that would be to somehow keep the features in memory, which to my understanding isn't possible across client invocations. For now I have resorted to offloading the features to an external Flask back-end, but that isn't compatible with the simulation engine.
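For anyone hitting the same wall, this is the shape of the in-memory approach I tried first (a sketch only; keying the cache per client is my own choice). It keeps the live tensors, graph included, in a module-level dict, but it only works while the same Python process handles both messages, which the simulation engine does not guarantee:

```python
import torch

# Module-level cache: survives between handler calls only as long as the
# same Python process, and hence this module, stays alive.
_GRAPH_CACHE: dict[str, torch.Tensor] = {}

def stash_features(key: str, features: torch.Tensor) -> None:
    _GRAPH_CACHE[key] = features        # grad_fn stays attached in RAM

def backward_from_server(key: str, upstream_grad: torch.Tensor) -> None:
    features = _GRAPH_CACHE.pop(key)
    features.backward(upstream_grad)    # uses the preserved graph
```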

It works, but it's not ideal, so I would welcome other alternatives as well.