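Implementing Image Classification in Spring Boot with DJL
1. What is DJL?
DJL (Deep Java Library) is a relatively young project: it was officially announced at AWS re:Invent in early December 2019. In short, DJL is an open-source deep learning toolkit, licensed under Apache-2.0, whose Java API simplifies training, testing and deploying models and running inference with deep learning models. Java developers can develop and use native machine learning and deep learning models directly in Java, which lowers the barrier to deep learning development. With DJL's intuitive, high-level API, Java developers can train their own models or run inference with models that data scientists have pre-trained in Python. If you are a Java developer interested in deep learning, DJL is a good starting point for building deep learning applications.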
2. Data preparation
Download the training set (UT Zappos50K, square images):
wget https://vision.cs.utexas.edu/projects/finegrained/utzap50k/ut-zap50k-images-square.zip
Unzip it for use in model training later:
unzip ut-zap50k-images-square.zip
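Before training, it can help to check what the unzipped folder contains. The service further down builds an `ImageFolder` dataset, which, as far as we understand DJL's behaviour, treats each top-level sub-folder as one label and searches nested folders for images (hence `optMaxDepth(10)`). The following is a minimal stdlib-only sketch, not part of the article's project. The class name `DatasetStats` and the accepted file extensions are assumptions.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical helper (not from the article): counts image files per top-level
// folder, assuming each top-level folder under the dataset root is one label.
final class DatasetStats {

    static Map<String, Long> countImagesPerLabel(Path root, int maxDepth) throws IOException {
        List<Path> labelDirs;
        try (Stream<Path> entries = Files.list(root)) {
            labelDirs = entries.filter(Files::isDirectory).collect(Collectors.toList());
        }
        Map<String, Long> counts = new TreeMap<>(); // sorted by label name
        for (Path labelDir : labelDirs) {
            // walk nested sub-folders of this label, up to maxDepth levels deep
            try (Stream<Path> files = Files.walk(labelDir, maxDepth)) {
                long n = files.filter(Files::isRegularFile).filter(DatasetStats::isImage).count();
                counts.put(labelDir.getFileName().toString(), n);
            }
        }
        return counts;
    }

    // Assumed set of image extensions; adjust if your dataset uses others
    private static boolean isImage(Path p) {
        String name = p.getFileName().toString().toLowerCase(Locale.ROOT);
        return name.endsWith(".jpg") || name.endsWith(".jpeg") || name.endsWith(".png");
    }

    public static void main(String[] args) throws IOException {
        Path root = Paths.get(args.length > 0 ? args[0] : "ut-zap50k-images-square");
        countImagesPerLabel(root, 10).forEach((label, n) -> System.out.println(label + ": " + n));
    }
}
```

Run it with the dataset root as the argument. If the layout matches the four labels mentioned in the service code (boots, sandals, shoes, slippers), it should print four label lines.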
3. Code project
Experiment goal
Implement image classification with DJL
pom.xml
```xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.1</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>
    <artifactId>djl</artifactId>

    <properties>
        <maven.compiler.source>17</maven.compiler.source>
        <maven.compiler.target>17</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <!-- DJL -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>
        <!-- pytorch-engine -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <scope>runtime</scope>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>windows</id>
            <activation>
                <activeByDefault>true</activeByDefault>
            </activation>
            <dependencies>
                <!-- Windows CPU -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu</artifactId>
                    <classifier>win-x86_64</classifier>
                    <scope>runtime</scope>
                    <version>2.0.1</version>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>centos7</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- For Pre-CXX11 build (CentOS7) -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu-precxx11</artifactId>
                    <classifier>linux-x86_64</classifier>
                    <version>2.0.1</version>
                    <scope>runtime</scope>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>linux</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- Linux CPU -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu</artifactId>
                    <classifier>linux-x86_64</classifier>
                    <scope>runtime</scope>
                    <version>2.0.1</version>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>aarch64</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- For aarch64 build -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu-precxx11</artifactId>
                    <classifier>linux-aarch64</classifier>
                    <scope>runtime</scope>
                    <version>2.0.1</version>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
    </profiles>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>0.23.0</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
</project>
```
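The four profiles differ only in which PyTorch native library they pull in, and `windows` is active by default. Maven deactivates an `activeByDefault` profile when another profile in the same POM is selected explicitly, for example with `-P linux`. The small helper below is a hypothetical sketch, not from the article. It maps the JVM's `os.name` and `os.arch` system properties to one of the profile ids. It cannot detect the pre-CXX11 `centos7` case and returns `null` on platforms the POM does not cover.

```java
import java.util.Locale;

// Hypothetical helper: suggests a Maven profile id from the pom (windows, linux, aarch64)
// for the given OS name/arch. "centos7" is never suggested because it depends on the
// system's C++ ABI, which os.name/os.arch do not reveal.
final class ProfileHint {

    static String suggestProfile(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String arch = osArch.toLowerCase(Locale.ROOT);
        boolean x64 = arch.equals("amd64") || arch.equals("x86_64");
        if (os.startsWith("windows") && x64) {
            return "windows";
        }
        if (os.startsWith("linux")) {
            if (arch.equals("aarch64")) {
                return "aarch64";
            }
            if (x64) {
                return "linux";
            }
        }
        return null; // no matching profile in the pom (e.g. macOS)
    }

    public static void main(String[] args) {
        String p = suggestProfile(System.getProperty("os.name"), System.getProperty("os.arch"));
        System.out.println(p == null ? "no matching profile" : "mvn -P " + p + " package");
    }
}
```

macOS is not targeted by any profile. The training log in section 4, recorded on such a machine, shows DJL ignoring the Windows native jar (`Ignore mismatching platform`) and still loading a PyTorch engine.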
Controller
```java
package com.et.controller;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import com.et.service.ImageClassificationService;
import lombok.RequiredArgsConstructor;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;

@RestController
@RequiredArgsConstructor
public class ImageClassificationController {

    private final ImageClassificationService imageClassificationService;

    // Classify an uploaded image with the model stored under modePath
    @PostMapping(path = "/analyze")
    public String predict(@RequestPart("image") MultipartFile image,
                          @RequestParam(defaultValue = "/home/djl-test/models") String modePath)
            throws TranslateException, MalformedModelException, IOException {
        return imageClassificationService.predict(image, modePath);
    }

    // Train a model on the image folders under datasetRoot and save it to modePath
    @PostMapping(path = "/training")
    public String training(@RequestParam(defaultValue = "/home/djl-test/images-test") String datasetRoot,
                           @RequestParam(defaultValue = "/home/djl-test/models") String modePath)
            throws TranslateException, IOException {
        return imageClassificationService.training(datasetRoot, modePath);
    }

    // Return a random image from directoryPath, handy for picking a test picture
    @GetMapping("/download")
    public ResponseEntity<Resource> downloadFile(
            @RequestParam(defaultValue = "/home/djl-test/images-test") String directoryPath) {
        List<String> imgPathList = new ArrayList<>();
        try (Stream<Path> paths = Files.walk(Paths.get(directoryPath))) {
            // Filter only regular files (excluding directories)
            paths.filter(Files::isRegularFile)
                 .forEach(c -> imgPathList.add(c.toString()));
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }
        // Guard against an empty directory: Random.nextInt(0) would throw
        if (imgPathList.isEmpty()) {
            return ResponseEntity.notFound().build();
        }
        Random random = new Random();
        String filePath = imgPathList.get(random.nextInt(imgPathList.size()));
        Path file = Paths.get(filePath);
        Resource resource = new FileSystemResource(file.toFile());
        if (!resource.exists()) {
            return ResponseEntity.notFound().build();
        }
        HttpHeaders headers = new HttpHeaders();
        headers.add(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=" + file.getFileName().toString());
        headers.add(HttpHeaders.CONTENT_TYPE, MediaType.IMAGE_JPEG_VALUE);
        try {
            return ResponseEntity.ok()
                    .headers(headers)
                    .contentLength(resource.contentLength())
                    .body(resource);
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }
    }
}
```
service
Interface
```java
package com.et.service;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;

public interface ImageClassificationService {

    // Classify an uploaded image using the model saved under modePath; returns JSON
    String predict(MultipartFile image, String modePath)
            throws IOException, MalformedModelException, TranslateException;

    // Train on the dataset under datasetRoot and save the model to modePath; returns the labels
    String training(String datasetRoot, String modePath) throws TranslateException, IOException;
}
```
Implementation class
```java
package com.et.service;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.ImageFolder;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.*;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import com.et.Models;
import lombok.Cleanup;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;

@Slf4j
@Service
public class ImageClassificationServiceImpl implements ImageClassificationService {

    // represents number of training samples processed before the model is updated
    private static final int BATCH_SIZE = 32;

    // the number of passes over the complete dataset
    private static final int EPOCHS = 2;

    // the number of classification labels: boots, sandals, shoes, slippers
    @Value("${djl.num-of-output:4}")
    public int numOfOutput;

    @Override
    public String predict(MultipartFile image, String modePath)
            throws IOException, MalformedModelException, TranslateException {
        @Cleanup
        InputStream is = image.getInputStream();
        Path modelDir = Paths.get(modePath);
        BufferedImage bi = ImageIO.read(is);
        // ImageIO.read returns null when the upload is not a readable image format
        if (bi == null) {
            throw new IOException("Unsupported or corrupt image: " + image.getOriginalFilename());
        }
        Image img = ImageFactory.getInstance().fromImage(bi);

        // empty model instance
        try (Model model = Models.getModel(numOfOutput)) {
            // load the model
            model.load(modelDir, Models.MODEL_NAME);

            // define a translator for pre and post processing
            // out of the box this translator converts images to ResNet friendly ResNet 18 shape
            Translator<Image, Classifications> translator =
                    ImageClassificationTranslator.builder()
                            .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                            .addTransform(new ToTensor())
                            .optApplySoftmax(true)
                            .build();

            // run the inference using a Predictor
            try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
                // holds the probability score per label
                Classifications predictResult = predictor.predict(img);
                log.info("result={}", predictResult.toJson());
                return predictResult.toJson();
            }
        }
    }

    @Override
    public String training(String datasetRoot, String modePath) throws TranslateException, IOException {
        log.info("Image dataset training started...Image dataset address path:{}", datasetRoot);
        // the location to save the model
        Path modelDir = Paths.get(modePath);

        // create ImageFolder dataset from directory
        ImageFolder dataset = initDataset(datasetRoot);
        // split the dataset into a training set and a validation set (80% / 20%)
        RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);

        // set loss function, which seeks to minimize errors
        // loss function evaluates model's predictions against the correct answer (during training)
        // higher numbers are bad - means model performed poorly; indicates more errors; want to
        // minimize errors (loss)
        Loss loss = Loss.softmaxCrossEntropyLoss();

        // setting training parameters (ie hyperparameters)
        TrainingConfig config = setupTrainingConfig(loss);

        try (Model model = Models.getModel(numOfOutput); // empty model instance to hold patterns
             Trainer trainer = model.newTrainer(config)) {
            // metrics collect and report key performance indicators, like accuracy
            trainer.setMetrics(new Metrics());

            // NCHW: batch size 1, 3 colour channels, then height and width
            Shape inputShape = new Shape(1, 3, Models.IMAGE_HEIGHT, Models.IMAGE_WIDTH);

            // initialize trainer with proper input shape
            trainer.initialize(inputShape);

            // find the patterns in data
            EasyTrain.fit(trainer, EPOCHS, datasets[0], datasets[1]);

            // set model properties
            TrainingResult result = trainer.getTrainingResult();
            model.setProperty("Epoch", String.valueOf(EPOCHS));
            model.setProperty(
                    "Accuracy", String.format("%.5f", result.getValidateEvaluation("Accuracy")));
            model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));

            // save the model after done training for inference later;
            // the file name carries the "Epoch" property, e.g. shoeclassifier-0002.params for EPOCHS = 2
            model.save(modelDir, Models.MODEL_NAME);

            // save labels into model directory
            Models.saveSynset(modelDir, dataset.getSynset());
            log.info("Image dataset training completed......");
            return String.join("\n", dataset.getSynset());
        }
    }

    private ImageFolder initDataset(String datasetRoot) throws IOException, TranslateException {
        ImageFolder dataset = ImageFolder.builder()
                // retrieve the data
                .setRepositoryPath(Paths.get(datasetRoot))
                .optMaxDepth(10)
                .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                .addTransform(new ToTensor())
                // random sampling; don't process the data in order
                .setSampling(BATCH_SIZE, true)
                .build();
        dataset.prepare();
        return dataset;
    }

    private TrainingConfig setupTrainingConfig(Loss loss) {
        return new DefaultTrainingConfig(loss)
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());
    }
}
```
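Both `predict` and `initDataset` run images through `Resize` followed by `ToTensor`, and `trainer.initialize` receives an NCHW input shape. The plain-Java sketch below illustrates our reading of DJL's `ToTensor`, which converts an HWC image with 0-255 values into a CHW float array with values in [0, 1]. It does not call DJL, the class name `ToTensorSketch` is hypothetical, and the RGB channel order is an assumption. It also makes the shape order explicit: channels first, then height, then width. This is why the input shape should read `(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH)`.

```java
// Hypothetical, DJL-free illustration of the Resize -> ToTensor output layout.
final class ToTensorSketch {

    // rgb[y][x] holds a packed 0xRRGGBB pixel, as returned by BufferedImage.getRGB
    static float[][][] toChw(int[][] rgb) {
        int h = rgb.length;
        int w = rgb[0].length;
        float[][][] chw = new float[3][h][w];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int p = rgb[y][x];
                chw[0][y][x] = ((p >> 16) & 0xFF) / 255f; // red channel scaled to [0, 1]
                chw[1][y][x] = ((p >> 8) & 0xFF) / 255f;  // green channel
                chw[2][y][x] = (p & 0xFF) / 255f;         // blue channel
            }
        }
        return chw;
    }

    // Shape of a batch of n such tensors in NCHW order: {n, channels, height, width}
    static long[] batchShape(int n, float[][][] chw) {
        return new long[] {n, chw.length, chw[0].length, chw[0][0].length};
    }
}
```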
application.yaml
```yaml
server:
  port: 8888
spring:
  application:
    name: djl-image-classification-demo
  servlet:
    multipart:
      max-file-size: 100MB
      max-request-size: 100MB
  mvc:
    pathmatch:
      matching-strategy: ant_path_matcher
```
Main class
```java
package com.et;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class DemoApplication {

    public static void main(String[] args) {
        SpringApplication.run(DemoApplication.class, args);
    }
}
```
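The above is only the key code (for example, the `com.et.Models` helper used by the service is not shown); see the code repository below for the complete project.
Code repository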
4. Testing
Start the Spring Boot application.
Train the model
Call the `/training` endpoint with `datasetRoot` pointing at the dataset downloaded earlier.
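The training step can also be triggered from Java instead of a REST tool. The sketch below is a hypothetical Java 11+ client. The class name, the base URL `http://localhost:8888` (from `server.port`) and both paths are assumptions to be replaced with your own. The query parameter names `datasetRoot` and `modePath` come from the controller.

```java
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;

// Hypothetical client for POST /training; paths below are placeholders.
final class TrainingClient {

    // Builds /training?datasetRoot=...&modePath=... with URL-encoded values
    static URI trainingUri(String baseUrl, String datasetRoot, String modelPath) {
        return URI.create(baseUrl + "/training"
                + "?datasetRoot=" + URLEncoder.encode(datasetRoot, StandardCharsets.UTF_8)
                + "&modePath=" + URLEncoder.encode(modelPath, StandardCharsets.UTF_8));
    }

    public static void main(String[] args) throws Exception {
        URI uri = trainingUri("http://localhost:8888",
                "/path/to/ut-zap50k-images-square", "/path/to/models");
        HttpRequest request = HttpRequest.newBuilder(uri)
                .POST(HttpRequest.BodyPublishers.noBody())
                .build();
        // The service trains inside the request thread, so this call blocks until training ends
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode());
        System.out.println(response.body()); // the label list returned by training()
    }
}
```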
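The console prints the training log. Without a GPU, training is slow: in the run below (`Training on: cpu()`), the first epoch took about 1 hour 43 minutes.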
```
2024-10-11T21:00:05.407+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] c.e.s.ImageClassificationServiceImpl : Image dataset training started...Image dataset address path:/Users/liuhaihua/ai/ut-zap50k-images-square
2024-10-11T21:00:08.455+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.util.Platform : Ignore mismatching platform from: jar:file:/Users/liuhaihua/.m2/repository/ai/djl/pytorch/pytorch-native-cpu/2.0.1/pytorch-native-cpu-2.0.1-win-x86_64.jar!/native/lib/pytorch.properties
2024-10-11T21:00:09.240+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization
2024-10-11T21:00:09.241+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : Number of inter-op threads is 4
2024-10-11T21:00:09.241+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : Number of intra-op threads is 4
2024-10-11T21:00:09.287+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Training on: cpu().
2024-10-11T21:00:09.290+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Load PyTorch Engine Version 1.13.1 in 0.044 ms.
Training: 100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.38
Validating: 100% |████████████████████████████████████████|
2024-10-11T22:42:48.142+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Epoch 1 finished.
2024-10-11T22:42:48.187+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.38
2024-10-11T22:42:48.189+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Validate: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.24
Training: 5% |███ | Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.22
```
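The log reports two numbers per phase: `Accuracy` and `SoftmaxCrossEntropyLoss`. The sketch below shows the standard definitions behind them, not DJL's implementation. The loss averages `-log(softmax(logits)[label])` over the samples, and accuracy is the share of samples whose highest-scoring class equals the label. The class name `TrainingMetrics` is hypothetical, and the assumption that DJL averages over the batch in the same way is ours.

```java
// Hypothetical plain-Java illustration of softmax cross-entropy loss and accuracy.
final class TrainingMetrics {

    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0;
        double[] p = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max); // subtract the max for numerical stability
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    static int argmax(double[] v) {
        int best = 0;
        for (int i = 1; i < v.length; i++) {
            if (v[i] > v[best]) {
                best = i;
            }
        }
        return best;
    }

    // Mean of -log(probability assigned to the correct label)
    static double crossEntropy(double[][] logits, int[] labels) {
        double total = 0;
        for (int i = 0; i < labels.length; i++) {
            total += -Math.log(softmax(logits[i])[labels[i]]);
        }
        return total / labels.length;
    }

    // Fraction of samples whose top-scoring class matches the label
    static double accuracy(double[][] logits, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argmax(logits[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }
}
```

A lower loss means the model puts more probability on the correct label. With four classes, a model that guesses uniformly would score `ln 4 ≈ 1.386`.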
Predict the image category
Call the `/analyze` endpoint to run a prediction with the model trained in the previous step.
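To call the prediction step from Java, the client must send the picture as a multipart part named `image`, matching `@RequestPart("image")` in the controller. The sketch below is a hypothetical Java 11+ client. The class name, the URL `http://localhost:8888/analyze`, the `image/jpeg` content type and the default file name `test.jpg` are assumptions, and `modePath` is left to its server-side default.

```java
import java.io.ByteArrayOutputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.UUID;

// Hypothetical client for POST /analyze that hand-builds a multipart/form-data body.
final class AnalyzeClient {

    // One file part named "image", terminated by the closing boundary
    static byte[] multipartBody(String boundary, String fileName, byte[] content) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        String head = "--" + boundary + "\r\n"
                + "Content-Disposition: form-data; name=\"image\"; filename=\"" + fileName + "\"\r\n"
                + "Content-Type: image/jpeg\r\n\r\n";
        out.writeBytes(head.getBytes(StandardCharsets.UTF_8));
        out.writeBytes(content);
        out.writeBytes(("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8));
        return out.toByteArray();
    }

    public static void main(String[] args) throws Exception {
        Path image = Path.of(args.length > 0 ? args[0] : "test.jpg");
        String boundary = "----djl" + UUID.randomUUID();
        byte[] body = multipartBody(boundary, image.getFileName().toString(), Files.readAllBytes(image));
        HttpRequest request = HttpRequest.newBuilder(URI.create("http://localhost:8888/analyze"))
                .header("Content-Type", "multipart/form-data; boundary=" + boundary)
                .POST(HttpRequest.BodyPublishers.ofByteArray(body))
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.body()); // JSON produced by Classifications.toJson()
    }
}
```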
The returned result shows that the Shoes class has the highest probability, so the image is classified as Shoes.
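Reading the result comes down to ranking labels by probability. In DJL itself, `Classifications.best()` returns the entry with the highest probability. The stdlib-only sketch below shows the same ranking over a label list and a probability array. It assumes the probabilities are already softmax-normalised, as `optApplySoftmax(true)` requests, and `TopClasses` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical helper: ranks labels by descending probability.
final class TopClasses {

    // Returns up to k labels ordered from most to least probable
    static List<String> topK(List<String> labels, double[] probabilities, int k) {
        if (labels.size() != probabilities.length) {
            throw new IllegalArgumentException("one probability per label expected");
        }
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < probabilities.length; i++) {
            idx.add(i);
        }
        idx.sort(Comparator.comparingDouble((Integer i) -> probabilities[i]).reversed());
        List<String> result = new ArrayList<>();
        for (int i = 0; i < Math.min(k, idx.size()); i++) {
            result.add(labels.get(idx.get(i)));
        }
        return result;
    }

    // The single most probable label
    static String best(List<String> labels, double[] probabilities) {
        return topK(labels, probabilities, 1).get(0);
    }
}
```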