From 32f6e1fb234ca69ebab76d119f7488268fb36aa8 Mon Sep 17 00:00:00 2001 From: Nicola Thouliss Date: Sat, 29 May 2021 14:20:24 +1000 Subject: [PATCH] Converted /ml folder to typescript --- src/ml/{Model.js => Model.ts} | 12 +++++++++--- .../{GridSizeModel.js => GridSizeModel.ts} | 15 ++++++++++----- 2 files changed, 19 insertions(+), 8 deletions(-) rename src/ml/{Model.js => Model.ts} (69%) rename src/ml/gridSize/{GridSizeModel.js => GridSizeModel.ts} (67%) diff --git a/src/ml/Model.js b/src/ml/Model.ts similarity index 69% rename from src/ml/Model.js rename to src/ml/Model.ts index 3809bf4..3c554bd 100644 --- a/src/ml/Model.js +++ b/src/ml/Model.ts @@ -1,15 +1,21 @@ +import { ModelJSON, WeightsManifestConfig } from "@tensorflow/tfjs-core/dist/io/types"; import blobToBuffer from "../helpers/blobToBuffer"; class Model { - constructor(config, weightsMapping) { + config: ModelJSON; + weightsMapping: { [path: string]: string }; + constructor(config: ModelJSON, weightsMapping: { [path: string]: string }) { this.config = config; this.weightsMapping = weightsMapping; } async load() { // Load weights from the manifest then fetch them into an ArrayBuffer - let buffers = []; - const manifest = this.config.weightsManifest[0]; + let buffers: ArrayBuffer[] = []; + if (this.config === undefined) { + return; + } + const manifest = this.config?.weightsManifest[0]; for (let path of manifest.paths) { const url = this.weightsMapping[path]; const response = await fetch(url); diff --git a/src/ml/gridSize/GridSizeModel.js b/src/ml/gridSize/GridSizeModel.ts similarity index 67% rename from src/ml/gridSize/GridSizeModel.js rename to src/ml/gridSize/GridSizeModel.ts index 6775f75..25bde7e 100644 --- a/src/ml/gridSize/GridSizeModel.js +++ b/src/ml/gridSize/GridSizeModel.ts @@ -2,17 +2,21 @@ import Model from "../Model"; import config from "./model.json"; import weights from "./group1-shard1of1.bin"; +import { LayersModel } from "@tensorflow/tfjs"; +import { ModelJSON } from "@tensorflow/tfjs-core/dist/io/types"; class GridSizeModel extends Model { // Store model as static to prevent extra network requests - static model; + static model: LayersModel; // Load tensorflow dynamically - static tf; + + // TODO: find type for tf + static tf: any; constructor() { - super(config, { "group1-shard1of1.bin": weights }); + super(config as ModelJSON, { "group1-shard1of1.bin": weights }); } - async predict(imageData) { + async predict(imageData: ImageData) { if (!GridSizeModel.tf) { GridSizeModel.tf = await import("@tensorflow/tfjs"); } @@ -23,7 +27,8 @@ class GridSizeModel extends Model { } const model = GridSizeModel.model; - const prediction = tf.tidy(() => { + // TODO: check this mess -> changing type on prediction causes issues + const prediction: any = tf.tidy(() => { const image = tf.browser.fromPixels(imageData, 1).toFloat(); const normalized = image.div(tf.scalar(255.0)); const batched = tf.expandDims(normalized);