Explorar el Código

feat: finalize supervised-only branch contents

kekezack hace 1 mes
padre
commit
83bc5ae2a8

+ 52 - 45
.gitignore

@@ -1,45 +1,52 @@
-# Python
-__pycache__/
-*.py[cod]
-*$py.class
-*.so
-*.egg
-*.egg-info/
-dist/
-build/
-*.whl
-
-# IDE
-.idea/
-.vscode/
-*.swp
-*.swo
-
-# OS
-.DS_Store
-Thumbs.db
-
-# Reference code & papers (do not upload)
-ref/
-tmp/
-
-# Weights & checkpoints
-*.pth
-*.pt
-*.ckpt
-*.onnx
-
-
-# Logs & outputs
-*.log
-outputs/
-runs/
-lightning_logs/
-
-# Jupyter
-.ipynb_checkpoints/
-
-# Environment
-.env
-.venv/
-venv/
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+*.egg
+*.egg-info/
+dist/
+build/
+*.whl
+
+# TypeScript
+lib/sam2/demo/
+
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Reference code & papers (do not upload)
+ref/
+tmp/
+
+# Weights & checkpoints
+*.pth
+*.pt
+*.ckpt
+*.onnx
+
+# Logs & outputs
+*.log
+outputs/
+runs/
+lightning_logs/
+
+# Jupyter
+.ipynb_checkpoints/
+
+# Environment
+.env
+.venv/
+venv/
+
+# data
+data/
+cache/

+ 72 - 72
LICENSE

@@ -1,72 +1,72 @@
-Apache License 
-Version 2.0, January 2004 
-http://www.apache.org/licenses/
-TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-1. Definitions.
-
-"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
-
-"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
-
-"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
-
-"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
-
-"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
-
-"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
-
-"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
-
-"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
-
-"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
-
-"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
-
-2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
-
-3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
-
-4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
-
-(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
-
-(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
-
-(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
-
-(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
-
-You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
-
-5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
-
-6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
-
-7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
-
-8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
-
-9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
-
-END OF TERMS AND CONDITIONS
-
-APPENDIX: How to apply the Apache License to your work.
-
-To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
-
-Copyright [yyyy] [name of copyright owner]
-
-Licensed under the Apache License, Version 2.0 (the "License"); 
-you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software 
-distributed under the License is distributed on an "AS IS" BASIS, 
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and 
-limitations under the License.
+Apache License 
+Version 2.0, January 2004 
+http://www.apache.org/licenses/
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
+
+"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
+
+"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
+
+"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
+
+"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
+
+"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
+
+"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
+
+"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
+
+"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
+
+"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
+
+(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
+
+(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
+
+(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
+
+(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
+
+You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+APPENDIX: How to apply the Apache License to your work.
+
+To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
+
+Copyright [yyyy] [name of copyright owner]
+
+Licensed under the Apache License, Version 2.0 (the "License"); 
+you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software 
+distributed under the License is distributed on an "AS IS" BASIS, 
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and 
+limitations under the License.

+ 18 - 18
configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml

@@ -1,19 +1,19 @@
-DATA:
-  DATASET: imagenet22K
-  IMG_SIZE: 192
-MODEL:
-  TYPE: swinv2
-  NAME: swinv2_base_patch4_window12_192_22k
-  DROP_PATH_RATE: 0.2
-  SWINV2:
-    EMBED_DIM: 128
-    DEPTHS: [ 2, 2, 18, 2 ]
-    NUM_HEADS: [ 4, 8, 16, 32 ]
-    WINDOW_SIZE: 12
-TRAIN:
-  EPOCHS: 90
-  WARMUP_EPOCHS: 5
-  WEIGHT_DECAY: 0.1
-  BASE_LR: 1.25e-4 # 4096 batch-size
-  WARMUP_LR: 1.25e-7
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window12_192_22k
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 12
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
   MIN_LR: 1.25e-6

+ 18 - 18
configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml

@@ -1,19 +1,19 @@
-DATA:
-  IMG_SIZE: 256
-MODEL:
-  TYPE: swinv2
-  NAME: swinv2_base_patch4_window12to16_192to256_22kto1k_ft
-  DROP_PATH_RATE: 0.2
-  SWINV2:
-    EMBED_DIM: 128
-    DEPTHS: [ 2, 2, 18, 2 ]
-    NUM_HEADS: [ 4, 8, 16, 32 ]
-    WINDOW_SIZE: 16
-    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
-TRAIN:
-  EPOCHS: 30
-  WARMUP_EPOCHS: 5
-  WEIGHT_DECAY: 1e-8
-  BASE_LR: 2e-05
-  WARMUP_LR: 2e-08
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window12to16_192to256_22kto1k_ft
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 16
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
   MIN_LR: 2e-07

+ 20 - 20
configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml

@@ -1,21 +1,21 @@
-DATA:
-  IMG_SIZE: 384
-MODEL:
-  TYPE: swinv2
-  NAME: swinv2_base_patch4_window12to24_192to384_22kto1k_ft
-  DROP_PATH_RATE: 0.2
-  SWINV2:
-    EMBED_DIM: 128
-    DEPTHS: [ 2, 2, 18, 2 ]
-    NUM_HEADS: [ 4, 8, 16, 32 ]
-    WINDOW_SIZE: 24
-    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
-TRAIN:
-  EPOCHS: 30
-  WARMUP_EPOCHS: 5
-  WEIGHT_DECAY: 1e-8
-  BASE_LR: 2e-05
-  WARMUP_LR: 2e-08
-  MIN_LR: 2e-07
-TEST:
+DATA:
+  IMG_SIZE: 384
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window12to24_192to384_22kto1k_ft
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 24
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07
+TEST:
   CROP: False

+ 10 - 10
configs/swinv2/swinv2_base_patch4_window16_256.yaml

@@ -1,11 +1,11 @@
-DATA:
-  IMG_SIZE: 256
-MODEL:
-  TYPE: swinv2
-  NAME: swinv2_base_patch4_window16_256
-  DROP_PATH_RATE: 0.5
-  SWINV2:
-    EMBED_DIM: 128
-    DEPTHS: [ 2, 2, 18, 2 ]
-    NUM_HEADS: [ 4, 8, 16, 32 ]
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window16_256
+  DROP_PATH_RATE: 0.5
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
     WINDOW_SIZE: 16

+ 10 - 10
configs/swinv2/swinv2_base_patch4_window8_256.yaml

@@ -1,11 +1,11 @@
-DATA:
-  IMG_SIZE: 256
-MODEL:
-  TYPE: swinv2
-  NAME: swinv2_base_patch4_window8_256
-  DROP_PATH_RATE: 0.5
-  SWINV2:
-    EMBED_DIM: 128
-    DEPTHS: [ 2, 2, 18, 2 ]
-    NUM_HEADS: [ 4, 8, 16, 32 ]
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window8_256
+  DROP_PATH_RATE: 0.5
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
     WINDOW_SIZE: 8

+ 18 - 18
configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml

@@ -1,19 +1,19 @@
-DATA:
-  DATASET: imagenet22K
-  IMG_SIZE: 192
-MODEL:
-  TYPE: swinv2
-  NAME: swinv2_large_patch4_window12_192_22k
-  DROP_PATH_RATE: 0.2
-  SWINV2:
-    EMBED_DIM: 192
-    DEPTHS: [ 2, 2, 18, 2 ]
-    NUM_HEADS: [ 6, 12, 24, 48 ]
-    WINDOW_SIZE: 12
-TRAIN:
-  EPOCHS: 90
-  WARMUP_EPOCHS: 5
-  WEIGHT_DECAY: 0.1
-  BASE_LR: 1.25e-4 # 4096 batch-size
-  WARMUP_LR: 1.25e-7
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_large_patch4_window12_192_22k
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 192
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 6, 12, 24, 48 ]
+    WINDOW_SIZE: 12
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
   MIN_LR: 1.25e-6

+ 18 - 18
configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml

@@ -1,19 +1,19 @@
-DATA:
-  IMG_SIZE: 256
-MODEL:
-  TYPE: swinv2
-  NAME: swinv2_base_patch4_window12to16_192to256_22kto1k_ft
-  DROP_PATH_RATE: 0.2
-  SWINV2:
-    EMBED_DIM: 192
-    DEPTHS: [ 2, 2, 18, 2 ]
-    NUM_HEADS: [ 6, 12, 24, 48 ]
-    WINDOW_SIZE: 16
-    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
-TRAIN:
-  EPOCHS: 30
-  WARMUP_EPOCHS: 5
-  WEIGHT_DECAY: 1e-8
-  BASE_LR: 2e-05
-  WARMUP_LR: 2e-08
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window12to16_192to256_22kto1k_ft
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 192
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 6, 12, 24, 48 ]
+    WINDOW_SIZE: 16
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
   MIN_LR: 2e-07

+ 20 - 20
configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml

@@ -1,21 +1,21 @@
-DATA:
-  IMG_SIZE: 384
-MODEL:
-  TYPE: swinv2
-  NAME: swinv2_large_patch4_window12to24_192to384_22kto1k_ft
-  DROP_PATH_RATE: 0.2
-  SWINV2:
-    EMBED_DIM: 192
-    DEPTHS: [ 2, 2, 18, 2 ]
-    NUM_HEADS: [ 6, 12, 24, 48 ]
-    WINDOW_SIZE: 24
-    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
-TRAIN:
-  EPOCHS: 30
-  WARMUP_EPOCHS: 5
-  WEIGHT_DECAY: 1e-8
-  BASE_LR: 2e-05
-  WARMUP_LR: 2e-08
-  MIN_LR: 2e-07
-TEST:
+DATA:
+  IMG_SIZE: 384
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_large_patch4_window12to24_192to384_22kto1k_ft
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 192
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 6, 12, 24, 48 ]
+    WINDOW_SIZE: 24
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07
+TEST:
   CROP: False

+ 10 - 10
configs/swinv2/swinv2_small_patch4_window16_256.yaml

@@ -1,11 +1,11 @@
-DATA:
-  IMG_SIZE: 256
-MODEL:
-  TYPE: swinv2
-  NAME: swinv2_small_patch4_window16_256
-  DROP_PATH_RATE: 0.3
-  SWINV2:
-    EMBED_DIM: 96
-    DEPTHS: [ 2, 2, 18, 2 ]
-    NUM_HEADS: [ 3, 6, 12, 24 ]
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_small_patch4_window16_256
+  DROP_PATH_RATE: 0.3
+  SWINV2:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
     WINDOW_SIZE: 16

+ 10 - 10
configs/swinv2/swinv2_small_patch4_window8_256.yaml

@@ -1,11 +1,11 @@
-DATA:
-  IMG_SIZE: 256
-MODEL:
-  TYPE: swinv2
-  NAME: swinv2_small_patch4_window8_256
-  DROP_PATH_RATE: 0.3
-  SWINV2:
-    EMBED_DIM: 96
-    DEPTHS: [ 2, 2, 18, 2 ]
-    NUM_HEADS: [ 3, 6, 12, 24 ]
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_small_patch4_window8_256
+  DROP_PATH_RATE: 0.3
+  SWINV2:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
     WINDOW_SIZE: 8

+ 10 - 10
configs/swinv2/swinv2_tiny_patch4_window16_256.yaml

@@ -1,11 +1,11 @@
-DATA:
-  IMG_SIZE: 256
-MODEL:
-  TYPE: swinv2
-  NAME: swinv2_tiny_patch4_window16_256
-  DROP_PATH_RATE: 0.2
-  SWINV2:
-    EMBED_DIM: 96
-    DEPTHS: [ 2, 2, 6, 2 ]
-    NUM_HEADS: [ 3, 6, 12, 24 ]
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_tiny_patch4_window16_256
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 6, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
     WINDOW_SIZE: 16

+ 10 - 10
configs/swinv2/swinv2_tiny_patch4_window8_256.yaml

@@ -1,11 +1,11 @@
-DATA:
-  IMG_SIZE: 256
-MODEL:
-  TYPE: swinv2
-  NAME: swinv2_tiny_patch4_window8_256
-  DROP_PATH_RATE: 0.2
-  SWINV2:
-    EMBED_DIM: 96
-    DEPTHS: [ 2, 2, 6, 2 ]
-    NUM_HEADS: [ 3, 6, 12, 24 ]
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_tiny_patch4_window8_256
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 6, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
     WINDOW_SIZE: 8

+ 25 - 25
lib/SwinTransformer/SUPPORT.md

@@ -1,25 +1,25 @@
-# TODO: The maintainer of this repo has not yet edited this file
-
-**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
-
-- **No CSS support:** Fill out this template with information about how to file issues and get help.
-- **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
-- **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
-
-*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
-
-# Support
-
-## How to file issues and get help  
-
-This project uses GitHub Issues to track bugs and feature requests. Please search the existing 
-issues before filing new issues to avoid duplicates.  For new issues, file your bug or 
-feature request as a new Issue.
-
-For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 
-FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
-CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
-
-## Microsoft Support Policy  
-
-Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
+# TODO: The maintainer of this repo has not yet edited this file
+
+**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
+
+- **No CSS support:** Fill out this template with information about how to file issues and get help.
+- **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
+- **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
+
+*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
+
+# Support
+
+## How to file issues and get help  
+
+This project uses GitHub Issues to track bugs and feature requests. Please search the existing 
+issues before filing new issues to avoid duplicates.  For new issues, file your bug or 
+feature request as a new Issue.
+
+For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 
+FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
+CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
+
+## Microsoft Support Policy  
+
+Support for this **PROJECT or PRODUCT** is limited to the resources listed above.

+ 3 - 3
lib/SwinTransformer/__init__.py

@@ -1,3 +1,3 @@
-from .models.swin_transformer_v2 import SwinTransformerV2, SwinTransformerBlock
-
-__all__ = ["SwinTransformerV2", "SwinTransformerBlock"]
+from .models.swin_transformer_v2 import SwinTransformerV2, SwinTransformerBlock
+
+__all__ = ["SwinTransformerV2", "SwinTransformerBlock"]

+ 315 - 121
lib/SwinTransformer/models/swin_transformer_v2.py

@@ -14,7 +14,14 @@ import numpy as np
 
 
 class Mlp(nn.Module):
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        act_layer=nn.GELU,
+        drop=0.0,
+    ):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
@@ -43,7 +50,9 @@ def window_partition(x, window_size):
     """
     B, H, W, C = x.shape
     x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    windows = (
+        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    )
     return windows
 
 
@@ -59,13 +68,15 @@ def window_reverse(windows, window_size, H, W):
         x: (B, H, W, C)
     """
     B = int(windows.shape[0] / (H * W / window_size / window_size))
-    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = windows.view(
+        B, H // window_size, W // window_size, window_size, window_size, -1
+    )
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
     return x
 
 
 class WindowAttention(nn.Module):
-    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+    r"""Window based multi-head self attention (W-MSA) module with relative position bias.
     It supports both of shifted and non-shifted window.
 
     Args:
@@ -78,8 +89,16 @@ class WindowAttention(nn.Module):
         pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
     """
 
-    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
-                 pretrained_window_size=(0, 0)):
+    def __init__(
+        self,
+        dim,
+        window_size,
+        num_heads,
+        qkv_bias=True,
+        attn_drop=0.0,
+        proj_drop=0.0,
+        pretrained_window_size=(0, 0),
+    ):
 
         super().__init__()
         self.dim = dim
@@ -87,28 +106,42 @@ class WindowAttention(nn.Module):
         self.pretrained_window_size = pretrained_window_size
         self.num_heads = num_heads
 
-        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
+        self.logit_scale = nn.Parameter(
+            torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
+        )
 
         # mlp to generate continuous relative position bias
-        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
-                                     nn.ReLU(inplace=True),
-                                     nn.Linear(512, num_heads, bias=False))
+        self.cpb_mlp = nn.Sequential(
+            nn.Linear(2, 512, bias=True),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, num_heads, bias=False),
+        )
 
         # get relative_coords_table
-        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
-        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
-        relative_coords_table = torch.stack(
-            torch.meshgrid([relative_coords_h,
-                            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
+        relative_coords_h = torch.arange(
+            -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32
+        )
+        relative_coords_w = torch.arange(
+            -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32
+        )
+        relative_coords_table = (
+            torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
+            .permute(1, 2, 0)
+            .contiguous()
+            .unsqueeze(0)
+        )  # 1, 2*Wh-1, 2*Ww-1, 2
         if pretrained_window_size[0] > 0:
-            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
-            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
+            relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
+            relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
         else:
-            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
-            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
+            relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
+            relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
         relative_coords_table *= 8  # normalize to -8, 8
-        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
-            torch.abs(relative_coords_table) + 1.0) / np.log2(8)
+        relative_coords_table = (
+            torch.sign(relative_coords_table)
+            * torch.log2(torch.abs(relative_coords_table) + 1.0)
+            / np.log2(8)
+        )
 
         self.register_buffer("relative_coords_table", relative_coords_table)
 
@@ -117,8 +150,12 @@ class WindowAttention(nn.Module):
         coords_w = torch.arange(self.window_size[1])
         coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
-        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords = (
+            coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        )  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(
+            1, 2, 0
+        ).contiguous()  # Wh*Ww, Wh*Ww, 2
         relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
         relative_coords[:, :, 1] += self.window_size[1] - 1
         relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
@@ -146,27 +183,50 @@ class WindowAttention(nn.Module):
         B_, N, C = x.shape
         qkv_bias = None
         if self.q_bias is not None:
-            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+            qkv_bias = torch.cat(
+                (
+                    self.q_bias,
+                    torch.zeros_like(self.v_bias, requires_grad=False),
+                    self.v_bias,
+                )
+            )
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = (
+            qkv[0],
+            qkv[1],
+            qkv[2],
+        )  # make torchscript happy (cannot use tensor as tuple)
 
         # cosine attention
-        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
-        logit_scale = torch.clamp(self.logit_scale,
-                                  max=torch.log(torch.tensor(1. / 0.01, device=self.logit_scale.device))).exp()
+        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
+        logit_scale = torch.clamp(
+            self.logit_scale,
+            max=torch.log(torch.tensor(1.0 / 0.01, device=self.logit_scale.device)),
+        ).exp()
         attn = attn * logit_scale
 
-        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
-        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
-            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
-        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(
+            -1, self.num_heads
+        )
+        relative_position_bias = relative_position_bias_table[
+            self.relative_position_index.view(-1)
+        ].view(
+            self.window_size[0] * self.window_size[1],
+            self.window_size[0] * self.window_size[1],
+            -1,
+        )  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(
+            2, 0, 1
+        ).contiguous()  # nH, Wh*Ww, Wh*Ww
         relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
         attn = attn + relative_position_bias.unsqueeze(0)
 
         if mask is not None:
             nW = mask.shape[0]
-            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
+                1
+            ).unsqueeze(0)
             attn = attn.view(-1, self.num_heads, N, N)
             attn = self.softmax(attn)
         else:
@@ -180,8 +240,10 @@ class WindowAttention(nn.Module):
         return x
 
     def extra_repr(self) -> str:
-        return f'dim={self.dim}, window_size={self.window_size}, ' \
-               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
+        return (
+            f"dim={self.dim}, window_size={self.window_size}, "
+            f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
+        )
 
     def flops(self, N):
         # calculate flops for 1 window with token length of N
@@ -198,7 +260,7 @@ class WindowAttention(nn.Module):
 
 
 class SwinTransformerBlock(nn.Module):
-    r""" Swin Transformer Block.
+    r"""Swin Transformer Block.
 
     Args:
         dim (int): Number of input channels.
@@ -216,9 +278,22 @@ class SwinTransformerBlock(nn.Module):
         pretrained_window_size (int): Window size in pre-training.
     """
 
-    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
-                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
-                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
+    def __init__(
+        self,
+        dim,
+        input_resolution,
+        num_heads,
+        window_size=7,
+        shift_size=0,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+        pretrained_window_size=0,
+    ):
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
@@ -230,39 +305,59 @@ class SwinTransformerBlock(nn.Module):
             # if window size is larger than input resolution, we don't partition windows
             self.shift_size = 0
             self.window_size = min(self.input_resolution)
-        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+        assert (
+            0 <= self.shift_size < self.window_size
+        ), "shift_size must in 0-window_size"
 
         self.norm1 = norm_layer(dim)
         self.attn = WindowAttention(
-            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
-            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
-            pretrained_window_size=to_2tuple(pretrained_window_size))
-
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+            dim,
+            window_size=to_2tuple(self.window_size),
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            pretrained_window_size=to_2tuple(pretrained_window_size),
+        )
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+        )
 
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
             H, W = self.input_resolution
             img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
-            h_slices = (slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None))
-            w_slices = (slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None))
+            h_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None),
+            )
+            w_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None),
+            )
             cnt = 0
             for h in h_slices:
                 for w in w_slices:
                     img_mask[:, h, w, :] = cnt
                     cnt += 1
 
-            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+            mask_windows = window_partition(
+                img_mask, self.window_size
+            )  # nW, window_size, window_size, 1
             mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
-            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+            attn_mask = attn_mask.masked_fill(
+                attn_mask != 0, float(-100.0)
+            ).masked_fill(attn_mask == 0, float(0.0))
         else:
             attn_mask = None
 
@@ -278,16 +373,24 @@ class SwinTransformerBlock(nn.Module):
 
         # cyclic shift
         if self.shift_size > 0:
-            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            shifted_x = torch.roll(
+                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
+            )
         else:
             shifted_x = x
 
         # partition windows
-        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+        x_windows = window_partition(
+            shifted_x, self.window_size
+        )  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(
+            -1, self.window_size * self.window_size, C
+        )  # nW*B, window_size*window_size, C
 
         # W-MSA/SW-MSA
-        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
+        attn_windows = self.attn(
+            x_windows, mask=self.attn_mask
+        )  # nW*B, window_size*window_size, C
 
         # merge windows
         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
@@ -295,7 +398,9 @@ class SwinTransformerBlock(nn.Module):
 
         # reverse cyclic shift
         if self.shift_size > 0:
-            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+            x = torch.roll(
+                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
+            )
         else:
             x = shifted_x
         x = x.view(B, H * W, C)
@@ -307,8 +412,10 @@ class SwinTransformerBlock(nn.Module):
         return x
 
     def extra_repr(self) -> str:
-        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
-               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+        return (
+            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
+            f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+        )
 
     def flops(self):
         flops = 0
@@ -326,7 +433,7 @@ class SwinTransformerBlock(nn.Module):
 
 
 class PatchMerging(nn.Module):
-    r""" Patch Merging Layer.
+    r"""Patch Merging Layer.
 
     Args:
         input_resolution (tuple[int]): Resolution of input feature.
@@ -375,7 +482,7 @@ class PatchMerging(nn.Module):
 
 
 class BasicLayer(nn.Module):
-    """ A basic Swin Transformer layer for one stage.
+    """A basic Swin Transformer layer for one stage.
 
     Args:
         dim (int): Number of input channels.
@@ -394,10 +501,23 @@ class BasicLayer(nn.Module):
         pretrained_window_size (int): Local window size in pre-training.
     """
 
-    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
-                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
-                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
-                 pretrained_window_size=0):
+    def __init__(
+        self,
+        dim,
+        input_resolution,
+        depth,
+        num_heads,
+        window_size,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        norm_layer=nn.LayerNorm,
+        downsample=None,
+        use_checkpoint=False,
+        pretrained_window_size=0,
+    ):
 
         super().__init__()
         self.dim = dim
@@ -406,21 +526,33 @@ class BasicLayer(nn.Module):
         self.use_checkpoint = use_checkpoint
 
         # build blocks
-        self.blocks = nn.ModuleList([
-            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
-                                 num_heads=num_heads, window_size=window_size,
-                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
-                                 mlp_ratio=mlp_ratio,
-                                 qkv_bias=qkv_bias,
-                                 drop=drop, attn_drop=attn_drop,
-                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
-                                 norm_layer=norm_layer,
-                                 pretrained_window_size=pretrained_window_size)
-            for i in range(depth)])
+        self.blocks = nn.ModuleList(
+            [
+                SwinTransformerBlock(
+                    dim=dim,
+                    input_resolution=input_resolution,
+                    num_heads=num_heads,
+                    window_size=window_size,
+                    shift_size=0 if (i % 2 == 0) else window_size // 2,
+                    mlp_ratio=mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    drop=drop,
+                    attn_drop=attn_drop,
+                    drop_path=(
+                        drop_path[i] if isinstance(drop_path, list) else drop_path
+                    ),
+                    norm_layer=norm_layer,
+                    pretrained_window_size=pretrained_window_size,
+                )
+                for i in range(depth)
+            ]
+        )
 
         # patch merging layer
         if downsample is not None:
-            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+            self.downsample = downsample(
+                input_resolution, dim=dim, norm_layer=norm_layer
+            )
         else:
             self.downsample = None
 
@@ -454,7 +586,7 @@ class BasicLayer(nn.Module):
 
 
 class PatchEmbed(nn.Module):
-    r""" Image to Patch Embedding
+    r"""Image to Patch Embedding
 
     Args:
         img_size (int): Image size.  Default: 224.
@@ -464,11 +596,16 @@ class PatchEmbed(nn.Module):
         norm_layer (nn.Module, optional): Normalization layer. Default: None
     """
 
-    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+    def __init__(
+        self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
+    ):
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
-        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+        patches_resolution = [
+            img_size[0] // patch_size[0],
+            img_size[1] // patch_size[1],
+        ]
         self.img_size = img_size
         self.patch_size = patch_size
         self.patches_resolution = patches_resolution
@@ -477,7 +614,9 @@ class PatchEmbed(nn.Module):
         self.in_chans = in_chans
         self.embed_dim = embed_dim
 
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
+        )
         if norm_layer is not None:
             self.norm = norm_layer(embed_dim)
         else:
@@ -486,8 +625,9 @@ class PatchEmbed(nn.Module):
     def forward(self, x):
         B, C, H, W = x.shape
         # FIXME look at relaxing size constraints
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        assert (
+            H == self.img_size[0] and W == self.img_size[1]
+        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
         x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
         if self.norm is not None:
             x = self.norm(x)
@@ -495,14 +635,20 @@ class PatchEmbed(nn.Module):
 
     def flops(self):
         Ho, Wo = self.patches_resolution
-        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+        flops = (
+            Ho
+            * Wo
+            * self.embed_dim
+            * self.in_chans
+            * (self.patch_size[0] * self.patch_size[1])
+        )
         if self.norm is not None:
             flops += Ho * Wo * self.embed_dim
         return flops
 
 
 class SwinTransformerV2(nn.Module):
-    r""" Swin Transformer
+    r"""Swin Transformer
         A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
           https://arxiv.org/pdf/2103.14030
 
@@ -527,12 +673,28 @@ class SwinTransformerV2(nn.Module):
         pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
     """
 
-    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
-                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
-                 window_size=7, mlp_ratio=4., qkv_bias=True,
-                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
-                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
-                 use_checkpoint=False, pretrained_window_sizes=(0, 0, 0, 0), **kwargs):
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=4,
+        in_chans=3,
+        num_classes=1000,
+        embed_dim=96,
+        depths=(2, 2, 6, 2),
+        num_heads=(3, 6, 12, 24),
+        window_size=7,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        drop_path_rate=0.1,
+        norm_layer=nn.LayerNorm,
+        ape=False,
+        patch_norm=True,
+        use_checkpoint=False,
+        pretrained_window_sizes=(0, 0, 0, 0),
+        **kwargs,
+    ):
         super().__init__()
 
         self.num_classes = num_classes
@@ -545,8 +707,12 @@ class SwinTransformerV2(nn.Module):
 
         # split image into non-overlapping patches
         self.patch_embed = PatchEmbed(
-            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
-            norm_layer=norm_layer if self.patch_norm else None)
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None,
+        )
         num_patches = self.patch_embed.num_patches
         patches_resolution = self.patch_embed.patches_resolution
         self.patches_resolution = patches_resolution
@@ -554,36 +720,49 @@ class SwinTransformerV2(nn.Module):
         # absolute position embedding
         if self.ape:
             # noinspection PyTypeChecker
-            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
-            trunc_normal_(self.absolute_pos_embed, std=.02)
+            self.absolute_pos_embed = nn.Parameter(
+                torch.zeros(1, num_patches, embed_dim)
+            )
+            trunc_normal_(self.absolute_pos_embed, std=0.02)
 
         self.pos_drop = nn.Dropout(p=drop_rate)
 
         # stochastic depth
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+        ]  # stochastic depth decay rule
 
         # build layers
         self.layers = nn.ModuleList()
         for i_layer in range(self.num_layers):
-            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
-                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
-                                                 patches_resolution[1] // (2 ** i_layer)),
-                               depth=depths[i_layer],
-                               num_heads=num_heads[i_layer],
-                               window_size=window_size,
-                               mlp_ratio=self.mlp_ratio,
-                               qkv_bias=qkv_bias,
-                               drop=drop_rate, attn_drop=attn_drop_rate,
-                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
-                               norm_layer=norm_layer,
-                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
-                               use_checkpoint=use_checkpoint,
-                               pretrained_window_size=pretrained_window_sizes[i_layer])
+            layer = BasicLayer(
+                dim=int(embed_dim * 2**i_layer),
+                input_resolution=(
+                    patches_resolution[0] // (2**i_layer),
+                    patches_resolution[1] // (2**i_layer),
+                ),
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                window_size=window_size,
+                mlp_ratio=self.mlp_ratio,
+                qkv_bias=qkv_bias,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
+                norm_layer=norm_layer,
+                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                use_checkpoint=use_checkpoint,
+                pretrained_window_size=pretrained_window_sizes[i_layer],
+            )
             self.layers.append(layer)
 
         self.norm = norm_layer(self.num_features)
         self.avgpool = nn.AdaptiveAvgPool1d(1)
-        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = (
+            nn.Linear(self.num_features, num_classes)
+            if num_classes > 0
+            else nn.Identity()
+        )
 
         self.apply(self._init_weights)
         for bly in self.layers:
@@ -601,7 +780,11 @@ class SwinTransformerV2(nn.Module):
     @staticmethod
     def proj_out(x, normalize=False):
         if normalize:
-            x = F.layer_norm(x.permute(0, 2, 3, 1), [x.shape[1]]).permute(0, 3, 1, 2).contiguous()
+            x = (
+                F.layer_norm(x.permute(0, 2, 3, 1), [x.shape[1]])
+                .permute(0, 3, 1, 2)
+                .contiguous()
+            )
         return x
 
     def forward_multiscale_features(self, x, normalize=True, include_patch_embed=True):
@@ -626,13 +809,17 @@ class SwinTransformerV2(nn.Module):
         resolution = tuple(self.patches_resolution)
 
         if include_patch_embed:
-            features.append(self.proj_out(self._tokens_to_feature_map(x, resolution), normalize))
+            features.append(
+                self.proj_out(self._tokens_to_feature_map(x, resolution), normalize)
+            )
 
         for layer in self.layers:
             x = layer(x)
             if layer.downsample is not None:
                 resolution = (resolution[0] // 2, resolution[1] // 2)
-            features.append(self.proj_out(self._tokens_to_feature_map(x, resolution), normalize))
+            features.append(
+                self.proj_out(self._tokens_to_feature_map(x, resolution), normalize)
+            )
 
         return features
 
@@ -657,7 +844,9 @@ class SwinTransformerV2(nn.Module):
                 else:
                     x = blk(x)
 
-            features.append(self.proj_out(self._tokens_to_feature_map(x, resolution), normalize))
+            features.append(
+                self.proj_out(self._tokens_to_feature_map(x, resolution), normalize)
+            )
 
             if layer.downsample is not None:
                 x = layer.downsample(x)
@@ -668,7 +857,7 @@ class SwinTransformerV2(nn.Module):
     @staticmethod
     def _init_weights(m):
         if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=.02)
+            trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
@@ -677,11 +866,11 @@ class SwinTransformerV2(nn.Module):
 
     @torch.jit.ignore
     def no_weight_decay(self):
-        return {'absolute_pos_embed'}
+        return {"absolute_pos_embed"}
 
     @torch.jit.ignore
     def no_weight_decay_keywords(self):
-        return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'}
+        return {"cpb_mlp", "logit_scale", "relative_position_bias_table"}
 
     def forward_features(self, x):
         x = self.patch_embed(x)
@@ -707,6 +896,11 @@ class SwinTransformerV2(nn.Module):
         flops += self.patch_embed.flops()
         for i, layer in enumerate(self.layers):
             flops += layer.flops()
-        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
+        flops += (
+            self.num_features
+            * self.patches_resolution[0]
+            * self.patches_resolution[1]
+            // (2**self.num_layers)
+        )
         flops += self.num_features * self.num_classes
         return flops

+ 132 - 132
lib/modules/attentions_2d.py

@@ -1,132 +1,132 @@
-"""
-Circulant Attention 2D.
-
-核心思想: 自注意力矩阵近似 BC CB 结构,通过 2D FFT 在 O(N log N) 时间内计算。
-"""
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import Literal
-
-try:
-    import ptwt
-except ImportError as exc:
-    raise ImportError(
-        "wavelet_fft requires ptwt. Install it before importing this package."
-    ) from exc
-
-from .layers_2d import Scale
-
-
-class ComplexLinear(nn.Linear):
-    def __init__(self, in_features, out_features, device=None, dtype=None):
-        super().__init__(in_features, out_features, bias=False, device=device, dtype=dtype)
-
-    def forward(self, inp):
-        x = torch.view_as_real(inp).transpose(-2, -1)
-        x = F.linear(x, self.weight).transpose(-2, -1)
-        if x.dtype != torch.float32:
-            x = x.to(torch.float32)
-        return torch.view_as_complex(x.contiguous())
-
-
-class CirculantAttention2d(nn.Module):
-    def __init__(self, dim, proj_drop=0.0):
-        super().__init__()
-        self.qkv = ComplexLinear(dim, dim * 3)
-        self.gate = nn.Sequential(nn.Linear(dim, dim), nn.SiLU())
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        spatial_perm = [0, 2, 3, 1]
-        spatial = x.permute(spatial_perm).contiguous()
-        gate = self.gate(spatial.reshape(b, h * w, c)).reshape(b, h, w, c)
-        freq = torch.fft.rfft2(spatial, dim=(1, 2), norm="ortho")
-        qkv = self.qkv(freq)
-        q, k, v = torch.chunk(qkv, chunks=3, dim=-1)
-        attn = torch.conj(q) * k
-        attn = torch.fft.irfft2(attn, s=(h, w), dim=(1, 2), norm="ortho")
-        attn = attn.reshape(b, h * w, c).softmax(dim=1).reshape(b, h, w, c)
-        attn = torch.fft.rfft2(attn, dim=(1, 2))
-        out = torch.conj(attn) * v
-        out = torch.fft.irfft2(out, s=(h, w), dim=(1, 2), norm="ortho")
-        out = out.reshape(b, h * w, c) * gate.reshape(b, h * w, c)
-        out = self.proj_drop(self.proj(out))
-        return out.transpose(1, 2).reshape(b, c, h, w)
-
-
-class WaveletAttentionGlobalBranch2d(nn.Module):
-    def __init__(
-            self, in_channels, kernel_size=5, stride=1, wt_levels=1,
-            wt_type="db1", wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero",
-            proj_drop=0.0,
-    ):
-        super().__init__()
-        if in_channels <= 0:
-            raise ValueError("in_channels must be positive.")
-
-        self.in_channels = in_channels
-        self.wt_levels = wt_levels
-        self.stride = stride
-
-        self.wavelet = wt_type
-        self.wt_mode = wt_mode
-
-        self.global_attn = CirculantAttention2d(in_channels, proj_drop=proj_drop)
-        self.base_scale = Scale((1, in_channels, 1, 1))
-
-        self.wavelet_convs = nn.ModuleList([
-            nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, 1,
-                      kernel_size // 2, groups=in_channels * 4, bias=False)
-            for _ in range(wt_levels)
-        ])
-        self.wavelet_scale = nn.ModuleList([
-            Scale((1, in_channels * 4, 1, 1), init_scale=0.1)
-            for _ in range(wt_levels)
-        ])
-
-        if stride > 1:
-            self.register_buffer("stride_filter", torch.ones(in_channels, 1, 1, 1), persistent=False)
-        else:
-            self.stride_filter = None
-
-    def forward(self, x):
-        low_levels, high_levels, shapes_in_levels = [], [], []
-        curr_low = x
-
-        for level in range(self.wt_levels):
-            shapes_in_levels.append(curr_low.shape[-2:])
-            coeffs = ptwt.wavedec2(curr_low, self.wavelet, mode=self.wt_mode, level=1)
-            low = coeffs[0]
-            detail = coeffs[1]
-            high = torch.stack([detail.horizontal, detail.vertical, detail.diagonal], dim=2)
-            bands = torch.cat([low.unsqueeze(2), high], dim=2)
-            b, c, _, h_half, w_half = bands.shape
-            bands = bands.reshape(b, c * 4, h_half, w_half)
-            bands = self.wavelet_scale[level](self.wavelet_convs[level](bands))
-            bands = bands.reshape(b, c, 4, h_half, w_half)
-            low_levels.append(bands[:, :, 0, :, :])
-            high_levels.append(bands[:, :, 1:4, :, :])
-            curr_low = low
-
-        wavelet_out = x
-        if self.wt_levels > 0:
-            next_low = None
-            for level in range(self.wt_levels - 1, -1, -1):
-                low = low_levels.pop()
-                high = high_levels.pop()
-                height, width = shapes_in_levels.pop()
-                if next_low is not None:
-                    low = low + next_low
-                cH, cV, cD = high.unbind(dim=2)
-                next_low = ptwt.waverec2((low, ptwt.constants.WaveletDetailTuple2d(cH, cV, cD)), self.wavelet)
-                next_low = next_low[:, :, :height, :width]
-            wavelet_out = next_low
-
-        out = self.base_scale(self.global_attn(x)) + wavelet_out
-        if self.stride_filter is not None:
-            out = F.conv2d(out, self.stride_filter, stride=self.stride, groups=self.in_channels)
-        return out
+"""
+Circulant Attention 2D.
+
+核心思想: 自注意力矩阵近似 BC CB 结构,通过 2D FFT 在 O(N log N) 时间内计算。
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Literal
+
+try:
+    import ptwt
+except ImportError as exc:
+    raise ImportError(
+        "wavelet_fft requires ptwt. Install it before importing this package."
+    ) from exc
+
+from .layers_2d import Scale
+
+
+class ComplexLinear(nn.Linear):
+    def __init__(self, in_features, out_features, device=None, dtype=None):
+        super().__init__(in_features, out_features, bias=False, device=device, dtype=dtype)
+
+    def forward(self, inp):
+        x = torch.view_as_real(inp).transpose(-2, -1)
+        x = F.linear(x, self.weight).transpose(-2, -1)
+        if x.dtype != torch.float32:
+            x = x.to(torch.float32)
+        return torch.view_as_complex(x.contiguous())
+
+
+class CirculantAttention2d(nn.Module):
+    def __init__(self, dim, proj_drop=0.0):
+        super().__init__()
+        self.qkv = ComplexLinear(dim, dim * 3)
+        self.gate = nn.Sequential(nn.Linear(dim, dim), nn.SiLU())
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        spatial_perm = [0, 2, 3, 1]
+        spatial = x.permute(spatial_perm).contiguous()
+        gate = self.gate(spatial.reshape(b, h * w, c)).reshape(b, h, w, c)
+        freq = torch.fft.rfft2(spatial, dim=(1, 2), norm="ortho")
+        qkv = self.qkv(freq)
+        q, k, v = torch.chunk(qkv, chunks=3, dim=-1)
+        attn = torch.conj(q) * k
+        attn = torch.fft.irfft2(attn, s=(h, w), dim=(1, 2), norm="ortho")
+        attn = attn.reshape(b, h * w, c).softmax(dim=1).reshape(b, h, w, c)
+        attn = torch.fft.rfft2(attn, dim=(1, 2))
+        out = torch.conj(attn) * v
+        out = torch.fft.irfft2(out, s=(h, w), dim=(1, 2), norm="ortho")
+        out = out.reshape(b, h * w, c) * gate.reshape(b, h * w, c)
+        out = self.proj_drop(self.proj(out))
+        return out.transpose(1, 2).reshape(b, c, h, w)
+
+
+class WaveletAttentionGlobalBranch2d(nn.Module):
+    def __init__(
+            self, in_channels, kernel_size=5, stride=1, wt_levels=1,
+            wt_type="db1", wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero",
+            proj_drop=0.0,
+    ):
+        super().__init__()
+        if in_channels <= 0:
+            raise ValueError("in_channels must be positive.")
+
+        self.in_channels = in_channels
+        self.wt_levels = wt_levels
+        self.stride = stride
+
+        self.wavelet = wt_type
+        self.wt_mode = wt_mode
+
+        self.global_attn = CirculantAttention2d(in_channels, proj_drop=proj_drop)
+        self.base_scale = Scale((1, in_channels, 1, 1))
+
+        self.wavelet_convs = nn.ModuleList([
+            nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, 1,
+                      kernel_size // 2, groups=in_channels * 4, bias=False)
+            for _ in range(wt_levels)
+        ])
+        self.wavelet_scale = nn.ModuleList([
+            Scale((1, in_channels * 4, 1, 1), init_scale=0.1)
+            for _ in range(wt_levels)
+        ])
+
+        if stride > 1:
+            self.register_buffer("stride_filter", torch.ones(in_channels, 1, 1, 1), persistent=False)
+        else:
+            self.stride_filter = None
+
+    def forward(self, x):
+        low_levels, high_levels, shapes_in_levels = [], [], []
+        curr_low = x
+
+        for level in range(self.wt_levels):
+            shapes_in_levels.append(curr_low.shape[-2:])
+            coeffs = ptwt.wavedec2(curr_low, self.wavelet, mode=self.wt_mode, level=1)
+            low = coeffs[0]
+            detail = coeffs[1]
+            high = torch.stack([detail.horizontal, detail.vertical, detail.diagonal], dim=2)
+            bands = torch.cat([low.unsqueeze(2), high], dim=2)
+            b, c, _, h_half, w_half = bands.shape
+            bands = bands.reshape(b, c * 4, h_half, w_half)
+            bands = self.wavelet_scale[level](self.wavelet_convs[level](bands))
+            bands = bands.reshape(b, c, 4, h_half, w_half)
+            low_levels.append(bands[:, :, 0, :, :])
+            high_levels.append(bands[:, :, 1:4, :, :])
+            curr_low = low
+
+        wavelet_out = x
+        if self.wt_levels > 0:
+            next_low = None
+            for level in range(self.wt_levels - 1, -1, -1):
+                low = low_levels.pop()
+                high = high_levels.pop()
+                height, width = shapes_in_levels.pop()
+                if next_low is not None:
+                    low = low + next_low
+                cH, cV, cD = high.unbind(dim=2)
+                next_low = ptwt.waverec2((low, ptwt.constants.WaveletDetailTuple2d(cH, cV, cD)), self.wavelet)
+                next_low = next_low[:, :, :height, :width]
+            wavelet_out = next_low
+
+        out = self.base_scale(self.global_attn(x)) + wavelet_out
+        if self.stride_filter is not None:
+            out = F.conv2d(out, self.stride_filter, stride=self.stride, groups=self.in_channels)
+        return out

+ 504 - 504
lib/modules/build_swinv2.py

@@ -1,504 +1,504 @@
-from __future__ import annotations
-
-from argparse import Namespace
-from pathlib import Path
-from types import SimpleNamespace
-from typing import Any
-
-import torch
-import yaml
-
-from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2
-
-ROOT_DIR = Path(__file__).resolve().parents[2]
-SWINV2_CONFIG_DIR = ROOT_DIR / "configs" / "swinv2"
-SWINV2_WEIGHT_DIR = ROOT_DIR / "weights" / "swinv2"
-MAP22KTO1K_PATH = ROOT_DIR / "lib" / "SwinTransformer" / "data" / "map22kto1k.txt"
-
-DEFAULTS: dict[str, Any] = {
-    "DATA": {
-        "IMG_SIZE": 224,
-    },
-    "MODEL": {
-        "TYPE": "swinv2",
-        "NAME": "swinv2_tiny_patch4_window8_256",
-        "NUM_CLASSES": 1000,
-        "DROP_RATE": 0.0,
-        "DROP_PATH_RATE": 0.1,
-        "PRETRAINED": "",
-        "SWINV2": {
-            "PATCH_SIZE": 4,
-            "IN_CHANS": 3,
-            "EMBED_DIM": 96,
-            "DEPTHS": [2, 2, 6, 2],
-            "NUM_HEADS": [3, 6, 12, 24],
-            "WINDOW_SIZE": 7,
-            "MLP_RATIO": 4.0,
-            "QKV_BIAS": True,
-            "APE": False,
-            "PATCH_NORM": True,
-            "PRETRAINED_WINDOW_SIZES": [0, 0, 0, 0],
-        },
-    },
-    "TRAIN": {
-        "USE_CHECKPOINT": False,
-    },
-}
-
-
-def _deep_copy_dict(value: dict[str, Any]) -> dict[str, Any]:
-    copied: dict[str, Any] = {}
-    for key, item in value.items():
-        if isinstance(item, dict):
-            copied[key] = _deep_copy_dict(item)
-        elif isinstance(item, list):
-            copied[key] = list(item)
-        else:
-            copied[key] = item
-    return copied
-
-
-def _merge_dict(dst: dict[str, Any], src: dict[str, Any]) -> dict[str, Any]:
-    for key, value in src.items():
-        if isinstance(value, dict) and isinstance(dst.get(key), dict):
-            _merge_dict(dst[key], value)
-        else:
-            dst[key] = value
-    return dst
-
-
-def _dict_to_namespace(value: Any) -> Any:
-    if isinstance(value, dict):
-        return SimpleNamespace(**{key: _dict_to_namespace(item) for key, item in value.items()})
-    if isinstance(value, list):
-        return [_dict_to_namespace(item) for item in value]
-    return value
-
-
-def _get_arg(args: Namespace | None, *names: str) -> Any:
-    if args is None:
-        return None
-    for name in names:
-        if hasattr(args, name):
-            value = getattr(args, name)
-            if value is not None:
-                return value
-    return None
-
-
-def _to_path(value: str | Path | None) -> Path | None:
-    if value is None:
-        return None
-    return Path(value).expanduser().resolve()
-
-
-def _resolve_model_name(
-        model_name: str | None,
-        config_path: Path | None,
-        weight_path: Path | None,
-        args: Namespace | None,
-) -> str | None:
-    return (
-            _get_arg(args, "model_name", "model")
-            or model_name
-            or (weight_path.stem if weight_path is not None else None)
-            or (config_path.stem if config_path is not None else None)
-    )
-
-
-def _resolve_config_path(model_name: str | None, config_path: str | Path | None, args: Namespace | None) -> Path | None:
-    cli_cfg = _to_path(_get_arg(args, "cfg", "config", "config_path"))
-    if cli_cfg is not None:
-        return cli_cfg
-
-    explicit_cfg = _to_path(config_path)
-    if explicit_cfg is not None:
-        return explicit_cfg
-
-    if model_name is None:
-        return None
-
-    candidate = SWINV2_CONFIG_DIR / f"{model_name}.yaml"
-    return candidate if candidate.exists() else None
-
-
-def _resolve_weight_path(model_name: str | None, weight_path: str | Path | None, args: Namespace | None) -> Path | None:
-    cli_weight = _to_path(_get_arg(args, "pretrained", "weights", "weight_path", "ckpt", "checkpoint"))
-    if cli_weight is not None:
-        return cli_weight
-
-    explicit_weight = _to_path(weight_path)
-    if explicit_weight is not None:
-        return explicit_weight
-
-    if model_name is None:
-        return None
-
-    candidate = SWINV2_WEIGHT_DIR / f"{model_name}.pth"
-    return candidate if candidate.exists() else None
-
-
-def _resolve_config_with_source(
-        model_name: str | None,
-        config_path: str | Path | None,
-        args: Namespace | None,
-) -> tuple[Path | None, str]:
-    cli_cfg = _to_path(_get_arg(args, "cfg", "config", "config_path"))
-    if cli_cfg is not None:
-        return cli_cfg, "args.cfg"
-
-    explicit_cfg = _to_path(config_path)
-    if explicit_cfg is not None:
-        return explicit_cfg, "function config_path"
-
-    if model_name is None:
-        return None, "defaults only"
-
-    candidate = SWINV2_CONFIG_DIR / f"{model_name}.yaml"
-    if candidate.exists():
-        return candidate, "auto by MODEL.NAME"
-    return None, "defaults only"
-
-
-def _resolve_weight_with_source(
-        model_name: str | None,
-        weight_path: str | Path | None,
-        args: Namespace | None,
-) -> tuple[Path | None, str]:
-    cli_weight = _to_path(_get_arg(args, "pretrained", "weights", "weight_path", "ckpt", "checkpoint"))
-    if cli_weight is not None:
-        return cli_weight, "args.pretrained"
-
-    explicit_weight = _to_path(weight_path)
-    if explicit_weight is not None:
-        return explicit_weight, "function weight_path"
-
-    if model_name is None:
-        return None, "not resolved"
-
-    candidate = SWINV2_WEIGHT_DIR / f"{model_name}.pth"
-    if candidate.exists():
-        return candidate, "auto by MODEL.NAME"
-    return None, "not resolved"
-
-
-def _load_yaml_config(config_path: Path | None) -> dict[str, Any]:
-    config = _deep_copy_dict(DEFAULTS)
-    if config_path is None:
-        return config
-    if not config_path.exists():
-        raise FileNotFoundError(f"SwinV2 config not found: {config_path}")
-
-    with config_path.open("r", encoding="utf-8") as handle:
-        yaml_config = yaml.safe_load(handle) or {}
-    return _merge_dict(config, yaml_config)
-
-
-def _set_nested(config: dict[str, Any], path: tuple[str, ...], value: Any):
-    current = config
-    for key in path[:-1]:
-        current = current.setdefault(key, {})
-    current[path[-1]] = value
-
-
-def _collect_function_overrides(
-        model_name: str | None,
-        weight_path: Path | None,
-        num_classes: int | None,
-        img_size: int | None,
-        in_chans: int | None,
-        use_checkpoint: bool | None,
-        model_kwargs: dict[str, Any],
-) -> list[tuple[tuple[str, ...], Any]]:
-    overrides: list[tuple[tuple[str, ...], Any]] = []
-    if model_name is not None:
-        overrides.append((("MODEL", "NAME"), model_name))
-    if weight_path is not None:
-        overrides.append((("MODEL", "PRETRAINED"), str(weight_path)))
-    if num_classes is not None:
-        overrides.append((("MODEL", "NUM_CLASSES"), num_classes))
-    if img_size is not None:
-        overrides.append((("DATA", "IMG_SIZE"), img_size))
-    if in_chans is not None:
-        overrides.append((("MODEL", "SWINV2", "IN_CHANS"), in_chans))
-    if use_checkpoint is not None:
-        overrides.append((("TRAIN", "USE_CHECKPOINT"), use_checkpoint))
-
-    model_key_map = {
-        "patch_size": ("MODEL", "SWINV2", "PATCH_SIZE"),
-        "embed_dim": ("MODEL", "SWINV2", "EMBED_DIM"),
-        "depths": ("MODEL", "SWINV2", "DEPTHS"),
-        "num_heads": ("MODEL", "SWINV2", "NUM_HEADS"),
-        "window_size": ("MODEL", "SWINV2", "WINDOW_SIZE"),
-        "mlp_ratio": ("MODEL", "SWINV2", "MLP_RATIO"),
-        "qkv_bias": ("MODEL", "SWINV2", "QKV_BIAS"),
-        "ape": ("MODEL", "SWINV2", "APE"),
-        "patch_norm": ("MODEL", "SWINV2", "PATCH_NORM"),
-        "pretrained_window_sizes": ("MODEL", "SWINV2", "PRETRAINED_WINDOW_SIZES"),
-        "drop_rate": ("MODEL", "DROP_RATE"),
-        "drop_path_rate": ("MODEL", "DROP_PATH_RATE"),
-    }
-    for key, path in model_key_map.items():
-        if key in model_kwargs and model_kwargs[key] is not None:
-            overrides.append((path, model_kwargs[key]))
-    return overrides
-
-
-def _collect_arg_overrides(args: Namespace | None) -> list[tuple[tuple[str, ...], Any]]:
-    if args is None:
-        return []
-
-    key_map = {
-        ("model_name", "model"): ("MODEL", "NAME"),
-        ("pretrained", "weights", "weight_path", "ckpt", "checkpoint"): ("MODEL", "PRETRAINED"),
-        ("num_classes",): ("MODEL", "NUM_CLASSES"),
-        ("img_size", "image_size", "input_size"): ("DATA", "IMG_SIZE"),
-        ("in_chans", "in_channels"): ("MODEL", "SWINV2", "IN_CHANS"),
-        ("patch_size",): ("MODEL", "SWINV2", "PATCH_SIZE"),
-        ("embed_dim",): ("MODEL", "SWINV2", "EMBED_DIM"),
-        ("depths",): ("MODEL", "SWINV2", "DEPTHS"),
-        ("num_heads",): ("MODEL", "SWINV2", "NUM_HEADS"),
-        ("window_size",): ("MODEL", "SWINV2", "WINDOW_SIZE"),
-        ("mlp_ratio",): ("MODEL", "SWINV2", "MLP_RATIO"),
-        ("qkv_bias",): ("MODEL", "SWINV2", "QKV_BIAS"),
-        ("ape",): ("MODEL", "SWINV2", "APE"),
-        ("patch_norm",): ("MODEL", "SWINV2", "PATCH_NORM"),
-        ("pretrained_window_sizes",): ("MODEL", "SWINV2", "PRETRAINED_WINDOW_SIZES"),
-        ("drop_rate",): ("MODEL", "DROP_RATE"),
-        ("drop_path_rate",): ("MODEL", "DROP_PATH_RATE"),
-        ("use_checkpoint",): ("TRAIN", "USE_CHECKPOINT"),
-    }
-
-    overrides: list[tuple[tuple[str, ...], Any]] = []
-    for names, path in key_map.items():
-        value = _get_arg(args, *names)
-        if value is not None:
-            overrides.append((path, value))
-    return overrides
-
-
-def _apply_overrides(config: dict[str, Any], overrides: list[tuple[tuple[str, ...], Any]]) -> dict[str, Any]:
-    for path, value in overrides:
-        _set_nested(config, path, value)
-    return config
-
-
-def _extract_state_dict(checkpoint: Any) -> dict[str, torch.Tensor]:
-    if isinstance(checkpoint, dict):
-        for key in ("model", "state_dict"):
-            if key in checkpoint and isinstance(checkpoint[key], dict):
-                return checkpoint[key]
-        return checkpoint
-    raise TypeError(f"Unsupported checkpoint format: {type(checkpoint)!r}")
-
-
-def _remap_head_if_needed(model: SwinTransformerV2, state_dict: dict[str, torch.Tensor]):
-    if "head.bias" not in state_dict or "head.weight" not in state_dict:
-        return
-
-    ckpt_classes = state_dict["head.bias"].shape[0]
-    head_bias = getattr(model.head, "bias", None)
-    model_classes = head_bias.shape[0] if head_bias is not None else 0
-    if ckpt_classes == model_classes:
-        return
-
-    if ckpt_classes == 21841 and model_classes == 1000 and MAP22KTO1K_PATH.exists():
-        with MAP22KTO1K_PATH.open("r", encoding="utf-8") as handle:
-            indices = [int(line.strip()) for line in handle if line.strip()]
-        state_dict["head.weight"] = state_dict["head.weight"][indices, :]
-        state_dict["head.bias"] = state_dict["head.bias"][indices]
-        return
-
-    state_dict.pop("head.weight", None)
-    state_dict.pop("head.bias", None)
-
-
-def _sanitize_state_dict(model: SwinTransformerV2, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-    if any(key.startswith("encoder.") for key in state_dict):
-        state_dict = {
-            key.replace("encoder.", "", 1): value
-            for key, value in state_dict.items()
-            if key.startswith("encoder.")
-        }
-
-    remapped: dict[str, torch.Tensor] = {}
-    for key, value in state_dict.items():
-        if "relative_position_index" in key or "relative_coords_table" in key or "attn_mask" in key:
-            continue
-        remapped[key.replace("rpe_mlp", "cpb_mlp")] = value
-
-    _remap_head_if_needed(model, remapped)
-
-    model_state = model.state_dict()
-    filtered: dict[str, torch.Tensor] = {}
-    for key, value in remapped.items():
-        if key in model_state and model_state[key].shape == value.shape:
-            filtered[key] = value
-    return filtered
-
-
-def _load_checkpoint(weight_path: Path) -> dict[str, Any]:
-    try:
-        return torch.load(weight_path, map_location="cpu")
-    except Exception as exc:
-        raise RuntimeError(f"Failed to load SwinV2 checkpoint: {weight_path}") from exc
-
-
-def build_swinv2(
-        model_name: str | None = None,
-        config_path: str | Path | None = None,
-        weight_path: str | Path | None = None,
-        args: Namespace | None = None,
-        *,
-        num_classes: int | None = None,
-        img_size: int | None = None,
-        in_chans: int | None = None,
-        use_checkpoint: bool | None = None,
-        strict: bool = False,
-        load_weights: bool = True,
-        return_config: bool = False,
-        **model_kwargs,
-):
-    """Build a SwinTransformerV2 with loaded weights.
-
-    Precedence order:
-    1. internal defaults
-    2. YAML config under ``configs/swinv2``
-    3. explicit function inputs outside config files
-    4. command line style ``args`` overrides
-    """
-
-    explicit_weight_path = _to_path(weight_path)
-    initial_model_name = _resolve_model_name(model_name, _to_path(config_path), explicit_weight_path, args)
-    resolved_config_path = _resolve_config_path(initial_model_name, config_path, args)
-    config_dict = _load_yaml_config(resolved_config_path)
-
-    resolved_model_name = _resolve_model_name(
-        initial_model_name or config_dict["MODEL"]["NAME"],
-        resolved_config_path,
-        explicit_weight_path,
-        args,
-    )
-    resolved_weight_path = _resolve_weight_path(resolved_model_name, weight_path, args)
-
-    function_overrides = _collect_function_overrides(
-        model_name=resolved_model_name,
-        weight_path=resolved_weight_path,
-        num_classes=num_classes,
-        img_size=img_size,
-        in_chans=in_chans,
-        use_checkpoint=use_checkpoint,
-        model_kwargs=model_kwargs,
-    )
-    arg_overrides = _collect_arg_overrides(args)
-    config_dict = _apply_overrides(config_dict, function_overrides)
-    config_dict = _apply_overrides(config_dict, arg_overrides)
-
-    model_cfg = config_dict["MODEL"]
-    swinv2_cfg = model_cfg["SWINV2"]
-    model = SwinTransformerV2(
-        img_size=config_dict["DATA"]["IMG_SIZE"],
-        patch_size=swinv2_cfg["PATCH_SIZE"],
-        in_chans=swinv2_cfg["IN_CHANS"],
-        num_classes=model_cfg["NUM_CLASSES"],
-        embed_dim=swinv2_cfg["EMBED_DIM"],
-        depths=tuple(swinv2_cfg["DEPTHS"]),
-        num_heads=tuple(swinv2_cfg["NUM_HEADS"]),
-        window_size=swinv2_cfg["WINDOW_SIZE"],
-        mlp_ratio=swinv2_cfg["MLP_RATIO"],
-        qkv_bias=swinv2_cfg["QKV_BIAS"],
-        drop_rate=model_cfg["DROP_RATE"],
-        drop_path_rate=model_cfg["DROP_PATH_RATE"],
-        ape=swinv2_cfg["APE"],
-        patch_norm=swinv2_cfg["PATCH_NORM"],
-        use_checkpoint=config_dict["TRAIN"]["USE_CHECKPOINT"],
-        pretrained_window_sizes=tuple(swinv2_cfg["PRETRAINED_WINDOW_SIZES"]),
-    )
-
-    if load_weights:
-        if resolved_weight_path is None:
-            raise FileNotFoundError(
-                f"No SwinV2 weight file resolved for model '{model_cfg['NAME']}'. "
-                f"Expected one under {SWINV2_WEIGHT_DIR}."
-            )
-        if not resolved_weight_path.exists():
-            raise FileNotFoundError(f"SwinV2 weight file not found: {resolved_weight_path}")
-
-        checkpoint = _load_checkpoint(resolved_weight_path)
-        state_dict = _sanitize_state_dict(model, _extract_state_dict(checkpoint))
-        model.load_state_dict(state_dict, strict=strict)
-
-    config = _dict_to_namespace(config_dict)
-    if return_config:
-        return model, config
-    return model
-
-
-def build_swinv2_auto(
-        model_name: str | None = None,
-        config_path: str | Path | None = None,
-        weight_path: str | Path | None = None,
-        args: Namespace | None = None,
-        *,
-        verbose: bool = True,
-        return_config: bool = False,
-        return_resolution: bool = False,
-        **kwargs,
-):
-    """Auto-resolve SwinV2 config and weights by ``MODEL.NAME`` and print sources.
-
-    This wrapper keeps the same precedence rules as ``build_swinv2`` while making
-    the config/weight resolution explicit for callers.
-    """
-
-    explicit_weight_path = _to_path(weight_path)
-    candidate_model_name = _resolve_model_name(model_name, _to_path(config_path), explicit_weight_path, args)
-    resolved_config_path, config_source = _resolve_config_with_source(candidate_model_name, config_path, args)
-
-    temp_config = _load_yaml_config(resolved_config_path)
-    final_model_name = _resolve_model_name(
-        candidate_model_name or temp_config["MODEL"]["NAME"],
-        resolved_config_path,
-        explicit_weight_path,
-        args,
-    ) or temp_config["MODEL"]["NAME"]
-    resolved_weight_path, weight_source = _resolve_weight_with_source(final_model_name, weight_path, args)
-
-    built = build_swinv2(
-        model_name=final_model_name,
-        config_path=resolved_config_path,
-        weight_path=resolved_weight_path,
-        args=args,
-        return_config=True,
-        **kwargs,
-    )
-    if not isinstance(built, tuple) or len(built) != 2:
-        raise RuntimeError("build_swinv2(return_config=True) must return (model, config)")
-    model, config = built
-
-    resolution = {
-        "model_name": config.MODEL.NAME,
-        "config_path": str(resolved_config_path) if resolved_config_path is not None else None,
-        "config_source": config_source,
-        "weight_path": str(resolved_weight_path) if resolved_weight_path is not None else None,
-        "weight_source": weight_source,
-    }
-
-    if verbose:
-        print(
-            "[build_swinv2_auto] "
-            f"MODEL.NAME={resolution['model_name']} | "
-            f"config={resolution['config_path']} ({resolution['config_source']}) | "
-            f"weight={resolution['weight_path']} ({resolution['weight_source']})"
-        )
-
-    if return_config and return_resolution:
-        return model, config, resolution
-    if return_config:
-        return model, config
-    if return_resolution:
-        return model, resolution
-    return model
-
-
-__all__ = ["build_swinv2", "build_swinv2_auto"]
+from __future__ import annotations
+
+from argparse import Namespace
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any
+
+import torch
+import yaml
+
+from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2
+
+ROOT_DIR = Path(__file__).resolve().parents[2]
+SWINV2_CONFIG_DIR = ROOT_DIR / "configs" / "swinv2"
+SWINV2_WEIGHT_DIR = ROOT_DIR / "weights" / "swinv2"
+MAP22KTO1K_PATH = ROOT_DIR / "lib" / "SwinTransformer" / "data" / "map22kto1k.txt"
+
+DEFAULTS: dict[str, Any] = {
+    "DATA": {
+        "IMG_SIZE": 224,
+    },
+    "MODEL": {
+        "TYPE": "swinv2",
+        "NAME": "swinv2_tiny_patch4_window8_256",
+        "NUM_CLASSES": 1000,
+        "DROP_RATE": 0.0,
+        "DROP_PATH_RATE": 0.1,
+        "PRETRAINED": "",
+        "SWINV2": {
+            "PATCH_SIZE": 4,
+            "IN_CHANS": 3,
+            "EMBED_DIM": 96,
+            "DEPTHS": [2, 2, 6, 2],
+            "NUM_HEADS": [3, 6, 12, 24],
+            "WINDOW_SIZE": 7,
+            "MLP_RATIO": 4.0,
+            "QKV_BIAS": True,
+            "APE": False,
+            "PATCH_NORM": True,
+            "PRETRAINED_WINDOW_SIZES": [0, 0, 0, 0],
+        },
+    },
+    "TRAIN": {
+        "USE_CHECKPOINT": False,
+    },
+}
+
+
+def _deep_copy_dict(value: dict[str, Any]) -> dict[str, Any]:
+    copied: dict[str, Any] = {}
+    for key, item in value.items():
+        if isinstance(item, dict):
+            copied[key] = _deep_copy_dict(item)
+        elif isinstance(item, list):
+            copied[key] = list(item)
+        else:
+            copied[key] = item
+    return copied
+
+
+def _merge_dict(dst: dict[str, Any], src: dict[str, Any]) -> dict[str, Any]:
+    for key, value in src.items():
+        if isinstance(value, dict) and isinstance(dst.get(key), dict):
+            _merge_dict(dst[key], value)
+        else:
+            dst[key] = value
+    return dst
+
+
+def _dict_to_namespace(value: Any) -> Any:
+    if isinstance(value, dict):
+        return SimpleNamespace(**{key: _dict_to_namespace(item) for key, item in value.items()})
+    if isinstance(value, list):
+        return [_dict_to_namespace(item) for item in value]
+    return value
+
+
+def _get_arg(args: Namespace | None, *names: str) -> Any:
+    if args is None:
+        return None
+    for name in names:
+        if hasattr(args, name):
+            value = getattr(args, name)
+            if value is not None:
+                return value
+    return None
+
+
+def _to_path(value: str | Path | None) -> Path | None:
+    if value is None:
+        return None
+    return Path(value).expanduser().resolve()
+
+
+def _resolve_model_name(
+        model_name: str | None,
+        config_path: Path | None,
+        weight_path: Path | None,
+        args: Namespace | None,
+) -> str | None:
+    return (
+            _get_arg(args, "model_name", "model")
+            or model_name
+            or (weight_path.stem if weight_path is not None else None)
+            or (config_path.stem if config_path is not None else None)
+    )
+
+
+def _resolve_config_path(model_name: str | None, config_path: str | Path | None, args: Namespace | None) -> Path | None:
+    cli_cfg = _to_path(_get_arg(args, "cfg", "config", "config_path"))
+    if cli_cfg is not None:
+        return cli_cfg
+
+    explicit_cfg = _to_path(config_path)
+    if explicit_cfg is not None:
+        return explicit_cfg
+
+    if model_name is None:
+        return None
+
+    candidate = SWINV2_CONFIG_DIR / f"{model_name}.yaml"
+    return candidate if candidate.exists() else None
+
+
+def _resolve_weight_path(model_name: str | None, weight_path: str | Path | None, args: Namespace | None) -> Path | None:
+    cli_weight = _to_path(_get_arg(args, "pretrained", "weights", "weight_path", "ckpt", "checkpoint"))
+    if cli_weight is not None:
+        return cli_weight
+
+    explicit_weight = _to_path(weight_path)
+    if explicit_weight is not None:
+        return explicit_weight
+
+    if model_name is None:
+        return None
+
+    candidate = SWINV2_WEIGHT_DIR / f"{model_name}.pth"
+    return candidate if candidate.exists() else None
+
+
+def _resolve_config_with_source(
+        model_name: str | None,
+        config_path: str | Path | None,
+        args: Namespace | None,
+) -> tuple[Path | None, str]:
+    cli_cfg = _to_path(_get_arg(args, "cfg", "config", "config_path"))
+    if cli_cfg is not None:
+        return cli_cfg, "args.cfg"
+
+    explicit_cfg = _to_path(config_path)
+    if explicit_cfg is not None:
+        return explicit_cfg, "function config_path"
+
+    if model_name is None:
+        return None, "defaults only"
+
+    candidate = SWINV2_CONFIG_DIR / f"{model_name}.yaml"
+    if candidate.exists():
+        return candidate, "auto by MODEL.NAME"
+    return None, "defaults only"
+
+
+def _resolve_weight_with_source(
+        model_name: str | None,
+        weight_path: str | Path | None,
+        args: Namespace | None,
+) -> tuple[Path | None, str]:
+    cli_weight = _to_path(_get_arg(args, "pretrained", "weights", "weight_path", "ckpt", "checkpoint"))
+    if cli_weight is not None:
+        return cli_weight, "args.pretrained"
+
+    explicit_weight = _to_path(weight_path)
+    if explicit_weight is not None:
+        return explicit_weight, "function weight_path"
+
+    if model_name is None:
+        return None, "not resolved"
+
+    candidate = SWINV2_WEIGHT_DIR / f"{model_name}.pth"
+    if candidate.exists():
+        return candidate, "auto by MODEL.NAME"
+    return None, "not resolved"
+
+
+def _load_yaml_config(config_path: Path | None) -> dict[str, Any]:
+    config = _deep_copy_dict(DEFAULTS)
+    if config_path is None:
+        return config
+    if not config_path.exists():
+        raise FileNotFoundError(f"SwinV2 config not found: {config_path}")
+
+    with config_path.open("r", encoding="utf-8") as handle:
+        yaml_config = yaml.safe_load(handle) or {}
+    return _merge_dict(config, yaml_config)
+
+
+def _set_nested(config: dict[str, Any], path: tuple[str, ...], value: Any):
+    current = config
+    for key in path[:-1]:
+        current = current.setdefault(key, {})
+    current[path[-1]] = value
+
+
+def _collect_function_overrides(
+        model_name: str | None,
+        weight_path: Path | None,
+        num_classes: int | None,
+        img_size: int | None,
+        in_chans: int | None,
+        use_checkpoint: bool | None,
+        model_kwargs: dict[str, Any],
+) -> list[tuple[tuple[str, ...], Any]]:
+    overrides: list[tuple[tuple[str, ...], Any]] = []
+    if model_name is not None:
+        overrides.append((("MODEL", "NAME"), model_name))
+    if weight_path is not None:
+        overrides.append((("MODEL", "PRETRAINED"), str(weight_path)))
+    if num_classes is not None:
+        overrides.append((("MODEL", "NUM_CLASSES"), num_classes))
+    if img_size is not None:
+        overrides.append((("DATA", "IMG_SIZE"), img_size))
+    if in_chans is not None:
+        overrides.append((("MODEL", "SWINV2", "IN_CHANS"), in_chans))
+    if use_checkpoint is not None:
+        overrides.append((("TRAIN", "USE_CHECKPOINT"), use_checkpoint))
+
+    model_key_map = {
+        "patch_size": ("MODEL", "SWINV2", "PATCH_SIZE"),
+        "embed_dim": ("MODEL", "SWINV2", "EMBED_DIM"),
+        "depths": ("MODEL", "SWINV2", "DEPTHS"),
+        "num_heads": ("MODEL", "SWINV2", "NUM_HEADS"),
+        "window_size": ("MODEL", "SWINV2", "WINDOW_SIZE"),
+        "mlp_ratio": ("MODEL", "SWINV2", "MLP_RATIO"),
+        "qkv_bias": ("MODEL", "SWINV2", "QKV_BIAS"),
+        "ape": ("MODEL", "SWINV2", "APE"),
+        "patch_norm": ("MODEL", "SWINV2", "PATCH_NORM"),
+        "pretrained_window_sizes": ("MODEL", "SWINV2", "PRETRAINED_WINDOW_SIZES"),
+        "drop_rate": ("MODEL", "DROP_RATE"),
+        "drop_path_rate": ("MODEL", "DROP_PATH_RATE"),
+    }
+    for key, path in model_key_map.items():
+        if key in model_kwargs and model_kwargs[key] is not None:
+            overrides.append((path, model_kwargs[key]))
+    return overrides
+
+
+def _collect_arg_overrides(args: Namespace | None) -> list[tuple[tuple[str, ...], Any]]:
+    if args is None:
+        return []
+
+    key_map = {
+        ("model_name", "model"): ("MODEL", "NAME"),
+        ("pretrained", "weights", "weight_path", "ckpt", "checkpoint"): ("MODEL", "PRETRAINED"),
+        ("num_classes",): ("MODEL", "NUM_CLASSES"),
+        ("img_size", "image_size", "input_size"): ("DATA", "IMG_SIZE"),
+        ("in_chans", "in_channels"): ("MODEL", "SWINV2", "IN_CHANS"),
+        ("patch_size",): ("MODEL", "SWINV2", "PATCH_SIZE"),
+        ("embed_dim",): ("MODEL", "SWINV2", "EMBED_DIM"),
+        ("depths",): ("MODEL", "SWINV2", "DEPTHS"),
+        ("num_heads",): ("MODEL", "SWINV2", "NUM_HEADS"),
+        ("window_size",): ("MODEL", "SWINV2", "WINDOW_SIZE"),
+        ("mlp_ratio",): ("MODEL", "SWINV2", "MLP_RATIO"),
+        ("qkv_bias",): ("MODEL", "SWINV2", "QKV_BIAS"),
+        ("ape",): ("MODEL", "SWINV2", "APE"),
+        ("patch_norm",): ("MODEL", "SWINV2", "PATCH_NORM"),
+        ("pretrained_window_sizes",): ("MODEL", "SWINV2", "PRETRAINED_WINDOW_SIZES"),
+        ("drop_rate",): ("MODEL", "DROP_RATE"),
+        ("drop_path_rate",): ("MODEL", "DROP_PATH_RATE"),
+        ("use_checkpoint",): ("TRAIN", "USE_CHECKPOINT"),
+    }
+
+    overrides: list[tuple[tuple[str, ...], Any]] = []
+    for names, path in key_map.items():
+        value = _get_arg(args, *names)
+        if value is not None:
+            overrides.append((path, value))
+    return overrides
+
+
+def _apply_overrides(config: dict[str, Any], overrides: list[tuple[tuple[str, ...], Any]]) -> dict[str, Any]:
+    for path, value in overrides:
+        _set_nested(config, path, value)
+    return config
+
+
+def _extract_state_dict(checkpoint: Any) -> dict[str, torch.Tensor]:
+    if isinstance(checkpoint, dict):
+        for key in ("model", "state_dict"):
+            if key in checkpoint and isinstance(checkpoint[key], dict):
+                return checkpoint[key]
+        return checkpoint
+    raise TypeError(f"Unsupported checkpoint format: {type(checkpoint)!r}")
+
+
+def _remap_head_if_needed(model: SwinTransformerV2, state_dict: dict[str, torch.Tensor]):
+    if "head.bias" not in state_dict or "head.weight" not in state_dict:
+        return
+
+    ckpt_classes = state_dict["head.bias"].shape[0]
+    head_bias = getattr(model.head, "bias", None)
+    model_classes = head_bias.shape[0] if head_bias is not None else 0
+    if ckpt_classes == model_classes:
+        return
+
+    if ckpt_classes == 21841 and model_classes == 1000 and MAP22KTO1K_PATH.exists():
+        with MAP22KTO1K_PATH.open("r", encoding="utf-8") as handle:
+            indices = [int(line.strip()) for line in handle if line.strip()]
+        state_dict["head.weight"] = state_dict["head.weight"][indices, :]
+        state_dict["head.bias"] = state_dict["head.bias"][indices]
+        return
+
+    state_dict.pop("head.weight", None)
+    state_dict.pop("head.bias", None)
+
+
+def _sanitize_state_dict(model: SwinTransformerV2, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+    if any(key.startswith("encoder.") for key in state_dict):
+        state_dict = {
+            key.replace("encoder.", "", 1): value
+            for key, value in state_dict.items()
+            if key.startswith("encoder.")
+        }
+
+    remapped: dict[str, torch.Tensor] = {}
+    for key, value in state_dict.items():
+        if "relative_position_index" in key or "relative_coords_table" in key or "attn_mask" in key:
+            continue
+        remapped[key.replace("rpe_mlp", "cpb_mlp")] = value
+
+    _remap_head_if_needed(model, remapped)
+
+    model_state = model.state_dict()
+    filtered: dict[str, torch.Tensor] = {}
+    for key, value in remapped.items():
+        if key in model_state and model_state[key].shape == value.shape:
+            filtered[key] = value
+    return filtered
+
+
+def _load_checkpoint(weight_path: Path) -> dict[str, Any]:
+    try:
+        return torch.load(weight_path, map_location="cpu")
+    except Exception as exc:
+        raise RuntimeError(f"Failed to load SwinV2 checkpoint: {weight_path}") from exc
+
+
+def build_swinv2(
+        model_name: str | None = None,
+        config_path: str | Path | None = None,
+        weight_path: str | Path | None = None,
+        args: Namespace | None = None,
+        *,
+        num_classes: int | None = None,
+        img_size: int | None = None,
+        in_chans: int | None = None,
+        use_checkpoint: bool | None = None,
+        strict: bool = False,
+        load_weights: bool = True,
+        return_config: bool = False,
+        **model_kwargs,
+):
+    """Build a SwinTransformerV2 with loaded weights.
+
+    Precedence order:
+    1. internal defaults
+    2. YAML config under ``configs/swinv2``
+    3. explicit function inputs outside config files
+    4. command line style ``args`` overrides
+    """
+
+    explicit_weight_path = _to_path(weight_path)
+    initial_model_name = _resolve_model_name(model_name, _to_path(config_path), explicit_weight_path, args)
+    resolved_config_path = _resolve_config_path(initial_model_name, config_path, args)
+    config_dict = _load_yaml_config(resolved_config_path)
+
+    resolved_model_name = _resolve_model_name(
+        initial_model_name or config_dict["MODEL"]["NAME"],
+        resolved_config_path,
+        explicit_weight_path,
+        args,
+    )
+    resolved_weight_path = _resolve_weight_path(resolved_model_name, weight_path, args)
+
+    function_overrides = _collect_function_overrides(
+        model_name=resolved_model_name,
+        weight_path=resolved_weight_path,
+        num_classes=num_classes,
+        img_size=img_size,
+        in_chans=in_chans,
+        use_checkpoint=use_checkpoint,
+        model_kwargs=model_kwargs,
+    )
+    arg_overrides = _collect_arg_overrides(args)
+    config_dict = _apply_overrides(config_dict, function_overrides)
+    config_dict = _apply_overrides(config_dict, arg_overrides)
+
+    model_cfg = config_dict["MODEL"]
+    swinv2_cfg = model_cfg["SWINV2"]
+    model = SwinTransformerV2(
+        img_size=config_dict["DATA"]["IMG_SIZE"],
+        patch_size=swinv2_cfg["PATCH_SIZE"],
+        in_chans=swinv2_cfg["IN_CHANS"],
+        num_classes=model_cfg["NUM_CLASSES"],
+        embed_dim=swinv2_cfg["EMBED_DIM"],
+        depths=tuple(swinv2_cfg["DEPTHS"]),
+        num_heads=tuple(swinv2_cfg["NUM_HEADS"]),
+        window_size=swinv2_cfg["WINDOW_SIZE"],
+        mlp_ratio=swinv2_cfg["MLP_RATIO"],
+        qkv_bias=swinv2_cfg["QKV_BIAS"],
+        drop_rate=model_cfg["DROP_RATE"],
+        drop_path_rate=model_cfg["DROP_PATH_RATE"],
+        ape=swinv2_cfg["APE"],
+        patch_norm=swinv2_cfg["PATCH_NORM"],
+        use_checkpoint=config_dict["TRAIN"]["USE_CHECKPOINT"],
+        pretrained_window_sizes=tuple(swinv2_cfg["PRETRAINED_WINDOW_SIZES"]),
+    )
+
+    if load_weights:
+        if resolved_weight_path is None:
+            raise FileNotFoundError(
+                f"No SwinV2 weight file resolved for model '{model_cfg['NAME']}'. "
+                f"Expected one under {SWINV2_WEIGHT_DIR}."
+            )
+        if not resolved_weight_path.exists():
+            raise FileNotFoundError(f"SwinV2 weight file not found: {resolved_weight_path}")
+
+        checkpoint = _load_checkpoint(resolved_weight_path)
+        state_dict = _sanitize_state_dict(model, _extract_state_dict(checkpoint))
+        model.load_state_dict(state_dict, strict=strict)
+
+    config = _dict_to_namespace(config_dict)
+    if return_config:
+        return model, config
+    return model
+
+
+def build_swinv2_auto(
+        model_name: str | None = None,
+        config_path: str | Path | None = None,
+        weight_path: str | Path | None = None,
+        args: Namespace | None = None,
+        *,
+        verbose: bool = True,
+        return_config: bool = False,
+        return_resolution: bool = False,
+        **kwargs,
+):
+    """Auto-resolve SwinV2 config and weights by ``MODEL.NAME`` and print sources.
+
+    This wrapper keeps the same precedence rules as ``build_swinv2`` while making
+    the config/weight resolution explicit for callers.
+    """
+
+    explicit_weight_path = _to_path(weight_path)
+    candidate_model_name = _resolve_model_name(model_name, _to_path(config_path), explicit_weight_path, args)
+    resolved_config_path, config_source = _resolve_config_with_source(candidate_model_name, config_path, args)
+
+    temp_config = _load_yaml_config(resolved_config_path)
+    final_model_name = _resolve_model_name(
+        candidate_model_name or temp_config["MODEL"]["NAME"],
+        resolved_config_path,
+        explicit_weight_path,
+        args,
+    ) or temp_config["MODEL"]["NAME"]
+    resolved_weight_path, weight_source = _resolve_weight_with_source(final_model_name, weight_path, args)
+
+    built = build_swinv2(
+        model_name=final_model_name,
+        config_path=resolved_config_path,
+        weight_path=resolved_weight_path,
+        args=args,
+        return_config=True,
+        **kwargs,
+    )
+    if not isinstance(built, tuple) or len(built) != 2:
+        raise RuntimeError("build_swinv2(return_config=True) must return (model, config)")
+    model, config = built
+
+    resolution = {
+        "model_name": config.MODEL.NAME,
+        "config_path": str(resolved_config_path) if resolved_config_path is not None else None,
+        "config_source": config_source,
+        "weight_path": str(resolved_weight_path) if resolved_weight_path is not None else None,
+        "weight_source": weight_source,
+    }
+
+    if verbose:
+        print(
+            "[build_swinv2_auto] "
+            f"MODEL.NAME={resolution['model_name']} | "
+            f"config={resolution['config_path']} ({resolution['config_source']}) | "
+            f"weight={resolution['weight_path']} ({resolution['weight_source']})"
+        )
+
+    if return_config and return_resolution:
+        return model, config, resolution
+    if return_config:
+        return model, config
+    if return_resolution:
+        return model, resolution
+    return model
+
+
+__all__ = ["build_swinv2", "build_swinv2_auto"]

+ 139 - 0
lib/modules/decoder_2d.py

@@ -0,0 +1,139 @@
+from __future__ import annotations
+
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .layers_2d import Conv2dBN
+
+
+class BoundaryRefineBlock2d(nn.Module):
+    """
+    使用边界提示和稳定性图对解码特征做轻量细化。
+    """
+
+    def __init__(self, channels: int) -> None:
+        super().__init__()
+        self.refine = nn.Sequential(
+            Conv2dBN(channels, channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            Conv2dBN(channels, channels, 3, 1, 1),
+        )
+
+    def forward(
+            self,
+            x: torch.Tensor,
+            boundary_hint: torch.Tensor | None = None,
+            stability_map: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        modulator = 1.0
+
+        if stability_map is not None:
+            stability_map = F.interpolate(
+                stability_map, size=x.shape[-2:], mode="bilinear", align_corners=False
+            )
+            modulator = modulator + stability_map
+
+        if boundary_hint is not None:
+            boundary_hint = F.interpolate(
+                boundary_hint, size=x.shape[-2:], mode="bilinear", align_corners=False
+            )
+            modulator = modulator + boundary_hint
+
+        return x + self.refine(x * modulator)
+
+
+class StructureAwareDecodeBlock2d(nn.Module):
+    """
+    单层结构感知解码块。
+    """
+
+    def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
+        super().__init__()
+        self.high_proj = nn.Sequential(
+            Conv2dBN(in_channels, out_channels, 1, 1, 0),
+            nn.ReLU(inplace=True),
+        )
+        self.skip_proj = nn.Sequential(
+            Conv2dBN(skip_channels, out_channels, 1, 1, 0),
+            nn.ReLU(inplace=True),
+        )
+        self.fuse = nn.Sequential(
+            Conv2dBN(out_channels * 2, out_channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            Conv2dBN(out_channels, out_channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+        )
+        self.refine = BoundaryRefineBlock2d(out_channels)
+
+    def forward(
+            self,
+            x: torch.Tensor,
+            skip: torch.Tensor,
+            stability_map: torch.Tensor | None = None,
+            boundary_hint: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
+        x = self.high_proj(x)
+        skip = self.skip_proj(skip)
+        x = self.fuse(torch.cat([x, skip], dim=1))
+        x = self.refine(x, boundary_hint=boundary_hint, stability_map=stability_map)
+        return x
+
+
+class StructureAwareDecoder2d(nn.Module):
+    """
+    第一版结构感知解码器骨架。
+
+    输入特征默认按从浅到深排列,最后一个特征视为最深层输入。
+    """
+
+    def __init__(self, encoder_channels: Sequence[int], decoder_channels: Sequence[int] | None = None) -> None:
+        super().__init__()
+        if len(encoder_channels) < 2:
+            raise ValueError("encoder_channels must contain at least two stages.")
+
+        self.encoder_channels = list(encoder_channels)
+        if decoder_channels is None:
+            decoder_channels = list(reversed(self.encoder_channels[:-1]))
+
+        if len(decoder_channels) != len(self.encoder_channels) - 1:
+            raise ValueError("decoder_channels length must match len(encoder_channels) - 1.")
+
+        in_channels = self.encoder_channels[-1]
+        skip_channels = list(reversed(self.encoder_channels[:-1]))
+
+        blocks = []
+        for skip_ch, out_ch in zip(skip_channels, decoder_channels):
+            blocks.append(StructureAwareDecodeBlock2d(in_channels, skip_ch, out_ch))
+            in_channels = out_ch
+        self.blocks = nn.ModuleList(blocks)
+        self.out_channels = in_channels
+
+    def forward(
+            self,
+            features: Sequence[torch.Tensor],
+            stability_map: torch.Tensor | None = None,
+            boundary_hints: Sequence[torch.Tensor] | None = None,
+    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+        if len(features) != len(self.encoder_channels):
+            raise ValueError(
+                f"feature count mismatch: got {len(features)}, expected {len(self.encoder_channels)}"
+            )
+
+        x = features[-1]
+        skips = list(reversed(features[:-1]))
+        decoder_features = []
+
+        if boundary_hints is None:
+            boundary_hints = [None] * len(self.blocks)
+        elif len(boundary_hints) != len(self.blocks):
+            raise ValueError("boundary_hints length must match decoder depth.")
+
+        for block, skip, boundary_hint in zip(self.blocks, skips, boundary_hints):
+            x = block(x, skip, stability_map=stability_map, boundary_hint=boundary_hint)
+            decoder_features.append(x)
+
+        return x, decoder_features

+ 302 - 0
lib/modules/fwta_2d.py

@@ -0,0 +1,302 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import ptwt
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def build_gaussian_lowpass(
+        channels: int,
+        sigma_ratio: float = 0.35,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    """
+    构建用于通道维度的 1D 高斯低通滤波器。
+
+    Returns:
+        Tensor of shape [1, 1, C].
+    """
+    sigma = max(channels * sigma_ratio, 1.0)
+    center = (channels - 1) / 2.0
+    coords = torch.arange(channels, device=device, dtype=dtype or torch.float32)
+    kernel = torch.exp(-0.5 * ((coords - center) / sigma) ** 2)
+    kernel = kernel / kernel.max().clamp_min(1e-6)
+    return kernel.view(1, 1, channels)
+
+
+@dataclass
+class FWTADebug:
+    fourier_score: torch.Tensor
+    wavelet_score: torch.Tensor
+    fused_score: torch.Tensor
+    gate: torch.Tensor
+    pooled_token: torch.Tensor
+
+
+class FourierWaveletTokenAggregation(nn.Module):
+    """
+    傅里叶 - 小波令牌聚合模块。
+
+    Inputs:
+        cls_token: [B, C]
+        patch_tokens: [B, N, C]
+
+    Output:
+        cls_out: [B, C]
+        gate: [B, N]
+
+    Design:
+        - Fourier branch estimates token-wise semantic stability.
+        - Wavelet branch estimates token-wise structural saliency.
+        - Fused score produces a softmax gate over tokens.
+        - Weighted pooled token is added back to the CLS token by residual update.
+    """
+
+    def __init__(
+            self,
+            dim: int,
+            grid_size: Tuple[int, int],
+            wavelet: str = "haar",
+            wavelet_level: int = 1,
+            sigma_ratio: float = 0.35,
+            tau_fourier: float = 0.15,
+            gate_temperature: float = 1.0,
+            residual_scale_init: float = 1.0,
+            fusion_hidden_ratio: float = 0.5,
+            use_cls_conditioning: bool = True,
+            eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.dim = dim
+        self.grid_size = grid_size
+        self.wavelet = wavelet
+        self.wavelet_level = wavelet_level
+        self.sigma_ratio = sigma_ratio
+        self.tau_fourier = tau_fourier
+        self.gate_temperature = gate_temperature
+        self.use_cls_conditioning = use_cls_conditioning
+        self.eps = eps
+
+        hidden_dim = max(int(dim * fusion_hidden_ratio), 32)
+        fuse_in_dim = 3 if use_cls_conditioning else 2
+
+        self.score_fuser = nn.Sequential(
+            nn.Linear(fuse_in_dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, 1),
+        )
+
+        self.token_proj = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, dim),
+            nn.GELU(),
+            nn.Linear(dim, dim),
+        )
+
+        self.out_norm = nn.LayerNorm(dim)
+        self.residual_scale = nn.Parameter(torch.tensor(float(residual_scale_init)))
+
+        # 学习系数以平衡粗结构、边缘线索和噪声。
+        self.wavelet_ll_weight = nn.Parameter(torch.tensor(1.0))
+        self.wavelet_edge_weight = nn.Parameter(torch.tensor(0.5))
+        self.wavelet_noise_weight = nn.Parameter(torch.tensor(0.5))
+
+        self.register_buffer("gaussian_kernel", build_gaussian_lowpass(dim, sigma_ratio), persistent=False)
+
+    def forward(
+            self,
+            cls_token: torch.Tensor,
+            patch_tokens: torch.Tensor,
+            return_debug: bool = False,
+    ):
+        B, N, C = patch_tokens.shape
+        H, W = self.grid_size
+        if N != H * W:
+            raise ValueError(f"patch count mismatch: got N={N}, expected H*W={H * W}")
+        if C != self.dim:
+            raise ValueError(f"channel mismatch: got C={C}, expected dim={self.dim}")
+
+        fourier_score = self._fourier_stability_score(patch_tokens)
+        wavelet_score = self._wavelet_saliency_score(patch_tokens)
+
+        fuse_inputs = [fourier_score, wavelet_score]
+        if self.use_cls_conditioning:
+            cls_alignment = self._cls_alignment_score(cls_token, patch_tokens)
+            fuse_inputs.append(cls_alignment)
+
+        fused_input = torch.stack(fuse_inputs, dim=-1)  # [B, N, 2 or 3]
+        fused_score = self.score_fuser(fused_input).squeeze(-1)  # [B, N]
+        gate = torch.softmax(fused_score / max(self.gate_temperature, self.eps), dim=1)
+
+        pooled_token = torch.sum(gate.unsqueeze(-1) * patch_tokens, dim=1)  # [B, C]
+        pooled_token = self.token_proj(pooled_token)
+
+        cls_out = cls_token + self.residual_scale * pooled_token
+        cls_out = self.out_norm(cls_out)
+
+        if return_debug:
+            debug = FWTADebug(
+                fourier_score=fourier_score,
+                wavelet_score=wavelet_score,
+                fused_score=fused_score,
+                gate=gate,
+                pooled_token=pooled_token,
+            )
+            return cls_out, gate, debug
+        return cls_out, gate
+
+    def get_stability_map(self, patch_tokens: torch.Tensor) -> torch.Tensor:
+        """
+        为分割任务提供二维稳定性图接口。
+
+        Returns:
+            Tensor of shape [B, 1, H, W].
+        """
+        _, gate = self.forward(
+            cls_token=patch_tokens.mean(dim=1),
+            patch_tokens=patch_tokens,
+            return_debug=False,
+        )
+        H, W = self.grid_size
+        return gate.reshape(patch_tokens.shape[0], 1, H, W)
+
+    def forward_with_map(
+            self,
+            cls_token: torch.Tensor,
+            patch_tokens: torch.Tensor,
+            return_debug: bool = False,
+    ):
+        """
+        同时返回 CLS 更新结果、门控权重以及二维稳定性图。
+        """
+        outputs = self.forward(cls_token, patch_tokens, return_debug=return_debug)
+        H, W = self.grid_size
+
+        if return_debug:
+            cls_out, gate, debug = outputs
+            stability_map = gate.reshape(patch_tokens.shape[0], 1, H, W)
+            return cls_out, gate, stability_map, debug
+
+        cls_out, gate = outputs
+        stability_map = gate.reshape(patch_tokens.shape[0], 1, H, W)
+        return cls_out, gate, stability_map
+
+    def _fourier_stability_score(self, patch_tokens: torch.Tensor) -> torch.Tensor:
+        """
+        通过通道级低通滤波后的变化量来评分令牌。
+
+        Higher score => more stable token => more likely to carry coherent semantics.
+        """
+        kernel = self.gaussian_kernel.to(device=patch_tokens.device, dtype=patch_tokens.dtype)
+
+        xf = torch.fft.fft(patch_tokens, dim=-1)
+        xf = torch.fft.fftshift(xf, dim=-1)
+        xf_low = xf * kernel
+        xf_low = torch.fft.ifftshift(xf_low, dim=-1)
+        x_low = torch.fft.ifft(xf_low, dim=-1).real
+
+        delta = torch.mean(torch.abs(patch_tokens - x_low), dim=-1)  # [B, N]
+        score = torch.exp(-delta / max(self.tau_fourier, self.eps))
+        return score
+
+    def _wavelet_saliency_score(self, patch_tokens: torch.Tensor) -> torch.Tensor:
+        """
+        使用 Token-Grid 小波分解来估计结构前景显著性。
+
+        The patch tokens are treated as a low-resolution feature map [B, C, H, W].
+        """
+        B, N, C = patch_tokens.shape
+        H, W = self.grid_size
+
+        x2d = patch_tokens.transpose(1, 2).reshape(B, C, H, W)
+        coeffs = ptwt.wavedec2(x2d, self.wavelet, level=self.wavelet_level)
+
+        ll = coeffs[0]
+        detail_coeffs = coeffs[1:]
+
+        ll_energy = ll.abs().mean(dim=1, keepdim=True)
+        ll_energy = F.interpolate(ll_energy, size=(H, W), mode="nearest")
+
+        edge_energy = torch.zeros_like(ll_energy)
+        noise_energy = torch.zeros_like(ll_energy)
+
+        for level_detail in detail_coeffs:
+            lh, hl, hh = level_detail
+            level_edge = 0.5 * (lh.abs().mean(dim=1, keepdim=True) + hl.abs().mean(dim=1, keepdim=True))
+            level_noise = hh.abs().mean(dim=1, keepdim=True)
+
+            target_size = (H, W)
+            level_edge = F.interpolate(level_edge, size=target_size, mode="nearest")
+            level_noise = F.interpolate(level_noise, size=target_size, mode="nearest")
+
+            edge_energy = edge_energy + level_edge
+            noise_energy = noise_energy + level_noise
+
+        raw_score = (
+                self.wavelet_ll_weight * ll_energy
+                + self.wavelet_edge_weight * edge_energy
+                - self.wavelet_noise_weight * noise_energy
+        )
+        raw_score = raw_score.flatten(1)  # [B, N]
+        score = torch.sigmoid(raw_score)
+        return score
+
+    def _cls_alignment_score(self, cls_token: torch.Tensor, patch_tokens: torch.Tensor) -> torch.Tensor:
+        """
+        可选稳定器:偏好已与现有 CLS 令牌对齐的令牌。
+        这有助于模块作为修正项而不是完全独立的分支发挥作用。
+        """
+        cls_norm = F.normalize(cls_token, dim=-1)
+        patch_norm = F.normalize(patch_tokens, dim=-1)
+        score = torch.sum(patch_norm * cls_norm.unsqueeze(1), dim=-1)
+        score = 0.5 * (score + 1.0)  # map cosine similarity from [-1, 1] to [0, 1]
+        return score
+
+
+class ViTBlockWithFWTA(nn.Module):
+    """
+    最小包装器,展示如何在 Transformer Block 后插入 FWTA。
+
+    Expected input:
+        x: [B, 1 + N, C]
+
+    Output:
+        x: [B, 1 + N, C]
+    """
+
+    def __init__(self, block: nn.Module, dim: int, grid_size: Tuple[int, int]) -> None:
+        super().__init__()
+        self.block = block
+        self.fwta = FourierWaveletTokenAggregation(dim=dim, grid_size=grid_size)
+
+    def forward(self, x: torch.Tensor):
+        x = self.block(x)
+        cls_token = x[:, 0]
+        patch_tokens = x[:, 1:]
+        cls_token, gate = self.fwta(cls_token, patch_tokens)
+        x = torch.cat([cls_token.unsqueeze(1), patch_tokens], dim=1)
+        return x, gate
+
+
+class FinalAggregatorWithFWTA(nn.Module):
+    """
+    适用于 torchvision / timm 风格 ViT 的更简单变体:
+    保持所有 Encoder Block 不变,仅在最后应用 FWTA。
+    """
+
+    def __init__(self, dim: int, grid_size: Tuple[int, int], num_classes: int) -> None:
+        super().__init__()
+        self.fwta = FourierWaveletTokenAggregation(dim=dim, grid_size=grid_size)
+        self.head = nn.Linear(dim, num_classes)
+
+    def forward(self, encoder_tokens: torch.Tensor):
+        cls_token = encoder_tokens[:, 0]
+        patch_tokens = encoder_tokens[:, 1:]
+        cls_token, gate = self.fwta(cls_token, patch_tokens)
+        logits = self.head(cls_token)
+        return logits, gate

+ 105 - 105
lib/modules/layers_2d.py

@@ -1,105 +1,105 @@
-"""
-通用基础层(2D)。
-"""
-
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-from timm.layers.drop import DropPath
-from timm.layers.mlp import Mlp
-from timm.layers.squeeze_excite import SqueezeExcite
-from timm.layers.weight_init import trunc_normal_
-
-
-class Scale(nn.Module):
-    def __init__(self, dims, init_scale=1.0):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
-
-    def forward(self, x):
-        return self.weight * x
-
-
-class Residual(nn.Module):
-    def __init__(self, module, drop=0.0):
-        super().__init__()
-        self.module = module
-        self.drop = drop
-
-    def forward(self, x):
-        if self.training and self.drop > 0.0:
-            keep = torch.rand(x.size(0), 1, 1, 1, device=x.device)
-            keep = keep.ge_(self.drop).div(1.0 - self.drop).detach()
-            return x + self.module(x) * keep
-        return x + self.module(x)
-
-
-class FFN2d(nn.Module):
-    def __init__(self, embed_dim, hidden_dim):
-        super().__init__()
-        self.mlp = Mlp(
-            in_features=embed_dim,
-            hidden_features=hidden_dim,
-            out_features=embed_dim,
-            act_layer=nn.ReLU,
-            use_conv=True,
-            bias=False,
-        )
-        for m in self.mlp.modules():
-            if isinstance(m, nn.BatchNorm2d) and m.num_features == embed_dim:
-                nn.init.constant_(m.weight, 0.0)
-                nn.init.constant_(m.bias, 0.0)
-
-    def forward(self, x):
-        return self.mlp(x)
-
-
-class BNLinear1d(nn.Sequential):
-    def __init__(self, in_features, out_features, bias=True, std=0.02):
-        bn = nn.BatchNorm1d(in_features)
-        linear = nn.Linear(in_features, out_features, bias=bias)
-        trunc_normal_(linear.weight, std=std)
-        if bias:
-            nn.init.constant_(linear.bias, 0)
-        super().__init__(OrderedDict([("bn", bn), ("linear", linear)]))
-
-
-class Conv2dBN(nn.Sequential):
-    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0,
-                 dilation=1, groups=1, bn_weight_init=1.0):
-        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
-        bn = nn.BatchNorm2d(out_channels)
-        nn.init.constant_(bn.weight, bn_weight_init)
-        nn.init.constant_(bn.bias, 0)
-        super().__init__(OrderedDict([("conv", conv), ("bn", bn)]))
-
-
-class DWConv2dBNReLU(nn.Sequential):
-    def __init__(self, in_channels, out_channels, kernel_size=3, bn_weight_init=1.0):
-        super().__init__(OrderedDict([
-            ("dwconv3x3",
-             nn.Conv2d(in_channels, in_channels, kernel_size, 1, kernel_size // 2, groups=in_channels, bias=False)),
-            ("bn1", nn.BatchNorm2d(in_channels)),
-            ("relu", nn.ReLU(inplace=True)),
-            ("dwconv1x1", nn.Conv2d(in_channels, out_channels, 1, 1, 0, groups=in_channels, bias=False)),
-            ("bn2", nn.BatchNorm2d(out_channels)),
-        ]))
-        for bn_name in ["bn1", "bn2"]:
-            bn = getattr(self, bn_name)
-            nn.init.constant_(bn.weight, bn_weight_init)
-            nn.init.constant_(bn.bias, 0)
-
-
-class PatchMerging2d(nn.Module):
-    def __init__(self, dim, out_dim):
-        super().__init__()
-        hidden_dim = int(dim * 4)
-        self.conv1 = Conv2dBN(dim, hidden_dim, 1, 1, 0)
-        self.act = nn.ReLU(inplace=True)
-        self.conv2 = Conv2dBN(hidden_dim, hidden_dim, 3, 2, 1, groups=hidden_dim)
-        self.se = SqueezeExcite(hidden_dim, rd_ratio=0.25)
-        self.conv3 = Conv2dBN(hidden_dim, out_dim, 1, 1, 0)
-
-    def forward(self, x):
-        return self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
+"""
+通用基础层(2D)。
+"""
+
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+from timm.layers.drop import DropPath
+from timm.layers.mlp import Mlp
+from timm.layers.squeeze_excite import SqueezeExcite
+from timm.layers.weight_init import trunc_normal_
+
+
+class Scale(nn.Module):
+    def __init__(self, dims, init_scale=1.0):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
+
+    def forward(self, x):
+        return self.weight * x
+
+
+class Residual(nn.Module):
+    def __init__(self, module, drop=0.0):
+        super().__init__()
+        self.module = module
+        self.drop = drop
+
+    def forward(self, x):
+        if self.training and self.drop > 0.0:
+            keep = torch.rand(x.size(0), 1, 1, 1, device=x.device)
+            keep = keep.ge_(self.drop).div(1.0 - self.drop).detach()
+            return x + self.module(x) * keep
+        return x + self.module(x)
+
+
+class FFN2d(nn.Module):
+    def __init__(self, embed_dim, hidden_dim):
+        super().__init__()
+        self.mlp = Mlp(
+            in_features=embed_dim,
+            hidden_features=hidden_dim,
+            out_features=embed_dim,
+            act_layer=nn.ReLU,
+            use_conv=True,
+            bias=False,
+        )
+        for m in self.mlp.modules():
+            if isinstance(m, nn.BatchNorm2d) and m.num_features == embed_dim:
+                nn.init.constant_(m.weight, 0.0)
+                nn.init.constant_(m.bias, 0.0)
+
+    def forward(self, x):
+        return self.mlp(x)
+
+
+class BNLinear1d(nn.Sequential):
+    def __init__(self, in_features, out_features, bias=True, std=0.02):
+        bn = nn.BatchNorm1d(in_features)
+        linear = nn.Linear(in_features, out_features, bias=bias)
+        trunc_normal_(linear.weight, std=std)
+        if bias:
+            nn.init.constant_(linear.bias, 0)
+        super().__init__(OrderedDict([("bn", bn), ("linear", linear)]))
+
+
+class Conv2dBN(nn.Sequential):
+    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0,
+                 dilation=1, groups=1, bn_weight_init=1.0):
+        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
+        bn = nn.BatchNorm2d(out_channels)
+        nn.init.constant_(bn.weight, bn_weight_init)
+        nn.init.constant_(bn.bias, 0)
+        super().__init__(OrderedDict([("conv", conv), ("bn", bn)]))
+
+
+class DWConv2dBNReLU(nn.Sequential):
+    def __init__(self, in_channels, out_channels, kernel_size=3, bn_weight_init=1.0):
+        super().__init__(OrderedDict([
+            ("dwconv3x3",
+             nn.Conv2d(in_channels, in_channels, kernel_size, 1, kernel_size // 2, groups=in_channels, bias=False)),
+            ("bn1", nn.BatchNorm2d(in_channels)),
+            ("relu", nn.ReLU(inplace=True)),
+            ("dwconv1x1", nn.Conv2d(in_channels, out_channels, 1, 1, 0, groups=in_channels, bias=False)),
+            ("bn2", nn.BatchNorm2d(out_channels)),
+        ]))
+        for bn_name in ["bn1", "bn2"]:
+            bn = getattr(self, bn_name)
+            nn.init.constant_(bn.weight, bn_weight_init)
+            nn.init.constant_(bn.bias, 0)
+
+
+class PatchMerging2d(nn.Module):
+    def __init__(self, dim, out_dim):
+        super().__init__()
+        hidden_dim = int(dim * 4)
+        self.conv1 = Conv2dBN(dim, hidden_dim, 1, 1, 0)
+        self.act = nn.ReLU(inplace=True)
+        self.conv2 = Conv2dBN(hidden_dim, hidden_dim, 3, 2, 1, groups=hidden_dim)
+        self.se = SqueezeExcite(hidden_dim, rd_ratio=0.25)
+        self.conv3 = Conv2dBN(hidden_dim, out_dim, 1, 1, 0)
+
+    def forward(self, x):
+        return self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))

+ 0 - 163
lib/modules/nets_2d.py

@@ -1,163 +0,0 @@
-"""
-WaveletFFTNet(2D 版本)。
-"""
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import Literal
-
-from .blocks_2d import WaveletFFTBlock2d
-from .layers_2d import (
-    BNLinear1d,
-    Conv2dBN,
-    FFN2d,
-    PatchMerging2d,
-    Residual,
-)
-
-
-class WaveletFFTNet2d(nn.Module):
-    def __init__(
-            self, img_size=224, in_chans=3, num_classes=1000,
-            embed_dim=(192, 384, 448), global_ratio=(0.8, 0.7, 0.6),
-            local_ratio=(0.2, 0.2, 0.3), depth=(1, 2, 2),
-            kernels=(7, 5, 3), down_ops=(("subsample", 2), ("subsample", 2), ("",)),
-            distillation=False, drop_path=0.0, wt_levels=1,
-            wt_type="db1", wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero",
-            proj_drop=0.0,
-    ):
-        super().__init__()
-        self.img_size = img_size
-
-        self.patch_embed = nn.Sequential(
-            Conv2dBN(in_chans, embed_dim[0] // 8, 3, 2, 1),
-            nn.ReLU(inplace=True),
-            Conv2dBN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1),
-            nn.ReLU(inplace=True),
-            Conv2dBN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1),
-            nn.ReLU(inplace=True),
-            Conv2dBN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1),
-        )
-
-        stages = [[], [], []]
-        dprs = [x.item() for x in torch.linspace(0, drop_path, sum(depth))]
-
-        for stage_idx, (ed, dpth, gr, lr, down_op, kernel) in enumerate(
-                zip(embed_dim, depth, global_ratio, local_ratio, down_ops, kernels)
-        ):
-            start = sum(depth[:stage_idx])
-            stage_drop = dprs[start: start + dpth]
-
-            for block_idx in range(dpth):
-                stages[stage_idx].append(
-                    WaveletFFTBlock2d(
-                        ed, global_ratio=gr, local_ratio=lr, kernel_size=kernel,
-                        wt_levels=wt_levels, wt_type=wt_type, wt_mode=wt_mode,
-                        proj_drop=proj_drop, drop_path=stage_drop[block_idx],
-                    )
-                )
-
-            if stage_idx < len(embed_dim) - 1 and down_op[0] == "subsample":
-                stages[stage_idx + 1].append(
-                    nn.Sequential(
-                        Residual(
-                            Conv2dBN(embed_dim[stage_idx], embed_dim[stage_idx], 3, 1, 1, groups=embed_dim[stage_idx])),
-                        Residual(FFN2d(embed_dim[stage_idx], int(embed_dim[stage_idx] * 2))),
-                    )
-                )
-                stages[stage_idx + 1].append(PatchMerging2d(embed_dim[stage_idx], embed_dim[stage_idx + 1]))
-                stages[stage_idx + 1].append(
-                    nn.Sequential(
-                        Residual(Conv2dBN(embed_dim[stage_idx + 1], embed_dim[stage_idx + 1], 3, 1, 1,
-                                          groups=embed_dim[stage_idx + 1])),
-                        Residual(FFN2d(embed_dim[stage_idx + 1], int(embed_dim[stage_idx + 1] * 2))),
-                    )
-                )
-
-        self.blocks1 = nn.Sequential(*stages[0])
-        self.blocks2 = nn.Sequential(*stages[1])
-        self.blocks3 = nn.Sequential(*stages[2])
-
-        self.head = BNLinear1d(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
-        self.distillation = distillation
-        if distillation:
-            self.head_dist = BNLinear1d(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
-
-    def forward_features(self, x):
-        x = self.patch_embed(x)
-        x = self.blocks1(x)
-        x = self.blocks2(x)
-        x = self.blocks3(x)
-        return F.adaptive_avg_pool2d(x, 1).flatten(1)
-
-    def forward(self, x):
-        x = self.forward_features(x)
-        if self.distillation:
-            x = self.head(x), self.head_dist(x)
-            if not self.training:
-                x = (x[0] + x[1]) / 2
-            return x
-        return self.head(x)
-
-
-CFG_WAVELET_FFT_T2 = {
-    "img_size": 192, "embed_dim": (144, 272, 368), "depth": (1, 2, 2),
-    "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3),
-    "kernels": (7, 5, 3), "drop_path": 0.0,
-}
-CFG_WAVELET_FFT_T4 = {
-    "img_size": 192, "embed_dim": (176, 368, 448), "depth": (1, 2, 2),
-    "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3),
-    "kernels": (7, 5, 3), "drop_path": 0.0,
-}
-CFG_WAVELET_FFT_S6 = {
-    "img_size": 224, "embed_dim": (192, 384, 448), "depth": (1, 2, 2),
-    "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3),
-    "kernels": (7, 5, 3), "drop_path": 0.0,
-}
-CFG_WAVELET_FFT_B1 = {
-    "img_size": 256, "embed_dim": (200, 376, 448), "depth": (2, 3, 2),
-    "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3),
-    "kernels": (7, 5, 3), "drop_path": 0.03,
-}
-CFG_WAVELET_FFT_B2 = {
-    "img_size": 384, "embed_dim": (200, 376, 448), "depth": (2, 3, 2),
-    "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3),
-    "kernels": (7, 5, 3), "drop_path": 0.03,
-}
-CFG_WAVELET_FFT_B4 = {
-    "img_size": 512, "embed_dim": (200, 376, 448), "depth": (2, 3, 2),
-    "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3),
-    "kernels": (7, 5, 3), "drop_path": 0.03,
-}
-
-
-def _build_model(model_cfg, **kwargs):
-    cfg = dict(model_cfg)
-    cfg.update(kwargs)
-    return WaveletFFTNet2d(**cfg)
-
-
-def wavelet_fft_t2(**kwargs):
-    return _build_model(CFG_WAVELET_FFT_T2, **kwargs)
-
-
-def wavelet_fft_t4(**kwargs):
-    return _build_model(CFG_WAVELET_FFT_T4, **kwargs)
-
-
-def wavelet_fft_s6(**kwargs):
-    return _build_model(CFG_WAVELET_FFT_S6, **kwargs)
-
-
-def wavelet_fft_b1(**kwargs):
-    return _build_model(CFG_WAVELET_FFT_B1, **kwargs)
-
-
-def wavelet_fft_b2(**kwargs):
-    return _build_model(CFG_WAVELET_FFT_B2, **kwargs)
-
-
-def wavelet_fft_b4(**kwargs):
-    return _build_model(CFG_WAVELET_FFT_B4, **kwargs)

+ 0 - 26
lib/modules/smoke_test.py

@@ -1,26 +0,0 @@
-import torch
-
-from .attentions_2d import WaveletAttentionGlobalBranch2d
-from .nets_2d import wavelet_fft_t2
-
-
-def run_smoke_test():
-    with torch.no_grad():
-        global_op = WaveletAttentionGlobalBranch2d(32, kernel_size=5, wt_levels=2)
-        for shape in ((2, 32, 32, 32), (1, 32, 31, 29)):
-            x = torch.randn(*shape)
-            y = global_op(x)
-            assert y.shape == x.shape, f"global_op shape mismatch: {shape} -> {tuple(y.shape)}"
-
-        model = wavelet_fft_t2(num_classes=10)
-        model.eval()
-
-        x = torch.randn(2, 3, 193, 193)
-        y = model(x)
-        assert y.shape == (2, 10), f"model output shape mismatch: {tuple(y.shape)}"
-
-    return "wavelet_fft smoke test passed"
-
-
-if __name__ == "__main__":
-    print(run_smoke_test())

+ 70 - 0
lib/modules/swin_transformer_v2_fwta.py

@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2
+from lib.modules.fwta_2d import FourierWaveletTokenAggregation
+
+
+class SwinTransformerV2FWTA(SwinTransformerV2):
+    """
+    Keep the original SwinTransformerV2 backbone intact and only replace the
+    final global aggregation path.
+    """
+
+    def __init__(
+            self,
+            *args,
+            fwta_wavelet: str = "haar",
+            fwta_level: int = 1,
+            fwta_sigma_ratio: float = 0.35,
+            fwta_tau_fourier: float = 0.15,
+            fwta_gate_temperature: float = 1.0,
+            fwta_fusion_hidden_ratio: float = 0.5,
+            fwta_use_global_conditioning: bool = True,
+            fwta_residual_scale_init: float = 1.0,
+            **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        final_resolution = (
+            int(self.patches_resolution[0] // (2 ** (self.num_layers - 1))),
+            int(self.patches_resolution[1] // (2 ** (self.num_layers - 1))),
+        )
+
+        self.fwta = FourierWaveletTokenAggregation(
+            dim=int(self.num_features),
+            grid_size=final_resolution,
+            wavelet=fwta_wavelet,
+            wavelet_level=fwta_level,
+            sigma_ratio=fwta_sigma_ratio,
+            tau_fourier=fwta_tau_fourier,
+            gate_temperature=fwta_gate_temperature,
+            residual_scale_init=fwta_residual_scale_init,
+            fusion_hidden_ratio=fwta_fusion_hidden_ratio,
+            use_cls_conditioning=fwta_use_global_conditioning,
+        )
+
+    def forward_features(self, x, return_gate: bool = False):
+        x = self.patch_embed(x)
+        if self.ape:
+            x = x + self.absolute_pos_embed
+        x = self.pos_drop(x)
+
+        for layer in self.layers:
+            x = layer(x)
+
+        x = self.norm(x)  # [B, L, C]
+        gap = x.mean(dim=1)  # [B, C]
+        feat, gate = self.fwta(gap, x)
+
+        if return_gate:
+            return feat, gate
+        return feat
+
+    def forward(self, x, return_gate: bool = False):
+        if return_gate:
+            feat, gate = self.forward_features(x, return_gate=True)
+            logits = self.head(feat)
+            return logits, gate
+        feat = self.forward_features(x, return_gate=False)
+        logits = self.head(feat)
+        return logits

+ 108 - 0
lib/modules/swinv2_fwta_encoder_2d.py

@@ -0,0 +1,108 @@
+from __future__ import annotations
+
+from argparse import Namespace
+from pathlib import Path
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+from .build_swinv2 import build_swinv2
+from .fwta_2d import FourierWaveletTokenAggregation
+
+
+class SwinV2FWTAEncoder2d(nn.Module):
+    """
+    面向分割的 SwinV2 + FWTA 编码器封装。
+    """
+
+    def __init__(
+            self,
+            model_name: str | None = None,
+            config_path: str | Path | None = None,
+            weight_path: str | Path | None = None,
+            args: Namespace | None = None,
+            *,
+            load_weights: bool = True,
+            normalize_features: bool = True,
+            use_multiscale_features: bool = True,
+            include_patch_embed: bool = True,
+            fwta_wavelet: str = "haar",
+            fwta_level: int = 1,
+            fwta_sigma_ratio: float = 0.35,
+            fwta_tau_fourier: float = 0.15,
+            fwta_gate_temperature: float = 1.0,
+            fwta_fusion_hidden_ratio: float = 0.5,
+            fwta_use_global_conditioning: bool = True,
+            fwta_residual_scale_init: float = 1.0,
+            **model_kwargs: Any,
+    ) -> None:
+        super().__init__()
+        backbone, cfg = build_swinv2(
+            model_name=model_name,
+            config_path=config_path,
+            weight_path=weight_path,
+            args=args,
+            load_weights=load_weights,
+            return_config=True,
+            **model_kwargs,
+        )
+        self.backbone = backbone
+        self.cfg = cfg
+        self.normalize_features = normalize_features
+        self.use_multiscale_features = use_multiscale_features
+        self.include_patch_embed = include_patch_embed
+
+        depths = tuple(cfg.MODEL.SWINV2.DEPTHS)
+        embed_dim = int(cfg.MODEL.SWINV2.EMBED_DIM)
+        if self.use_multiscale_features:
+            stage_channels = []
+            if self.include_patch_embed:
+                stage_channels.append(embed_dim)
+            for i in range(len(depths)):
+                # forward_multiscale_features appends each layer output after its internal downsample.
+                channel_multiplier = 2 ** min(i + 1, len(depths) - 1)
+                stage_channels.append(int(embed_dim * channel_multiplier))
+            self.stage_channels = stage_channels
+        else:
+            self.stage_channels = [int(embed_dim * 2 ** i) for i in range(len(depths))]
+
+        final_resolution = (
+            int(self.backbone.patches_resolution[0] // (2 ** (len(depths) - 1))),
+            int(self.backbone.patches_resolution[1] // (2 ** (len(depths) - 1))),
+        )
+        self.fwta = FourierWaveletTokenAggregation(
+            dim=int(self.backbone.num_features),
+            grid_size=final_resolution,
+            wavelet=fwta_wavelet,
+            wavelet_level=fwta_level,
+            sigma_ratio=fwta_sigma_ratio,
+            tau_fourier=fwta_tau_fourier,
+            gate_temperature=fwta_gate_temperature,
+            residual_scale_init=fwta_residual_scale_init,
+            fusion_hidden_ratio=fwta_fusion_hidden_ratio,
+            use_cls_conditioning=fwta_use_global_conditioning,
+        )
+
+    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
+        if self.use_multiscale_features:
+            features = self.backbone.forward_multiscale_features(
+                x,
+                normalize=self.normalize_features,
+                include_patch_embed=self.include_patch_embed,
+            )
+        else:
+            features = self.backbone.forward_stage_features(x, normalize=self.normalize_features)
+        deepest = features[-1]
+        b, c, h, w = deepest.shape
+        patch_tokens = deepest.flatten(2).transpose(1, 2)
+        cls_token = patch_tokens.mean(dim=1)
+        cls_out, gate, stability_map = self.fwta.forward_with_map(cls_token, patch_tokens)
+
+        return {
+            "features": features,
+            "deepest_feature": deepest,
+            "global_token": cls_out,
+            "token_gate": gate,
+            "stability_map": stability_map,
+        }

+ 44 - 0
lib/utils/config.py

@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from copy import deepcopy
+from pathlib import Path
+from typing import Any
+
+import yaml
+
+
+def load_yaml_config(path: str | Path) -> dict[str, Any]:
+    with Path(path).open("r", encoding="utf-8") as handle:
+        return yaml.safe_load(handle) or {}
+
+
+def deep_update(dst: dict[str, Any], src: dict[str, Any]) -> dict[str, Any]:
+    for key, value in src.items():
+        if isinstance(value, dict) and isinstance(dst.get(key), dict):
+            deep_update(dst[key], value)
+        else:
+            dst[key] = value
+    return dst
+
+
+def apply_dotlist_overrides(cfg: dict[str, Any], overrides: list[str] | None) -> dict[str, Any]:
+    if not overrides:
+        return cfg
+
+    updated = deepcopy(cfg)
+    for item in overrides:
+        if "=" not in item:
+            raise ValueError(f"Invalid override '{item}'. Expected key=value format.")
+        key, raw_value = item.split("=", 1)
+        value = yaml.safe_load(raw_value)
+        parts = key.split(".")
+        current = updated
+        for part in parts[:-1]:
+            if part not in current or not isinstance(current[part], dict):
+                current[part] = {}
+            current = current[part]
+        current[parts[-1]] = value
+    return updated
+
+
+__all__ = ["load_yaml_config", "deep_update", "apply_dotlist_overrides"]