[Feature] Add selected_out_keys to ProbabilisticTensorDictSequential. #1497
Hi @tobiabir! Thank you for your pull request and welcome to our community.
Action Required
In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.
Process
In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.
Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with …
If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Hey @tobiabir!
Thanks for this, and sorry for the delay in reviewing!
I hope you don't mind: I added a couple of small fixes to handle some edge cases:
- Passing `selected_out_keys` through the recursive call (line ~1003): when a dict is passed to the constructor, the code recursively calls `self.__init__()` but wasn't forwarding `selected_out_keys`. Now it does!
- Fixed the "no probabilistic modules" fallback (line ~1034-1047): the original code used `return TensorDictSequential(...)`, which doesn't work because `__init__` must return `None` (Python raises a `TypeError` otherwise). I rewrote this to properly initialize `self` via `TensorDictSequential.__init__()` and set the required attributes (`_requires_sample`, `_det_part`, `return_composite`). Also added `selected_out_keys` here.
- Set `return_composite=True` for the no-prob-modules case so that `forward()` just iterates through the modules normally without trying to find a distribution to sample from.
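A minimal sketch of the two constructor fixes listed above, using hypothetical stand-in classes (`Sequential`, `BrokenProbSequential`, `FixedProbSequential`) rather than the real tensordict classes. The `is_probabilistic` marker and the attribute values in the fallback branch are assumptions for illustration only; the actual implementation detects probabilistic modules and builds `_det_part` in its own way.

```python
# Hypothetical stand-ins; these are NOT the real tensordict classes.


class Sequential:
    """Stand-in for TensorDictSequential: stores modules and selected_out_keys."""

    def __init__(self, *modules, selected_out_keys=None):
        self.modules = list(modules)
        self.selected_out_keys = selected_out_keys


class BrokenProbSequential(Sequential):
    """Mimics the original fallback: returning a new object from __init__."""

    def __init__(self, *modules, selected_out_keys=None):
        # Python rejects a non-None return value from __init__ with a TypeError.
        return Sequential(*modules, selected_out_keys=selected_out_keys)


class FixedProbSequential(Sequential):
    """Mimics the fixed constructor: initializes self in place."""

    def __init__(self, *modules, selected_out_keys=None):
        if len(modules) == 1 and isinstance(modules[0], dict):
            # Dict input: unpack the values and re-run __init__,
            # forwarding selected_out_keys instead of dropping it.
            self.__init__(*modules[0].values(), selected_out_keys=selected_out_keys)
            return
        # `is_probabilistic` is a hypothetical marker used only in this sketch.
        if not any(getattr(m, "is_probabilistic", False) for m in modules):
            # No probabilistic modules: run the parent __init__ on self and
            # set the attributes forward() relies on (placeholder values).
            Sequential.__init__(self, *modules, selected_out_keys=selected_out_keys)
            self._requires_sample = False
            self._det_part = None
            self.return_composite = True
            return
        super().__init__(*modules, selected_out_keys=selected_out_keys)
        self._requires_sample = True
        self.return_composite = False


def double(x):
    return 2 * x


try:
    BrokenProbSequential(double)
except TypeError as err:
    print(f"TypeError: {err}")

seq = FixedProbSequential({"double": double}, selected_out_keys=["out"])
print(seq.modules, seq.selected_out_keys, seq.return_composite)
```

Initializing `self` through the parent's `__init__` keeps the object an instance of the subclass, so attributes such as `return_composite` end up on the object the caller actually receives.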
All your original tests pass, and I verified the edge cases work too. LGTM otherwise — nice clean implementation! 🎉
Branch updated from 929b437 to fdf28f6.
Branch updated from 87f6030 to ee1f422.
Description
Describe your changes in detail.
Add the `selected_out_keys` argument to the `__init__` of `ProbabilisticTensorDictSequential` and pass it through to the parent `TensorDictSequential`.
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax `close #15213` if this solves the issue #15213.
close #1496
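As a rough, plain-Python illustration of the pass-through described in this PR: the classes below are hypothetical stand-ins that operate on dicts, not the tensordict API. The filtering behaviour (keep only the selected keys in the output) is an assumption meant to mimic what `selected_out_keys` does on `TensorDictSequential`; the tensordict documentation is the authority on the exact semantics.

```python
# Hypothetical stand-ins operating on plain dicts, NOT the tensordict API.
from typing import Callable, Dict, Iterable, Optional

Data = Dict[str, float]


class Sequential:
    """Runs modules in order; optionally keeps only selected_out_keys."""

    def __init__(self, *modules: Callable[[Data], Data],
                 selected_out_keys: Optional[Iterable[str]] = None) -> None:
        self.modules = list(modules)
        self.selected_out_keys = (
            list(selected_out_keys) if selected_out_keys is not None else None
        )

    def __call__(self, data: Data) -> Data:
        out = dict(data)
        for module in self.modules:
            # Each module reads the current entries and writes new ones.
            out.update(module(out))
        if self.selected_out_keys is not None:
            # Drop everything except the requested output keys.
            out = {k: out[k] for k in self.selected_out_keys}
        return out


class ProbSequential(Sequential):
    """Subclass that accepts selected_out_keys and forwards it to the parent."""

    def __init__(self, *modules: Callable[[Data], Data],
                 selected_out_keys: Optional[Iterable[str]] = None) -> None:
        super().__init__(*modules, selected_out_keys=selected_out_keys)


def add_hidden(d: Data) -> Data:
    return {"hidden": d["x"] + 1.0}


def add_y(d: Data) -> Data:
    return {"y": d["hidden"] * 2.0}


seq = ProbSequential(add_hidden, add_y, selected_out_keys=["y"])
print(seq({"x": 1.0}))  # {'y': 4.0}
```

Without `selected_out_keys`, the same chain returns every intermediate entry (`x`, `hidden`, and `y`); forwarding the argument through `super().__init__` is all the subclass needs to inherit the filtering.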
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!