Interpretability Tools as Feedback Loops

Idea

Class Activation Mappings (CAMs) are a popular interpretability technique for CNNs. They highlight the regions of an image that contribute to the final classification.
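
As a rough illustration: for a network ending in global average pooling, the CAM is a channel-weighted sum of the final convolutional feature maps, using the classifier weights of the class of interest. The sketch below assumes that setup; the names and shapes are illustrative, not taken from this repository.

import numpy as np

def class_activation_map(features, class_weights):
    """Minimal CAM sketch for a GAP classifier.

    features:      (H, W, C) activations of the last conv layer
    class_weights: (C,) dense-layer weights for the class of interest
    """
    cam = features @ class_weights   # weighted sum over channels -> (H, W)
    cam -= cam.min()                 # shift to non-negative
    return cam / (cam.max() + 1e-8)  # normalize to [0, 1]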

I think they can be used within the training loop against a kernel loss, so as to leverage 'shared knowledge' within a dataset. In this case, most targets in the dataset are centred. Using the CAM, we check whether the region outside the centre carries high activation, and take the mean squared error of that region, weighting pixels more heavily the further they lie from the centre.
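
Here is a minimal sketch of that kernel loss, written in NumPy for clarity. The function name, the zero-activation target, and the inverted-Gaussian weighting are assumptions made for illustration, not the repository's actual implementation.

import numpy as np

def kernel_loss(cam, sigma=0.3):
    """Penalize CAM activation far from the image centre.

    cam: (H, W) activation map, values in [0, 1].
    """
    h, w = cam.shape
    ys, xs = np.mgrid[0:h, 0:w]
    # normalized distance of every pixel from the centre
    dist = np.sqrt(((ys - (h - 1) / 2) / h) ** 2
                   + ((xs - (w - 1) / 2) / w) ** 2)
    # weight is ~0 at the centre and approaches 1 outward (assumed form)
    weight = 1.0 - np.exp(-dist ** 2 / (2 * sigma ** 2))
    # mean squared background activation, i.e. MSE against a zero target
    return float(np.mean(weight * cam ** 2))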

Since a high loss is assigned to background activation, the model is less likely to overfit by memorizing the background, even when the dataset itself doesn't guard against that (e.g. through varied backgrounds).

Since CNNs are largely invariant to pixel shifts, positional generalization is not an issue: at prediction time, the model identifies off-centre targets with about the same accuracy as centred ones.

Task Setup

In this (largely contrived) example, the CALTECH-256 dataset is restricted to two classes: orcas and leopards, reducing it to a binary classification task.

I've split up the test subset such that the leopards have a blue background, and the orcas have uncommon backgrounds (I could not find green, but I did find one with a yellow tint!). I'm trying to mislead the model as much as possible.

The model architecture is standard and doesn't change between experiments; only the training strategy is updated.

The model with the updated training strategy tended to have better performance than the standard model over a large set of runs.

Running the Code

The repository is structured as a package, so the relative imports resolve correctly when the code is run with the -m flag.

Setup

From the root of the repository:

pip install -r requirements.txt  # please use a virtual environment!
dvc pull -r origin

Training

python -m src.models.train {camloss|default}

The 'default' argument trains the default model. Any other name will train using the updated strategy. The resultant model is saved under models/{model_name}.
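
For example, one run of each variant (using the argument names above):

python -m src.models.train camloss
python -m src.models.train default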

CAMs

python -m src.data.generator {model_name}

CAMs are generated for {model_name}, which should be in the models/ directory.
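
For example, to generate CAMs for the CAM-loss model trained above:

python -m src.data.generator camloss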

Project Structure

├── LICENSE
├── README.md
├── data
│   ├── cams                <- generated activation mappings from models
│   ├── test
│   ├── train
│   └── val
│
├── models                  <- trained and serialized models
│   ├── default             <- standard model
│   └── camloss             <- model with cam loss and training strategy applied
│
├── slides                  <- talk slides
│   └── img                 <- source images
│
├── requirements.txt
│
└── src
    ├── data
    │   ├── augment.py      <- make sure leopards and orcas have equal training samples
    │   └── generator.py    <- dataset + cams generator
    │
    ├── models
    │   ├── arch.py         <- model architecture
    │   ├── callbacks.py    <- callbacks for managing training parameters
    │   ├── loss.py         <- kernel loss function
    │   ├── train.py        <- driver
    │   └── predict.py      <- prediction script
    │
    └── scripts             <- helper scripts
        ├── pipeline.sh     <- run all the code at once
        └── colab_setup.sh  <- automate colab setup
