Thanks @adamgayoso. It’s very helpful! I think I now understand the definitions of the E and Q functions. However, I’m still a little confused about how the loss function gets called (sorry, I’m not very good at statistics and deep learning). I saw the call to the generative function in line 119 of CellAssign, but I’m not sure when the loss function defined in CellAssignModule is called.
I added a print statement inside the loss function to check whether it is called, but nothing was printed. I’m not sure whether or when the loss function is called, so I guess I still need to understand the code better.
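As far as I understand the scvi-tools design, `train()` runs a training step that calls the module's `forward`, which in turn runs `inference`, `generative` and then `loss`. `predict` calls `generative` directly, so a print inside `loss` would be expected during `train()`, not during `predict()`. The sketch below only illustrates that call order: `ToyModule`, `training_step` and `predict_step` are hypothetical plain-Python stand-ins, not the scvi-tools API.

```python
class ToyModule:
    """Hypothetical stand-in for a scvi-tools module; it records which hooks run."""

    def __init__(self):
        self.calls = []

    def inference(self, x):
        self.calls.append("inference")
        return {}

    def generative(self, x):
        self.calls.append("generative")
        return {"mu": x}

    def loss(self, x, generative_outputs):
        self.calls.append("loss")
        print("loss called")  # the kind of debug print described above
        # Toy squared-error loss standing in for the real model's loss
        return sum((xi - mi) ** 2 for xi, mi in zip(x, generative_outputs["mu"]))

    def forward(self, x):
        # Pattern assumed here: inference -> generative -> loss in one call
        inference_outputs = self.inference(x)
        generative_outputs = self.generative(x)
        loss = self.loss(x, generative_outputs)
        return inference_outputs, generative_outputs, loss


def training_step(module, batch):
    # Hypothetical training path: the only place the loss is computed
    _, _, loss = module.forward(batch)
    return loss


def predict_step(module, batch):
    # Hypothetical prediction path: calls generative directly, never loss
    return module.generative(batch)
```

Running `predict_step` records only `"generative"`, while `training_step` records all three hooks and prints `loss called`.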
Thanks so much @adamgayoso! It really helps, and I now fully understand the abstraction in the code, which makes it more flexible and compatible. Really appreciate the hard work!
Basically, I modified CellAssignModule slightly to incorporate other variables into the CellAssign algorithm. My concern is that inference of these variables in the new model might not work: the predict function seems to output the initial randomized delta variables instead of the optimized ones. I debugged it for a while, and it looks as if the loss function is never called in the new code. I’m not sure why this happens.
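One common reason for values staying at their random initialization (an assumption here, since the modified module isn't shown) is that the new delta variables are stored as plain tensors rather than `torch.nn.Parameter`. A plain tensor attribute is not returned by `module.parameters()`, so the optimizer never updates it. The module names and the squared-norm loss below are hypothetical; the loss only stands in for the real CellAssign loss.

```python
import torch
from torch import nn


class PlainTensorModule(nn.Module):
    def __init__(self, n_genes: int, n_labels: int):
        super().__init__()
        # Plain tensor attribute: NOT registered, invisible to the optimizer
        self.delta_log = torch.randn(n_genes, n_labels)


class ParameterModule(nn.Module):
    def __init__(self, n_genes: int, n_labels: int):
        super().__init__()
        # nn.Parameter registers the tensor, so it receives gradients and updates
        self.delta_log = nn.Parameter(torch.randn(n_genes, n_labels))


def registered_names(module: nn.Module) -> list[str]:
    """Names of the parameters an optimizer built from module.parameters() would see."""
    return [name for name, _ in module.named_parameters()]


def delta_changes_after_step(module: nn.Module) -> bool:
    """Take one SGD step on a toy loss and report whether delta_log changed."""
    before = module.delta_log.detach().clone()
    params = list(module.parameters())
    if not params:
        return False  # nothing registered, so nothing can be optimized
    optimizer = torch.optim.SGD(params, lr=0.1)
    loss = (module.delta_log ** 2).sum()  # toy loss standing in for the real one
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return not torch.equal(before, module.delta_log.detach())
```

Checking `registered_names(self.module)` on the modified module is a quick way to see whether the new deltas are actually trainable.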
The train function is unchanged. The prediction function is below:
import numpy as np
import torch

@torch.no_grad()  # no gradients needed at prediction time; also lets .numpy() work
def predict(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return delta_c, delta_p and delta_cp from the generative model, concatenated over minibatches."""
    adata = self._validate_anndata(None)
    scdl = self._make_data_loader(adata=adata)
    delta_c, delta_p, delta_cp = [], [], []
    for tensors in scdl:
        generative_inputs = self.module._get_generative_input(tensors, None)
        outputs = self.module.generative(**generative_inputs)
        # Move each minibatch result to the CPU so .numpy() also works on GPU runs
        delta_c.append(outputs["delta_c"].cpu())
        delta_p.append(outputs["delta_p"].cpu())
        delta_cp.append(outputs["delta_cp"].cpu())
        # gamma = outputs["gamma"]
        # predictions += [gamma.cpu()]
    # to be better specified ??
    # Concatenate once at the end, so the first batch is not added twice
    return (
        torch.cat(delta_c).numpy(),
        torch.cat(delta_p).numpy(),
        torch.cat(delta_cp).numpy(),
    )
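Accumulating minibatch outputs has a separate pitfall. Seeding the result with the first batch under `if idx == 0` and then running an unconditional `torch.cat` puts that batch in the result twice; appending to a list and concatenating once avoids it. The helper names below are hypothetical. Note also that if a delta is a global parameter rather than a per-cell quantity, every minibatch returns the same tensor, so concatenating across batches only stacks identical copies.

```python
import torch


def concat_seed_then_cat(batches):
    # Seed with the first batch, then always concatenate:
    # the first batch ends up in the result twice.
    for idx, batch in enumerate(batches):
        if idx == 0:
            out = batch
        out = torch.cat((out, batch))
    return out


def concat_list_pattern(batches):
    # Collect per-batch tensors, then concatenate once at the end.
    return torch.cat(list(batches))
```

With `batches = [torch.tensor([1.0, 2.0]), torch.tensor([3.0])]`, the first helper returns `[1., 2., 1., 2., 3.]` and the second returns `[1., 2., 3.]`.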