iBioHash2024

Config

import os

import pickle

import random

import sklearn


import numpy as np

import pandas as pd

import tensorflow as tf

import matplotlib.pyplot as plt


from tqdm.notebook import *

from sklearn.neighbors import *

from sklearn.decomposition import *

from sklearn.model_selection import *


try:

    import farmhash

    import keras_cv_attention_models


    import tensorflow_addons as tfa

except:

    !pip install -qq cityhash

    !pip install -qq tensorflow-addons

    !pip install -qq keras-cv-attention-models


    import farmhash

    import keras_cv_attention_models


    import tensorflow_addons as tfa

# ---- Accelerator detection: prefer TPU, then multi-GPU, then CPU ----
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
except ValueError:
    # No TPU cluster available; fall back to the logical GPUs (if any).
    tpu = None
    gpus = tf.config.experimental.list_logical_devices("GPU")

if tpu:
    # `gpus` is undefined on this path, which is fine: the elif below is
    # only reached when tpu is None (i.e. the except branch ran).
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # NOTE(review): tf.distribute.experimental.TPUStrategy is the legacy
    # alias; newer TF versions expose tf.distribute.TPUStrategy.
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
    tf.config.set_soft_device_placement(True)

    print('Running on TPU ', tpu.master())
elif len(gpus) > 0:
    # Synchronous data parallelism across all visible GPUs.
    strategy = tf.distribute.MirroredStrategy(gpus)
    print('Running on ', len(gpus), ' GPU(s) ')
else:
    # Default no-op strategy: single device (CPU).
    strategy = tf.distribute.get_strategy()
    print('Running on CPU')

print("Number of accelerators: ", strategy.num_replicas_in_sync)

# Let tf.data auto-tune parallelism and prefetch buffer sizes.
AUTO = tf.data.experimental.AUTOTUNE


# Global experiment configuration.
config = {
    "seed": 1213,  # RNG seed for python / numpy / tensorflow

    "lr": 1e-5,    # Adam learning rate
    "epochs": 10,
    # Global batch size: 32 per replica, scaled by the number of replicas.
    "batch_size": 32 * strategy.num_replicas_in_sync,

    "n_classes": 1000,            # one-hot depth for training labels
    "image_size": [224, 224, 3],  # width, height, channels fed to backbones
    "hashLength": 2048,           # dimensionality of the 'feature' layer

    # GCS buckets holding the tfrecord shards (Kaggle dataset mounts).
    "data_paths": ['gs://kds-3f5ce2e56559d17c577c6d9f0c1acd2179da3bc0c1f87ae24d66f478', 'gs://kds-6d0cfb649153417b33f05a9913c59a358a3c03706a8fd6c5f523b8fc', 'gs://kds-0a7e8b42af254bd7711b5670f5aeb0eeda8d169841722ea59680d3d4', 'gs://kds-d27f1d2f32cff13b0ce97a6b912eb44a68e8ee3128504dcb52e23a3e', 'gs://kds-7a36360336e83a60d74c0196158fed9d6fa70cd342dc96383b7f6e2a', 'gs://kds-f99762a40174159b2d5b1caedff9d3e42758db9134c67405c68ce52d', 'gs://kds-af8ef10c715e13cb364f2655d5188aedfbb874b71e250d2773dade89', 'gs://kds-b0b80433ab5beabf70aca68659672322018f637039ae27fbd9ee9bf1'],

    "save_path": "./",
    # Ensemble backbones: either a tf.keras.applications class name, or a
    # "<module>.<Model>" path inside keras_cv_attention_models.
    "backbones": ["EfficientNetV2S", "beit.BeitV2BasePatch16"] #["EfficientNetV2M", "beit.BeitV2BasePatch16", "davit.DaViT_S"]
}

def seed_everything(seed):
    """Seed the python, numpy and tensorflow RNGs for reproducibility."""
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)


seed_everything(config["seed"])

Data

def train_transform(image, label):
    """Apply light stochastic augmentation to a single training image.

    The image is temporarily given a batch dimension because the ops used
    here operate on 4-D tensors; it is squeezed back before returning.
    """
    image = tf.expand_dims(image, axis=0)

    if (tf.random.uniform([1]) < 0.5):
        # NOTE(review): random_flip_left_right itself flips with prob 0.5,
        # so gating it behind a 0.5 coin gives an effective flip
        # probability of 0.25 — confirm this is intended.
        image = tf.image.random_flip_left_right(image)
    if (tf.random.uniform([1]) < 0.2):
        # Brightness delta drawn from [-0.08, 0.08].
        image = tf.image.random_brightness(image, 0.08)
    if (tf.random.uniform([1]) < 0.3):
        # Rotation angle in radians, uniform in [-0.5, 0.5).
        image = tfa.image.rotate(image, 1.0*tf.random.uniform([1]) - 0.5)
    return tf.squeeze(image, axis=0), label

def margin_format(image, label):
    """Pack (image, label) into the dict input the Margin head expects,
    while keeping the label as the training target."""
    features = {'image': image, 'label': label}
    return features, label


def onehot(data, label):
    """One-hot encode an integer label into a float32 vector of depth
    config["n_classes"]; the image payload passes through untouched."""
    depth = tf.constant(config["n_classes"])
    encoded = tf.one_hot(label, depth, dtype=tf.float32, axis=-1)
    return data, encoded


def decode_image(image_data):
    """Decode JPEG bytes and letterbox-resize to the configured size."""
    target_w, target_h = config["image_size"][0], config["image_size"][1]
    decoded = tf.image.decode_jpeg(image_data, channels=3)
    return tf.image.resize_with_pad(decoded,
                                    target_width=target_w,
                                    target_height=target_h)


def read_tfrecord(example):
    """Parse one serialized tf.Example into (decoded image, int64 label)."""
    parsed = tf.io.parse_single_example(example, {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    })
    label = tf.cast(parsed["label"], dtype=tf.int64)
    return decode_image(parsed["image"]), label


def load_dataset(filenames, ordered):
    """Read tfrecord shards into a parsed dataset.

    When `ordered` is False, deterministic interleaving is disabled so
    shards can be consumed as fast as they arrive.
    """
    opts = tf.data.Options()
    if not ordered:
        opts.experimental_deterministic = False

    ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return ds.with_options(opts).map(read_tfrecord, num_parallel_calls=AUTO)


def get_train_dataset(filenames):
    """Infinite, shuffled, augmented, one-hot, batched training pipeline."""
    # NOTE(review): config["seed"] (1213) is passed as the *shuffle buffer
    # size* here, not as a random seed — confirm that was intended.
    ds = load_dataset(filenames, ordered=False).repeat().shuffle(config["seed"])

    for fn in (train_transform, onehot, margin_format):
        ds = ds.map(fn, num_parallel_calls=AUTO)

    return ds.batch(config["batch_size"]).prefetch(AUTO)


def get_valid_dataset(filenames):
    """Ordered, un-augmented validation pipeline (one-hot labels)."""
    ds = load_dataset(filenames, ordered=True)

    for fn in (onehot, margin_format):
        ds = ds.map(fn, num_parallel_calls=AUTO)

    return ds.batch(config["batch_size"]).prefetch(AUTO)


def get_test_dataset(filenames):
    """Ordered inference pipeline: raw integer labels (no one-hot)."""
    ds = load_dataset(filenames, ordered=True)
    ds = ds.map(margin_format, num_parallel_calls=AUTO)
    return ds.batch(config["batch_size"]).prefetch(AUTO)


Model

class Margin(tf.keras.layers.Layer):
    """Margin classification head over learned per-class center vectors.

    Logits are derived from the similarity between the input feature and
    each class center.  Three variants are implemented (soft-Hamming,
    squared euclidean, cosine); `call` uses the cosine variant, adding a
    noisy angular margin to the true-class logit during training
    (ArcFace-style, with the margin selected via the one-hot label mask).

    Args:
        num_classes: number of class centers (rows of W).
        margin: mean of the random margin added at train time.
        scale: logit scaling factor.
    """

    def __init__(self, num_classes, margin=0.1, scale=32, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale
        self.margin = margin
        self.num_classes = num_classes

    def build(self, input_shape):
        # input_shape = [feature_shape, label_shape]; one center per class,
        # sized to the incoming feature dimension.
        self.W = self.add_weight(shape=(self.num_classes, input_shape[0][-1]),
                                 initializer='glorot_uniform', trainable=True)
        super().build(input_shape)

    def build_hash(self, x, scale=10.0):
        # Soft binarization: sigmoid of the scaled, l2-normalized input
        # approaches {0, 1} as `scale` grows.
        return tf.nn.sigmoid(scale * tf.nn.l2_normalize(x, axis=1))

    def hamming_distance(self, feature):
        # Soft Hamming distance between the hashed feature and the hashed
        # class centers.  FIX: the original computed the two hashes below
        # and then tiled the *raw* feature / self.W instead, silently
        # discarding them; the hashed tensors are now actually used.
        x = self.build_hash(feature, self.scale)
        w = self.build_hash(self.W, self.scale)

        x = tf.tile(tf.expand_dims(x, 2), [1, 1, self.W.shape[0]])
        w = tf.transpose(w)
        # Keep distances strictly positive (they are divided into below)
        # and capped at 48 — presumably the submission hash length; confirm.
        return tf.clip_by_value(tf.reduce_sum(tf.math.abs(x - w), axis=1), 1e-4, 48)

    def logits_hamming(self, feature, labels):
        # Inverse-distance logits; the true class gets a noisy extra margin.
        distance = self.hamming_distance(feature)
        mr = tf.random.normal(shape=tf.shape(distance),
                              mean=self.margin, stddev=0.1 * self.margin)
        distance_add = distance + mr
        mask = tf.cast(labels, dtype=distance.dtype)
        return self.scale / (mask * distance_add + (1 - mask) * distance)

    def distance(self, feature):
        # Squared euclidean distance between the l2-normalized feature and
        # the l2-normalized class centers.
        x = tf.nn.l2_normalize(feature, axis=1)
        w = tf.nn.l2_normalize(self.W, axis=1)

        x = tf.tile(tf.expand_dims(x, 2), [1, 1, self.W.shape[0]])
        w = tf.transpose(w)
        return tf.reduce_sum(tf.math.pow(x - w, 2), axis=1)

    def logits_distance(self, feature, labels):
        # Inverse-distance logits with the same noisy-margin scheme.
        distance = self.distance(feature)
        mr = tf.random.normal(shape=tf.shape(distance),
                              mean=self.margin, stddev=0.1 * self.margin)
        distance_add = distance + mr
        mask = tf.cast(labels, dtype=distance.dtype)
        return self.scale / (mask * distance_add + (1 - mask) * distance)

    def cosine(self, feature):
        # Cosine similarity between the feature and every class center.
        x = tf.nn.l2_normalize(feature, axis=1)
        w = tf.nn.l2_normalize(self.W, axis=1)
        return tf.matmul(x, tf.transpose(w))

    def logits_cosine(self, feature, labels):
        # ArcFace-style: add a noisy angular margin in angle space, then
        # keep the margined value only for the true class.
        cosine = self.cosine(feature)
        mr = tf.random.normal(shape=tf.shape(cosine),
                              mean=self.margin, stddev=0.1 * self.margin)
        theta = tf.acos(tf.clip_by_value(cosine, -1, 1))
        cosine_add = tf.math.cos(theta + mr)

        mask = tf.cast(labels, dtype=cosine.dtype)
        return (mask * cosine_add + (1 - mask) * cosine) * self.scale

    def call(self, inputs, training=None):
        """inputs = [feature, one-hot labels].

        Margined cosine logits during training; plain (scaled-by-call-site)
        cosine similarity at inference.  `training=None` default added so
        the layer can be called without an explicit flag.
        """
        feature, labels = inputs

        if training:
            return self.logits_cosine(feature, labels)
        return self.cosine(feature)

    def get_config(self):
        # `base` avoids shadowing the module-level `config` dict.
        base = super().get_config().copy()
        base.update({
            'scale': self.scale,
            'margin': self.margin,
            'num_classes': self.num_classes,
        })
        return base


def get_backbone(backbone_name, x):
    """Instantiate a backbone by name and apply it to image tensor `x`.

    Names found in tf.keras.applications use "tf" preprocessing and
    imagenet weights; otherwise the name is resolved as
    "<module>.<Model>" inside keras_cv_attention_models with "torch"
    preprocessing.  Returns a pooled feature vector.
    """
    if hasattr(tf.keras.applications, backbone_name):
        preprocessed = tf.keras.layers.Lambda(
            lambda data: tf.keras.applications.imagenet_utils.preprocess_input(
                tf.cast(data, tf.float32), mode="tf"))(x)
        net = getattr(tf.keras.applications, backbone_name)(
            weights="imagenet", include_top=False)
        return tf.keras.layers.GlobalAveragePooling2D()(net(preprocessed))

    module_name, model_name = backbone_name.split(".")
    backbone = getattr(getattr(keras_cv_attention_models, module_name),
                       model_name)(num_classes=0)
    preprocessed = tf.keras.layers.Lambda(
        lambda data: tf.keras.applications.imagenet_utils.preprocess_input(
            tf.cast(data, tf.float32), mode="torch"))(x)
    backbone.trainable = True
    features = backbone(preprocessed)
    if "beit" in backbone_name:
        # BEiT variants already output a pooled feature vector.
        return features
    return tf.keras.layers.GlobalAveragePooling2D()(features)


def model_factory(backbones, n_classes):
    """Assemble the retrieval model: backbone ensemble -> concat ->
    dense embedding ('feature') -> Margin head -> softmax."""
    image = tf.keras.layers.Input(shape=(None, None, 3), dtype=tf.uint8, name='image')
    label = tf.keras.layers.Input(shape=(), name='label', dtype=tf.int64)

    pooled = [get_backbone(name, image) for name in backbones]
    embedding = tf.keras.layers.Concatenate(name="concat")(pooled)
    embedding = tf.keras.layers.Dense(config["hashLength"],
                                      activation="linear",
                                      name="feature")(embedding)

    logits = Margin(num_classes=n_classes, name="margin")([embedding, label])
    # float32 softmax keeps the output numerically stable under mixed precision.
    output = tf.keras.layers.Softmax(dtype=tf.float32, name="output")(logits)

    return tf.keras.models.Model(inputs=[image, label], outputs=[output])

Callbacks

%%writefile generate_submit_csv.py

import argparse

import numpy as np


parser = argparse.ArgumentParser()

parser.add_argument('--query', type=str, default='submit_query.csv', help='48-bit hash code file for query set.')

parser.add_argument('--gallery', type=str, default='submit_gallery.csv', help='48-bit hash code file for gallery set.')

parser.add_argument('--submit', type=str, default='submit.csv', help='Final submit csv file.')

parser.add_argument('--k', type=int, default=100, help='Topk == 100.')

args = parser.parse_args()


k = args.k

query_code_path = args.query

gallery_code_path = args.gallery

submit_path = args.submit


query_code = []

query_images = []

with open(query_code_path, 'r') as f:

    f.readline()

    for line in f:

        image, code = line[:-1].split(',')

        code = [-1. if i == 0 else 1. for i in list(map(float, list(code[1:-1])))]

        query_code.append(code)

        query_images.append(image)

query_code = np.array(query_code)


gallery_code = []

gallery_images = []

with open(gallery_code_path, 'r') as f:

    f.readline()

    for line in f:

        image, code = line[:-1].split(',')

        code = [-1. if i == 0 else 1. for i in list(map(float, list(code[1:-1])))]

        gallery_code.append(code)

        gallery_images.append(image[:-4])

gallery_code = np.array(gallery_code)


with open(submit_path, 'w') as f:

    f.write('Id,Predicted\n')

    for i, q in enumerate(query_code):

        hamming_dist = q.shape[0] - np.dot(q, gallery_code.T)

        index = np.argsort(hamming_dist)[:k]

        retrieval_images = [gallery_images[j] for j in index]

        f.write(query_images[i] + ',')

        f.write(' '.join(retrieval_images) + '\n')

        print(f'\r writing {i + 1}/{query_code.shape[0]}', end='')

class Evaluation(tf.keras.callbacks.Callback):
    """End-of-epoch feature dump.

    After each epoch, builds a sub-model that maps the training inputs to
    the 'feature' layer output plus the label input, runs it over the
    query, test-gallery and validation-gallery datasets, and pickles the
    (l2-normalized float16 features, ids) triples to
    config['save_path']/data_<epoch>.
    """

    def __init__(self, g_data, q_data, v_g_data):
        # FIX: keras callbacks must call super().__init__() so the base
        # class can set up its model-attachment bookkeeping.
        super().__init__()
        self.g_data = g_data      # test gallery dataset
        self.q_data = q_data      # test query dataset
        self.v_g_data = v_g_data  # validation gallery dataset

    def bitdigest(self, digest):
        # Stringify every element of a digest vector (index was unused).
        return [str(x) for x in digest.tolist()]

    def to_id(self, ids):
        # Convert numeric ids back to "<id>.jpg" filenames.
        return [str(x) + ".jpg" for x in ids]

    def _extract(self, model, data, steps=None):
        """Predict over `data`; return (l2-normalized float16 features, ids).

        Shared by the three process_* methods below, which previously
        duplicated this code verbatim.
        """
        feature, ids = model.predict(data, verbose=1, steps=steps)
        feature = np.array(tf.nn.l2_normalize(feature, axis=1)).astype(np.float16)
        return feature, ids

    def process_vgallery(self, model, epoch, steps=None):
        # `epoch` kept for interface compatibility; not used.
        return self._extract(model, self.v_g_data, steps)

    def process_gallery(self, model, epoch, steps=None):
        return self._extract(model, self.g_data, steps)

    def process_query(self, model, epoch, steps=None):
        return self._extract(model, self.q_data, steps)

    def gen_sub(self, epoch):
        # Shell out to the submission script written by the %%writefile cell.
        return os.popen(f"python3 generate_submit_csv.py --gallery {config['save_path']}G_{epoch}.csv --query {config['save_path']}Q_{epoch}.csv --submit {config['save_path']}S_{epoch}.csv").read()

    def on_epoch_end(self, epoch, logs=None):  # logs={} mutable default removed
        # Sub-model: training inputs -> ('feature' embedding, label input).
        model = tf.keras.models.Model(
            inputs=self.model.inputs,
            outputs=[self.model.get_layer('feature').output, self.model.inputs[1]])

        with open(f"{config['save_path']}data_{epoch}", 'wb') as handle:
            pickle.dump([self.process_query(model, epoch),
                         self.process_gallery(model, epoch),
                         self.process_vgallery(model, epoch)],
                        handle, protocol=pickle.HIGHEST_PROTOCOL)
        # self.gen_sub(epoch)


class SaveModel(tf.keras.callbacks.Callback):
    """Checkpoint callback; saving is currently disabled (no-op)."""

    def __init__(self, path):
        # FIX: keras callbacks must call super().__init__().
        super().__init__()
        self.path = path  # directory prefix for the saved model

    def on_epoch_end(self, epoch, logs=None):  # logs={} mutable default removed
        # Saving intentionally disabled; re-enable when needed:
        # self.model.save(self.path + "model.keras")
        pass

Run

# Collect tfrecord shard paths from every GCS bucket.
DATA_FILENAMES = []
TEST_G_FILENAMES = []
TEST_Q_FILENAMES = []

for bucket in config["data_paths"]:
    DATA_FILENAMES.extend(tf.io.gfile.glob(bucket + '/*BIO*.tfrec'))
    TEST_G_FILENAMES.extend(tf.io.gfile.glob(bucket + '/FGVC11*Test_G*.tfrec'))
    TEST_Q_FILENAMES.extend(tf.io.gfile.glob(bucket + '/FGVC11*Test_Q*.tfrec'))

# Hold out 2% of the training shards for validation.
TRAINING_FILENAMES, VALIDATION_FILENAMES = train_test_split(
    DATA_FILENAMES, test_size=0.02, random_state=config["seed"])

sample_dataset = get_test_dataset(random.choices(DATA_FILENAMES, k=1))
train_dataset = get_train_dataset(TRAINING_FILENAMES)
valid_dataset = get_valid_dataset(VALIDATION_FILENAMES)
valid_G_dataset = get_test_dataset(VALIDATION_FILENAMES)
# Sorted so gallery/query order is reproducible across runs.
test_G_dataset = get_test_dataset(sorted(TEST_G_FILENAMES))
test_Q_dataset = get_test_dataset(sorted(TEST_Q_FILENAMES))

evaluation = Evaluation(test_G_dataset, test_Q_dataset, valid_G_dataset)


# Preview a 4x4 grid of augmented training images.
# FIX: the original called plt.figure() and then plt.subplots(), which
# created an extra empty figure window; subplots() creates its own figure.
f, axarr = plt.subplots(4, 4)

counter = 0
for data in train_dataset:
    for image in data[0]['image']:
        # Pipeline images are floats on a 0..255 scale; divide by 255
        # for matplotlib's expected [0, 1] range (the original /256
        # slightly darkened the preview).
        axarr[counter // 4, counter % 4].imshow(image / 255.0)
        axarr[counter // 4, counter % 4].axis('off')

        counter += 1
        if counter == 16:
            break
    if counter == 16:
        break

plt.show()


with strategy.scope():
    # Build and compile inside the distribution scope so model variables
    # are created on the TPU/GPU replicas.
    model = model_factory(backbones = config["backbones"],
                          n_classes = config["n_classes"])

    optimizer = tf.keras.optimizers.Adam(learning_rate = config["lr"])
    # Softmax output + one-hot labels -> categorical cross-entropy;
    # top-k accuracies track retrieval-style ranking quality.
    model.compile(optimizer = optimizer,
                  loss = [tf.keras.losses.CategoricalCrossentropy()],
                  metrics = [tf.keras.metrics.CategoricalAccuracy(name = "ACC@1"),
                             tf.keras.metrics.TopKCategoricalAccuracy(k = 10, name = "ACC@10"),
                             tf.keras.metrics.TopKCategoricalAccuracy(k = 50, name = "ACC@50"),
                             ])
    savemodel = SaveModel(path = config['save_path'])

    # if os.path.isfile("model.keras"):
    #     model = tf.keras.models.load_model("model.keras", safe_mode=False, custom_objects={"Hash": Hash, "Margin": Margin})

# steps_per_epoch is required because the training dataset repeats forever.
H = model.fit(train_dataset, verbose = 1,
              validation_data = valid_dataset,
              callbacks = [savemodel, evaluation],
              steps_per_epoch = 1000,
              epochs = config["epochs"])