Walkthru 12
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"  # expose only the first GPU to PyTorch
from fastai.vision.all import *
from fastcore.parallel import *
path = Path()
trn_path = path/'train_images'
arch = 'convnext_tiny_in22k'  # ConvNeXt-Tiny pretrained on ImageNet-22k, from timm
df = pd.read_csv(path/'train.csv')
df
| | image_id | label | variety | age |
|---|---|---|---|---|
| 0 | 100330.jpg | bacterial_leaf_blight | ADT45 | 45 | 
| 1 | 100365.jpg | bacterial_leaf_blight | ADT45 | 45 | 
| 2 | 100382.jpg | bacterial_leaf_blight | ADT45 | 45 | 
| 3 | 100632.jpg | bacterial_leaf_blight | ADT45 | 45 | 
| 4 | 101918.jpg | bacterial_leaf_blight | ADT45 | 45 | 
| ... | ... | ... | ... | ... | 
| 10402 | 107607.jpg | tungro | Zonal | 55 | 
| 10403 | 107811.jpg | tungro | Zonal | 55 | 
| 10404 | 108547.jpg | tungro | Zonal | 55 | 
| 10405 | 110245.jpg | tungro | Zonal | 55 | 
| 10406 | 110381.jpg | tungro | Zonal | 55 | 
10407 rows × 4 columns
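As a quick sanity check on class balance before training, we could count the images per disease label using the df just loaded:
df['label'].value_counts()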
dls = ImageDataLoaders.from_folder(
    trn_path, valid_pct=0.2, seed=42, item_tfms=Resize(224),
    batch_tfms=aug_transforms(size=224, min_scale=0.75), bs=32)
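Before fitting, it can be worth eyeballing one batch to confirm the resize and augmentations look sensible:
dls.show_batch(max_n=6)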
learn = vision_learner(dls, arch, metrics=error_rate)
learn.fine_tune(1, 0.02)
fine_tune first runs one epoch with the pretrained body frozen:

| epoch | train_loss | valid_loss | error_rate | time |
|---|---|---|---|---|
| 0 | 1.392303 | 0.888152 | 0.271985 | 00:38 |

...then one epoch with the whole model unfrozen:

| epoch | train_loss | valid_loss | error_rate | time |
|---|---|---|---|---|
| 0 | 0.507126 | 0.295987 | 0.097069 | 01:55 |
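With the error rate just under 10% after a single fine-tuning epoch, an optional next step would be to see where the model is still confused, e.g. with fastai's interpretation tools:
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(7, 7))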
m = learn.model
m
Sequential(
  (0): TimmBody(
    (model): ConvNeXt(
      (stem): Sequential(
        (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
        (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
      )
      (stages): Sequential(
        (0): ConvNeXtStage(
          (downsample): Identity()
          (blocks): Sequential(
            (0): ConvNeXtBlock(
              (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
              (norm): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=96, out_features=384, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=384, out_features=96, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (1): ConvNeXtBlock(
              (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
              (norm): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=96, out_features=384, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=384, out_features=96, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (2): ConvNeXtBlock(
              (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
              (norm): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=96, out_features=384, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=384, out_features=96, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
          )
        )
        (1): ConvNeXtStage(
          (downsample): Sequential(
            (0): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
            (1): Conv2d(96, 192, kernel_size=(2, 2), stride=(2, 2))
          )
          (blocks): Sequential(
            (0): ConvNeXtBlock(
              (conv_dw): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=192)
              (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=192, out_features=768, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=768, out_features=192, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (1): ConvNeXtBlock(
              (conv_dw): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=192)
              (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=192, out_features=768, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=768, out_features=192, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (2): ConvNeXtBlock(
              (conv_dw): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=192)
              (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=192, out_features=768, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=768, out_features=192, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
          )
        )
        (2): ConvNeXtStage(
          (downsample): Sequential(
            (0): LayerNorm2d((192,), eps=1e-06, elementwise_affine=True)
            (1): Conv2d(192, 384, kernel_size=(2, 2), stride=(2, 2))
          )
          (blocks): Sequential(
            (0): ConvNeXtBlock(
              (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)
              (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=384, out_features=1536, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=1536, out_features=384, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (1): ConvNeXtBlock(
              (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)
              (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=384, out_features=1536, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=1536, out_features=384, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (2): ConvNeXtBlock(
              (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)
              (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=384, out_features=1536, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=1536, out_features=384, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (3): ConvNeXtBlock(
              (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)
              (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=384, out_features=1536, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=1536, out_features=384, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (4): ConvNeXtBlock(
              (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)
              (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=384, out_features=1536, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=1536, out_features=384, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (5): ConvNeXtBlock(
              (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)
              (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=384, out_features=1536, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=1536, out_features=384, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (6): ConvNeXtBlock(
              (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)
              (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=384, out_features=1536, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=1536, out_features=384, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (7): ConvNeXtBlock(
              (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)
              (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=384, out_features=1536, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=1536, out_features=384, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (8): ConvNeXtBlock(
              (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)
              (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=384, out_features=1536, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=1536, out_features=384, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
          )
        )
        (3): ConvNeXtStage(
          (downsample): Sequential(
            (0): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
            (1): Conv2d(384, 768, kernel_size=(2, 2), stride=(2, 2))
          )
          (blocks): Sequential(
            (0): ConvNeXtBlock(
              (conv_dw): Conv2d(768, 768, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=768)
              (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=768, out_features=3072, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=3072, out_features=768, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (1): ConvNeXtBlock(
              (conv_dw): Conv2d(768, 768, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=768)
              (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=768, out_features=3072, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=3072, out_features=768, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
            (2): ConvNeXtBlock(
              (conv_dw): Conv2d(768, 768, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=768)
              (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=768, out_features=3072, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (fc2): Linear(in_features=3072, out_features=768, bias=True)
                (drop2): Dropout(p=0.0, inplace=False)
              )
              (drop_path): Identity()
            )
          )
        )
      )
      (norm_pre): Identity()
      (head): Sequential(
        (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Identity())
        (norm): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
        (flatten): Flatten(start_dim=1, end_dim=-1)
        (drop): Dropout(p=0.0, inplace=False)
        (fc): Identity()
      )
    )
  )
  (1): Sequential(
    (0): AdaptiveConcatPool2d(
      (ap): AdaptiveAvgPool2d(output_size=1)
      (mp): AdaptiveMaxPool2d(output_size=1)
    )
    (1): Flatten(full=False)
    (2): BatchNorm1d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.25, inplace=False)
    (4): Linear(in_features=1536, out_features=512, bias=False)
    (5): ReLU(inplace=True)
    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Dropout(p=0.5, inplace=False)
    (8): Linear(in_features=512, out_features=10, bias=False)
  )
)
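The dump shows fastai's usual two-part structure: m[0] wraps the timm ConvNeXt body, and m[1] is the head fastai attached (note its final Linear maps 512 features to our 10 classes). A quick way to confirm the split:
body, head = m[0], m[1]
type(body).__name__, type(head).__name__
('TimmBody', 'Sequential')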
learn.summary()
Sequential (Input shape: 32 x 3 x 224 x 224)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     32 x 96 x 56 x 56   
Conv2d                                    4704       True      
LayerNorm2d                               192        True      
Identity                                                       
Conv2d                                    4800       True      
LayerNorm                                 192        True      
____________________________________________________________________________
                     32 x 56 x 56 x 384  
Linear                                    37248      True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 56 x 56 x 96   
Linear                                    36960      True      
Dropout                                                        
Identity                                                       
Conv2d                                    4800       True      
LayerNorm                                 192        True      
____________________________________________________________________________
                     32 x 56 x 56 x 384  
Linear                                    37248      True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 56 x 56 x 96   
Linear                                    36960      True      
Dropout                                                        
Identity                                                       
Conv2d                                    4800       True      
LayerNorm                                 192        True      
____________________________________________________________________________
                     32 x 56 x 56 x 384  
Linear                                    37248      True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 56 x 56 x 96   
Linear                                    36960      True      
Dropout                                                        
Identity                                                       
LayerNorm2d                               192        True      
____________________________________________________________________________
                     32 x 192 x 28 x 28  
Conv2d                                    73920      True      
Conv2d                                    9600       True      
LayerNorm                                 384        True      
____________________________________________________________________________
                     32 x 28 x 28 x 768  
Linear                                    148224     True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 28 x 28 x 192  
Linear                                    147648     True      
Dropout                                                        
Identity                                                       
Conv2d                                    9600       True      
LayerNorm                                 384        True      
____________________________________________________________________________
                     32 x 28 x 28 x 768  
Linear                                    148224     True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 28 x 28 x 192  
Linear                                    147648     True      
Dropout                                                        
Identity                                                       
Conv2d                                    9600       True      
LayerNorm                                 384        True      
____________________________________________________________________________
                     32 x 28 x 28 x 768  
Linear                                    148224     True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 28 x 28 x 192  
Linear                                    147648     True      
Dropout                                                        
Identity                                                       
LayerNorm2d                               384        True      
____________________________________________________________________________
                     32 x 384 x 14 x 14  
Conv2d                                    295296     True      
Conv2d                                    19200      True      
LayerNorm                                 768        True      
____________________________________________________________________________
                     32 x 14 x 14 x 1536 
Linear                                    591360     True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 14 x 14 x 384  
Linear                                    590208     True      
Dropout                                                        
Identity                                                       
Conv2d                                    19200      True      
LayerNorm                                 768        True      
____________________________________________________________________________
                     32 x 14 x 14 x 1536 
Linear                                    591360     True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 14 x 14 x 384  
Linear                                    590208     True      
Dropout                                                        
Identity                                                       
Conv2d                                    19200      True      
LayerNorm                                 768        True      
____________________________________________________________________________
                     32 x 14 x 14 x 1536 
Linear                                    591360     True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 14 x 14 x 384  
Linear                                    590208     True      
Dropout                                                        
Identity                                                       
Conv2d                                    19200      True      
LayerNorm                                 768        True      
____________________________________________________________________________
                     32 x 14 x 14 x 1536 
Linear                                    591360     True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 14 x 14 x 384  
Linear                                    590208     True      
Dropout                                                        
Identity                                                       
Conv2d                                    19200      True      
LayerNorm                                 768        True      
____________________________________________________________________________
                     32 x 14 x 14 x 1536 
Linear                                    591360     True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 14 x 14 x 384  
Linear                                    590208     True      
Dropout                                                        
Identity                                                       
Conv2d                                    19200      True      
LayerNorm                                 768        True      
____________________________________________________________________________
                     32 x 14 x 14 x 1536 
Linear                                    591360     True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 14 x 14 x 384  
Linear                                    590208     True      
Dropout                                                        
Identity                                                       
Conv2d                                    19200      True      
LayerNorm                                 768        True      
____________________________________________________________________________
                     32 x 14 x 14 x 1536 
Linear                                    591360     True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 14 x 14 x 384  
Linear                                    590208     True      
Dropout                                                        
Identity                                                       
Conv2d                                    19200      True      
LayerNorm                                 768        True      
____________________________________________________________________________
                     32 x 14 x 14 x 1536 
Linear                                    591360     True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 14 x 14 x 384  
Linear                                    590208     True      
Dropout                                                        
Identity                                                       
Conv2d                                    19200      True      
LayerNorm                                 768        True      
____________________________________________________________________________
                     32 x 14 x 14 x 1536 
Linear                                    591360     True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 14 x 14 x 384  
Linear                                    590208     True      
Dropout                                                        
Identity                                                       
LayerNorm2d                               768        True      
____________________________________________________________________________
                     32 x 768 x 7 x 7    
Conv2d                                    1180416    True      
Conv2d                                    38400      True      
LayerNorm                                 1536       True      
____________________________________________________________________________
                     32 x 7 x 7 x 3072   
Linear                                    2362368    True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 7 x 7 x 768    
Linear                                    2360064    True      
Dropout                                                        
Identity                                                       
Conv2d                                    38400      True      
LayerNorm                                 1536       True      
____________________________________________________________________________
                     32 x 7 x 7 x 3072   
Linear                                    2362368    True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 7 x 7 x 768    
Linear                                    2360064    True      
Dropout                                                        
Identity                                                       
Conv2d                                    38400      True      
LayerNorm                                 1536       True      
____________________________________________________________________________
                     32 x 7 x 7 x 3072   
Linear                                    2362368    True      
GELU                                                           
Dropout                                                        
____________________________________________________________________________
                     32 x 7 x 7 x 768    
Linear                                    2360064    True      
Dropout                                                        
Identity                                                       
Identity                                                       
____________________________________________________________________________
                     32 x 768 x 1 x 1    
AdaptiveAvgPool2d                                              
AdaptiveMaxPool2d                                              
____________________________________________________________________________
                     32 x 1536           
Flatten                                                        
BatchNorm1d                               3072       True      
Dropout                                                        
____________________________________________________________________________
                     32 x 512            
Linear                                    786432     True      
ReLU                                                           
BatchNorm1d                               1024       True      
Dropout                                                        
____________________________________________________________________________
Total params: 28,602,496
Total trainable params: 28,602,496
Total non-trainable params: 0
Optimizer used: <function Adam at 0x7fba154d2820>
Loss function: FlattenedLoss of CrossEntropyLoss()
Model unfrozen
Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
h = m[1]
h
Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Flatten(full=False)
  (2): BatchNorm1d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25, inplace=False)
  (4): Linear(in_features=1536, out_features=512, bias=False)
  (5): ReLU(inplace=True)
  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Dropout(p=0.5, inplace=False)
  (8): Linear(in_features=512, out_features=10, bias=False)
)
ll = h[-1]
ll
Linear(in_features=512, out_features=10, bias=False)
ll.parameters()
<generator object Module.parameters at 0x7f3eb07f6510>
llp = list(ll.parameters())[0]
llp
Parameter containing:
tensor([[-0.0744,  0.0336,  0.0605,  ..., -0.0420,  0.0424, -0.0034],
        [-0.1795, -0.2076,  0.1650,  ..., -0.0621, -0.0006,  0.0082],
        [ 0.1246,  0.1805, -0.0646,  ...,  0.0427, -0.0914, -0.0676],
        ...,
        [-0.0006,  0.0500,  0.1109,  ..., -0.1318, -0.1138,  0.0212],
        [-0.0090,  0.0812, -0.0604,  ...,  0.1045,  0.2306,  0.1239],
        [-0.0352, -0.0684,  0.1508,  ...,  0.0299,  0.0247, -0.0539]],
       device='cuda:0', requires_grad=True)
llp.shape
torch.Size([10, 512])
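So this weight matrix maps the 512 head features to 10 class outputs; applying the layer is a single matrix product. A minimal shape check, using a dummy input created on the same device as the layer:
x = torch.randn(1, 512, device=llp.device)
ll(x).shape
torch.Size([1, 10])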
ll
Linear(in_features=512, out_features=10, bias=False)
from copy import deepcopy
dls = ImageDataLoaders.from_folder(
    trn_path, valid_pct=0.2, seed=42, item_tfms=Resize(224),
    batch_tfms=aug_transforms(size=224, min_scale=0.75))
learn = vision_learner(dls, arch)  # no metrics=error_rate here: once the model returns a tuple of predictions, error_rate would raise an error
learn2 = deepcopy(learn)
curr_loss = learn2.loss_func
def dtc_loss(preds, targs):
    # the new model returns (rice_preds, dis_preds); for now, train only on the disease output
    rice_preds, dis_preds = preds
    return curr_loss(dis_preds, targs)
class DiseaseAndTypeClassifier(nn.Module):
    def __init__(self, m):
        super().__init__()
        self.l1 = nn.Linear(in_features=512, out_features=10, bias=False)  # rice variety head
        self.l2 = nn.Linear(in_features=512, out_features=10, bias=False)  # disease head
        del(m[1][-1])  # remove the original final linear layer from the head
        self.m = m

    def forward(self, x):
        x = self.m(x)  # shared body + truncated head -> 512 features
        x1 = self.l1(x)
        x2 = self.l2(x)
        return x1, x2
        
dtc = DiseaseAndTypeClassifier(learn2.model)
learn2.model = dtc
learn2.loss_func = dtc_loss
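Before running the whole validation set through get_preds, a one-batch smoke test is a cheap way to confirm the rewired model returns two tensors and that the loss runs on them (a minimal sketch; one_batch draws a single batch from the validation DataLoader):
xb, yb = learn2.dls.valid.one_batch()
with torch.no_grad():
    r, d = learn2.model(xb)
r.shape, d.shape, dtc_loss((r, d), yb)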
preds, targs = learn2.get_preds(dl=learn2.dls.valid)
rice_preds, dis_preds = preds
rice_preds.shape, dis_preds.shape
(torch.Size([2081, 10]), torch.Size([2081, 10]))
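Both heads emit 10 outputs for each of the 2,081 validation images. The custom loss also runs end-to-end on these predictions; since neither new head has been trained yet, its value should sit near chance level (roughly ln 10 ≈ 2.3 for 10 classes):
dtc_loss(preds, targs)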
TODO: check fastai's create_head function to see how the standard head above gets assembled.